From 4a903d6b9c5f13ad8342e2f05b3b21f85e46cef4 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 1 Jul 2015 18:43:49 -0700 Subject: [PATCH] Increase HeartReceiverSuite coverage and clean up Now it covers expiring dead hosts and has much less duplicate code. --- .../org/apache/spark/HeartbeatReceiver.scala | 89 +++++++--- .../scala/org/apache/spark/SparkContext.scala | 2 +- .../apache/spark/HeartbeatReceiverSuite.scala | 161 +++++++++++++----- 3 files changed, 191 insertions(+), 61 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 6909015ff66e6..221b1dab43278 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -24,8 +24,8 @@ import scala.collection.mutable import org.apache.spark.executor.TaskMetrics import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEnv, RpcCallContext} import org.apache.spark.storage.BlockManagerId -import org.apache.spark.scheduler.{SlaveLost, TaskScheduler} -import org.apache.spark.util.{ThreadUtils, Utils} +import org.apache.spark.scheduler._ +import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} /** * A heartbeat from executors to the driver. This is a shared message used by several internal @@ -45,13 +45,23 @@ private[spark] case object TaskSchedulerIsSet private[spark] case object ExpireDeadHosts +private case class ExecutorRegistered(executorId: String) + +private case class ExecutorRemoved(executorId: String) + private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean) /** * Lives in the driver to receive heartbeats from executors.. */ -private[spark] class HeartbeatReceiver(sc: SparkContext) - extends ThreadSafeRpcEndpoint with Logging { +private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock) + extends ThreadSafeRpcEndpoint with SparkListener with Logging { + + def this(sc: SparkContext) { + this(sc, new SystemClock) + } + + sc.addSparkListener(this) override val rpcEnv: RpcEnv = sc.env.rpcEnv @@ -86,30 +96,48 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) override def onStart(): Unit = { timeoutCheckingTask = eventLoopThread.scheduleAtFixedRate(new Runnable { override def run(): Unit = Utils.tryLogNonFatalError { - Option(self).foreach(_.send(ExpireDeadHosts)) + Option(self).foreach(_.ask[Boolean](ExpireDeadHosts)) } }, 0, checkTimeoutIntervalMs, TimeUnit.MILLISECONDS) } - override def receive: PartialFunction[Any, Unit] = { - case ExpireDeadHosts => - expireDeadHosts() + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + + // Messages sent and received locally + case ExecutorRegistered(executorId) => + executorLastSeen(executorId) = clock.getTimeMillis() + context.reply(true) + case ExecutorRemoved(executorId) => + executorLastSeen.remove(executorId) + context.reply(true) case TaskSchedulerIsSet => scheduler = sc.taskScheduler - } + context.reply(true) + case ExpireDeadHosts => + expireDeadHosts() + context.reply(true) - override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + // Messages received from executors case heartbeat @ Heartbeat(executorId, taskMetrics, blockManagerId) => if (scheduler != null) { - executorLastSeen(executorId) = System.currentTimeMillis() - eventLoopThread.submit(new Runnable { - override def run(): Unit = Utils.tryLogNonFatalError { - val unknownExecutor = !scheduler.executorHeartbeatReceived( - executorId, taskMetrics, blockManagerId) - val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor) - context.reply(response) - } - }) + if (executorLastSeen.contains(executorId)) { + executorLastSeen(executorId) = clock.getTimeMillis() + eventLoopThread.submit(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + val unknownExecutor = !scheduler.executorHeartbeatReceived( + executorId, taskMetrics, blockManagerId) + val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor) + context.reply(response) + } + }) + } else { + // This may happen if we get an executor's in-flight heartbeat immediately + // after we just removed it. It's not really an error condition so we should + // not log warning here. Otherwise there may be a lot of noise especially if + // we explicitly remove executors (SPARK-4134). + logDebug(s"Received heartbeat from unknown executor $executorId") + context.reply(HeartbeatResponse(reregisterBlockManager = true)) + } } else { // Because Executor will sleep several seconds before sending the first "Heartbeat", this // case rarely happens. However, if it really happens, log it and ask the executor to @@ -119,9 +147,30 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) } } + /** + * If the heartbeat receiver is not stopped, notify it of executor registrations. + */ + override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = { + Option(self).foreach(_.ask[Boolean](ExecutorRegistered(executorAdded.executorId))) + } + + /** + * If the heartbeat receiver is not stopped, notify it of executor removals so it doesn't + * log superfluous errors. + * + * Note that we must do this after the executor is actually removed to guard against the + * following race condition: if we remove an executor's metadata from our data structure + * prematurely, we may get an in-flight heartbeat from the executor before the executor is + * actually removed, in which case we will still mark the executor as a dead host later + * and expire it with loud error messages. + */ + override def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved): Unit = { + Option(self).foreach(_.ask[Boolean](ExecutorRemoved(executorRemoved.executorId))) + } + private def expireDeadHosts(): Unit = { logTrace("Checking for hosts with no recent heartbeats in HeartbeatReceiver.") - val now = System.currentTimeMillis() + val now = clock.getTimeMillis() for ((executorId, lastSeenMs) <- executorLastSeen) { if (now - lastSeenMs > executorTimeoutMs) { logWarning(s"Removing executor $executorId with no recent heartbeats: " + diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 0e5a86f44e410..75591de4b5f0e 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -498,7 +498,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _schedulerBackend = sched _taskScheduler = ts _dagScheduler = new DAGScheduler(this) - _heartbeatReceiver.send(TaskSchedulerIsSet) + _heartbeatReceiver.ask[Boolean](TaskSchedulerIsSet) // start TaskScheduler after taskScheduler sets DAGScheduler reference in DAGScheduler's // constructor diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index 911b3bddd1836..b31b09196608f 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -17,64 +17,145 @@ package org.apache.spark -import scala.concurrent.duration._ import scala.language.postfixOps -import org.apache.spark.executor.TaskMetrics -import org.apache.spark.storage.BlockManagerId +import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester} import org.mockito.Mockito.{mock, spy, verify, when} import org.mockito.Matchers import org.mockito.Matchers._ -import org.apache.spark.scheduler.TaskScheduler -import org.apache.spark.util.RpcUtils -import org.scalatest.concurrent.Eventually._ +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.scheduler._ +import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.ManualClock -class HeartbeatReceiverSuite extends SparkFunSuite with LocalSparkContext { +class HeartbeatReceiverSuite + extends SparkFunSuite + with BeforeAndAfterEach + with PrivateMethodTester + with LocalSparkContext { - test("HeartbeatReceiver") { + private val executorId1 = "executor-1" + private val executorId2 = "executor-2" + + // Shared state that must be reset before and after each test + private var scheduler: TaskScheduler = null + private var heartbeatReceiver: HeartbeatReceiver = null + private var heartbeatReceiverRef: RpcEndpointRef = null + private var heartbeatReceiverClock: ManualClock = null + + override def beforeEach(): Unit = { sc = spy(new SparkContext("local[2]", "test")) - val scheduler = mock(classOf[TaskScheduler]) - when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(true) + scheduler = mock(classOf[TaskScheduler]) when(sc.taskScheduler).thenReturn(scheduler) + heartbeatReceiverClock = new ManualClock + heartbeatReceiver = new HeartbeatReceiver(sc, heartbeatReceiverClock) + heartbeatReceiverRef = sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver) + when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(true) + } - val heartbeatReceiver = new HeartbeatReceiver(sc) - sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver).send(TaskSchedulerIsSet) - eventually(timeout(5 seconds), interval(5 millis)) { - assert(heartbeatReceiver.scheduler != null) - } - val receiverRef = RpcUtils.makeDriverRef("heartbeat", sc.conf, sc.env.rpcEnv) + override def afterEach(): Unit = { + resetSparkContext() + scheduler = null + heartbeatReceiver = null + heartbeatReceiverRef = null + heartbeatReceiverClock = null + } - val metrics = new TaskMetrics - val blockManagerId = BlockManagerId("executor-1", "localhost", 12345) - val response = receiverRef.askWithRetry[HeartbeatResponse]( - Heartbeat("executor-1", Array(1L -> metrics), blockManagerId)) + test("task scheduler is set correctly") { + assert(heartbeatReceiver.scheduler === null) + heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) + assert(heartbeatReceiver.scheduler !== null) + } - verify(scheduler).executorHeartbeatReceived( - Matchers.eq("executor-1"), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId)) - assert(false === response.reregisterBlockManager) + test("normal heartbeat") { + heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) + heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null)) + heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId2, null)) + triggerHeartbeat(executorId1, executorShouldReregister = false) + triggerHeartbeat(executorId2, executorShouldReregister = false) + val trackedExecutors = executorLastSeen(heartbeatReceiver) + assert(trackedExecutors.size === 2) + assert(trackedExecutors.contains(executorId1)) + assert(trackedExecutors.contains(executorId2)) } - test("HeartbeatReceiver re-register") { - sc = spy(new SparkContext("local[2]", "test")) - val scheduler = mock(classOf[TaskScheduler]) - when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(false) - when(sc.taskScheduler).thenReturn(scheduler) + test("reregister if scheduler is not ready yet") { + heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null)) + // Task scheduler not set in HeartbeatReceiver + triggerHeartbeat(executorId1, executorShouldReregister = true) + } - val heartbeatReceiver = new HeartbeatReceiver(sc) - sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver).send(TaskSchedulerIsSet) - eventually(timeout(5 seconds), interval(5 millis)) { - assert(heartbeatReceiver.scheduler != null) - } - val receiverRef = RpcUtils.makeDriverRef("heartbeat", sc.conf, sc.env.rpcEnv) + test("reregister if heartbeat from unregistered executor") { + heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) + // Received heartbeat from unknown receiver, so we ask it to re-register + triggerHeartbeat(executorId1, executorShouldReregister = true) + assert(executorLastSeen(heartbeatReceiver).isEmpty) + } + + test("reregister if heartbeat from removed executor") { + heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) + heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null)) + heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId2, null)) + // Remove the second executor but not the first + heartbeatReceiver.onExecutorRemoved(SparkListenerExecutorRemoved(0, executorId2, "bad boy")) + // Now trigger the heartbeats + // A heartbeat from the second executor should require reregistering + triggerHeartbeat(executorId1, executorShouldReregister = false) + triggerHeartbeat(executorId2, executorShouldReregister = true) + val trackedExecutors = executorLastSeen(heartbeatReceiver) + assert(trackedExecutors.size === 1) + assert(trackedExecutors.contains(executorId1)) + assert(!trackedExecutors.contains(executorId2)) + } + test("expire dead hosts") { + val executorTimeout = executorTimeoutMs(heartbeatReceiver) + heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) + heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null)) + heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId2, null)) + triggerHeartbeat(executorId1, executorShouldReregister = false) + triggerHeartbeat(executorId2, executorShouldReregister = false) + // Advance the clock and only trigger a heartbeat for the first executor + heartbeatReceiverClock.advance(executorTimeout / 2) + triggerHeartbeat(executorId1, executorShouldReregister = false) + heartbeatReceiverClock.advance(executorTimeout) + heartbeatReceiverRef.askWithRetry[Boolean](ExpireDeadHosts) + // Only the second executor should be expired as a dead host + verify(scheduler).executorLost(Matchers.eq(executorId2), any()) + val trackedExecutors = executorLastSeen(heartbeatReceiver) + assert(trackedExecutors.size === 1) + assert(trackedExecutors.contains(executorId1)) + assert(!trackedExecutors.contains(executorId2)) + } + + /** Manually send a heartbeat and return the response. */ + private def triggerHeartbeat( + executorId: String, + executorShouldReregister: Boolean): Unit = { val metrics = new TaskMetrics - val blockManagerId = BlockManagerId("executor-1", "localhost", 12345) - val response = receiverRef.askWithRetry[HeartbeatResponse]( - Heartbeat("executor-1", Array(1L -> metrics), blockManagerId)) + val blockManagerId = BlockManagerId(executorId, "localhost", 12345) + val response = heartbeatReceiverRef.askWithRetry[HeartbeatResponse]( + Heartbeat(executorId, Array(1L -> metrics), blockManagerId)) + if (executorShouldReregister) { + assert(response.reregisterBlockManager) + } else { + assert(!response.reregisterBlockManager) + // Additionally verify that the scheduler callback is called with the correct parameters + verify(scheduler).executorHeartbeatReceived( + Matchers.eq(executorId), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId)) + } + } - verify(scheduler).executorHeartbeatReceived( - Matchers.eq("executor-1"), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId)) - assert(true === response.reregisterBlockManager) + // Helper methods to access private fields in HeartbeatReceiver + private val _executorLastSeen = PrivateMethod[collection.Map[String, Long]]('executorLastSeen) + private val _executorTimeoutMs = PrivateMethod[Long]('executorTimeoutMs) + private def executorLastSeen(receiver: HeartbeatReceiver): collection.Map[String, Long] = { + receiver invokePrivate _executorLastSeen() + } + private def executorTimeoutMs(receiver: HeartbeatReceiver): Long = { + receiver invokePrivate _executorTimeoutMs() } + }