
Commit 0745a30

Tighten up field/method visibility in Executor and made some code more clear to read.

I was reading Executor just now and found that some recent changes introduced a rather convoluted code path, with too much monadic chaining and unnecessary fields. I cleaned it up a bit, tightened up the visibility of various fields/methods, and added some inline documentation to make this code easier to understand.
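As an illustration of the "unnecessary fields" point: the diff below removes the `attemptedTask` field, which existed only to mirror `task` for the failure path. Here is a minimal sketch of the before/after shape, using stand-in `Task`/`TaskMetrics` stubs rather than the real Spark classes:

```scala
// Stand-ins for the real Spark classes, for illustration only.
class TaskMetrics { def setExecutorRunTime(ms: Long): Unit = () }
class Task { def metrics: Option[TaskMetrics] = Some(new TaskMetrics) }

class RunnerSketch {
  // Before: a second @volatile field shadowed `task` just so the error
  // path could write `attemptedTask.flatMap(t => t.metrics)`.
  // @volatile var attemptedTask: Option[Task] = None

  // After: `task` itself is the source of truth; wrapping it in Option(...)
  // at the use site covers the not-yet-deserialized (null) case without
  // keeping a second field in sync.
  @volatile var task: Task = _

  def metricsOnFailure(serviceTimeMs: Long): Option[TaskMetrics] =
    Option(task).flatMap { t =>
      t.metrics.map { m =>
        m.setExecutorRunTime(serviceTimeMs)
        m
      }
    }
}
```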

Author: Reynold Xin <rxin@databricks.com>

Closes apache#4850 from rxin/executor and squashes the following commits:

866fc60 [Reynold Xin] Code review feedback.
020efbb [Reynold Xin] Tighten up field/method visibility in Executor and made some code more clear to read.
rxin committed Mar 20, 2015
1 parent f17d43b commit 0745a30
Showing 5 changed files with 120 additions and 106 deletions.
6 changes: 1 addition & 5 deletions core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -151,11 +151,7 @@ case object TaskKilled extends TaskFailedReason {
* Task requested the driver to commit, but was denied.
*/
@DeveloperApi
case class TaskCommitDenied(
jobID: Int,
partitionID: Int,
attemptID: Int)
extends TaskFailedReason {
case class TaskCommitDenied(jobID: Int, partitionID: Int, attemptID: Int) extends TaskFailedReason {
override def toErrorString: String = s"TaskCommitDenied (Driver denied task commit)" +
s" for job: $jobID, partition: $partitionID, attempt: $attemptID"
}
core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala
@@ -22,14 +22,12 @@ import org.apache.spark.{TaskCommitDenied, TaskEndReason}
/**
* Exception thrown when a task attempts to commit output to HDFS but is denied by the driver.
*/
class CommitDeniedException(
private[spark] class CommitDeniedException(
msg: String,
jobID: Int,
splitID: Int,
attemptID: Int)
extends Exception(msg) {

def toTaskEndReason: TaskEndReason = new TaskCommitDenied(jobID, splitID, attemptID)

def toTaskEndReason: TaskEndReason = TaskCommitDenied(jobID, splitID, attemptID)
}
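The `toTaskEndReason` change above simply drops a redundant `new`: a case class synthesizes a companion `apply`, so the two call forms are equivalent. A standalone sketch (an identically-shaped stand-in, not the real Spark class):

```scala
// A case class gets a companion object with an apply method,
// so `new` is unnecessary at construction sites.
case class TaskCommitDenied(jobID: Int, partitionID: Int, attemptID: Int)

object ApplyDemo extends App {
  val viaNew   = new TaskCommitDenied(1, 2, 3) // explicit constructor call
  val viaApply = TaskCommitDenied(1, 2, 3)     // companion apply
  assert(viaNew == viaApply)                   // case classes compare structurally
}
```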

196 changes: 106 additions & 90 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -21,7 +21,7 @@ import java.io.File
import java.lang.management.ManagementFactory
import java.net.URL
import java.nio.ByteBuffer
import java.util.concurrent._
import java.util.concurrent.ConcurrentHashMap

import scala.collection.JavaConversions._
import scala.collection.mutable.{ArrayBuffer, HashMap}
@@ -31,24 +31,26 @@ import akka.actor.Props

import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.scheduler._
import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task}
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader,
SparkUncaughtExceptionHandler, AkkaUtils, Utils}
import org.apache.spark.util._

/**
* Spark executor used with Mesos, YARN, and the standalone scheduler.
* In coarse-grained mode, an existing actor system is provided.
* Spark executor, backed by a threadpool to run tasks.
*
* This can be used with Mesos, YARN, and the standalone scheduler.
* An internal RPC interface (at the moment Akka) is used for communication with the driver,
* except in the case of Mesos fine-grained mode.
*/
private[spark] class Executor(
executorId: String,
executorHostname: String,
env: SparkEnv,
userClassPath: Seq[URL] = Nil,
isLocal: Boolean = false)
extends Logging
{
extends Logging {

logInfo(s"Starting executor ID $executorId on host $executorHostname")

// Application dependencies (added through SparkContext) that we've fetched so far on this node.
@@ -78,9 +80,8 @@ private[spark] class Executor(
}

// Start worker thread pool
val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker")

val executorSource = new ExecutorSource(this, executorId)
private val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker")
private val executorSource = new ExecutorSource(threadPool, executorId)

if (!isLocal) {
env.metricsSystem.registerSource(executorSource)
@@ -122,21 +123,21 @@
taskId: Long,
attemptNumber: Int,
taskName: String,
serializedTask: ByteBuffer) {
serializedTask: ByteBuffer): Unit = {
val tr = new TaskRunner(context, taskId = taskId, attemptNumber = attemptNumber, taskName,
serializedTask)
runningTasks.put(taskId, tr)
threadPool.execute(tr)
}

def killTask(taskId: Long, interruptThread: Boolean) {
def killTask(taskId: Long, interruptThread: Boolean): Unit = {
val tr = runningTasks.get(taskId)
if (tr != null) {
tr.kill(interruptThread)
}
}

def stop() {
def stop(): Unit = {
env.metricsSystem.report()
env.actorSystem.stop(executorActor)
isStopped = true
@@ -146,7 +147,10 @@
}
}

private def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum
/** Returns the total amount of time this JVM process has spent in garbage collection. */
private def computeTotalGcTime(): Long = {
ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum
}

class TaskRunner(
execBackend: ExecutorBackend,
@@ -156,27 +160,34 @@
serializedTask: ByteBuffer)
extends Runnable {

/** Whether this task has been killed. */
@volatile private var killed = false
@volatile var task: Task[Any] = _
@volatile var attemptedTask: Option[Task[Any]] = None

/** How much the JVM process has spent in GC when the task starts to run. */
@volatile var startGCTime: Long = _

def kill(interruptThread: Boolean) {
/**
* The task to run. This will be set in run() by deserializing the task binary coming
* from the driver. Once it is set, it will never be changed.
*/
@volatile var task: Task[Any] = _

def kill(interruptThread: Boolean): Unit = {
logInfo(s"Executor is trying to kill $taskName (TID $taskId)")
killed = true
if (task != null) {
task.kill(interruptThread)
}
}

override def run() {
override def run(): Unit = {
val deserializeStartTime = System.currentTimeMillis()
Thread.currentThread.setContextClassLoader(replClassLoader)
val ser = env.closureSerializer.newInstance()
logInfo(s"Running $taskName (TID $taskId)")
execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
var taskStart: Long = 0
startGCTime = gcTime
startGCTime = computeTotalGcTime()

try {
val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
@@ -193,7 +204,6 @@
throw new TaskKilledException
}

attemptedTask = Some(task)
logDebug("Task " + taskId + "'s epoch is " + task.epoch)
env.mapOutputTracker.updateEpoch(task.epoch)

@@ -215,18 +225,17 @@
for (m <- task.metrics) {
m.setExecutorDeserializeTime(taskStart - deserializeStartTime)
m.setExecutorRunTime(taskFinish - taskStart)
m.setJvmGCTime(gcTime - startGCTime)
m.setJvmGCTime(computeTotalGcTime() - startGCTime)
m.setResultSerializationTime(afterSerialization - beforeSerialization)
}

val accumUpdates = Accumulators.values

val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull)
val serializedDirectResult = ser.serialize(directResult)
val resultSize = serializedDirectResult.limit

// directSend = sending directly back to the driver
val serializedResult = {
val serializedResult: ByteBuffer = {
if (maxResultSize > 0 && resultSize > maxResultSize) {
logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " +
s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +
@@ -248,42 +257,40 @@
execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)

} catch {
case ffe: FetchFailedException => {
case ffe: FetchFailedException =>
val reason = ffe.toTaskEndReason
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
}

case _: TaskKilledException | _: InterruptedException if task.killed => {
case _: TaskKilledException | _: InterruptedException if task.killed =>
logInfo(s"Executor killed $taskName (TID $taskId)")
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
}

case cDE: CommitDeniedException => {
case cDE: CommitDeniedException =>
val reason = cDE.toTaskEndReason
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
}

case t: Throwable => {
case t: Throwable =>
// Attempt to exit cleanly by informing the driver of our failure.
// If anything goes wrong (or this was a fatal exception), we will delegate to
// the default uncaught exception handler, which will terminate the Executor.
logError(s"Exception in $taskName (TID $taskId)", t)

val serviceTime = System.currentTimeMillis() - taskStart
val metrics = attemptedTask.flatMap(t => t.metrics)
for (m <- metrics) {
m.setExecutorRunTime(serviceTime)
m.setJvmGCTime(gcTime - startGCTime)
val metrics: Option[TaskMetrics] = Option(task).flatMap { task =>
task.metrics.map { m =>
m.setExecutorRunTime(System.currentTimeMillis() - taskStart)
m.setJvmGCTime(computeTotalGcTime() - startGCTime)
m
}
}
val reason = new ExceptionFailure(t, metrics)
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
val taskEndReason = new ExceptionFailure(t, metrics)
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(taskEndReason))

// Don't forcibly exit unless the exception was inherently fatal, to avoid
// stopping other tasks unnecessarily.
if (Utils.isFatalError(t)) {
SparkUncaughtExceptionHandler.uncaughtException(t)
}
}

} finally {
// Release memory used by this thread for shuffles
env.shuffleMemoryManager.releaseMemoryForThisThread()
@@ -358,7 +365,7 @@
for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
// Fetch file with useCache mode, close cache for local mode.
Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf,
Utils.fetchFile(name, new File(SparkFiles.getRootDirectory()), conf,
env.securityManager, hadoopConf, timestamp, useCache = !isLocal)
currentFiles(name) = timestamp
}
Expand All @@ -370,12 +377,12 @@ private[spark] class Executor(
if (currentTimeStamp < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
// Fetch file with useCache mode, close cache for local mode.
Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf,
Utils.fetchFile(name, new File(SparkFiles.getRootDirectory()), conf,
env.securityManager, hadoopConf, timestamp, useCache = !isLocal)
currentJars(name) = timestamp
// Add it to our class loader
val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL
if (!urlClassLoader.getURLs.contains(url)) {
val url = new File(SparkFiles.getRootDirectory(), localName).toURI.toURL
if (!urlClassLoader.getURLs().contains(url)) {
logInfo("Adding " + url + " to class loader")
urlClassLoader.addURL(url)
}
@@ -384,61 +391,70 @@
}
}

def startDriverHeartbeater() {
val interval = conf.getInt("spark.executor.heartbeatInterval", 10000)
val timeout = AkkaUtils.lookupTimeout(conf)
val retryAttempts = AkkaUtils.numRetries(conf)
val retryIntervalMs = AkkaUtils.retryWaitMs(conf)
val heartbeatReceiverRef = AkkaUtils.makeDriverRef("HeartbeatReceiver", conf, env.actorSystem)
private val timeout = AkkaUtils.lookupTimeout(conf)
private val retryAttempts = AkkaUtils.numRetries(conf)
private val retryIntervalMs = AkkaUtils.retryWaitMs(conf)
private val heartbeatReceiverRef =
AkkaUtils.makeDriverRef("HeartbeatReceiver", conf, env.actorSystem)

/** Reports heartbeat and metrics for active tasks to the driver. */
private def reportHeartBeat(): Unit = {
// list of (task id, metrics) to send back to the driver
val tasksMetrics = new ArrayBuffer[(Long, TaskMetrics)]()
val curGCTime = computeTotalGcTime()

for (taskRunner <- runningTasks.values()) {
if (taskRunner.task != null) {
taskRunner.task.metrics.foreach { metrics =>
metrics.updateShuffleReadMetrics()
metrics.updateInputMetrics()
metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime)

if (isLocal) {
// JobProgressListener will hold an reference of it during
// onExecutorMetricsUpdate(), then JobProgressListener can not see
// the changes of metrics any more, so make a deep copy of it
val copiedMetrics = Utils.deserialize[TaskMetrics](Utils.serialize(metrics))
tasksMetrics += ((taskRunner.taskId, copiedMetrics))
} else {
// It will be copied by serialization
tasksMetrics += ((taskRunner.taskId, metrics))
}
}
}
}

val t = new Thread() {
val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId)
try {
val response = AkkaUtils.askWithReply[HeartbeatResponse](message, heartbeatReceiverRef,
retryAttempts, retryIntervalMs, timeout)
if (response.reregisterBlockManager) {
logWarning("Told to re-register on heartbeat")
env.blockManager.reregister()
}
} catch {
case NonFatal(e) => logWarning("Issue communicating with driver in heartbeater", e)
}
}

/**
* Starts a thread to report heartbeat and partial metrics for active tasks to driver.
* This thread stops running when the executor is stopped.
*/
private def startDriverHeartbeater(): Unit = {
val interval = conf.getInt("spark.executor.heartbeatInterval", 10000)
val thread = new Thread() {
override def run() {
// Sleep a random interval so the heartbeats don't end up in sync
Thread.sleep(interval + (math.random * interval).asInstanceOf[Int])

while (!isStopped) {
val tasksMetrics = new ArrayBuffer[(Long, TaskMetrics)]()
val curGCTime = gcTime

for (taskRunner <- runningTasks.values()) {
if (taskRunner.attemptedTask.nonEmpty) {
Option(taskRunner.task).flatMap(_.metrics).foreach { metrics =>
metrics.updateShuffleReadMetrics()
metrics.updateInputMetrics()
metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime)

if (isLocal) {
// JobProgressListener will hold an reference of it during
// onExecutorMetricsUpdate(), then JobProgressListener can not see
// the changes of metrics any more, so make a deep copy of it
val copiedMetrics = Utils.deserialize[TaskMetrics](Utils.serialize(metrics))
tasksMetrics += ((taskRunner.taskId, copiedMetrics))
} else {
// It will be copied by serialization
tasksMetrics += ((taskRunner.taskId, metrics))
}
}
}
}

val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId)
try {
val response = AkkaUtils.askWithReply[HeartbeatResponse](message, heartbeatReceiverRef,
retryAttempts, retryIntervalMs, timeout)
if (response.reregisterBlockManager) {
logWarning("Told to re-register on heartbeat")
env.blockManager.reregister()
}
} catch {
case NonFatal(t) => logWarning("Issue communicating with driver in heartbeater", t)
}

reportHeartBeat()
Thread.sleep(interval)
}
}
}
t.setDaemon(true)
t.setName("Driver Heartbeater")
t.start()
thread.setDaemon(true)
thread.setName("driver-heartbeater")
thread.start()
}
}
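The heartbeat refactor above pulls the reporting logic out into `reportHeartBeat()` and leaves `startDriverHeartbeater()` as a thin loop on a named daemon thread. A minimal standalone sketch of that pattern, with illustrative names rather than the real Spark APIs:

```scala
// Sketch of a periodic, named daemon heartbeater (illustrative names only).
object HeartbeaterSketch {
  @volatile private var isStopped = false

  /** Stand-in for Executor.reportHeartBeat(). */
  private def reportHeartBeat(): Unit =
    println(s"heartbeat at ${System.currentTimeMillis()}")

  def start(intervalMs: Long): Thread = {
    val thread = new Thread {
      override def run(): Unit = {
        // Sleep a random initial offset so many executors don't heartbeat in sync.
        Thread.sleep(intervalMs + (math.random * intervalMs).toLong)
        while (!isStopped) {
          reportHeartBeat()
          Thread.sleep(intervalMs)
        }
      }
    }
    thread.setDaemon(true)               // daemon: never blocks JVM shutdown
    thread.setName("driver-heartbeater") // named for easier thread dumps
    thread.start()
    thread
  }

  def stop(): Unit = isStopped = true
}
```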
