diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 3e10b9eee4e24..5d48bc7c96555 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -55,7 +55,8 @@ private class ShuffleStatus(numPartitions: Int) {
    * locations is so small that we choose to ignore that case and store only a single location
    * for each output.
    */
-  private[this] val mapStatuses = new Array[MapStatus](numPartitions)
+  // Exposed for testing
+  val mapStatuses = new Array[MapStatus](numPartitions)
 
   /**
    * The cached result of serializing the map statuses array. This cache is lazily populated when
@@ -105,14 +106,30 @@ private class ShuffleStatus(numPartitions: Int) {
     }
   }
 
+  /**
+   * Removes all shuffle outputs associated with this host. Note that this will also remove
+   * outputs which are served by an external shuffle server (if one exists).
+   */
+  def removeOutputsOnHost(host: String): Unit = {
+    removeOutputsByFilter(x => x.host == host)
+  }
+
   /**
    * Removes all map outputs associated with the specified executor. Note that this will also
    * remove outputs which are served by an external shuffle server (if one exists), as they are
    * still registered with that execId.
    */
   def removeOutputsOnExecutor(execId: String): Unit = synchronized {
+    removeOutputsByFilter(x => x.executorId == execId)
+  }
+
+  /**
+   * Removes all shuffle outputs which satisfy the filter. Note that this will also
+   * remove outputs which are served by an external shuffle server (if one exists).
+   */
+  def removeOutputsByFilter(f: (BlockManagerId) => Boolean): Unit = synchronized {
     for (mapId <- 0 until mapStatuses.length) {
-      if (mapStatuses(mapId) != null && mapStatuses(mapId).location.executorId == execId) {
+      if (mapStatuses(mapId) != null && f(mapStatuses(mapId).location)) {
         _numAvailableOutputs -= 1
         mapStatuses(mapId) = null
         invalidateSerializedMapOutputStatusCache()
@@ -317,7 +334,8 @@ private[spark] class MapOutputTrackerMaster(
 
   // HashMap for storing shuffleStatuses in the driver.
   // Statuses are dropped only by explicit de-registering.
-  private val shuffleStatuses = new ConcurrentHashMap[Int, ShuffleStatus]().asScala
+  // Exposed for testing
+  val shuffleStatuses = new ConcurrentHashMap[Int, ShuffleStatus]().asScala
 
   private val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf)
 
@@ -415,6 +433,15 @@ private[spark] class MapOutputTrackerMaster(
     }
   }
 
+  /**
+   * Removes all shuffle outputs associated with this host. Note that this will also remove
+   * outputs which are served by an external shuffle server (if one exists).
+   */
+  def removeOutputsOnHost(host: String): Unit = {
+    shuffleStatuses.valuesIterator.foreach { _.removeOutputsOnHost(host) }
+    incrementEpoch()
+  }
+
   /**
    * Removes all shuffle outputs associated with this executor. Note that this will also remove
    * outputs which are served by an external shuffle server (if one exists), as they are still
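Note on the MapOutputTracker change: host-level and executor-level removal now funnel through a single predicate over each output's BlockManagerId. A minimal, self-contained sketch of that shape, suitable for pasting into a Scala REPL (the `Location` and `ToyShuffleStatus` names are illustrative stand-ins, not Spark's internals):

```scala
// Toy model of the predicate-based removal above; Location and
// ToyShuffleStatus are hypothetical stand-ins, not Spark classes.
case class Location(host: String, executorId: String)

class ToyShuffleStatus(numPartitions: Int) {
  private val statuses = new Array[Location](numPartitions)

  def register(mapId: Int, loc: Location): Unit = synchronized {
    statuses(mapId) = loc
  }

  def numAvailableOutputs: Int = synchronized { statuses.count(_ != null) }

  // Mirrors removeOutputsByFilter: drop every output whose location matches.
  def removeOutputsByFilter(f: Location => Boolean): Unit = synchronized {
    for (mapId <- statuses.indices) {
      if (statuses(mapId) != null && f(statuses(mapId))) {
        statuses(mapId) = null
      }
    }
  }

  // Host- and executor-level removal are two instantiations of the same filter.
  def removeOutputsOnHost(host: String): Unit =
    removeOutputsByFilter(_.host == host)

  def removeOutputsOnExecutor(execId: String): Unit =
    removeOutputsByFilter(_.executorId == execId)
}
```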
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 7827e6760f355..84ef57f2d271b 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -151,6 +151,14 @@ package object config {
       .createOptional
   // End blacklist confs
 
+  private[spark] val UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE =
+    ConfigBuilder("spark.files.fetchFailure.unRegisterOutputOnHost")
+      .doc("Whether to unregister all the outputs on the host when we receive a FetchFailure. " +
+        "This defaults to false, which means we only unregister the outputs related to the " +
+        "exact executor (instead of the host) on a FetchFailure.")
+      .booleanConf
+      .createWithDefault(false)
+
   private[spark] val LISTENER_BUS_EVENT_QUEUE_CAPACITY =
     ConfigBuilder("spark.scheduler.listenerbus.eventqueue.capacity")
       .withAlternative("spark.scheduler.listenerbus.eventqueue.size")
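Usage note: the new flag is opt-in and, per the DAGScheduler change below, is only consulted when the external shuffle service is enabled. A sketch of how a job would enable it (the same two settings the new test uses):

```scala
import org.apache.spark.SparkConf

// Host-level unregistration on fetch failure is opt-in; it only takes
// effect together with the external shuffle service.
val conf = new SparkConf()
  .set("spark.shuffle.service.enabled", "true")
  .set("spark.files.fetchFailure.unRegisterOutputOnHost", "true")
```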
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 932e6c138e1c4..fafe9cafdc18f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -35,6 +35,7 @@ import org.apache.commons.lang3.SerializationUtils
 import org.apache.spark._
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.internal.config
 import org.apache.spark.internal.Logging
 import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
@@ -187,6 +188,14 @@ class DAGScheduler(
   /** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */
   private val disallowStageRetryForTest = sc.getConf.getBoolean("spark.test.noStageRetry", false)
 
+  /**
+   * Whether to unregister all the outputs on the host when we receive a FetchFailure. This
+   * defaults to false, which means we only unregister the outputs related to the exact
+   * executor (instead of the host) on a FetchFailure.
+   */
+  private[scheduler] val unRegisterOutputOnHostOnFetchFailure =
+    sc.getConf.get(config.UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE)
+
   /**
    * Number of consecutive stage attempts allowed before a stage is aborted.
    */
@@ -1336,7 +1345,21 @@ class DAGScheduler(
 
           // TODO: mark the executor as failed only if there were lots of fetch failures on it
           if (bmAddress != null) {
-            handleExecutorLost(bmAddress.executorId, filesLost = true, Some(task.epoch))
+            val hostToUnregisterOutputs = if (env.blockManager.externalShuffleServiceEnabled &&
+                unRegisterOutputOnHostOnFetchFailure) {
+              // We had a fetch failure with the external shuffle service, so we
+              // assume all shuffle data on the node is bad.
+              Some(bmAddress.host)
+            } else {
+              // Unregister shuffle data just for one executor (we don't have any
+              // reason to believe shuffle data has been lost for the entire host).
+              None
+            }
+            removeExecutorAndUnregisterOutputs(
+              execId = bmAddress.executorId,
+              fileLost = true,
+              hostToUnregisterOutputs = hostToUnregisterOutputs,
+              maybeEpoch = Some(task.epoch))
           }
         }
 
@@ -1370,22 +1393,41 @@ class DAGScheduler(
    */
   private[scheduler] def handleExecutorLost(
       execId: String,
-      filesLost: Boolean,
-      maybeEpoch: Option[Long] = None) {
+      workerLost: Boolean): Unit = {
+    // If the cluster manager explicitly tells us that the entire worker was lost, then
+    // we know to unregister shuffle output. (Note that "worker" specifically refers to the
+    // process from a Standalone cluster, where the shuffle service lives in the Worker.)
+    val fileLost = workerLost || !env.blockManager.externalShuffleServiceEnabled
+    removeExecutorAndUnregisterOutputs(
+      execId = execId,
+      fileLost = fileLost,
+      hostToUnregisterOutputs = None,
+      maybeEpoch = None)
+  }
+
+  private def removeExecutorAndUnregisterOutputs(
+      execId: String,
+      fileLost: Boolean,
+      hostToUnregisterOutputs: Option[String],
+      maybeEpoch: Option[Long] = None): Unit = {
     val currentEpoch = maybeEpoch.getOrElse(mapOutputTracker.getEpoch)
     if (!failedEpoch.contains(execId) || failedEpoch(execId) < currentEpoch) {
       failedEpoch(execId) = currentEpoch
       logInfo("Executor lost: %s (epoch %d)".format(execId, currentEpoch))
       blockManagerMaster.removeExecutor(execId)
-
-      if (filesLost || !env.blockManager.externalShuffleServiceEnabled) {
-        logInfo("Shuffle files lost for executor: %s (epoch %d)".format(execId, currentEpoch))
-        mapOutputTracker.removeOutputsOnExecutor(execId)
+      if (fileLost) {
+        hostToUnregisterOutputs match {
+          case Some(host) =>
+            logInfo("Shuffle files lost for host: %s (epoch %d)".format(host, currentEpoch))
+            mapOutputTracker.removeOutputsOnHost(host)
+          case None =>
+            logInfo("Shuffle files lost for executor: %s (epoch %d)".format(execId, currentEpoch))
+            mapOutputTracker.removeOutputsOnExecutor(execId)
+        }
         clearCacheLocs()
+      } else {
+        logDebug("Additional executor lost message for %s (epoch %d)".format(execId, currentEpoch))
       }
-    } else {
-      logDebug("Additional executor lost message for " + execId +
-        "(epoch " + currentEpoch + ")")
     }
   }
@@ -1678,11 +1721,11 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
       dagScheduler.handleExecutorAdded(execId, host)
 
     case ExecutorLost(execId, reason) =>
-      val filesLost = reason match {
+      val workerLost = reason match {
         case SlaveLost(_, true) => true
         case _ => false
       }
-      dagScheduler.handleExecutorLost(execId, filesLost)
+      dagScheduler.handleExecutorLost(execId, workerLost)
 
     case BeginEvent(task, taskInfo) =>
       dagScheduler.handleBeginEvent(task, taskInfo)
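The fetch-failure branch above boils down to a small decision: unregister at host granularity only when both the external shuffle service and the new flag are on; otherwise fall back to the old executor-level behavior. Distilled into a standalone function for illustration (the name `hostToUnregister` is hypothetical; the real logic is inline in handleTaskCompletion):

```scala
// Host-level cleanup only happens when both conditions hold; otherwise
// only the single executor's outputs are unregistered.
def hostToUnregister(
    fetchFailedHost: String,
    externalShuffleServiceEnabled: Boolean,
    unRegisterOutputOnHostOnFetchFailure: Boolean): Option[String] = {
  if (externalShuffleServiceEnabled && unRegisterOutputOnHostOnFetchFailure) {
    Some(fetchFailedHost) // assume all shuffle data on the node is bad
  } else {
    None                  // no reason to believe the whole host lost its data
  }
}
```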
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 67145e7445061..ddd3281106745 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -396,6 +396,70 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
     assertDataStructuresEmpty()
   }
 
+  test("All shuffle files on the slave should be cleaned up when slave lost") {
+    // reset the test context with the right shuffle service config
+    afterEach()
+    val conf = new SparkConf()
+    conf.set("spark.shuffle.service.enabled", "true")
+    conf.set("spark.files.fetchFailure.unRegisterOutputOnHost", "true")
+    init(conf)
+    runEvent(ExecutorAdded("exec-hostA1", "hostA"))
+    runEvent(ExecutorAdded("exec-hostA2", "hostA"))
+    runEvent(ExecutorAdded("exec-hostB", "hostB"))
+    val firstRDD = new MyRDD(sc, 3, Nil)
+    val firstShuffleDep = new ShuffleDependency(firstRDD, new HashPartitioner(3))
+    val firstShuffleId = firstShuffleDep.shuffleId
+    val shuffleMapRdd = new MyRDD(sc, 3, List(firstShuffleDep))
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(3))
+    val secondShuffleId = shuffleDep.shuffleId
+    val reduceRdd = new MyRDD(sc, 1, List(shuffleDep))
+    submit(reduceRdd, Array(0))
+    // map stage 1 completes successfully, with one task on each executor
+    complete(taskSets(0), Seq(
+      (Success,
+        MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2))),
+      (Success,
+        MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2))),
+      (Success, makeMapStatus("hostB", 1))
+    ))
+    // map stage 2 completes successfully, with one task on each executor
+    complete(taskSets(1), Seq(
+      (Success,
+        MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2))),
+      (Success,
+        MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2))),
+      (Success, makeMapStatus("hostB", 1))
+    ))
+    // make sure our test setup is correct
+    val initialMapStatus1 = mapOutputTracker.shuffleStatuses(firstShuffleId).mapStatuses
+    assert(initialMapStatus1.count(_ != null) === 3)
+    assert(initialMapStatus1.map { _.location.executorId }.toSet ===
+      Set("exec-hostA1", "exec-hostA2", "exec-hostB"))
+
+    val initialMapStatus2 = mapOutputTracker.shuffleStatuses(secondShuffleId).mapStatuses
+    assert(initialMapStatus2.count(_ != null) === 3)
+    assert(initialMapStatus2.map { _.location.executorId }.toSet ===
+      Set("exec-hostA1", "exec-hostA2", "exec-hostB"))
+
+    // reduce stage fails with a fetch failure from one host
+    complete(taskSets(2), Seq(
+      (FetchFailed(BlockManagerId("exec-hostA2", "hostA", 12345), firstShuffleId, 0, 0, "ignored"),
+        null)
+    ))
+
+    // Here is the main assertion -- make sure that we de-register
+    // the map outputs for both map stages from both executors on hostA
+    val mapStatus1 = mapOutputTracker.shuffleStatuses(firstShuffleId).mapStatuses
+    assert(mapStatus1.count(_ != null) === 1)
+    assert(mapStatus1(2).location.executorId === "exec-hostB")
+    assert(mapStatus1(2).location.host === "hostB")
+
+    val mapStatus2 = mapOutputTracker.shuffleStatuses(secondShuffleId).mapStatuses
+    assert(mapStatus2.count(_ != null) === 1)
+    assert(mapStatus2(2).location.executorId === "exec-hostB")
+    assert(mapStatus2(2).location.host === "hostB")
+  }
+
   test("zero split job") {
     var numResults = 0
     var failureReason: Option[Exception] = None
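For intuition, the end state the test asserts can be reproduced with the illustrative ToyShuffleStatus sketched after the MapOutputTracker diff: two outputs on hostA and one on hostB, then host-level removal for hostA, leaving only the entry at index 2.

```scala
// Miniature of the test scenario, using the hypothetical toy model from
// the earlier sketch (not Spark's real classes).
val status = new ToyShuffleStatus(3)
status.register(0, Location("hostA", "exec-hostA1"))
status.register(1, Location("hostA", "exec-hostA2"))
status.register(2, Location("hostB", "exec-hostB"))
status.removeOutputsOnHost("hostA")
assert(status.numAvailableOutputs == 1) // only the hostB output survives
```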