Skip to content

Commit

Permalink
[SPARK-6980] Changed addMessageIfTimeout to PartialFunction, cleanup …
Browse files Browse the repository at this point in the history
…from PR comments
  • Loading branch information
BryanCutler committed Jun 8, 2015
1 parent 2f94095 commit 1607a5f
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 50 deletions.
51 changes: 19 additions & 32 deletions core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import scala.concurrent.{Awaitable, Await, Future}
import scala.language.postfixOps

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.util.{ThreadUtils, RpcUtils, Utils}
import org.apache.spark.util.{RpcUtils, Utils}


/**
Expand Down Expand Up @@ -190,8 +190,8 @@ private[spark] object RpcAddress {
/**
* An exception thrown if RpcTimeout modifies a [[TimeoutException]].
*/
private[rpc] class RpcTimeoutException(message: String)
extends TimeoutException(message)
private[rpc] class RpcTimeoutException(message: String, cause: TimeoutException)
extends TimeoutException(message) { initCause(cause) }


/**
Expand All @@ -209,27 +209,23 @@ private[spark] class RpcTimeout(timeout: FiniteDuration, description: String) {
def message: String = description

/** Amends the standard message of TimeoutException to include the description */
def createRpcTimeoutException(te: TimeoutException): RpcTimeoutException = {
new RpcTimeoutException(te.getMessage() + " " + description)
private def createRpcTimeoutException(te: TimeoutException): RpcTimeoutException = {
new RpcTimeoutException(te.getMessage() + " " + description, te)
}

/**
* Add a callback to the given Future so that if it completes as failed with a TimeoutException
* then the timeout description is added to the message
* PartialFunction to match a TimeoutException and add the timeout description to the message
*
* @note This can be used in the recover callback of a Future to add to a TimeoutException
* Example:
* val timeout = new RpcTimeout(5 millis, "short timeout")
* Future(throw new TimeoutException).recover(timeout.addMessageIfTimeout)
*/
def addMessageIfTimeout[T](future: Future[T]): Future[T] = {
future.recover {
// Add a warning message if Future is passed to addMessageIfTimeoutTest more than once
case rte: RpcTimeoutException => throw new RpcTimeoutException(rte.getMessage() +
" (Future has multiple calls to RpcTimeout.addMessageIfTimeoutTest)")
// Any other TimeoutException get converted to a RpcTimeoutException with modified message
case te: TimeoutException => throw createRpcTimeoutException(te)
}(ThreadUtils.sameThread)
}

/** Applies the duration to create future before calling addMessageIfTimeout*/
def addMessageIfTimeout[T](f: FiniteDuration => Future[T]): Future[T] = {
addMessageIfTimeout(f(duration))
def addMessageIfTimeout[T]: PartialFunction[Throwable, T] = {
// The exception has already been converted to a RpcTimeoutException so just raise it
case rte: RpcTimeoutException => throw rte
// Any other TimeoutException get converted to a RpcTimeoutException with modified message
case te: TimeoutException => throw createRpcTimeoutException(te)
}

/**
Expand All @@ -241,13 +237,7 @@ private[spark] class RpcTimeout(timeout: FiniteDuration, description: String) {
def awaitResult[T](awaitable: Awaitable[T]): T = {
try {
Await.result(awaitable, duration)
}
catch {
// The exception has already been converted to a RpcTimeoutException so just raise it
case rte: RpcTimeoutException => throw rte
// Any other TimeoutException get converted to a RpcTimeoutException with modified message
case te: TimeoutException => throw createRpcTimeoutException(te)
}
} catch addMessageIfTimeout
}
}

Expand Down Expand Up @@ -299,13 +289,10 @@ object RpcTimeout {

// Find the first set property or use the default value with the first property
val itr = timeoutPropList.iterator
var foundProp = None: Option[(String, String)]
var foundProp: Option[(String, String)] = None
while (itr.hasNext && foundProp.isEmpty){
val propKey = itr.next()
conf.getOption(propKey) match {
case Some(prop) => foundProp = Some(propKey,prop)
case None =>
}
conf.getOption(propKey).foreach { prop => foundProp = Some(propKey, prop) }
}
val finalProp = foundProp.getOrElse(timeoutPropList.head, defaultValue)
val timeout = { Utils.timeStringAsSeconds(finalProp._2) seconds }
Expand Down
36 changes: 18 additions & 18 deletions core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -213,10 +213,11 @@ private[spark] class AkkaRpcEnv private[akka] (

override def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = {
import actorSystem.dispatcher
defaultLookupTimeout.addMessageIfTimeout(
actorSystem.actorSelection(uri).resolveOne(_).
map(new AkkaRpcEndpointRef(defaultAddress, _, conf))
)
actorSystem.actorSelection(uri).resolveOne(defaultLookupTimeout.duration).
map(new AkkaRpcEndpointRef(defaultAddress, _, conf)).
// this is just in case there is a timeout from creating the future in resolveOne, we want the
// exception to indicate the conf that determines the timeout
recover(defaultLookupTimeout.addMessageIfTimeout)
}

override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String = {
Expand Down Expand Up @@ -297,20 +298,19 @@ private[akka] class AkkaRpcEndpointRef(
}

override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = {
timeout.addMessageIfTimeout(
actorRef.ask(AkkaMessage(message, true))(_).flatMap {
// The function will run in the calling thread, so it should be short and never block.
case msg @ AkkaMessage(message, reply) =>
if (reply) {
logError(s"Receive $msg but the sender cannot reply")
Future.failed(new SparkException(s"Receive $msg but the sender cannot reply"))
} else {
Future.successful(message)
}
case AkkaFailure(e) =>
Future.failed(e)
}(ThreadUtils.sameThread).mapTo[T]
)
actorRef.ask(AkkaMessage(message, true))(timeout.duration).flatMap {
// The function will run in the calling thread, so it should be short and never block.
case msg @ AkkaMessage(message, reply) =>
if (reply) {
logError(s"Receive $msg but the sender cannot reply")
Future.failed(new SparkException(s"Receive $msg but the sender cannot reply"))
} else {
Future.successful(message)
}
case AkkaFailure(e) =>
Future.failed(e)
}(ThreadUtils.sameThread).mapTo[T].
recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread)
}

override def toString: String = s"${getClass.getSimpleName}($actorRef)"
Expand Down

0 comments on commit 1607a5f

Please sign in to comment.