Skip to content

Commit

Permalink
Increase HeartReceiverSuite coverage and clean up
Browse files Browse the repository at this point in the history
Now it covers expiring dead hosts and has much less duplicate code.
  • Loading branch information
Andrew Or committed Jul 2, 2015
1 parent 3a342de commit 4a903d6
Show file tree
Hide file tree
Showing 3 changed files with 191 additions and 61 deletions.
89 changes: 69 additions & 20 deletions core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ import scala.collection.mutable
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEnv, RpcCallContext}
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.scheduler.{SlaveLost, TaskScheduler}
import org.apache.spark.util.{ThreadUtils, Utils}
import org.apache.spark.scheduler._
import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils}

/**
* A heartbeat from executors to the driver. This is a shared message used by several internal
Expand All @@ -45,13 +45,23 @@ private[spark] case object TaskSchedulerIsSet

/**
 * Message that makes the [[HeartbeatReceiver]] run its dead-host check (`expireDeadHosts()`).
 * It is sent to the receiver by the fixed-rate task scheduled in `onStart`.
 */
private[spark] case object ExpireDeadHosts

/**
 * Message the [[HeartbeatReceiver]] sends to itself from `onExecutorAdded`. On receipt the
 * receiver records the current clock time as the last-seen time of `executorId`.
 */
private case class ExecutorRegistered(executorId: String)

/**
 * Message the [[HeartbeatReceiver]] sends to itself from `onExecutorRemoved`. On receipt the
 * receiver drops `executorId` from its last-seen bookkeeping.
 */
private case class ExecutorRemoved(executorId: String)

/**
 * Reply to a `Heartbeat`. `reregisterBlockManager` is true when the receiver does not know
 * the executor or when the scheduler reports the heartbeat as coming from an unknown executor.
 */
private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean)

/**
* Lives in the driver to receive heartbeats from executors.
*/
private[spark] class HeartbeatReceiver(sc: SparkContext)
extends ThreadSafeRpcEndpoint with Logging {
private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock)
extends ThreadSafeRpcEndpoint with SparkListener with Logging {

/** Convenience constructor that backs the receiver with a fresh `SystemClock`. */
def this(sc: SparkContext) = this(sc, new SystemClock)

sc.addSparkListener(this)

override val rpcEnv: RpcEnv = sc.env.rpcEnv

Expand Down Expand Up @@ -86,30 +96,48 @@ private[spark] class HeartbeatReceiver(sc: SparkContext)
override def onStart(): Unit = {
timeoutCheckingTask = eventLoopThread.scheduleAtFixedRate(new Runnable {
override def run(): Unit = Utils.tryLogNonFatalError {
Option(self).foreach(_.send(ExpireDeadHosts))
Option(self).foreach(_.ask[Boolean](ExpireDeadHosts))
}
}, 0, checkTimeoutIntervalMs, TimeUnit.MILLISECONDS)
}

override def receive: PartialFunction[Any, Unit] = {
case ExpireDeadHosts =>
expireDeadHosts()
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {

// Messages sent and received locally
case ExecutorRegistered(executorId) =>
executorLastSeen(executorId) = clock.getTimeMillis()
context.reply(true)
case ExecutorRemoved(executorId) =>
executorLastSeen.remove(executorId)
context.reply(true)
case TaskSchedulerIsSet =>
scheduler = sc.taskScheduler
}
context.reply(true)
case ExpireDeadHosts =>
expireDeadHosts()
context.reply(true)

override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
// Messages received from executors
case heartbeat @ Heartbeat(executorId, taskMetrics, blockManagerId) =>
if (scheduler != null) {
executorLastSeen(executorId) = System.currentTimeMillis()
eventLoopThread.submit(new Runnable {
override def run(): Unit = Utils.tryLogNonFatalError {
val unknownExecutor = !scheduler.executorHeartbeatReceived(
executorId, taskMetrics, blockManagerId)
val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor)
context.reply(response)
}
})
if (executorLastSeen.contains(executorId)) {
executorLastSeen(executorId) = clock.getTimeMillis()
eventLoopThread.submit(new Runnable {
override def run(): Unit = Utils.tryLogNonFatalError {
val unknownExecutor = !scheduler.executorHeartbeatReceived(
executorId, taskMetrics, blockManagerId)
val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor)
context.reply(response)
}
})
} else {
// This may happen if we get an executor's in-flight heartbeat immediately
// after we just removed it. It's not really an error condition so we should
// not log warning here. Otherwise there may be a lot of noise especially if
// we explicitly remove executors (SPARK-4134).
logDebug(s"Received heartbeat from unknown executor $executorId")
context.reply(HeartbeatResponse(reregisterBlockManager = true))
}
} else {
// Because Executor will sleep several seconds before sending the first "Heartbeat", this
// case rarely happens. However, if it really happens, log it and ask the executor to
Expand All @@ -119,9 +147,30 @@ private[spark] class HeartbeatReceiver(sc: SparkContext)
}
}

/**
 * Listener callback for executor registrations. Forwards the new executor's id to this
 * endpoint as an `ExecutorRegistered` message; nothing is sent when `self` is null, which
 * is expected to be the case once the heartbeat receiver has been stopped.
 */
override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = {
  for (endpoint <- Option(self)) {
    endpoint.ask[Boolean](ExecutorRegistered(executorAdded.executorId))
  }
}

/**
 * Listener callback for executor removals. Forwards the removed executor's id to this
 * endpoint as an `ExecutorRemoved` message so that it stops being tracked and does not
 * produce superfluous errors; nothing is sent when `self` is null, which is expected to be
 * the case once the heartbeat receiver has been stopped.
 *
 * This must only happen after the executor has actually been removed. If its metadata were
 * dropped from our data structure too early, an in-flight heartbeat arriving before the
 * real removal could still lead to the executor being marked as a dead host later and
 * expired with loud error messages.
 */
override def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved): Unit = {
  for (endpoint <- Option(self)) {
    endpoint.ask[Boolean](ExecutorRemoved(executorRemoved.executorId))
  }
}

private def expireDeadHosts(): Unit = {
logTrace("Checking for hosts with no recent heartbeats in HeartbeatReceiver.")
val now = System.currentTimeMillis()
val now = clock.getTimeMillis()
for ((executorId, lastSeenMs) <- executorLastSeen) {
if (now - lastSeenMs > executorTimeoutMs) {
logWarning(s"Removing executor $executorId with no recent heartbeats: " +
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/SparkContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
_schedulerBackend = sched
_taskScheduler = ts
_dagScheduler = new DAGScheduler(this)
_heartbeatReceiver.send(TaskSchedulerIsSet)
_heartbeatReceiver.ask[Boolean](TaskSchedulerIsSet)

// start TaskScheduler after taskScheduler sets DAGScheduler reference in DAGScheduler's
// constructor
Expand Down
161 changes: 121 additions & 40 deletions core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,64 +17,145 @@

package org.apache.spark

import scala.concurrent.duration._
import scala.language.postfixOps

import org.apache.spark.executor.TaskMetrics
import org.apache.spark.storage.BlockManagerId
import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester}
import org.mockito.Mockito.{mock, spy, verify, when}
import org.mockito.Matchers
import org.mockito.Matchers._

import org.apache.spark.scheduler.TaskScheduler
import org.apache.spark.util.RpcUtils
import org.scalatest.concurrent.Eventually._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.scheduler._
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.ManualClock

class HeartbeatReceiverSuite extends SparkFunSuite with LocalSparkContext {
class HeartbeatReceiverSuite
extends SparkFunSuite
with BeforeAndAfterEach
with PrivateMethodTester
with LocalSparkContext {

test("HeartbeatReceiver") {
private val executorId1 = "executor-1"
private val executorId2 = "executor-2"

// Shared state that must be reset before and after each test
private var scheduler: TaskScheduler = null
private var heartbeatReceiver: HeartbeatReceiver = null
private var heartbeatReceiverRef: RpcEndpointRef = null
private var heartbeatReceiverClock: ManualClock = null

override def beforeEach(): Unit = {
sc = spy(new SparkContext("local[2]", "test"))
val scheduler = mock(classOf[TaskScheduler])
when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(true)
scheduler = mock(classOf[TaskScheduler])
when(sc.taskScheduler).thenReturn(scheduler)
heartbeatReceiverClock = new ManualClock
heartbeatReceiver = new HeartbeatReceiver(sc, heartbeatReceiverClock)
heartbeatReceiverRef = sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver)
when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(true)
}

val heartbeatReceiver = new HeartbeatReceiver(sc)
sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver).send(TaskSchedulerIsSet)
eventually(timeout(5 seconds), interval(5 millis)) {
assert(heartbeatReceiver.scheduler != null)
}
val receiverRef = RpcUtils.makeDriverRef("heartbeat", sc.conf, sc.env.rpcEnv)
/**
 * Tear down the per-test fixture: reset the shared SparkContext via `resetSparkContext()`
 * and drop every reference the test may have populated, then hand over to the `afterEach`
 * of the mixed-in traits.
 */
override def afterEach(): Unit = {
  try {
    resetSparkContext()
    scheduler = null
    heartbeatReceiver = null
    heartbeatReceiverRef = null
    heartbeatReceiverClock = null
  } finally {
    // BeforeAndAfterEach is a stackable trait: an override is expected to chain to super,
    // and doing so in `finally` keeps the chain intact even if our own cleanup throws.
    super.afterEach()
  }
}

val metrics = new TaskMetrics
val blockManagerId = BlockManagerId("executor-1", "localhost", 12345)
val response = receiverRef.askWithRetry[HeartbeatResponse](
Heartbeat("executor-1", Array(1L -> metrics), blockManagerId))
test("task scheduler is set correctly") {
  // The receiver has no scheduler until it is told that the task scheduler is set
  val schedulerBefore = heartbeatReceiver.scheduler
  assert(schedulerBefore === null)
  heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet)
  val schedulerAfter = heartbeatReceiver.scheduler
  assert(schedulerAfter !== null)
}

verify(scheduler).executorHeartbeatReceived(
Matchers.eq("executor-1"), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId))
assert(false === response.reregisterBlockManager)
test("normal heartbeat") {
  val executors = Seq(executorId1, executorId2)
  heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet)
  // Register every executor first, then heartbeat each of them in the same order
  executors.foreach { id =>
    heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, id, null))
  }
  executors.foreach { id =>
    triggerHeartbeat(id, executorShouldReregister = false)
  }
  // Both executors must still be tracked by the receiver
  val tracked = executorLastSeen(heartbeatReceiver)
  assert(tracked.size === 2)
  executors.foreach { id =>
    assert(tracked.contains(id))
  }
}

test("HeartbeatReceiver re-register") {
sc = spy(new SparkContext("local[2]", "test"))
val scheduler = mock(classOf[TaskScheduler])
when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(false)
when(sc.taskScheduler).thenReturn(scheduler)
test("reregister if scheduler is not ready yet") {
  val registration = SparkListenerExecutorAdded(0, executorId1, null)
  heartbeatReceiver.onExecutorAdded(registration)
  // TaskSchedulerIsSet is deliberately never sent, so the receiver has no scheduler yet
  // and must ask the executor to re-register
  triggerHeartbeat(executorId1, executorShouldReregister = true)
}

val heartbeatReceiver = new HeartbeatReceiver(sc)
sc.env.rpcEnv.setupEndpoint("heartbeat", heartbeatReceiver).send(TaskSchedulerIsSet)
eventually(timeout(5 seconds), interval(5 millis)) {
assert(heartbeatReceiver.scheduler != null)
}
val receiverRef = RpcUtils.makeDriverRef("heartbeat", sc.conf, sc.env.rpcEnv)
test("reregister if heartbeat from unregistered executor") {
  heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet)
  // No executor was ever registered, so this heartbeat comes from an unknown executor
  // and the receiver asks it to re-register
  triggerHeartbeat(executorId1, executorShouldReregister = true)
  // The unknown executor must not have been added to the tracked set
  val tracked = executorLastSeen(heartbeatReceiver)
  assert(tracked.isEmpty)
}

test("reregister if heartbeat from removed executor") {
  heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet)
  Seq(executorId1, executorId2).foreach { id =>
    heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, id, null))
  }
  // Drop only the second executor; the first one stays registered
  val removal = SparkListenerExecutorRemoved(0, executorId2, "bad boy")
  heartbeatReceiver.onExecutorRemoved(removal)
  // The surviving executor heartbeats normally, the removed one is told to re-register
  triggerHeartbeat(executorId1, executorShouldReregister = false)
  triggerHeartbeat(executorId2, executorShouldReregister = true)
  val tracked = executorLastSeen(heartbeatReceiver)
  assert(tracked.size === 1)
  assert(tracked.contains(executorId1))
  assert(!tracked.contains(executorId2))
}

test("expire dead hosts") {
  val timeoutMs = executorTimeoutMs(heartbeatReceiver)
  val executors = Seq(executorId1, executorId2)
  heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet)
  executors.foreach { id =>
    heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, id, null))
  }
  executors.foreach { id =>
    triggerHeartbeat(id, executorShouldReregister = false)
  }
  // Move half a timeout forward and refresh only the first executor
  heartbeatReceiverClock.advance(timeoutMs / 2)
  triggerHeartbeat(executorId1, executorShouldReregister = false)
  // After one more full timeout the first executor is exactly one timeout old, while the
  // second one, never refreshed, is older than the timeout
  heartbeatReceiverClock.advance(timeoutMs)
  heartbeatReceiverRef.askWithRetry[Boolean](ExpireDeadHosts)
  // Only the second executor should have been reported as lost and dropped
  verify(scheduler).executorLost(Matchers.eq(executorId2), any())
  val tracked = executorLastSeen(heartbeatReceiver)
  assert(tracked.size === 1)
  assert(tracked.contains(executorId1))
  assert(!tracked.contains(executorId2))
}

/** Manually send a heartbeat and check whether the response asks the executor to re-register. */
private def triggerHeartbeat(
executorId: String,
executorShouldReregister: Boolean): Unit = {
val metrics = new TaskMetrics
val blockManagerId = BlockManagerId("executor-1", "localhost", 12345)
val response = receiverRef.askWithRetry[HeartbeatResponse](
Heartbeat("executor-1", Array(1L -> metrics), blockManagerId))
val blockManagerId = BlockManagerId(executorId, "localhost", 12345)
val response = heartbeatReceiverRef.askWithRetry[HeartbeatResponse](
Heartbeat(executorId, Array(1L -> metrics), blockManagerId))
if (executorShouldReregister) {
assert(response.reregisterBlockManager)
} else {
assert(!response.reregisterBlockManager)
// Additionally verify that the scheduler callback is called with the correct parameters
verify(scheduler).executorHeartbeatReceived(
Matchers.eq(executorId), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId))
}
}

verify(scheduler).executorHeartbeatReceived(
Matchers.eq("executor-1"), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId))
assert(true === response.reregisterBlockManager)
// Reflection-based accessors for private members of HeartbeatReceiver
private val _executorLastSeen =
  PrivateMethod[collection.Map[String, Long]](Symbol("executorLastSeen"))
private val _executorTimeoutMs = PrivateMethod[Long](Symbol("executorTimeoutMs"))

/** Read the receiver's private `executorLastSeen` member through `PrivateMethodTester`. */
private def executorLastSeen(receiver: HeartbeatReceiver): collection.Map[String, Long] =
  receiver.invokePrivate(_executorLastSeen())

/** Read the receiver's private `executorTimeoutMs` member through `PrivateMethodTester`. */
private def executorTimeoutMs(receiver: HeartbeatReceiver): Long =
  receiver.invokePrivate(_executorTimeoutMs())

}

0 comments on commit 4a903d6

Please sign in to comment.