diff --git a/.rat-excludes b/.rat-excludes index 613096395dc72..baddec547537c 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -15,6 +15,7 @@ TAGS RELEASE control docs +docker.properties.template fairscheduler.xml.template spark-defaults.conf.template log4j.properties diff --git a/bin/spark-shell2.cmd b/bin/spark-shell2.cmd index 02f51fe59a911..00fd30fa38d36 100644 --- a/bin/spark-shell2.cmd +++ b/bin/spark-shell2.cmd @@ -19,7 +19,7 @@ rem set SPARK_HOME=%~dp0.. -echo "%*" | findstr " --help -h" >nul +echo "%*" | findstr " \<--help\> \<-h\>" >nul if %ERRORLEVEL% equ 0 ( call :usage exit /b 0 diff --git a/conf/docker.properties.template b/conf/docker.properties.template new file mode 100644 index 0000000000000..26e3bfd9c5b9b --- /dev/null +++ b/conf/docker.properties.template @@ -0,0 +1,3 @@ +spark.mesos.executor.docker.image: +spark.mesos.executor.docker.volumes: /usr/local/lib:/host/usr/local/lib:ro +spark.mesos.executor.home: /opt/spark diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index b986fa87dc2f4..228d9149df2a2 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -27,13 +27,20 @@ import org.apache.spark.util.{ThreadUtils, Clock, SystemClock, Utils} /** * An agent that dynamically allocates and removes executors based on the workload. * - * The add policy depends on whether there are backlogged tasks waiting to be scheduled. If - * the scheduler queue is not drained in N seconds, then new executors are added. If the queue - * persists for another M seconds, then more executors are added and so on. The number added - * in each round increases exponentially from the previous round until an upper bound on the - * number of executors has been reached. The upper bound is based both on a configured property - * and on the number of tasks pending: the policy will never increase the number of executor - * requests past the number needed to handle all pending tasks. + * The ExecutorAllocationManager maintains a moving target number of executors which is periodically + * synced to the cluster manager. The target starts at a configured initial value and changes with + * the number of pending and running tasks. + * + * Decreasing the target number of executors happens when the current target is more than needed to + * handle the current load. The target number of executors is always truncated to the number of + * executors that could run all current running and pending tasks at once. + * + * Increasing the target number of executors happens in response to backlogged tasks waiting to be + * scheduled. If the scheduler queue is not drained in N seconds, then new executors are added. If + * the queue persists for another M seconds, then more executors are added and so on. The number + * added in each round increases exponentially from the previous round until an upper bound has been + * reached. The upper bound is based both on a configured property and on the current number of + * running and pending tasks, as described above. * * The rationale for the exponential increase is twofold: (1) Executors should be added slowly * in the beginning in case the number of extra executors needed turns out to be small. 
Otherwise, @@ -105,8 +112,10 @@ private[spark] class ExecutorAllocationManager( // Number of executors to add in the next round private var numExecutorsToAdd = 1 - // Number of executors that have been requested but have not registered yet - private var numExecutorsPending = 0 + // The desired number of executors at this moment in time. If all our executors were to die, this + // is the number of executors we would immediately want from the cluster manager. + private var numExecutorsTarget = + conf.getInt("spark.dynamicAllocation.initialExecutors", minNumExecutors) // Executors that have been requested to be removed but have not been killed yet private val executorsPendingToRemove = new mutable.HashSet[String] @@ -199,13 +208,6 @@ private[spark] class ExecutorAllocationManager( executor.awaitTermination(10, TimeUnit.SECONDS) } - /** - * The number of executors we would have if the cluster manager were to fulfill all our existing - * requests. - */ - private def targetNumExecutors(): Int = - numExecutorsPending + executorIds.size - executorsPendingToRemove.size - /** * The maximum number of executors we would need under the current load to satisfy all running * and pending tasks, rounded up. @@ -227,7 +229,7 @@ private[spark] class ExecutorAllocationManager( private def schedule(): Unit = synchronized { val now = clock.getTimeMillis - addOrCancelExecutorRequests(now) + updateAndSyncNumExecutorsTarget(now) removeTimes.retain { case (executorId, expireTime) => val expired = now >= expireTime @@ -239,26 +241,28 @@ private[spark] class ExecutorAllocationManager( } /** + * Updates our target number of executors and syncs the result with the cluster manager. + * * Check to see whether our existing allocation and the requests we've made previously exceed our - * current needs. If so, let the cluster manager know so that it can cancel pending requests that - * are unneeded. + * current needs. If so, truncate our target and let the cluster manager know so that it can + * cancel pending requests that are unneeded. * * If not, and the add time has expired, see if we can request new executors and refresh the add * time. * * @return the delta in the target number of executors. */ - private def addOrCancelExecutorRequests(now: Long): Int = synchronized { - val currentTarget = targetNumExecutors + private def updateAndSyncNumExecutorsTarget(now: Long): Int = synchronized { val maxNeeded = maxNumExecutorsNeeded - if (maxNeeded < currentTarget) { + if (maxNeeded < numExecutorsTarget) { // The target number exceeds the number we actually need, so stop adding new - // executors and inform the cluster manager to cancel the extra pending requests. 
- val newTotalExecutors = math.max(maxNeeded, minNumExecutors) - client.requestTotalExecutors(newTotalExecutors) + // executors and inform the cluster manager to cancel the extra pending requests + val oldNumExecutorsTarget = numExecutorsTarget + numExecutorsTarget = math.max(maxNeeded, minNumExecutors) + client.requestTotalExecutors(numExecutorsTarget) numExecutorsToAdd = 1 - updateNumExecutorsPending(newTotalExecutors) + numExecutorsTarget - oldNumExecutorsTarget } else if (addTime != NOT_SET && now >= addTime) { val delta = addExecutors(maxNeeded) logDebug(s"Starting timer to add more executors (to " + @@ -281,21 +285,30 @@ private[spark] class ExecutorAllocationManager( */ private def addExecutors(maxNumExecutorsNeeded: Int): Int = { // Do not request more executors if it would put our target over the upper bound - val currentTarget = targetNumExecutors - if (currentTarget >= maxNumExecutors) { - logDebug(s"Not adding executors because there are already ${executorIds.size} " + - s"registered and $numExecutorsPending pending executor(s) (limit $maxNumExecutors)") + if (numExecutorsTarget >= maxNumExecutors) { + val numExecutorsPending = numExecutorsTarget - executorIds.size + logDebug(s"Not adding executors because there are already ${executorIds.size} registered " + + s"and ${numExecutorsPending} pending executor(s) (limit $maxNumExecutors)") numExecutorsToAdd = 1 return 0 } - val actualMaxNumExecutors = math.min(maxNumExecutors, maxNumExecutorsNeeded) - val newTotalExecutors = math.min(currentTarget + numExecutorsToAdd, actualMaxNumExecutors) - val addRequestAcknowledged = testing || client.requestTotalExecutors(newTotalExecutors) + val oldNumExecutorsTarget = numExecutorsTarget + // There's no point in wasting time ramping up to the number of executors we already have, so + // make sure our target is at least as much as our current allocation: + numExecutorsTarget = math.max(numExecutorsTarget, executorIds.size) + // Boost our target with the number to add for this round: + numExecutorsTarget += numExecutorsToAdd + // Ensure that our target doesn't exceed what we need at the present moment: + numExecutorsTarget = math.min(numExecutorsTarget, maxNumExecutorsNeeded) + // Ensure that our target fits within configured bounds: + numExecutorsTarget = math.max(math.min(numExecutorsTarget, maxNumExecutors), minNumExecutors) + + val addRequestAcknowledged = testing || client.requestTotalExecutors(numExecutorsTarget) if (addRequestAcknowledged) { - val delta = updateNumExecutorsPending(newTotalExecutors) + val delta = numExecutorsTarget - oldNumExecutorsTarget logInfo(s"Requesting $delta new executor(s) because tasks are backlogged" + - s" (new desired total will be $newTotalExecutors)") + s" (new desired total will be $numExecutorsTarget)") numExecutorsToAdd = if (delta == numExecutorsToAdd) { numExecutorsToAdd * 2 } else { @@ -304,23 +317,11 @@ private[spark] class ExecutorAllocationManager( delta } else { logWarning( - s"Unable to reach the cluster manager to request $newTotalExecutors total executors!") + s"Unable to reach the cluster manager to request $numExecutorsTarget total executors!") 0 } } - /** - * Given the new target number of executors, update the number of pending executor requests, - * and return the delta from the old number of pending requests. 
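The clamping order used above (ramp up from the current allocation, cap at the current need, then at the configured bounds) can be summarized in a small standalone sketch. This is illustrative only; nextTarget and its parameter names are not part of the patch, but the arithmetic mirrors updateAndSyncNumExecutorsTarget and addExecutors:

    // Sketch of one scheduling round for the moving executor target.
    def nextTarget(
        currentTarget: Int,      // numExecutorsTarget before this round
        numExecutorsToAdd: Int,  // exponential step: 1, 2, 4, ...
        registered: Int,         // executorIds.size
        maxNeeded: Int,          // maxNumExecutorsNeeded (running + pending tasks)
        minExecutors: Int,
        maxExecutors: Int): Int = {
      if (maxNeeded < currentTarget) {
        // Over-provisioned: truncate so the cluster manager can cancel extra requests.
        math.max(maxNeeded, minExecutors)
      } else {
        // Backlogged: ramp up from at least the current allocation, but never past
        // the present need or the configured bounds.
        val boosted = math.max(currentTarget, registered) + numExecutorsToAdd
        math.max(math.min(math.min(boosted, maxNeeded), maxExecutors), minExecutors)
      }
    }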
- */ - private def updateNumExecutorsPending(newTotalExecutors: Int): Int = { - val newNumExecutorsPending = - newTotalExecutors - executorIds.size + executorsPendingToRemove.size - val delta = newNumExecutorsPending - numExecutorsPending - numExecutorsPending = newNumExecutorsPending - delta - } - /** * Request the cluster manager to remove the given executor. * Return whether the request is received. @@ -372,10 +373,6 @@ private[spark] class ExecutorAllocationManager( // as idle again so as not to forget that it is a candidate for removal. (see SPARK-4951) executorIds.filter(listener.isExecutorIdle).foreach(onExecutorIdle) logInfo(s"New executor $executorId has registered (new total is ${executorIds.size})") - if (numExecutorsPending > 0) { - numExecutorsPending -= 1 - logDebug(s"Decremented number of pending executors ($numExecutorsPending left)") - } } else { logWarning(s"Duplicate executor $executorId has registered") } diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 3653f724ba192..8aed1e20e0686 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -150,8 +150,13 @@ import org.apache.spark.util.Utils * authorization. If not filter is in place the user is generally null and no authorization * can take place. * - * Connection encryption (SSL) configuration is organized hierarchically. The user can configure - * the default SSL settings which will be used for all the supported communication protocols unless + * When authentication is being used, encryption can also be enabled by setting the option + * spark.authenticate.enableSaslEncryption to true. This is only supported by communication + * channels that use the network-common library, and can be used as an alternative to SSL in those + * cases. + * + * SSL can be used for encryption for certain communication channels. The user can configure the + * default SSL settings which will be used for all the supported communication protocols unless * they are overwritten by protocol specific settings. This way the user can easily provide the * common settings for all the protocols without disabling the ability to configure each one * individually. @@ -412,6 +417,14 @@ private[spark] class SecurityManager(sparkConf: SparkConf) */ def isAuthenticationEnabled(): Boolean = authOn + /** + * Checks whether SASL encryption should be enabled. + * @return Whether to enable SASL encryption when connecting to services that support it. + */ + def isSaslEncryptionEnabled(): Boolean = { + sparkConf.getBoolean("spark.authenticate.enableSaslEncryption", false) + } + /** * Gets the user used for authenticating HTTP connections. * For now use a single hardcoded user. diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 3f7cba6dbcdb5..4ef90546a2452 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -347,6 +347,19 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli value } + /** Control our logLevel. This overrides any user-defined log settings. + * @param logLevel The desired log level as a string. 
+ * Valid log levels include: ALL, DEBUG, ERROR, FATAL, INFO, OFF, TRACE, WARN + */ + def setLogLevel(logLevel: String) { + val validLevels = Seq("ALL", "DEBUG", "ERROR", "FATAL", "INFO", "OFF", "TRACE", "WARN") + if (!validLevels.contains(logLevel)) { + throw new IllegalArgumentException( + s"Supplied level $logLevel did not match one of: ${validLevels.mkString(",")}") + } + Utils.setLogLevel(org.apache.log4j.Level.toLevel(logLevel)) + } + try { _conf = config.clone() _conf.validateSettings() diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index fe6320b504e15..398ca41e16151 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -105,18 +105,23 @@ private[spark] object TestUtils { URI.create(s"string:///${name.replace(".", "/")}${SOURCE.extension}") } - private[spark] class JavaSourceFromString(val name: String, val code: String) + private class JavaSourceFromString(val name: String, val code: String) extends SimpleJavaFileObject(createURI(name), SOURCE) { override def getCharContent(ignoreEncodingErrors: Boolean): String = code } - /** Creates a compiled class with the source file. Class file will be placed in destDir. */ + /** Creates a compiled class with the given name. Class file will be placed in destDir. */ def createCompiledClass( className: String, destDir: File, - sourceFile: JavaSourceFromString, - classpathUrls: Seq[URL]): File = { + toStringValue: String = "", + baseClass: String = null, + classpathUrls: Seq[URL] = Seq()): File = { val compiler = ToolProvider.getSystemJavaCompiler + val extendsText = Option(baseClass).map { c => s" extends ${c}" }.getOrElse("") + val sourceFile = new JavaSourceFromString(className, + "public class " + className + extendsText + " implements java.io.Serializable {" + + " @Override public String toString() { return \"" + toStringValue + "\"; }}") // Calling this outputs a class file in pwd. It's easier to just rename the file than // build a custom FileManager that controls the output location. @@ -139,18 +144,4 @@ private[spark] object TestUtils { assert(out.exists(), "Destination file not moved: " + out.getAbsolutePath()) out } - - /** Creates a compiled class with the given name. Class file will be placed in destDir. */ - def createCompiledClass( - className: String, - destDir: File, - toStringValue: String = "", - baseClass: String = null, - classpathUrls: Seq[URL] = Seq()): File = { - val extendsText = Option(baseClass).map { c => s" extends ${c}" }.getOrElse("") - val sourceFile = new JavaSourceFromString(className, - "public class " + className + extendsText + " implements java.io.Serializable {" + - " @Override public String toString() { return \"" + toStringValue + "\"; }}") - createCompiledClass(className, destDir, sourceFile, classpathUrls) - } } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 3be6783bba49d..02e49a853c5f7 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -755,6 +755,14 @@ class JavaSparkContext(val sc: SparkContext) */ def getLocalProperty(key: String): String = sc.getLocalProperty(key) + /** Control our logLevel. This overrides any user-defined log settings. + * @param logLevel The desired log level as a string. 
+ * Valid log levels include: ALL, DEBUG, ERROR, FATAL, INFO, OFF, TRACE, WARN + */ + def setLogLevel(logLevel: String) { + sc.setLogLevel(logLevel) + } + /** * Assigns a group ID to all the jobs started by this thread until the group ID is set to a * different value or cleared. diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index c2c3e9a9e4827..848b62f9de71b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -17,6 +17,7 @@ package org.apache.spark.deploy +import scala.collection.mutable.HashSet import scala.concurrent._ import akka.actor._ @@ -31,21 +32,24 @@ import org.apache.spark.util.{ActorLogReceive, AkkaUtils, RpcUtils, Utils} /** * Proxy that relays messages to the driver. + * + * We currently don't support retry if submission fails. In HA mode, client will submit request to + * all masters and see which one could handle it. */ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) extends Actor with ActorLogReceive with Logging { - var masterActor: ActorSelection = _ + private val masterActors = driverArgs.masters.map { m => + context.actorSelection(Master.toAkkaUrl(m, AkkaUtils.protocol(context.system))) + } + private val lostMasters = new HashSet[Address] + private var activeMasterActor: ActorSelection = null + val timeout = RpcUtils.askTimeout(conf) override def preStart(): Unit = { - masterActor = context.actorSelection( - Master.toAkkaUrl(driverArgs.master, AkkaUtils.protocol(context.system))) - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) - println(s"Sending ${driverArgs.cmd} command to ${driverArgs.master}") - driverArgs.cmd match { case "launch" => // TODO: We could add an env variable here and intercept it in `sc.addJar` that would @@ -79,11 +83,17 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) driverArgs.supervise, command) - masterActor ! RequestSubmitDriver(driverDescription) + // This assumes only one Master is active at a time + for (masterActor <- masterActors) { + masterActor ! RequestSubmitDriver(driverDescription) + } case "kill" => val driverId = driverArgs.driverId - masterActor ! RequestKillDriver(driverId) + // This assumes only one Master is active at a time + for (masterActor <- masterActors) { + masterActor ! RequestKillDriver(driverId) + } } } @@ -92,10 +102,9 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) println("... waiting before polling master for driver state") Thread.sleep(5000) println("... polling master for driver state") - val statusFuture = (masterActor ? RequestDriverStatus(driverId))(timeout) + val statusFuture = (activeMasterActor ? 
RequestDriverStatus(driverId))(timeout) .mapTo[DriverStatusResponse] val statusResponse = Await.result(statusFuture, timeout) - statusResponse.found match { case false => println(s"ERROR: Cluster master did not recognize $driverId") @@ -122,20 +131,46 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) case SubmitDriverResponse(success, driverId, message) => println(message) - if (success) pollAndReportStatus(driverId.get) else System.exit(-1) + if (success) { + activeMasterActor = context.actorSelection(sender.path) + pollAndReportStatus(driverId.get) + } else if (!Utils.responseFromBackup(message)) { + System.exit(-1) + } + case KillDriverResponse(driverId, success, message) => println(message) - if (success) pollAndReportStatus(driverId) else System.exit(-1) + if (success) { + activeMasterActor = context.actorSelection(sender.path) + pollAndReportStatus(driverId) + } else if (!Utils.responseFromBackup(message)) { + System.exit(-1) + } case DisassociatedEvent(_, remoteAddress, _) => - println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.") - System.exit(-1) + if (!lostMasters.contains(remoteAddress)) { + println(s"Error connecting to master $remoteAddress.") + lostMasters += remoteAddress + // Note that this heuristic does not account for the fact that a Master can recover within + // the lifetime of this client. Thus, once a Master is lost it is lost to us forever. This + // is not currently a concern, however, because this client does not retry submissions. + if (lostMasters.size >= masterActors.size) { + println("No master is available, exiting.") + System.exit(-1) + } + } case AssociationErrorEvent(cause, _, remoteAddress, _, _) => - println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.") - println(s"Cause was: $cause") - System.exit(-1) + if (!lostMasters.contains(remoteAddress)) { + println(s"Error connecting to master ($remoteAddress).") + println(s"Cause was: $cause") + lostMasters += remoteAddress + if (lostMasters.size >= masterActors.size) { + println("No master is available, exiting.") + System.exit(-1) + } + } } } @@ -163,7 +198,9 @@ object Client { "driverClient", Utils.localHostName(), 0, conf, new SecurityManager(conf)) // Verify driverArgs.master is a valid url so that we can use it in ClientActor safely - Master.toAkkaUrl(driverArgs.master, AkkaUtils.protocol(actorSystem)) + for (m <- driverArgs.masters) { + Master.toAkkaUrl(m, AkkaUtils.protocol(actorSystem)) + } actorSystem.actorOf(Props(classOf[ClientActor], driverArgs, conf)) actorSystem.awaitTermination() diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala index 5cbac787dceeb..316e2d59f01b8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala @@ -22,8 +22,7 @@ import java.net.{URI, URISyntaxException} import scala.collection.mutable.ListBuffer import org.apache.log4j.Level - -import org.apache.spark.util.{IntParam, MemoryParam} +import org.apache.spark.util.{IntParam, MemoryParam, Utils} /** * Command-line parser for the driver client. 
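The HA handling above relies on Utils.parseStandaloneMasterUrls to expand a single comma-separated standalone master string into one spark:// URL per master; the WorkerArguments hunk further down replaces the old stripPrefix/split logic with the same call. A small sketch, assuming that behavior and Spark-internal access to Utils:

    import org.apache.spark.util.Utils

    val masters = Utils.parseStandaloneMasterUrls("spark://host1:7077,host2:7077")
    // masters: Array("spark://host1:7077", "spark://host2:7077")
    // ClientActor builds one actorSelection per URL and sends the submit/kill request
    // to all of them; only the ALIVE master handles it, backups answer with a
    // "backup master" message that the client ignores.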
@@ -35,7 +34,7 @@ private[deploy] class ClientArguments(args: Array[String]) { var logLevel = Level.WARN // launch parameters - var master: String = "" + var masters: Array[String] = null var jarUrl: String = "" var mainClass: String = "" var supervise: Boolean = DEFAULT_SUPERVISE @@ -80,13 +79,13 @@ private[deploy] class ClientArguments(args: Array[String]) { } jarUrl = _jarUrl - master = _master + masters = Utils.parseStandaloneMasterUrls(_master) mainClass = _mainClass _driverOptions ++= tail case "kill" :: _master :: _driverId :: tail => cmd = "kill" - master = _master + masters = Utils.parseStandaloneMasterUrls(_master) driverId = _driverId case _ => diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index cd16f992a3c0a..09973a0a2c998 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -19,10 +19,12 @@ package org.apache.spark.deploy import java.util.concurrent.CountDownLatch +import scala.collection.JavaConversions._ + import org.apache.spark.{Logging, SparkConf, SecurityManager} import org.apache.spark.network.TransportContext import org.apache.spark.network.netty.SparkTransportConf -import org.apache.spark.network.sasl.SaslRpcHandler +import org.apache.spark.network.sasl.SaslServerBootstrap import org.apache.spark.network.server.TransportServer import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler import org.apache.spark.util.Utils @@ -44,10 +46,7 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana private val transportConf = SparkTransportConf.fromSparkConf(sparkConf, numUsableCores = 0) private val blockHandler = new ExternalShuffleBlockHandler(transportConf) - private val transportContext: TransportContext = { - val handler = if (useSasl) new SaslRpcHandler(blockHandler, securityManager) else blockHandler - new TransportContext(transportConf, handler) - } + private val transportContext: TransportContext = new TransportContext(transportConf, blockHandler) private var server: TransportServer = _ @@ -62,7 +61,13 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana def start() { require(server == null, "Shuffle server already started") logInfo(s"Starting shuffle service on port $port with useSasl = $useSasl") - server = transportContext.createServer(port) + val bootstraps = + if (useSasl) { + Seq(new SaslServerBootstrap(transportConf, securityManager)) + } else { + Nil + } + server = transportContext.createServer(port, bootstraps) } def stop() { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index cfaebf9ea5050..b563034457a91 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -17,12 +17,16 @@ package org.apache.spark.deploy +import java.io.{ByteArrayInputStream, DataInputStream} import java.lang.reflect.Method import java.security.PrivilegedExceptionAction +import java.util.{Arrays, Comparator} +import com.google.common.primitives.Longs import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter} import org.apache.hadoop.fs.FileSystem.Statistics +import 
org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifier import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.mapreduce.JobContext import org.apache.hadoop.security.{Credentials, UserGroupInformation} @@ -32,6 +36,8 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.Utils import scala.collection.JavaConversions._ +import scala.concurrent.duration._ +import scala.language.postfixOps /** * :: DeveloperApi :: @@ -39,7 +45,8 @@ import scala.collection.JavaConversions._ */ @DeveloperApi class SparkHadoopUtil extends Logging { - val conf: Configuration = newConfiguration(new SparkConf()) + private val sparkConf = new SparkConf() + val conf: Configuration = newConfiguration(sparkConf) UserGroupInformation.setConfiguration(conf) /** @@ -201,6 +208,61 @@ class SparkHadoopUtil extends Logging { if (baseStatus.isDir) recurse(basePath) else Array(baseStatus) } + /** + * Lists all the files in a directory with the specified prefix, and does not end with the + * given suffix. The returned {{FileStatus}} instances are sorted by the modification times of + * the respective files. + */ + def listFilesSorted( + remoteFs: FileSystem, + dir: Path, + prefix: String, + exclusionSuffix: String): Array[FileStatus] = { + val fileStatuses = remoteFs.listStatus(dir, + new PathFilter { + override def accept(path: Path): Boolean = { + val name = path.getName + name.startsWith(prefix) && !name.endsWith(exclusionSuffix) + } + }) + Arrays.sort(fileStatuses, new Comparator[FileStatus] { + override def compare(o1: FileStatus, o2: FileStatus): Int = { + Longs.compare(o1.getModificationTime, o2.getModificationTime) + } + }) + fileStatuses + } + + /** + * How much time is remaining (in millis) from now to (fraction * renewal time for the token that + * is valid the latest)? + * This will return -ve (or 0) value if the fraction of validity has already expired. + */ + def getTimeFromNowToRenewal( + sparkConf: SparkConf, + fraction: Double, + credentials: Credentials): Long = { + val now = System.currentTimeMillis() + + val renewalInterval = + sparkConf.getLong("spark.yarn.token.renewal.interval", (24 hours).toMillis) + + credentials.getAllTokens.filter(_.getKind == DelegationTokenIdentifier.HDFS_DELEGATION_KIND) + .map { t => + val identifier = new DelegationTokenIdentifier() + identifier.readFields(new DataInputStream(new ByteArrayInputStream(t.getIdentifier))) + (identifier.getIssueDate + fraction * renewalInterval).toLong - now + }.foldLeft(0L)(math.max) + } + + + private[spark] def getSuffixForCredentialsPath(credentialsPath: Path): Int = { + val fileName = credentialsPath.getName + fileName.substring( + fileName.lastIndexOf(SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM) + 1).toInt + } + + private val HADOOP_CONF_PATTERN = "(\\$\\{hadoopconf-[^\\}\\$\\s]+\\})".r.unanchored /** @@ -231,6 +293,17 @@ class SparkHadoopUtil extends Logging { } } } + + /** + * Start a thread to periodically update the current user's credentials with new delegation + * tokens so that writes to HDFS do not fail. + */ + private[spark] def startExecutorDelegationTokenRenewer(conf: SparkConf) {} + + /** + * Stop the thread that does the delegation token updates. 
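For the getTimeFromNowToRenewal helper above, the result is the largest value of (issueDate + fraction * renewalInterval) - now across the HDFS delegation tokens. A worked sketch with made-up numbers, using the 24-hour default interval from the code:

    val now = System.currentTimeMillis()
    val renewalInterval = 24L * 60 * 60 * 1000   // 24 hours, the configured default
    val issueDate = now - 2L * 60 * 60 * 1000    // token issued 2 hours ago
    val fraction = 0.75
    val remaining = (issueDate + fraction * renewalInterval).toLong - now
    // remaining == 16 hours in millis; a value <= 0 means the renewal window has passed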
+ */ + private[spark] def stopExecutorDelegationTokenRenewer() {} } object SparkHadoopUtil { @@ -251,6 +324,10 @@ object SparkHadoopUtil { } } + val SPARK_YARN_CREDS_TEMP_EXTENSION = ".tmp" + + val SPARK_YARN_CREDS_COUNTER_DELIM = "-" + def get: SparkHadoopUtil = { hadoop } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 0d149e703aff2..42b5d41b7b526 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -20,7 +20,6 @@ package org.apache.spark.deploy import java.io.{File, PrintStream} import java.lang.reflect.{InvocationTargetException, Modifier, UndeclaredThrowableException} import java.net.URL -import java.nio.file.{Path => JavaPath} import java.security.PrivilegedExceptionAction import scala.collection.mutable.{ArrayBuffer, HashMap, Map} @@ -119,8 +118,8 @@ object SparkSubmit { * Kill an existing submission using the REST protocol. Standalone and Mesos cluster mode only. */ private def kill(args: SparkSubmitArguments): Unit = { - new RestSubmissionClient() - .killSubmission(args.master, args.submissionToKill) + new RestSubmissionClient(args.master) + .killSubmission(args.submissionToKill) } /** @@ -128,8 +127,8 @@ object SparkSubmit { * Standalone and Mesos cluster mode only. */ private def requestStatus(args: SparkSubmitArguments): Unit = { - new RestSubmissionClient() - .requestSubmissionStatus(args.master, args.submissionToRequestStatusFor) + new RestSubmissionClient(args.master) + .requestSubmissionStatus(args.submissionToRequestStatusFor) } /** @@ -401,6 +400,10 @@ object SparkSubmit { OptionAssigner(args.archives, YARN, CLUSTER, clOption = "--archives"), OptionAssigner(args.jars, YARN, CLUSTER, clOption = "--addJars"), + // Yarn client or cluster + OptionAssigner(args.principal, YARN, ALL_DEPLOY_MODES, clOption = "--principal"), + OptionAssigner(args.keytab, YARN, ALL_DEPLOY_MODES, clOption = "--keytab"), + // Other options OptionAssigner(args.executorCores, STANDALONE, ALL_DEPLOY_MODES, sysProp = "spark.executor.cores"), @@ -709,9 +712,7 @@ private[deploy] object SparkSubmitUtils { * @param artifactId the artifactId of the coordinate * @param version the version of the coordinate */ - private[deploy] case class MavenCoordinate(groupId: String, artifactId: String, version: String) { - override def toString: String = s"$groupId:$artifactId:$version" - } + private[deploy] case class MavenCoordinate(groupId: String, artifactId: String, version: String) /** * Extracts maven coordinates from a comma-delimited string. Coordinates should be provided @@ -734,10 +735,6 @@ private[deploy] object SparkSubmitUtils { } } - /** Path of the local Maven cache. 
*/ - private[spark] def m2Path: JavaPath = new File(System.getProperty("user.home"), - ".m2" + File.separator + "repository" + File.separator).toPath - /** * Extracts maven coordinates from a comma-delimited string * @param remoteRepos Comma-delimited string of remote repositories @@ -751,7 +748,8 @@ private[deploy] object SparkSubmitUtils { val localM2 = new IBiblioResolver localM2.setM2compatible(true) - localM2.setRoot(m2Path.toUri.toString) + val m2Path = ".m2" + File.separator + "repository" + File.separator + localM2.setRoot(new File(System.getProperty("user.home"), m2Path).toURI.toString) localM2.setUsepoms(true) localM2.setName("local-m2-cache") cr.add(localM2) @@ -876,72 +874,69 @@ private[deploy] object SparkSubmitUtils { "" } else { val sysOut = System.out - try { - // To prevent ivy from logging to system out - System.setOut(printStream) - val artifacts = extractMavenCoordinates(coordinates) - // Default configuration name for ivy - val ivyConfName = "default" - // set ivy settings for location of cache - val ivySettings: IvySettings = new IvySettings - // Directories for caching downloads through ivy and storing the jars when maven coordinates - // are supplied to spark-submit - val alternateIvyCache = ivyPath.getOrElse("") - val packagesDirectory: File = - if (alternateIvyCache.trim.isEmpty) { - new File(ivySettings.getDefaultIvyUserDir, "jars") - } else { - ivySettings.setDefaultIvyUserDir(new File(alternateIvyCache)) - ivySettings.setDefaultCache(new File(alternateIvyCache, "cache")) - new File(alternateIvyCache, "jars") - } - printStream.println( - s"Ivy Default Cache set to: ${ivySettings.getDefaultCache.getAbsolutePath}") - printStream.println(s"The jars for the packages stored in: $packagesDirectory") - // create a pattern matcher - ivySettings.addMatcher(new GlobPatternMatcher) - // create the dependency resolvers - val repoResolver = createRepoResolvers(remoteRepos, ivySettings) - ivySettings.addResolver(repoResolver) - ivySettings.setDefaultResolver(repoResolver.getName) - - val ivy = Ivy.newInstance(ivySettings) - // Set resolve options to download transitive dependencies as well - val resolveOptions = new ResolveOptions - resolveOptions.setTransitive(true) - val retrieveOptions = new RetrieveOptions - // Turn downloading and logging off for testing - if (isTest) { - resolveOptions.setDownload(false) - resolveOptions.setLog(LogOptions.LOG_QUIET) - retrieveOptions.setLog(LogOptions.LOG_QUIET) + // To prevent ivy from logging to system out + System.setOut(printStream) + val artifacts = extractMavenCoordinates(coordinates) + // Default configuration name for ivy + val ivyConfName = "default" + // set ivy settings for location of cache + val ivySettings: IvySettings = new IvySettings + // Directories for caching downloads through ivy and storing the jars when maven coordinates + // are supplied to spark-submit + val alternateIvyCache = ivyPath.getOrElse("") + val packagesDirectory: File = + if (alternateIvyCache.trim.isEmpty) { + new File(ivySettings.getDefaultIvyUserDir, "jars") } else { - resolveOptions.setDownload(true) + ivySettings.setDefaultIvyUserDir(new File(alternateIvyCache)) + ivySettings.setDefaultCache(new File(alternateIvyCache, "cache")) + new File(alternateIvyCache, "jars") } + printStream.println( + s"Ivy Default Cache set to: ${ivySettings.getDefaultCache.getAbsolutePath}") + printStream.println(s"The jars for the packages stored in: $packagesDirectory") + // create a pattern matcher + ivySettings.addMatcher(new GlobPatternMatcher) + // create the 
dependency resolvers + val repoResolver = createRepoResolvers(remoteRepos, ivySettings) + ivySettings.addResolver(repoResolver) + ivySettings.setDefaultResolver(repoResolver.getName) + + val ivy = Ivy.newInstance(ivySettings) + // Set resolve options to download transitive dependencies as well + val resolveOptions = new ResolveOptions + resolveOptions.setTransitive(true) + val retrieveOptions = new RetrieveOptions + // Turn downloading and logging off for testing + if (isTest) { + resolveOptions.setDownload(false) + resolveOptions.setLog(LogOptions.LOG_QUIET) + retrieveOptions.setLog(LogOptions.LOG_QUIET) + } else { + resolveOptions.setDownload(true) + } - // A Module descriptor must be specified. Entries are dummy strings - val md = getModuleDescriptor - md.setDefaultConf(ivyConfName) + // A Module descriptor must be specified. Entries are dummy strings + val md = getModuleDescriptor + md.setDefaultConf(ivyConfName) - // Add exclusion rules for Spark and Scala Library - addExclusionRules(ivySettings, ivyConfName, md) - // add all supplied maven artifacts as dependencies - addDependenciesToIvy(md, artifacts, ivyConfName) + // Add exclusion rules for Spark and Scala Library + addExclusionRules(ivySettings, ivyConfName, md) + // add all supplied maven artifacts as dependencies + addDependenciesToIvy(md, artifacts, ivyConfName) - // resolve dependencies - val rr: ResolveReport = ivy.resolve(md, resolveOptions) - if (rr.hasError) { - throw new RuntimeException(rr.getAllProblemMessages.toString) - } - // retrieve all resolved dependencies - ivy.retrieve(rr.getModuleDescriptor.getModuleRevisionId, - packagesDirectory.getAbsolutePath + File.separator + - "[organization]_[artifact]-[revision].[ext]", - retrieveOptions.setConfs(Array(ivyConfName))) - resolveDependencyPaths(rr.getArtifacts.toArray, packagesDirectory) - } finally { - System.setOut(sysOut) + // resolve dependencies + val rr: ResolveReport = ivy.resolve(md, resolveOptions) + if (rr.hasError) { + throw new RuntimeException(rr.getAllProblemMessages.toString) } + // retrieve all resolved dependencies + ivy.retrieve(rr.getModuleDescriptor.getModuleRevisionId, + packagesDirectory.getAbsolutePath + File.separator + + "[organization]_[artifact]-[revision].[ext]", + retrieveOptions.setConfs(Array(ivyConfName))) + System.setOut(sysOut) + resolveDependencyPaths(rr.getArtifacts.toArray, packagesDirectory) } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index c621b8fc86f94..c0e4c771908b3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -63,6 +63,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S var action: SparkSubmitAction = null val sparkProperties: HashMap[String, String] = new HashMap[String, String]() var proxyUser: String = null + var principal: String = null + var keytab: String = null // Standalone cluster mode only var supervise: Boolean = false @@ -393,6 +395,12 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S case PROXY_USER => proxyUser = value + case PRINCIPAL => + principal = value + + case KEYTAB => + keytab = value + case HELP => printUsageAndExit(0) @@ -506,6 +514,13 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | --num-executors NUM Number of executors to launch (Default: 2). 
| --archives ARCHIVES Comma separated list of archives to be extracted into the | working directory of each executor. + | --principal PRINCIPAL Principal to be used to login to KDC, while running on + | secure HDFS. + | --keytab KEYTAB The full path to the file that contains the keytab for the + | principal specified above. This keytab will be copied to + | the node running the Application Master via the Secure + | Distributed Cache, for renewing the login tickets and the + | delegation tokens periodically. """.stripMargin ) SparkSubmit.exitFn() diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index dc6077f3d132b..0fac3cdcf55e7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -254,7 +254,8 @@ private[master] class Master( case RequestSubmitDriver(description) => { if (state != RecoveryState.ALIVE) { - val msg = s"Can only accept driver submissions in ALIVE state. Current state: $state." + val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + + "Can only accept driver submissions in ALIVE state." sender ! SubmitDriverResponse(false, None, msg) } else { logInfo("Driver submitted " + description.command.mainClass) @@ -274,7 +275,8 @@ private[master] class Master( case RequestKillDriver(driverId) => { if (state != RecoveryState.ALIVE) { - val msg = s"Can only kill drivers in ALIVE state. Current state: $state." + val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + + s"Can only kill drivers in ALIVE state." sender ! KillDriverResponse(driverId, success = false, msg) } else { logInfo("Asked to kill driver " + driverId) @@ -305,12 +307,18 @@ private[master] class Master( } case RequestDriverStatus(driverId) => { - (drivers ++ completedDrivers).find(_.id == driverId) match { - case Some(driver) => - sender ! DriverStatusResponse(found = true, Some(driver.state), - driver.worker.map(_.id), driver.worker.map(_.hostPort), driver.exception) - case None => - sender ! DriverStatusResponse(found = false, None, None, None, None) + if (state != RecoveryState.ALIVE) { + val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + + "Can only request driver status in ALIVE state." + sender ! DriverStatusResponse(found = false, None, None, None, Some(new Exception(msg))) + } else { + (drivers ++ completedDrivers).find(_.id == driverId) match { + case Some(driver) => + sender ! DriverStatusResponse(found = true, Some(driver.state), + driver.worker.map(_.id), driver.worker.map(_.hostPort), driver.exception) + case None => + sender ! DriverStatusResponse(found = false, None, None, None, None) + } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala index da5060778edeb..a03d460509e03 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala @@ -33,7 +33,7 @@ import scala.reflect.ClassTag * The implementation of this trait defines how name-object pairs are stored or retrieved. */ @DeveloperApi -trait PersistenceEngine { +abstract class PersistenceEngine { /** * Defines how the object is serialized and persisted. 
Implementation will diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index aad9c87bdb987..dea0a65eeeaa6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -44,9 +44,9 @@ class MasterWebUI(val master: Master, requestedPort: Int) attachPage(masterPage) attachHandler(createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR, "/static")) attachHandler(createRedirectHandler( - "/app/kill", "/", masterPage.handleAppKillRequest, httpMethod = "POST")) + "/app/kill", "/", masterPage.handleAppKillRequest, httpMethods = Set("POST"))) attachHandler(createRedirectHandler( - "/driver/kill", "/", masterPage.handleDriverKillRequest, httpMethod = "POST")) + "/driver/kill", "/", masterPage.handleDriverKillRequest, httpMethods = Set("POST"))) } /** Attach a reconstructed UI to this Master UI. Only valid after bind(). */ diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala new file mode 100644 index 0000000000000..be8560d10fc62 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.mesos.ui + +import javax.servlet.http.HttpServletRequest + +import scala.xml.Node + +import org.apache.spark.deploy.Command +import org.apache.spark.deploy.mesos.MesosDriverDescription +import org.apache.spark.scheduler.cluster.mesos.{MesosClusterSubmissionState, MesosClusterRetryState} +import org.apache.spark.ui.{UIUtils, WebUIPage} + + +private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver") { + + override def render(request: HttpServletRequest): Seq[Node] = { + val driverId = request.getParameter("id") + require(driverId != null && driverId.nonEmpty, "Missing id parameter") + + val state = parent.scheduler.getDriverState(driverId) + if (state.isEmpty) { + val content = +
+        <div>
+          <p>Cannot find driver {driverId}</p>
+        </div>
+ return UIUtils.basicSparkPage(content, s"Details for Job $driverId") + } + val driverState = state.get + val driverHeaders = Seq("Driver property", "Value") + val schedulerHeaders = Seq("Scheduler property", "Value") + val commandEnvHeaders = Seq("Command environment variable", "Value") + val launchedHeaders = Seq("Launched property", "Value") + val commandHeaders = Seq("Comamnd property", "Value") + val retryHeaders = Seq("Last failed status", "Next retry time", "Retry count") + val driverDescription = Iterable.apply(driverState.description) + val submissionState = Iterable.apply(driverState.submissionState) + val command = Iterable.apply(driverState.description.command) + val schedulerProperties = Iterable.apply(driverState.description.schedulerProperties) + val commandEnv = Iterable.apply(driverState.description.command.environment) + val driverTable = + UIUtils.listingTable(driverHeaders, driverRow, driverDescription) + val commandTable = + UIUtils.listingTable(commandHeaders, commandRow, command) + val commandEnvTable = + UIUtils.listingTable(commandEnvHeaders, propertiesRow, commandEnv) + val schedulerTable = + UIUtils.listingTable(schedulerHeaders, propertiesRow, schedulerProperties) + val launchedTable = + UIUtils.listingTable(launchedHeaders, launchedRow, submissionState) + val retryTable = + UIUtils.listingTable( + retryHeaders, retryRow, Iterable.apply(driverState.description.retryState)) + val content = +

+      <p>Driver state information for driver id {driverId}</p>
+        <a href="/">Back to Drivers</a>
+        <div>
+          <div>
+            <p>Driver state: {driverState.state}</p>
+            <h4>Driver properties</h4>
+            {driverTable}
+            <h4>Driver command</h4>
+            {commandTable}
+            <h4>Driver command environment</h4>
+            {commandEnvTable}
+            <h4>Scheduler properties</h4>
+            {schedulerTable}
+            <h4>Launched state</h4>
+            {launchedTable}
+            <h4>Retry state</h4>
+            {retryTable}
+          </div>
+        </div>
; + + UIUtils.basicSparkPage(content, s"Details for Job $driverId") + } + + private def launchedRow(submissionState: Option[MesosClusterSubmissionState]): Seq[Node] = { + submissionState.map { state => + + Mesos Slave ID + {state.slaveId.getValue} + + + Mesos Task ID + {state.taskId.getValue} + + + Launch Time + {state.startDate} + + + Finish Time + {state.finishDate.map(_.toString).getOrElse("")} + + + Last Task Status + {state.mesosTaskStatus.map(_.toString).getOrElse("")} + + }.getOrElse(Seq[Node]()) + } + + private def propertiesRow(properties: collection.Map[String, String]): Seq[Node] = { + properties.map { case (k, v) => + + {k}{v} + + }.toSeq + } + + private def commandRow(command: Command): Seq[Node] = { + + Main class{command.mainClass} + + + Arguments{command.arguments.mkString(" ")} + + + Class path entries{command.classPathEntries.mkString(" ")} + + + Java options{command.javaOpts.mkString((" "))} + + + Library path entries{command.libraryPathEntries.mkString((" "))} + + } + + private def driverRow(driver: MesosDriverDescription): Seq[Node] = { + + Name{driver.name} + + + Id{driver.submissionId} + + + Cores{driver.cores} + + + Memory{driver.mem} + + + Submitted{driver.submissionDate} + + + Supervise{driver.supervise} + + } + + private def retryRow(retryState: Option[MesosClusterRetryState]): Seq[Node] = { + retryState.map { state => + + + {state.lastFailureStatus} + + + {state.nextRetry} + + + {state.retries} + + + }.getOrElse(Seq[Node]()) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala index 7b2005e0f1237..7419fa9699648 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala @@ -56,8 +56,9 @@ private[mesos] class MesosClusterPage(parent: MesosClusterUI) extends WebUIPage( } private def queuedRow(submission: MesosDriverDescription): Seq[Node] = { + val id = submission.submissionId - {submission.submissionId} + {id} {submission.submissionDate} {submission.command.mainClass} cpus: {submission.cores}, mem: {submission.mem} @@ -65,8 +66,9 @@ private[mesos] class MesosClusterPage(parent: MesosClusterUI) extends WebUIPage( } private def driverRow(state: MesosClusterSubmissionState): Seq[Node] = { + val id = state.driverDescription.submissionId - {state.driverDescription.submissionId} + {id} {state.driverDescription.submissionDate} {state.driverDescription.command.mainClass} cpus: {state.driverDescription.cores}, mem: {state.driverDescription.mem} @@ -77,8 +79,9 @@ private[mesos] class MesosClusterPage(parent: MesosClusterUI) extends WebUIPage( } private def retryRow(submission: MesosDriverDescription): Seq[Node] = { + val id = submission.submissionId - {submission.submissionId} + {id} {submission.submissionDate} {submission.command.mainClass} {submission.retryState.get.lastFailureStatus} diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala index 4865d46dbc4ab..3f693545a0349 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala @@ -39,6 +39,7 @@ private[spark] class MesosClusterUI( override def initialize() { attachPage(new MesosClusterPage(this)) + attachPage(new DriverPage(this)) 
attachHandler(createStaticHandler(MesosClusterUI.STATIC_RESOURCE_DIR, "/static")) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala index 307cebfb4bd09..6078f50518ba4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala @@ -18,9 +18,10 @@ package org.apache.spark.deploy.rest import java.io.{DataOutputStream, FileNotFoundException} -import java.net.{HttpURLConnection, SocketException, URL} +import java.net.{ConnectException, HttpURLConnection, SocketException, URL} import javax.servlet.http.HttpServletResponse +import scala.collection.mutable import scala.io.Source import com.fasterxml.jackson.core.JsonProcessingException @@ -51,57 +52,109 @@ import org.apache.spark.util.Utils * implementation of this client can use that information to retry using the version specified * by the server. */ -private[spark] class RestSubmissionClient extends Logging { +private[spark] class RestSubmissionClient(master: String) extends Logging { import RestSubmissionClient._ private val supportedMasterPrefixes = Seq("spark://", "mesos://") + private val masters: Array[String] = Utils.parseStandaloneMasterUrls(master) + + // Set of masters that lost contact with us, used to keep track of + // whether there are masters still alive for us to communicate with + private val lostMasters = new mutable.HashSet[String] + /** * Submit an application specified by the parameters in the provided request. * * If the submission was successful, poll the status of the submission and report * it to the user. Otherwise, report the error message provided by the server. */ - def createSubmission( - master: String, - request: CreateSubmissionRequest): SubmitRestProtocolResponse = { + def createSubmission(request: CreateSubmissionRequest): SubmitRestProtocolResponse = { logInfo(s"Submitting a request to launch an application in $master.") - validateMaster(master) - val url = getSubmitUrl(master) - val response = postJson(url, request.toJson) - response match { - case s: CreateSubmissionResponse => - reportSubmissionStatus(master, s) - handleRestResponse(s) - case unexpected => - handleUnexpectedRestResponse(unexpected) + var handled: Boolean = false + var response: SubmitRestProtocolResponse = null + for (m <- masters if !handled) { + validateMaster(m) + val url = getSubmitUrl(m) + try { + response = postJson(url, request.toJson) + response match { + case s: CreateSubmissionResponse => + if (s.success) { + reportSubmissionStatus(s) + handleRestResponse(s) + handled = true + } + case unexpected => + handleUnexpectedRestResponse(unexpected) + } + } catch { + case e: SubmitRestConnectionException => + if (handleConnectionException(m)) { + throw new SubmitRestConnectionException("Unable to connect to server", e) + } + } } response } /** Request that the server kill the specified submission. 
*/ - def killSubmission(master: String, submissionId: String): SubmitRestProtocolResponse = { + def killSubmission(submissionId: String): SubmitRestProtocolResponse = { logInfo(s"Submitting a request to kill submission $submissionId in $master.") - validateMaster(master) - val response = post(getKillUrl(master, submissionId)) - response match { - case k: KillSubmissionResponse => handleRestResponse(k) - case unexpected => handleUnexpectedRestResponse(unexpected) + var handled: Boolean = false + var response: SubmitRestProtocolResponse = null + for (m <- masters if !handled) { + validateMaster(m) + val url = getKillUrl(m, submissionId) + try { + response = post(url) + response match { + case k: KillSubmissionResponse => + if (!Utils.responseFromBackup(k.message)) { + handleRestResponse(k) + handled = true + } + case unexpected => + handleUnexpectedRestResponse(unexpected) + } + } catch { + case e: SubmitRestConnectionException => + if (handleConnectionException(m)) { + throw new SubmitRestConnectionException("Unable to connect to server", e) + } + } } response } /** Request the status of a submission from the server. */ def requestSubmissionStatus( - master: String, submissionId: String, quiet: Boolean = false): SubmitRestProtocolResponse = { logInfo(s"Submitting a request for the status of submission $submissionId in $master.") - validateMaster(master) - val response = get(getStatusUrl(master, submissionId)) - response match { - case s: SubmissionStatusResponse => if (!quiet) { handleRestResponse(s) } - case unexpected => handleUnexpectedRestResponse(unexpected) + + var handled: Boolean = false + var response: SubmitRestProtocolResponse = null + for (m <- masters if !handled) { + validateMaster(m) + val url = getStatusUrl(m, submissionId) + try { + response = get(url) + response match { + case s: SubmissionStatusResponse if s.success => + if (!quiet) { + handleRestResponse(s) + } + handled = true + case unexpected => + handleUnexpectedRestResponse(unexpected) + } + } catch { + case e: SubmitRestConnectionException => + if (handleConnectionException(m)) { + throw new SubmitRestConnectionException("Unable to connect to server", e) + } + } } response } @@ -148,11 +201,16 @@ private[spark] class RestSubmissionClient extends Logging { conn.setRequestProperty("Content-Type", "application/json") conn.setRequestProperty("charset", "utf-8") conn.setDoOutput(true) - val out = new DataOutputStream(conn.getOutputStream) - Utils.tryWithSafeFinally { - out.write(json.getBytes(Charsets.UTF_8)) - } { - out.close() + try { + val out = new DataOutputStream(conn.getOutputStream) + Utils.tryWithSafeFinally { + out.write(json.getBytes(Charsets.UTF_8)) + } { + out.close() + } + } catch { + case e: ConnectException => + throw new SubmitRestConnectionException("Connect Exception when connect to server", e) } readResponse(conn) } @@ -191,11 +249,9 @@ private[spark] class RestSubmissionClient extends Logging { } } catch { case unreachable @ (_: FileNotFoundException | _: SocketException) => - throw new SubmitRestConnectionException( - s"Unable to connect to server ${connection.getURL}", unreachable) + throw new SubmitRestConnectionException("Unable to connect to server", unreachable) case malformed @ (_: JsonProcessingException | _: SubmitRestProtocolException) => - throw new SubmitRestProtocolException( - "Malformed response received from server", malformed) + throw new SubmitRestProtocolException("Malformed response received from server", malformed) } } @@ -241,13 +297,12 @@ private[spark] class 
RestSubmissionClient extends Logging { /** Report the status of a newly created submission. */ private def reportSubmissionStatus( - master: String, submitResponse: CreateSubmissionResponse): Unit = { if (submitResponse.success) { val submissionId = submitResponse.submissionId if (submissionId != null) { logInfo(s"Submission successfully created as $submissionId. Polling submission state...") - pollSubmissionStatus(master, submissionId) + pollSubmissionStatus(submissionId) } else { // should never happen logError("Application successfully submitted, but submission ID was not provided!") @@ -262,9 +317,9 @@ private[spark] class RestSubmissionClient extends Logging { * Poll the status of the specified submission and log it. * This retries up to a fixed number of times before giving up. */ - private def pollSubmissionStatus(master: String, submissionId: String): Unit = { + private def pollSubmissionStatus(submissionId: String): Unit = { (1 to REPORT_DRIVER_STATUS_MAX_TRIES).foreach { _ => - val response = requestSubmissionStatus(master, submissionId, quiet = true) + val response = requestSubmissionStatus(submissionId, quiet = true) val statusResponse = response match { case s: SubmissionStatusResponse => s case _ => return // unexpected type, let upstream caller handle it @@ -302,6 +357,21 @@ private[spark] class RestSubmissionClient extends Logging { private def handleUnexpectedRestResponse(unexpected: SubmitRestProtocolResponse): Unit = { logError(s"Error: Server responded with message of unexpected type ${unexpected.messageType}.") } + + /** + * When a connection exception is caught, return true if all masters are lost. + * Note that the heuristic used here does not take into account that masters + * can recover during the lifetime of this client. This assumption should be + * harmless because this client currently does not support retrying submission + * on failure yet (SPARK-6443). 
+ */ + private def handleConnectionException(masterUrl: String): Boolean = { + if (!lostMasters.contains(masterUrl)) { + logWarning(s"Unable to connect to server ${masterUrl}.") + lostMasters += masterUrl + } + lostMasters.size >= masters.size + } } private[spark] object RestSubmissionClient { @@ -324,10 +394,10 @@ private[spark] object RestSubmissionClient { } val sparkProperties = conf.getAll.toMap val environmentVariables = env.filter { case (k, _) => k.startsWith("SPARK_") } - val client = new RestSubmissionClient + val client = new RestSubmissionClient(master) val submitRequest = client.constructSubmitRequest( appResource, mainClass, appArgs, sparkProperties, environmentVariables) - client.createSubmission(master, submitRequest) + client.createSubmission(submitRequest) } def main(args: Array[String]): Unit = { diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala index fd17a980c9319..8198296eeb341 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala @@ -53,7 +53,7 @@ private[spark] class MesosRestServer( new MesosStatusRequestServlet(scheduler, masterConf) } -private[deploy] class MesosSubmitRequestServlet( +private[mesos] class MesosSubmitRequestServlet( scheduler: MesosClusterScheduler, conf: SparkConf) extends SubmitRequestServlet { @@ -139,7 +139,7 @@ private[deploy] class MesosSubmitRequestServlet( } } -private[deploy] class MesosKillRequestServlet(scheduler: MesosClusterScheduler, conf: SparkConf) +private[mesos] class MesosKillRequestServlet(scheduler: MesosClusterScheduler, conf: SparkConf) extends KillRequestServlet { protected override def handleKill(submissionId: String): KillSubmissionResponse = { val k = scheduler.killDriver(submissionId) @@ -148,7 +148,7 @@ private[deploy] class MesosKillRequestServlet(scheduler: MesosClusterScheduler, } } -private[deploy] class MesosStatusRequestServlet(scheduler: MesosClusterScheduler, conf: SparkConf) +private[mesos] class MesosStatusRequestServlet(scheduler: MesosClusterScheduler, conf: SparkConf) extends StatusRequestServlet { protected override def handleStatus(submissionId: String): SubmissionStatusResponse = { val d = scheduler.getDriverStatus(submissionId) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala index 88f9d880ac209..9678631da9f6f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala @@ -105,7 +105,7 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { if (masters != null) { // Two positional arguments were given printUsageAndExit(1) } - masters = value.stripPrefix("spark://").split(",").map("spark://" + _) + masters = Utils.parseStandaloneMasterUrls(value) parse(tail) case Nil => diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 79aed90b53e2f..ed159dec4f998 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -20,6 +20,8 @@ package org.apache.spark.executor import java.net.URL import 
java.nio.ByteBuffer +import org.apache.hadoop.conf.Configuration + import scala.collection.mutable import scala.util.{Failure, Success} @@ -168,6 +170,12 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { driverConf.set(key, value) } } + if (driverConf.contains("spark.yarn.credentials.file")) { + logInfo("Will periodically update credentials from: " + + driverConf.get("spark.yarn.credentials.file")) + SparkHadoopUtil.get.startExecutorDelegationTokenRenewer(driverConf) + } + val env = SparkEnv.createExecutorEnv( driverConf, executorId, hostname, port, cores, isLocal = false) @@ -183,6 +191,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { env.rpcEnv.setupEndpoint("WorkerWatcher", new WorkerWatcher(env.rpcEnv, url)) } env.rpcEnv.awaitTermination() + SparkHadoopUtil.get.stopExecutorDelegationTokenRenewer() } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 3f0950dae1f24..6181c0ee9fa2b 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -24,7 +24,7 @@ import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCallback, TransportClientFactory} -import org.apache.spark.network.sasl.{SaslRpcHandler, SaslClientBootstrap} +import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap} import org.apache.spark.network.server._ import org.apache.spark.network.shuffle.{RetryingBlockFetcher, BlockFetchingListener, OneForOneBlockFetcher} import org.apache.spark.network.shuffle.protocol.UploadBlock @@ -49,18 +49,18 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage private[this] var appId: String = _ override def init(blockDataManager: BlockDataManager): Unit = { - val (rpcHandler: RpcHandler, bootstrap: Option[TransportClientBootstrap]) = { - val nettyRpcHandler = new NettyBlockRpcServer(serializer, blockDataManager) - if (!authEnabled) { - (nettyRpcHandler, None) - } else { - (new SaslRpcHandler(nettyRpcHandler, securityManager), - Some(new SaslClientBootstrap(transportConf, conf.getAppId, securityManager))) - } + val rpcHandler = new NettyBlockRpcServer(serializer, blockDataManager) + var serverBootstrap: Option[TransportServerBootstrap] = None + var clientBootstrap: Option[TransportClientBootstrap] = None + if (authEnabled) { + serverBootstrap = Some(new SaslServerBootstrap(transportConf, securityManager)) + clientBootstrap = Some(new SaslClientBootstrap(transportConf, conf.getAppId, securityManager, + securityManager.isSaslEncryptionEnabled())) } transportContext = new TransportContext(transportConf, rpcHandler) - clientFactory = transportContext.createClientFactory(bootstrap.toList) - server = transportContext.createServer(conf.getInt("spark.blockManager.port", 0)) + clientFactory = transportContext.createClientFactory(clientBootstrap.toList) + server = transportContext.createServer(conf.getInt("spark.blockManager.port", 0), + serverBootstrap.toList) appId = conf.getAppId logInfo("Server created on " + server.getPort) } diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala 
index 16e905982cf64..497871ed6d5e5 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -656,7 +656,7 @@ private[nio] class ConnectionManager( connection.synchronized { if (connection.sparkSaslServer == null) { logDebug("Creating sasl Server") - connection.sparkSaslServer = new SparkSaslServer(conf.getAppId, securityManager) + connection.sparkSaslServer = new SparkSaslServer(conf.getAppId, securityManager, false) } } replyToken = connection.sparkSaslServer.response(securityMsg.getToken) @@ -800,7 +800,7 @@ private[nio] class ConnectionManager( if (!conn.isSaslComplete()) { conn.synchronized { if (conn.sparkSaslClient == null) { - conn.sparkSaslClient = new SparkSaslClient(conf.getAppId, securityManager) + conn.sparkSaslClient = new SparkSaslClient(conf.getAppId, securityManager, false) var firstResponse: Array[Byte] = null try { firstResponse = conn.sparkSaslClient.firstToken() diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 05b8ab0d0a1f9..5d812918a13d1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1399,6 +1399,7 @@ class DAGScheduler( def stop() { logInfo("Stopping DAGScheduler") + messageScheduler.shutdownNow() eventProcessLoop.stop() taskScheduler.stop() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 7352fa1fe9ebd..f107148f3b8c6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -68,6 +68,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp class DriverEndpoint(override val rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) extends ThreadSafeRpcEndpoint with Logging { + override protected def log = CoarseGrainedSchedulerBackend.this.log private val addressToExecutorId = new HashMap[RpcAddress, String] @@ -112,6 +113,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Ignoring the task kill since the executor is not registered. 
logWarning(s"Attempted to kill task $taskId for unknown executor $executorId.") } + } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { @@ -122,7 +124,6 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } else { logInfo("Registered executor: " + executorRef + " with ID " + executorId) context.reply(RegisteredExecutor) - addressToExecutorId(executorRef.address) = executorId totalCoreCount.addAndGet(cores) totalRegisteredExecutors.addAndGet(1) @@ -243,6 +244,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp properties += ((key, value)) } } + // TODO (prashant) send conf instead of properties driverEndpoint = rpcEnv.setupEndpoint( CoarseGrainedSchedulerBackend.ENDPOINT_NAME, new DriverEndpoint(rpcEnv, properties)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 3412301e64fd7..dc59545b43314 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -196,9 +196,14 @@ private[spark] class CoarseMesosSchedulerBackend( .addResources(createResource("cpus", cpusToUse)) .addResources(createResource("mem", MemoryUtils.calculateTotalMemory(sc))) - .build() + + sc.conf.getOption("spark.mesos.executor.docker.image").foreach { image => + MesosSchedulerBackendUtil + .setupContainerBuilderDockerInfo(image, sc.conf, task.getContainerBuilder()) + } + d.launchTasks( - Collections.singleton(offer.getId), Collections.singletonList(task), filters) + Collections.singleton(offer.getId), Collections.singletonList(task.build()), filters) } else { // Filter it out d.launchTasks( diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index 0396e62be5309..06f0e2881c344 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -50,12 +50,13 @@ private[spark] class MesosClusterSubmissionState( val taskId: TaskID, val slaveId: SlaveID, var mesosTaskStatus: Option[TaskStatus], - var startDate: Date) + var startDate: Date, + var finishDate: Option[Date]) extends Serializable { def copy(): MesosClusterSubmissionState = { new MesosClusterSubmissionState( - driverDescription, taskId, slaveId, mesosTaskStatus, startDate) + driverDescription, taskId, slaveId, mesosTaskStatus, startDate, finishDate) } } @@ -95,6 +96,14 @@ private[spark] class MesosClusterSchedulerState( val finishedDrivers: Iterable[MesosClusterSubmissionState], val pendingRetryDrivers: Iterable[MesosDriverDescription]) +/** + * The full state of a Mesos driver, that is being used to display driver information on the UI. + */ +private[spark] class MesosDriverState( + val state: String, + val description: MesosDriverDescription, + val submissionState: Option[MesosClusterSubmissionState] = None) + /** * A Mesos scheduler that is responsible for launching submitted Spark drivers in cluster mode * as Mesos tasks in a Mesos cluster. @@ -233,6 +242,22 @@ private[spark] class MesosClusterScheduler( s } + /** + * Gets the driver state to be displayed on the Web UI. 
+ */ + def getDriverState(submissionId: String): Option[MesosDriverState] = { + stateLock.synchronized { + queuedDrivers.find(_.submissionId.equals(submissionId)) + .map(d => new MesosDriverState("QUEUED", d)) + .orElse(launchedDrivers.get(submissionId) + .map(d => new MesosDriverState("RUNNING", d.driverDescription, Some(d)))) + .orElse(finishedDrivers.find(_.driverDescription.submissionId.equals(submissionId)) + .map(d => new MesosDriverState("FINISHED", d.driverDescription, Some(d)))) + .orElse(pendingRetryDrivers.find(_.submissionId.equals(submissionId)) + .map(d => new MesosDriverState("RETRYING", d))) + } + } + private def isQueueFull(): Boolean = launchedDrivers.size >= queuedCapacity /** @@ -439,7 +464,7 @@ private[spark] class MesosClusterScheduler( logTrace(s"Using offer ${offer.offer.getId.getValue} to launch driver " + submission.submissionId) val newState = new MesosClusterSubmissionState(submission, taskId, offer.offer.getSlaveId, - None, new Date()) + None, new Date(), None) launchedDrivers(submission.submissionId) = newState launchedDriversState.persist(submission.submissionId, newState) afterLaunchCallback(submission.submissionId) @@ -534,6 +559,7 @@ private[spark] class MesosClusterScheduler( // Check if the driver is supervise enabled and can be relaunched. if (state.driverDescription.supervise && shouldRelaunch(status.getState)) { removeFromLaunchedDrivers(taskId) + state.finishDate = Some(new Date()) val retryState: Option[MesosClusterRetryState] = state.driverDescription.retryState val (retries, waitTimeSec) = retryState .map { rs => (rs.retries + 1, Math.min(maxRetryWaitTime, rs.waitTime * 2)) } @@ -546,6 +572,7 @@ private[spark] class MesosClusterScheduler( pendingRetryDriversState.persist(taskId, newDriverDescription) } else if (TaskState.isFinished(TaskState.fromMesos(status.getState))) { removeFromLaunchedDrivers(taskId) + state.finishDate = Some(new Date()) if (finishedDrivers.size >= retainedDrivers) { val toRemove = math.max(retainedDrivers / 10, 1) finishedDrivers.trimStart(toRemove) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 8346a2407489f..db0a080b3b0c0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -23,7 +23,7 @@ import java.util.{ArrayList => JArrayList, Collections, List => JList} import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, HashSet} -import org.apache.mesos.Protos.{ExecutorInfo => MesosExecutorInfo, TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _} +import org.apache.mesos.Protos.{ExecutorInfo => MesosExecutorInfo, TaskInfo => MesosTaskInfo, _} import org.apache.mesos.protobuf.ByteString import org.apache.mesos.{Scheduler => MScheduler, _} import org.apache.spark.executor.MesosExecutorBackend @@ -56,7 +56,7 @@ private[spark] class MesosSchedulerBackend( // The listener bus to publish executor added/removed events. 
val listenerBus = sc.listenerBus - + private[mesos] val mesosExecutorCores = sc.conf.getDouble("spark.mesos.mesosExecutor.cores", 1) @volatile var appId: String = _ @@ -124,13 +124,19 @@ private[spark] class MesosSchedulerBackend( Value.Scalar.newBuilder() .setValue(MemoryUtils.calculateTotalMemory(sc)).build()) .build() - MesosExecutorInfo.newBuilder() + val executorInfo = MesosExecutorInfo.newBuilder() .setExecutorId(ExecutorID.newBuilder().setValue(execId).build()) .setCommand(command) .setData(ByteString.copyFrom(createExecArg())) .addResources(cpus) .addResources(memory) - .build() + + sc.conf.getOption("spark.mesos.executor.docker.image").foreach { image => + MesosSchedulerBackendUtil + .setupContainerBuilderDockerInfo(image, sc.conf, executorInfo.getContainerBuilder()) + } + + executorInfo.build() } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala new file mode 100644 index 0000000000000..928c5cfed417a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster.mesos + +import org.apache.mesos.Protos.{ContainerInfo, Volume} +import org.apache.mesos.Protos.ContainerInfo.DockerInfo + +import org.apache.spark.{Logging, SparkConf} + +/** + * A collection of utility functions which can be used by both the + * MesosSchedulerBackend and the CoarseMesosSchedulerBackend. + */ +private[mesos] object MesosSchedulerBackendUtil extends Logging { + /** + * Parse a comma-delimited list of volume specs, each of which + * takes the form [host-dir:]container-dir[:rw|:ro]. + */ + def parseVolumesSpec(volumes: String): List[Volume] = { + volumes.split(",").map(_.split(":")).flatMap { spec => + val vol: Volume.Builder = Volume + .newBuilder() + .setMode(Volume.Mode.RW) + spec match { + case Array(container_path) => + Some(vol.setContainerPath(container_path)) + case Array(container_path, "rw") => + Some(vol.setContainerPath(container_path)) + case Array(container_path, "ro") => + Some(vol.setContainerPath(container_path) + .setMode(Volume.Mode.RO)) + case Array(host_path, container_path) => + Some(vol.setContainerPath(container_path) + .setHostPath(host_path)) + case Array(host_path, container_path, "rw") => + Some(vol.setContainerPath(container_path) + .setHostPath(host_path)) + case Array(host_path, container_path, "ro") => + Some(vol.setContainerPath(container_path) + .setHostPath(host_path) + .setMode(Volume.Mode.RO)) + case spec => { + logWarning(s"Unable to parse volume specs: $volumes. 
" + + "Expected form: \"[host-dir:]container-dir[:rw|:ro](, ...)\"") + None + } + } + } + .map { _.build() } + .toList + } + + /** + * Parse a comma-delimited list of port mapping specs, each of which + * takes the form host_port:container_port[:udp|:tcp] + * + * Note: + * the docker form is [ip:]host_port:container_port, but the DockerInfo + * message has no field for 'ip', and instead has a 'protocol' field. + * Docker itself only appears to support TCP, so this alternative form + * anticipates the expansion of the docker form to allow for a protocol + * and leaves open the chance for mesos to begin to accept an 'ip' field + */ + def parsePortMappingsSpec(portmaps: String): List[DockerInfo.PortMapping] = { + portmaps.split(",").map(_.split(":")).flatMap { spec: Array[String] => + val portmap: DockerInfo.PortMapping.Builder = DockerInfo.PortMapping + .newBuilder() + .setProtocol("tcp") + spec match { + case Array(host_port, container_port) => + Some(portmap.setHostPort(host_port.toInt) + .setContainerPort(container_port.toInt)) + case Array(host_port, container_port, protocol) => + Some(portmap.setHostPort(host_port.toInt) + .setContainerPort(container_port.toInt) + .setProtocol(protocol)) + case spec => { + logWarning(s"Unable to parse port mapping specs: $portmaps. " + + "Expected form: \"host_port:container_port[:udp|:tcp](, ...)\"") + None + } + } + } + .map { _.build() } + .toList + } + + /** + * Construct a DockerInfo structure and insert it into a ContainerInfo + */ + def addDockerInfo( + container: ContainerInfo.Builder, + image: String, + volumes: Option[List[Volume]] = None, + network: Option[ContainerInfo.DockerInfo.Network] = None, + portmaps: Option[List[ContainerInfo.DockerInfo.PortMapping]] = None):Unit = { + + val docker = ContainerInfo.DockerInfo.newBuilder().setImage(image) + + network.foreach(docker.setNetwork) + portmaps.foreach(_.foreach(docker.addPortMappings)) + container.setType(ContainerInfo.Type.DOCKER) + container.setDocker(docker.build()) + volumes.foreach(_.foreach(container.addVolumes)) + } + + /** + * Setup a docker containerizer + */ + def setupContainerBuilderDockerInfo( + imageName: String, + conf: SparkConf, + builder: ContainerInfo.Builder): Unit = { + val volumes = conf + .getOption("spark.mesos.executor.docker.volumes") + .map(parseVolumesSpec) + val portmaps = conf + .getOption("spark.mesos.executor.docker.portmaps") + .map(parsePortMappingsSpec) + addDockerInfo( + builder, + imageName, + volumes = volumes, + portmaps = portmaps) + logDebug("setupContainerDockerInfo: using docker image: " + imageName) + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 402ee1c7648c5..a46fecd2274ef 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -111,7 +111,8 @@ private[spark] class BlockManager( // standard BlockTransferService to directly connect to other Executors. 
private[spark] val shuffleClient = if (externalShuffleServiceEnabled) { val transConf = SparkTransportConf.fromSparkConf(conf, numUsableCores) - new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled()) + new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled(), + securityManager.isSaslEncryptionEnabled()) } else { blockTransferService } diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index a091ca650c60c..dfd6fdb5e9993 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -115,19 +115,21 @@ private[spark] object JettyUtils extends Logging { destPath: String, beforeRedirect: HttpServletRequest => Unit = x => (), basePath: String = "", - httpMethod: String = "GET"): ServletContextHandler = { + httpMethods: Set[String] = Set("GET")): ServletContextHandler = { val prefixedDestPath = attachPrefix(basePath, destPath) val servlet = new HttpServlet { override def doGet(request: HttpServletRequest, response: HttpServletResponse): Unit = { - httpMethod match { - case "GET" => doRequest(request, response) - case _ => response.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED) + if (httpMethods.contains("GET")) { + doRequest(request, response) + } else { + response.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED) } } override def doPost(request: HttpServletRequest, response: HttpServletResponse): Unit = { - httpMethod match { - case "POST" => doRequest(request, response) - case _ => response.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED) + if (httpMethods.contains("POST")) { + doRequest(request, response) + } else { + response.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED) } } private def doRequest(request: HttpServletRequest, response: HttpServletResponse): Unit = { diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 580ab8b1325f8..06fce86bd38d2 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -55,8 +55,10 @@ private[spark] class SparkUI private ( attachTab(new ExecutorsTab(this)) attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) attachHandler(createRedirectHandler("/", "/jobs", basePath = basePath)) + // This should be POST only, but, the YARN AM proxy won't proxy POSTs attachHandler(createRedirectHandler( - "/stages/stage/kill", "/stages", stagesTab.handleKillRequest, httpMethod = "POST")) + "/stages/stage/kill", "/stages", stagesTab.handleKillRequest, + httpMethods = Set("GET", "POST"))) } initialize() diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 6d8c7e1fda8d8..a33243d4252bf 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -76,15 +76,20 @@ private[ui] class StageTableBase( val basePathUri = UIUtils.prependBaseUri(basePath) val killLink = if (killEnabled) { - val killLinkUri = s"$basePathUri/stages/stage/kill/" val confirm = s"if (window.confirm('Are you sure you want to kill stage ${s.stageId} ?')) " + "{ this.parentNode.submit(); return true; } else { return false; }" + // SPARK-6846 this should be POST-only but YARN AM won't proxy POST + /* + val killLinkUri = 
s"$basePathUri/stages/stage/kill/"
(kill)
+ */ + val killLinkUri = s"$basePathUri/stages/stage/kill/?id=${s.stageId}&terminate=true" + (kill) } val nameLinkUri = s"$basePathUri/stages/stage?id=${s.stageId}&attempt=${s.attemptId}" diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 4b5a5df5ef7b7..be4db02ab86d0 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2022,6 +2022,13 @@ private[spark] object Utils extends Logging { } } + /** + * configure a new log4j level + */ + def setLogLevel(l: org.apache.log4j.Level) { + org.apache.log4j.Logger.getRootLogger().setLevel(l) + } + /** * config a log4j properties used for testsuite */ @@ -2152,6 +2159,22 @@ private[spark] object Utils extends Logging { .getOrElse(UserGroupInformation.getCurrentUser().getShortUserName()) } + /** + * Split the comma delimited string of master URLs into a list. + * For instance, "spark://abc,def" becomes [spark://abc, spark://def]. + */ + def parseStandaloneMasterUrls(masterUrls: String): Array[String] = { + masterUrls.stripPrefix("spark://").split(",").map("spark://" + _) + } + + /** An identifier that backup masters use in their responses. */ + val BACKUP_STANDALONE_MASTER_PREFIX = "Current state is not alive" + + /** Return true if the response message is sent from a backup Master on standby. */ + def responseFromBackup(msg: String): Boolean = { + msg.startsWith(BACKUP_STANDALONE_MASTER_PREFIX) + } + /** * Adds a shutdown hook with default priority. * diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 22acc270b983e..49e6de4e0bafd 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -78,7 +78,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext wit test("starting state") { sc = createSparkContext() val manager = sc.executorAllocationManager.get - assert(numExecutorsPending(manager) === 0) + assert(numExecutorsTarget(manager) === 1) assert(executorsPendingToRemove(manager).isEmpty) assert(executorIds(manager).isEmpty) assert(addTime(manager) === ExecutorAllocationManager.NOT_SET) @@ -91,108 +91,108 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext wit sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 1000))) // Keep adding until the limit is reached - assert(numExecutorsPending(manager) === 0) + assert(numExecutorsTarget(manager) === 1) assert(numExecutorsToAdd(manager) === 1) assert(addExecutors(manager) === 1) - assert(numExecutorsPending(manager) === 1) + assert(numExecutorsTarget(manager) === 2) assert(numExecutorsToAdd(manager) === 2) assert(addExecutors(manager) === 2) - assert(numExecutorsPending(manager) === 3) + assert(numExecutorsTarget(manager) === 4) assert(numExecutorsToAdd(manager) === 4) assert(addExecutors(manager) === 4) - assert(numExecutorsPending(manager) === 7) + assert(numExecutorsTarget(manager) === 8) assert(numExecutorsToAdd(manager) === 8) - assert(addExecutors(manager) === 3) // reached the limit of 10 - assert(numExecutorsPending(manager) === 10) + assert(addExecutors(manager) === 2) // reached the limit of 10 + assert(numExecutorsTarget(manager) === 10) assert(numExecutorsToAdd(manager) === 1) assert(addExecutors(manager) === 0) - 
assert(numExecutorsPending(manager) === 10) + assert(numExecutorsTarget(manager) === 10) assert(numExecutorsToAdd(manager) === 1) // Register previously requested executors onExecutorAdded(manager, "first") - assert(numExecutorsPending(manager) === 9) + assert(numExecutorsTarget(manager) === 10) onExecutorAdded(manager, "second") onExecutorAdded(manager, "third") onExecutorAdded(manager, "fourth") - assert(numExecutorsPending(manager) === 6) + assert(numExecutorsTarget(manager) === 10) onExecutorAdded(manager, "first") // duplicates should not count onExecutorAdded(manager, "second") - assert(numExecutorsPending(manager) === 6) + assert(numExecutorsTarget(manager) === 10) // Try adding again // This should still fail because the number pending + running is still at the limit assert(addExecutors(manager) === 0) - assert(numExecutorsPending(manager) === 6) + assert(numExecutorsTarget(manager) === 10) assert(numExecutorsToAdd(manager) === 1) assert(addExecutors(manager) === 0) - assert(numExecutorsPending(manager) === 6) + assert(numExecutorsTarget(manager) === 10) assert(numExecutorsToAdd(manager) === 1) } test("add executors capped by num pending tasks") { - sc = createSparkContext(1, 10) + sc = createSparkContext(0, 10) val manager = sc.executorAllocationManager.get sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 5))) // Verify that we're capped at number of tasks in the stage - assert(numExecutorsPending(manager) === 0) + assert(numExecutorsTarget(manager) === 0) assert(numExecutorsToAdd(manager) === 1) assert(addExecutors(manager) === 1) - assert(numExecutorsPending(manager) === 1) + assert(numExecutorsTarget(manager) === 1) assert(numExecutorsToAdd(manager) === 2) assert(addExecutors(manager) === 2) - assert(numExecutorsPending(manager) === 3) + assert(numExecutorsTarget(manager) === 3) assert(numExecutorsToAdd(manager) === 4) assert(addExecutors(manager) === 2) - assert(numExecutorsPending(manager) === 5) + assert(numExecutorsTarget(manager) === 5) assert(numExecutorsToAdd(manager) === 1) - // Verify that running a task reduces the cap + // Verify that running a task doesn't affect the target sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(1, 3))) sc.listenerBus.postToAll(SparkListenerExecutorAdded( 0L, "executor-1", new ExecutorInfo("host1", 1, Map.empty))) sc.listenerBus.postToAll(SparkListenerTaskStart(1, 0, createTaskInfo(0, 0, "executor-1"))) - assert(numExecutorsPending(manager) === 4) + assert(numExecutorsTarget(manager) === 5) assert(addExecutors(manager) === 1) - assert(numExecutorsPending(manager) === 5) + assert(numExecutorsTarget(manager) === 6) assert(numExecutorsToAdd(manager) === 2) assert(addExecutors(manager) === 2) - assert(numExecutorsPending(manager) === 7) + assert(numExecutorsTarget(manager) === 8) assert(numExecutorsToAdd(manager) === 4) assert(addExecutors(manager) === 0) - assert(numExecutorsPending(manager) === 7) + assert(numExecutorsTarget(manager) === 8) assert(numExecutorsToAdd(manager) === 1) - // Verify that re-running a task doesn't reduce the cap further + // Verify that re-running a task doesn't blow things up sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(2, 3))) sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, createTaskInfo(0, 0, "executor-1"))) sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, createTaskInfo(1, 0, "executor-1"))) assert(addExecutors(manager) === 1) - assert(numExecutorsPending(manager) === 8) + assert(numExecutorsTarget(manager) === 9) 
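    // The targets asserted in these tests follow a simple doubling rule: each round requests
    // twice as many executors as the last, and the target is truncated to the smaller of the
    // configured maximum and the number of executors needed for all running and pending tasks.
    // A standalone sketch of that arithmetic (an illustration with a hypothetical name, not the
    // manager's actual code):
    def rampSketch(initial: Int, maxNeeded: Int, maxExecutors: Int): Seq[Int] = {
      val cap = math.min(maxNeeded, maxExecutors)
      var target = initial
      var toAdd = 1
      val steps = scala.collection.mutable.ArrayBuffer[Int]()
      while (target < cap) {
        val newTarget = math.min(target + toAdd, cap)
        // Double the next request only if this round's request was granted in full
        toAdd = if (newTarget == target + toAdd) toAdd * 2 else 1
        target = newTarget
        steps += target
      }
      steps
    }
    assert(rampSketch(1, 1000, 10) === Seq(2, 4, 8, 10)) // the "limit of 10" sequence above
    assert(rampSketch(0, 5, 10) === Seq(1, 3, 5))        // capped by the 5 pending tasks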
assert(numExecutorsToAdd(manager) === 2) assert(addExecutors(manager) === 1) - assert(numExecutorsPending(manager) === 9) + assert(numExecutorsTarget(manager) === 10) assert(numExecutorsToAdd(manager) === 1) // Verify that running a task once we're at our limit doesn't blow things up sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, createTaskInfo(0, 1, "executor-1"))) assert(addExecutors(manager) === 0) - assert(numExecutorsPending(manager) === 9) + assert(numExecutorsTarget(manager) === 10) } test("cancel pending executors when no longer needed") { - sc = createSparkContext(1, 10) + sc = createSparkContext(0, 10) val manager = sc.executorAllocationManager.get sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(2, 5))) - assert(numExecutorsPending(manager) === 0) + assert(numExecutorsTarget(manager) === 0) assert(numExecutorsToAdd(manager) === 1) assert(addExecutors(manager) === 1) - assert(numExecutorsPending(manager) === 1) + assert(numExecutorsTarget(manager) === 1) assert(numExecutorsToAdd(manager) === 2) assert(addExecutors(manager) === 2) - assert(numExecutorsPending(manager) === 3) + assert(numExecutorsTarget(manager) === 3) val task1Info = createTaskInfo(0, 0, "executor-1") sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, task1Info)) @@ -266,7 +266,6 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext wit // Add a few executors assert(addExecutors(manager) === 1) assert(addExecutors(manager) === 2) - assert(addExecutors(manager) === 4) onExecutorAdded(manager, "1") onExecutorAdded(manager, "2") onExecutorAdded(manager, "3") @@ -274,55 +273,57 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext wit onExecutorAdded(manager, "5") onExecutorAdded(manager, "6") onExecutorAdded(manager, "7") - assert(executorIds(manager).size === 7) + onExecutorAdded(manager, "8") + assert(executorIds(manager).size === 8) // Remove until limit assert(removeExecutor(manager, "1")) assert(removeExecutor(manager, "2")) - assert(!removeExecutor(manager, "3")) // lower limit reached - assert(!removeExecutor(manager, "4")) + assert(removeExecutor(manager, "3")) + assert(!removeExecutor(manager, "4")) // lower limit reached + assert(!removeExecutor(manager, "5")) onExecutorRemoved(manager, "1") onExecutorRemoved(manager, "2") + onExecutorRemoved(manager, "3") assert(executorIds(manager).size === 5) // Add until limit - assert(addExecutors(manager) === 5) // upper limit reached + assert(addExecutors(manager) === 2) // upper limit reached assert(addExecutors(manager) === 0) - assert(!removeExecutor(manager, "3")) // still at lower limit - assert(!removeExecutor(manager, "4")) - onExecutorAdded(manager, "8") + assert(!removeExecutor(manager, "4")) // still at lower limit + assert(!removeExecutor(manager, "5")) onExecutorAdded(manager, "9") onExecutorAdded(manager, "10") onExecutorAdded(manager, "11") onExecutorAdded(manager, "12") + onExecutorAdded(manager, "13") assert(executorIds(manager).size === 10) // Remove succeeds again, now that we are no longer at the lower limit - assert(removeExecutor(manager, "3")) assert(removeExecutor(manager, "4")) assert(removeExecutor(manager, "5")) assert(removeExecutor(manager, "6")) + assert(removeExecutor(manager, "7")) assert(executorIds(manager).size === 10) - assert(addExecutors(manager) === 1) - onExecutorRemoved(manager, "3") + assert(addExecutors(manager) === 0) onExecutorRemoved(manager, "4") + onExecutorRemoved(manager, "5") assert(executorIds(manager).size === 8) - // Add succeeds 
again, now that we are no longer at the upper limit - // Number of executors added restarts at 1 - assert(addExecutors(manager) === 2) - assert(addExecutors(manager) === 1) // upper limit reached + // Number of executors pending restarts at 1 + assert(numExecutorsToAdd(manager) === 1) assert(addExecutors(manager) === 0) assert(executorIds(manager).size === 8) - onExecutorRemoved(manager, "5") onExecutorRemoved(manager, "6") - onExecutorAdded(manager, "13") + onExecutorRemoved(manager, "7") onExecutorAdded(manager, "14") + onExecutorAdded(manager, "15") assert(executorIds(manager).size === 8) assert(addExecutors(manager) === 0) // still at upper limit - onExecutorAdded(manager, "15") onExecutorAdded(manager, "16") + onExecutorAdded(manager, "17") assert(executorIds(manager).size === 10) + assert(numExecutorsTarget(manager) === 10) } test("starting/canceling add timer") { @@ -405,33 +406,33 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext wit } test("mock polling loop with no events") { - sc = createSparkContext(1, 20) + sc = createSparkContext(0, 20) val manager = sc.executorAllocationManager.get val clock = new ManualClock(2020L) manager.setClock(clock) // No events - we should not be adding or removing - assert(numExecutorsPending(manager) === 0) + assert(numExecutorsTarget(manager) === 0) assert(executorsPendingToRemove(manager).isEmpty) schedule(manager) - assert(numExecutorsPending(manager) === 0) + assert(numExecutorsTarget(manager) === 0) assert(executorsPendingToRemove(manager).isEmpty) clock.advance(100L) schedule(manager) - assert(numExecutorsPending(manager) === 0) + assert(numExecutorsTarget(manager) === 0) assert(executorsPendingToRemove(manager).isEmpty) clock.advance(1000L) schedule(manager) - assert(numExecutorsPending(manager) === 0) + assert(numExecutorsTarget(manager) === 0) assert(executorsPendingToRemove(manager).isEmpty) clock.advance(10000L) schedule(manager) - assert(numExecutorsPending(manager) === 0) + assert(numExecutorsTarget(manager) === 0) assert(executorsPendingToRemove(manager).isEmpty) } test("mock polling loop add behavior") { - sc = createSparkContext(1, 20) + sc = createSparkContext(0, 20) val clock = new ManualClock(2020L) val manager = sc.executorAllocationManager.get manager.setClock(clock) @@ -441,43 +442,43 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext wit onSchedulerBacklogged(manager) clock.advance(schedulerBacklogTimeout * 1000 / 2) schedule(manager) - assert(numExecutorsPending(manager) === 0) // timer not exceeded yet + assert(numExecutorsTarget(manager) === 0) // timer not exceeded yet clock.advance(schedulerBacklogTimeout * 1000) schedule(manager) - assert(numExecutorsPending(manager) === 1) // first timer exceeded + assert(numExecutorsTarget(manager) === 1) // first timer exceeded clock.advance(sustainedSchedulerBacklogTimeout * 1000 / 2) schedule(manager) - assert(numExecutorsPending(manager) === 1) // second timer not exceeded yet + assert(numExecutorsTarget(manager) === 1) // second timer not exceeded yet clock.advance(sustainedSchedulerBacklogTimeout * 1000) schedule(manager) - assert(numExecutorsPending(manager) === 1 + 2) // second timer exceeded + assert(numExecutorsTarget(manager) === 1 + 2) // second timer exceeded clock.advance(sustainedSchedulerBacklogTimeout * 1000) schedule(manager) - assert(numExecutorsPending(manager) === 1 + 2 + 4) // third timer exceeded + assert(numExecutorsTarget(manager) === 1 + 2 + 4) // third timer exceeded // Scheduler queue drained 
onSchedulerQueueEmpty(manager) clock.advance(sustainedSchedulerBacklogTimeout * 1000) schedule(manager) - assert(numExecutorsPending(manager) === 7) // timer is canceled + assert(numExecutorsTarget(manager) === 7) // timer is canceled clock.advance(sustainedSchedulerBacklogTimeout * 1000) schedule(manager) - assert(numExecutorsPending(manager) === 7) + assert(numExecutorsTarget(manager) === 7) // Scheduler queue backlogged again onSchedulerBacklogged(manager) clock.advance(schedulerBacklogTimeout * 1000) schedule(manager) - assert(numExecutorsPending(manager) === 7 + 1) // timer restarted + assert(numExecutorsTarget(manager) === 7 + 1) // timer restarted clock.advance(sustainedSchedulerBacklogTimeout * 1000) schedule(manager) - assert(numExecutorsPending(manager) === 7 + 1 + 2) + assert(numExecutorsTarget(manager) === 7 + 1 + 2) clock.advance(sustainedSchedulerBacklogTimeout * 1000) schedule(manager) - assert(numExecutorsPending(manager) === 7 + 1 + 2 + 4) + assert(numExecutorsTarget(manager) === 7 + 1 + 2 + 4) clock.advance(sustainedSchedulerBacklogTimeout * 1000) schedule(manager) - assert(numExecutorsPending(manager) === 20) // limit reached + assert(numExecutorsTarget(manager) === 20) // limit reached } test("mock polling loop remove behavior") { @@ -671,6 +672,31 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext wit assert(!removeTimes(manager).contains("executor-1")) } + test("avoid ramp up when target < running executors") { + sc = createSparkContext(0, 100000) + val manager = sc.executorAllocationManager.get + val stage1 = createStageInfo(0, 1000) + sc.listenerBus.postToAll(SparkListenerStageSubmitted(stage1)) + + assert(addExecutors(manager) === 1) + assert(addExecutors(manager) === 2) + assert(addExecutors(manager) === 4) + assert(addExecutors(manager) === 8) + assert(numExecutorsTarget(manager) === 15) + (0 until 15).foreach { i => + onExecutorAdded(manager, s"executor-$i") + } + assert(executorIds(manager).size === 15) + sc.listenerBus.postToAll(SparkListenerStageCompleted(stage1)) + + adjustRequestedExecutors(manager) + assert(numExecutorsTarget(manager) === 0) + + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(1, 1000))) + addExecutors(manager) + assert(numExecutorsTarget(manager) === 16) + } + private def createSparkContext(minExecutors: Int = 1, maxExecutors: Int = 5): SparkContext = { val conf = new SparkConf() .setMaster("local") @@ -713,7 +739,7 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { * ------------------------------------------------------- */ private val _numExecutorsToAdd = PrivateMethod[Int]('numExecutorsToAdd) - private val _numExecutorsPending = PrivateMethod[Int]('numExecutorsPending) + private val _numExecutorsTarget = PrivateMethod[Int]('numExecutorsTarget) private val _maxNumExecutorsNeeded = PrivateMethod[Int]('maxNumExecutorsNeeded) private val _executorsPendingToRemove = PrivateMethod[collection.Set[String]]('executorsPendingToRemove) @@ -722,7 +748,8 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { private val _removeTimes = PrivateMethod[collection.Map[String, Long]]('removeTimes) private val _schedule = PrivateMethod[Unit]('schedule) private val _addExecutors = PrivateMethod[Int]('addExecutors) - private val _addOrCancelExecutorRequests = PrivateMethod[Int]('addOrCancelExecutorRequests) + private val _updateAndSyncNumExecutorsTarget = + PrivateMethod[Int]('updateAndSyncNumExecutorsTarget) private val _removeExecutor = 
PrivateMethod[Boolean]('removeExecutor) private val _onExecutorAdded = PrivateMethod[Unit]('onExecutorAdded) private val _onExecutorRemoved = PrivateMethod[Unit]('onExecutorRemoved) @@ -735,8 +762,8 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { manager invokePrivate _numExecutorsToAdd() } - private def numExecutorsPending(manager: ExecutorAllocationManager): Int = { - manager invokePrivate _numExecutorsPending() + private def numExecutorsTarget(manager: ExecutorAllocationManager): Int = { + manager invokePrivate _numExecutorsTarget() } private def executorsPendingToRemove( @@ -766,7 +793,7 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { } private def adjustRequestedExecutors(manager: ExecutorAllocationManager): Int = { - manager invokePrivate _addOrCancelExecutorRequests(0L) + manager invokePrivate _updateAndSyncNumExecutorsTarget(0L) } private def removeExecutor(manager: ExecutorAllocationManager, id: String): Boolean = { diff --git a/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala b/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala deleted file mode 100644 index 529f91e8eaf9e..0000000000000 --- a/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala +++ /dev/null @@ -1,262 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy - -import java.io.{File, FileInputStream, FileOutputStream} -import java.nio.file.{Files, Path} -import java.util.jar.{JarEntry, JarOutputStream} - -import org.apache.spark.TestUtils.{createCompiledClass, JavaSourceFromString} - -import com.google.common.io.ByteStreams - -import org.apache.commons.io.FileUtils - -import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate - -private[deploy] object IvyTestUtils { - - /** - * Create the path for the jar and pom from the maven coordinate. Extension should be `jar` - * or `pom`. - */ - private def pathFromCoordinate( - artifact: MavenCoordinate, - prefix: Path, - ext: String, - useIvyLayout: Boolean): Path = { - val groupDirs = artifact.groupId.replace(".", File.separator) - val artifactDirs = artifact.artifactId - val artifactPath = - if (!useIvyLayout) { - Seq(groupDirs, artifactDirs, artifact.version).mkString(File.separator) - } else { - Seq(groupDirs, artifactDirs, artifact.version, ext + "s").mkString(File.separator) - } - new File(prefix.toFile, artifactPath).toPath - } - - private def artifactName(artifact: MavenCoordinate, ext: String = ".jar"): String = { - s"${artifact.artifactId}-${artifact.version}$ext" - } - - /** Write the contents to a file to the supplied directory. 
*/ - private def writeFile(dir: File, fileName: String, contents: String): File = { - val outputFile = new File(dir, fileName) - val outputStream = new FileOutputStream(outputFile) - outputStream.write(contents.toCharArray.map(_.toByte)) - outputStream.close() - outputFile - } - - /** Create an example Python file. */ - private def createPythonFile(dir: File): File = { - val contents = - """def myfunc(x): - | return x + 1 - """.stripMargin - writeFile(dir, "mylib.py", contents) - } - - /** Create a simple testable Class. */ - private def createJavaClass(dir: File, className: String, packageName: String): File = { - val contents = - s"""package $packageName; - | - |import java.lang.Integer; - | - |class $className implements java.io.Serializable { - | - | public $className() {} - | - | public Integer myFunc(Integer x) { - | return x + 1; - | } - |} - """.stripMargin - val sourceFile = - new JavaSourceFromString(new File(dir, className + ".java").getAbsolutePath, contents) - createCompiledClass(className, dir, sourceFile, Seq.empty) - } - - /** Helper method to write artifact information in the pom. */ - private def pomArtifactWriter(artifact: MavenCoordinate, tabCount: Int = 1): String = { - var result = "\n" + " " * tabCount + s"${artifact.groupId}" - result += "\n" + " " * tabCount + s"${artifact.artifactId}" - result += "\n" + " " * tabCount + s"${artifact.version}" - result - } - - /** Create a pom file for this artifact. */ - private def createPom( - dir: File, - artifact: MavenCoordinate, - dependencies: Option[Seq[MavenCoordinate]]): File = { - var content = """ - | - | - | 4.0.0 - """.stripMargin.trim - content += pomArtifactWriter(artifact) - content += dependencies.map { deps => - val inside = deps.map { dep => - "\t" + pomArtifactWriter(dep, 3) + "\n\t" - }.mkString("\n") - "\n \n" + inside + "\n " - }.getOrElse("") - content += "\n" - writeFile(dir, artifactName(artifact, ".pom"), content.trim) - } - - /** Create the jar for the given maven coordinate, using the supplied files. */ - private def packJar( - dir: File, - artifact: MavenCoordinate, - files: Seq[(String, File)]): File = { - val jarFile = new File(dir, artifactName(artifact)) - val jarFileStream = new FileOutputStream(jarFile) - val jarStream = new JarOutputStream(jarFileStream, new java.util.jar.Manifest()) - - for (file <- files) { - val jarEntry = new JarEntry(file._1) - jarStream.putNextEntry(jarEntry) - - val in = new FileInputStream(file._2) - ByteStreams.copy(in, jarStream) - in.close() - } - jarStream.close() - jarFileStream.close() - - jarFile - } - - /** - * Creates a jar and pom file, mocking a Maven repository. The root path can be supplied with - * `tempDir`, dependencies can be created into the same repo, and python files can also be packed - * inside the jar. - * - * @param artifact The maven coordinate to generate the jar and pom for. - * @param dependencies List of dependencies this artifact might have to also create jars and poms. - * @param tempDir The root folder of the repository - * @param useIvyLayout whether to mock the Ivy layout for local repository testing - * @param withPython Whether to pack python files inside the jar for extensive testing. 
- * @return Root path of the repository - */ - private def createLocalRepository( - artifact: MavenCoordinate, - dependencies: Option[Seq[MavenCoordinate]] = None, - tempDir: Option[Path] = None, - useIvyLayout: Boolean = false, - withPython: Boolean = false): Path = { - // Where the root of the repository exists, and what Ivy will search in - val tempPath = tempDir.getOrElse(Files.createTempDirectory(null)) - // Create directory if it doesn't exist - Files.createDirectories(tempPath) - // Where to create temporary class files and such - val root = Files.createTempDirectory(tempPath, null).toFile - try { - val jarPath = pathFromCoordinate(artifact, tempPath, "jar", useIvyLayout) - Files.createDirectories(jarPath) - val className = "MyLib" - - val javaClass = createJavaClass(root, className, artifact.groupId) - // A tuple of files representation in the jar, and the file - val javaFile = (artifact.groupId.replace(".", "/") + "/" + javaClass.getName, javaClass) - val allFiles = - if (withPython) { - val pythonFile = createPythonFile(root) - Seq(javaFile, (pythonFile.getName, pythonFile)) - } else { - Seq(javaFile) - } - val jarFile = packJar(jarPath.toFile, artifact, allFiles) - assert(jarFile.exists(), "Problem creating Jar file") - val pomPath = pathFromCoordinate(artifact, tempPath, "pom", useIvyLayout) - Files.createDirectories(pomPath) - val pomFile = createPom(pomPath.toFile, artifact, dependencies) - assert(pomFile.exists(), "Problem creating Pom file") - } finally { - FileUtils.deleteDirectory(root) - } - tempPath - } - - /** - * Creates a suite of jars and poms, with or without dependencies, mocking a maven repository. - * @param artifact The main maven coordinate to generate the jar and pom for. - * @param dependencies List of dependencies this artifact might have to also create jars and poms. - * @param rootDir The root folder of the repository (like `~/.m2/repositories`) - * @param useIvyLayout whether to mock the Ivy layout for local repository testing - * @param withPython Whether to pack python files inside the jar for extensive testing. - * @return Root path of the repository. Will be `rootDir` if supplied. - */ - private[deploy] def createLocalRepositoryForTests( - artifact: MavenCoordinate, - dependencies: Option[String], - rootDir: Option[Path], - useIvyLayout: Boolean = false, - withPython: Boolean = false): Path = { - val deps = dependencies.map(SparkSubmitUtils.extractMavenCoordinates) - val mainRepo = createLocalRepository(artifact, deps, rootDir, useIvyLayout, withPython) - deps.foreach { seq => seq.foreach { dep => - createLocalRepository(dep, None, Some(mainRepo), useIvyLayout, withPython = false) - }} - mainRepo - } - - /** - * Creates a repository for a test, and cleans it up afterwards. - * - * @param artifact The main maven coordinate to generate the jar and pom for. - * @param dependencies List of dependencies this artifact might have to also create jars and poms. - * @param rootDir The root folder of the repository (like `~/.m2/repositories`) - * @param useIvyLayout whether to mock the Ivy layout for local repository testing - * @param withPython Whether to pack python files inside the jar for extensive testing. - * @return Root path of the repository. Will be `rootDir` if supplied. 
- */ - private[deploy] def withRepository( - artifact: MavenCoordinate, - dependencies: Option[String], - rootDir: Option[Path], - useIvyLayout: Boolean = false, - withPython: Boolean = false)(f: String => Unit): Unit = { - val repo = createLocalRepositoryForTests(artifact, dependencies, rootDir, useIvyLayout, - withPython) - try { - f(repo.toUri.toString) - } finally { - // Clean up - if (repo.toString.contains(".m2") || repo.toString.contains(".ivy2")) { - FileUtils.deleteDirectory(new File(repo.toFile, - artifact.groupId.replace(".", File.separator) + File.separator + artifact.artifactId)) - dependencies.map(SparkSubmitUtils.extractMavenCoordinates).foreach { seq => - seq.foreach { dep => - FileUtils.deleteDirectory(new File(repo.toFile, - dep.artifactId.replace(".", File.separator))) - } - } - } else { - FileUtils.deleteDirectory(repo.toFile) - } - } - } -} diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 35382be7e0ef1..029a1156fda6b 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -30,7 +30,6 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark._ import org.apache.spark.deploy.SparkSubmit._ -import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate import org.apache.spark.util.{ResetSystemProperties, Utils} // Note: this suite mixes in ResetSystemProperties because SparkSubmit.main() sets a bunch @@ -335,22 +334,18 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties runSparkSubmit(args) } - test("includes jars passed in through --packages") { + ignore("includes jars passed in through --packages") { val unusedJar = TestUtils.createJarWithClasses(Seq.empty) - val main = MavenCoordinate("my.great.lib", "mylib", "0.1") - val dep = MavenCoordinate("my.great.dep", "mylib", "0.1") - IvyTestUtils.withRepository(main, Some(dep.toString), None) { repo => - val args = Seq( - "--class", JarCreationTest.getClass.getName.stripSuffix("$"), - "--name", "testApp", - "--master", "local-cluster[2,1,512]", - "--packages", Seq(main, dep).mkString(","), - "--repositories", repo, - "--conf", "spark.ui.enabled=false", - unusedJar.toString, - "my.great.lib.MyLib", "my.great.dep.MyLib") - runSparkSubmit(args) - } + val packagesString = "com.databricks:spark-csv_2.10:0.1,com.databricks:spark-avro_2.10:0.1" + val args = Seq( + "--class", JarCreationTest.getClass.getName.stripSuffix("$"), + "--name", "testApp", + "--master", "local-cluster[2,1,512]", + "--packages", packagesString, + "--conf", "spark.ui.enabled=false", + unusedJar.toString, + "com.databricks.spark.csv.DefaultSource", "com.databricks.spark.avro.DefaultSource") + runSparkSubmit(args) } test("resolves command line argument paths correctly") { diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala index cc79ee7ea20b4..2df2597e058cd 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala @@ -19,15 +19,15 @@ package org.apache.spark.deploy import java.io.{PrintStream, OutputStream, File} +import org.apache.ivy.core.settings.IvySettings + import scala.collection.mutable.ArrayBuffer + import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.apache.ivy.core.module.descriptor.MDArtifact 
-import org.apache.ivy.core.settings.IvySettings import org.apache.ivy.plugins.resolver.IBiblioResolver -import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate - class SparkSubmitUtilsSuite extends FunSuite with BeforeAndAfterAll { private val noOpOutputStream = new OutputStream { @@ -89,7 +89,7 @@ class SparkSubmitUtilsSuite extends FunSuite with BeforeAndAfterAll { } test("ivy path works correctly") { - val ivyPath = "dummy" + File.separator + "ivy" + val ivyPath = "dummy/ivy" val md = SparkSubmitUtils.getModuleDescriptor val artifacts = for (i <- 0 until 3) yield new MDArtifact(md, s"jar-$i", "jar", "jar") var jPaths = SparkSubmitUtils.resolveDependencyPaths(artifacts.toArray, new File(ivyPath)) @@ -98,38 +98,17 @@ class SparkSubmitUtilsSuite extends FunSuite with BeforeAndAfterAll { assert(index >= 0) jPaths = jPaths.substring(index + ivyPath.length) } - val main = MavenCoordinate("my.awesome.lib", "mylib", "0.1") - IvyTestUtils.withRepository(main, None, None) { repo => - // end to end - val jarPath = SparkSubmitUtils.resolveMavenCoordinates(main.toString, Option(repo), - Option(ivyPath), true) - assert(jarPath.indexOf(ivyPath) >= 0, "should use non-default ivy path") - } + // end to end + val jarPath = SparkSubmitUtils.resolveMavenCoordinates( + "com.databricks:spark-csv_2.10:0.1", None, Option(ivyPath), true) + assert(jarPath.indexOf(ivyPath) >= 0, "should use non-default ivy path") } - test("search for artifact at local repositories") { - val main = new MavenCoordinate("my.awesome.lib", "mylib", "0.1") - // Local M2 repository - IvyTestUtils.withRepository(main, None, Some(SparkSubmitUtils.m2Path)) { repo => - val jarPath = SparkSubmitUtils.resolveMavenCoordinates(main.toString, None, None, true) - assert(jarPath.indexOf("mylib") >= 0, "should find artifact") - } - // Local Ivy Repository - val settings = new IvySettings - val ivyLocal = new File(settings.getDefaultIvyUserDir, "local" + File.separator) - IvyTestUtils.withRepository(main, None, Some(ivyLocal.toPath), true) { repo => - val jarPath = SparkSubmitUtils.resolveMavenCoordinates(main.toString, None, None, true) - assert(jarPath.indexOf("mylib") >= 0, "should find artifact") - } - // Local ivy repository with modified home - val dummyIvyPath = "dummy" + File.separator + "ivy" - val dummyIvyLocal = new File(dummyIvyPath, "local" + File.separator) - IvyTestUtils.withRepository(main, None, Some(dummyIvyLocal.toPath), true) { repo => - val jarPath = SparkSubmitUtils.resolveMavenCoordinates(main.toString, None, - Some(dummyIvyPath), true) - assert(jarPath.indexOf("mylib") >= 0, "should find artifact") - assert(jarPath.indexOf(dummyIvyPath) >= 0, "should be in new ivy path") - } + test("search for artifact at other repositories") { + val path = SparkSubmitUtils.resolveMavenCoordinates("com.agimatec:agimatec-validation:0.9.3", + Option("https://oss.sonatype.org/content/repositories/agimatec/"), None, true) + assert(path.indexOf("agimatec-validation") >= 0, "should find package. If it doesn't, check" + + "if package still exists. 
If it has been removed, replace the example in this test.") } test("dependency not found throws RuntimeException") { @@ -138,7 +117,7 @@ class SparkSubmitUtilsSuite extends FunSuite with BeforeAndAfterAll { } } - test("neglects Spark and Spark's dependencies") { + ignore("neglects Spark and Spark's dependencies") { val components = Seq("bagel_", "catalyst_", "core_", "graphx_", "hive_", "mllib_", "repl_", "sql_", "streaming_", "yarn_", "network-common_", "network-shuffle_", "network-yarn_") @@ -148,11 +127,11 @@ class SparkSubmitUtilsSuite extends FunSuite with BeforeAndAfterAll { val path = SparkSubmitUtils.resolveMavenCoordinates(coordinates, None, None, true) assert(path === "", "should return empty path") - val main = MavenCoordinate("org.apache.spark", "spark-streaming-kafka-assembly_2.10", "1.2.0") - IvyTestUtils.withRepository(main, None, None) { repo => - val files = SparkSubmitUtils.resolveMavenCoordinates(coordinates + "," + main.toString, - Some(repo), None, true) - assert(files.indexOf(main.artifactId) >= 0, "Did not return artifact") + // Should not exclude the following dependency. Will throw an error, because it doesn't exist, + // but the fact that it is checking means that it wasn't excluded. + intercept[RuntimeException] { + SparkSubmitUtils.resolveMavenCoordinates(coordinates + + ",org.apache.spark:spark-streaming-kafka-assembly_2.10:1.2.0", None, None, true) } } } diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index 0a318a27ac212..f4d548d9e7720 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -39,7 +39,6 @@ import org.apache.spark.deploy.master.DriverState._ * Tests for the REST application submission protocol used in standalone cluster mode. 
*/ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { - private val client = new RestSubmissionClient private var actorSystem: Option[ActorSystem] = None private var server: Option[RestSubmissionServer] = None @@ -52,7 +51,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { val appArgs = Array("one", "two", "three") val sparkProperties = Map("spark.app.name" -> "pi") val environmentVariables = Map("SPARK_ONE" -> "UN", "SPARK_TWO" -> "DEUX") - val request = client.constructSubmitRequest( + val request = new RestSubmissionClient("spark://host:port").constructSubmitRequest( "my-app-resource", "my-main-class", appArgs, sparkProperties, environmentVariables) assert(request.action === Utils.getFormattedClassName(request)) assert(request.clientSparkVersion === SPARK_VERSION) @@ -71,7 +70,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { val request = constructSubmitRequest(masterUrl, appArgs) assert(request.appArgs === appArgs) assert(request.sparkProperties("spark.master") === masterUrl) - val response = client.createSubmission(masterUrl, request) + val response = new RestSubmissionClient(masterUrl).createSubmission(request) val submitResponse = getSubmitResponse(response) assert(submitResponse.action === Utils.getFormattedClassName(submitResponse)) assert(submitResponse.serverSparkVersion === SPARK_VERSION) @@ -102,7 +101,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { val submissionId = "my-lyft-driver" val killMessage = "your driver is killed" val masterUrl = startDummyServer(killMessage = killMessage) - val response = client.killSubmission(masterUrl, submissionId) + val response = new RestSubmissionClient(masterUrl).killSubmission(submissionId) val killResponse = getKillResponse(response) assert(killResponse.action === Utils.getFormattedClassName(killResponse)) assert(killResponse.serverSparkVersion === SPARK_VERSION) @@ -116,7 +115,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { val submissionState = KILLED val submissionException = new Exception("there was an irresponsible mix of alcohol and cars") val masterUrl = startDummyServer(state = submissionState, exception = Some(submissionException)) - val response = client.requestSubmissionStatus(masterUrl, submissionId) + val response = new RestSubmissionClient(masterUrl).requestSubmissionStatus(submissionId) val statusResponse = getStatusResponse(response) assert(statusResponse.action === Utils.getFormattedClassName(statusResponse)) assert(statusResponse.serverSparkVersion === SPARK_VERSION) @@ -129,13 +128,14 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { test("create then kill") { val masterUrl = startSmartServer() val request = constructSubmitRequest(masterUrl) - val response1 = client.createSubmission(masterUrl, request) + val client = new RestSubmissionClient(masterUrl) + val response1 = client.createSubmission(request) val submitResponse = getSubmitResponse(response1) assert(submitResponse.success) assert(submitResponse.submissionId != null) // kill submission that was just created val submissionId = submitResponse.submissionId - val response2 = client.killSubmission(masterUrl, submissionId) + val response2 = client.killSubmission(submissionId) val killResponse = getKillResponse(response2) assert(killResponse.success) assert(killResponse.submissionId === submissionId) @@ -144,13 +144,14 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { 
test("create then request status") { val masterUrl = startSmartServer() val request = constructSubmitRequest(masterUrl) - val response1 = client.createSubmission(masterUrl, request) + val client = new RestSubmissionClient(masterUrl) + val response1 = client.createSubmission(request) val submitResponse = getSubmitResponse(response1) assert(submitResponse.success) assert(submitResponse.submissionId != null) // request status of submission that was just created val submissionId = submitResponse.submissionId - val response2 = client.requestSubmissionStatus(masterUrl, submissionId) + val response2 = client.requestSubmissionStatus(submissionId) val statusResponse = getStatusResponse(response2) assert(statusResponse.success) assert(statusResponse.submissionId === submissionId) @@ -160,8 +161,9 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { test("create then kill then request status") { val masterUrl = startSmartServer() val request = constructSubmitRequest(masterUrl) - val response1 = client.createSubmission(masterUrl, request) - val response2 = client.createSubmission(masterUrl, request) + val client = new RestSubmissionClient(masterUrl) + val response1 = client.createSubmission(request) + val response2 = client.createSubmission(request) val submitResponse1 = getSubmitResponse(response1) val submitResponse2 = getSubmitResponse(response2) assert(submitResponse1.success) @@ -171,13 +173,13 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { val submissionId1 = submitResponse1.submissionId val submissionId2 = submitResponse2.submissionId // kill only submission 1, but not submission 2 - val response3 = client.killSubmission(masterUrl, submissionId1) + val response3 = client.killSubmission(submissionId1) val killResponse = getKillResponse(response3) assert(killResponse.success) assert(killResponse.submissionId === submissionId1) // request status for both submissions: 1 should be KILLED but 2 should be RUNNING still - val response4 = client.requestSubmissionStatus(masterUrl, submissionId1) - val response5 = client.requestSubmissionStatus(masterUrl, submissionId2) + val response4 = client.requestSubmissionStatus(submissionId1) + val response5 = client.requestSubmissionStatus(submissionId2) val statusResponse1 = getStatusResponse(response4) val statusResponse2 = getStatusResponse(response5) assert(statusResponse1.submissionId === submissionId1) @@ -189,13 +191,14 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { test("kill or request status before create") { val masterUrl = startSmartServer() val doesNotExist = "does-not-exist" + val client = new RestSubmissionClient(masterUrl) // kill a non-existent submission - val response1 = client.killSubmission(masterUrl, doesNotExist) + val response1 = client.killSubmission(doesNotExist) val killResponse = getKillResponse(response1) assert(!killResponse.success) assert(killResponse.submissionId === doesNotExist) // request status for a non-existent submission - val response2 = client.requestSubmissionStatus(masterUrl, doesNotExist) + val response2 = client.requestSubmissionStatus(doesNotExist) val statusResponse = getStatusResponse(response2) assert(!statusResponse.success) assert(statusResponse.submissionId === doesNotExist) @@ -339,6 +342,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { test("client handles faulty server") { val masterUrl = startFaultyServer() + val client = new RestSubmissionClient(masterUrl) val httpUrl = 
masterUrl.replace("spark://", "http://") val v = RestSubmissionServer.PROTOCOL_VERSION val submitRequestPath = s"$httpUrl/$v/submissions/create" @@ -425,7 +429,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { mainJar) ++ appArgs val args = new SparkSubmitArguments(commandLineArgs) val (_, _, sparkProperties, _) = SparkSubmit.prepareSubmitEnvironment(args) - client.constructSubmitRequest( + new RestSubmissionClient("spark://host:port").constructSubmitRequest( mainJar, mainClass, appArgs, sparkProperties.toMap, Map.empty) } @@ -492,7 +496,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { method: String, body: String = ""): (SubmitRestProtocolResponse, Int) = { val conn = sendHttpRequest(url, method, body) - (client.readResponse(conn), conn.getResponseCode) + (new RestSubmissionClient("spark://host:port").readResponse(conn), conn.getResponseCode) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala index cdd7be0fbe5dd..ab863f3d8d672 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala @@ -73,6 +73,52 @@ class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with Mo s"cd test-app-1*; ./bin/spark-class ${classOf[MesosExecutorBackend].getName}") } + test("spark docker properties correctly populate the DockerInfo message") { + val taskScheduler = mock[TaskSchedulerImpl] + + val conf = new SparkConf() + .set("spark.mesos.executor.docker.image", "spark/mock") + .set("spark.mesos.executor.docker.volumes", "/a,/b:/b,/c:/c:rw,/d:ro,/e:/e:ro") + .set("spark.mesos.executor.docker.portmaps", "80:8080,53:53:tcp") + + val listenerBus = mock[LiveListenerBus] + listenerBus.post( + SparkListenerExecutorAdded(anyLong, "s1", new ExecutorInfo("host1", 2, Map.empty))) + + val sc = mock[SparkContext] + when(sc.executorMemory).thenReturn(100) + when(sc.getSparkHome()).thenReturn(Option("/spark-home")) + when(sc.executorEnvs).thenReturn(new mutable.HashMap[String, String]) + when(sc.conf).thenReturn(conf) + when(sc.listenerBus).thenReturn(listenerBus) + + val backend = new MesosSchedulerBackend(taskScheduler, sc, "master") + + val execInfo = backend.createExecutorInfo("mockExecutor") + assert(execInfo.getContainer.getDocker.getImage.equals("spark/mock")) + val portmaps = execInfo.getContainer.getDocker.getPortMappingsList + assert(portmaps.get(0).getHostPort.equals(80)) + assert(portmaps.get(0).getContainerPort.equals(8080)) + assert(portmaps.get(0).getProtocol.equals("tcp")) + assert(portmaps.get(1).getHostPort.equals(53)) + assert(portmaps.get(1).getContainerPort.equals(53)) + assert(portmaps.get(1).getProtocol.equals("tcp")) + val volumes = execInfo.getContainer.getVolumesList + assert(volumes.get(0).getContainerPath.equals("/a")) + assert(volumes.get(0).getMode.equals(Volume.Mode.RW)) + assert(volumes.get(1).getContainerPath.equals("/b")) + assert(volumes.get(1).getHostPath.equals("/b")) + assert(volumes.get(1).getMode.equals(Volume.Mode.RW)) + assert(volumes.get(2).getContainerPath.equals("/c")) + assert(volumes.get(2).getHostPath.equals("/c")) + assert(volumes.get(2).getMode.equals(Volume.Mode.RW)) + assert(volumes.get(3).getContainerPath.equals("/d")) + assert(volumes.get(3).getMode.equals(Volume.Mode.RO)) + 
assert(volumes.get(4).getContainerPath.equals("/e")) + assert(volumes.get(4).getHostPath.equals("/e")) + assert(volumes.get(4).getMode.equals(Volume.Mode.RO)) + } + test("mesos resource offers result in launching tasks") { def createOffer(id: Int, mem: Int, cpu: Int): Offer = { val builder = Offer.newBuilder() diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index eb9db550fd74c..d53d7f3ba5ae7 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -350,7 +350,7 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers with Before } } - test("kill stage is POST only") { + test("kill stage POST/GET response is correct") { def getResponseCode(url: URL, method: String): Int = { val connection = url.openConnection().asInstanceOf[HttpURLConnection] connection.setRequestMethod(method) @@ -365,7 +365,8 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers with Before eventually(timeout(5 seconds), interval(50 milliseconds)) { val url = new URL( sc.ui.get.appUIAddress.stripSuffix("/") + "/stages/stage/kill/?id=0&terminate=true") - getResponseCode(url, "GET") should be (405) + // SPARK-6846: should be POST only but YARN AM doesn't proxy POST + getResponseCode(url, "GET") should be (200) getResponseCode(url, "POST") should be (200) } } diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 62a3cbcdf69ea..651ead6ff1de2 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -35,9 +35,10 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.network.util.ByteUnit +import org.apache.spark.Logging import org.apache.spark.SparkConf -class UtilsSuite extends FunSuite with ResetSystemProperties { +class UtilsSuite extends FunSuite with ResetSystemProperties with Logging { test("timeConversion") { // Test -1 @@ -68,7 +69,7 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { intercept[NumberFormatException] { Utils.timeStringAsMs("600l") } - + intercept[NumberFormatException] { Utils.timeStringAsMs("This breaks 600s") } @@ -99,7 +100,7 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { assert(Utils.byteStringAsGb("1k") === 0) assert(Utils.byteStringAsGb("1t") === ByteUnit.TiB.toGiB(1)) assert(Utils.byteStringAsGb("1p") === ByteUnit.PiB.toGiB(1)) - + assert(Utils.byteStringAsMb("1") === 1) assert(Utils.byteStringAsMb("1m") === 1) assert(Utils.byteStringAsMb("1048575b") === 0) @@ -118,7 +119,7 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { assert(Utils.byteStringAsKb("1g") === ByteUnit.GiB.toKiB(1)) assert(Utils.byteStringAsKb("1t") === ByteUnit.TiB.toKiB(1)) assert(Utils.byteStringAsKb("1p") === ByteUnit.PiB.toKiB(1)) - + assert(Utils.byteStringAsBytes("1") === 1) assert(Utils.byteStringAsBytes("1k") === ByteUnit.KiB.toBytes(1)) assert(Utils.byteStringAsBytes("1m") === ByteUnit.MiB.toBytes(1)) @@ -127,17 +128,17 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { assert(Utils.byteStringAsBytes("1p") === ByteUnit.PiB.toBytes(1)) // Overflow handling, 1073741824p exceeds Long.MAX_VALUE if converted straight to Bytes - // This demonstrates that we can have e.g 1024^3 PB without overflowing. 
+ // This demonstrates that we can have e.g 1024^3 PB without overflowing. assert(Utils.byteStringAsGb("1073741824p") === ByteUnit.PiB.toGiB(1073741824)) assert(Utils.byteStringAsMb("1073741824p") === ByteUnit.PiB.toMiB(1073741824)) - + // Run this to confirm it doesn't throw an exception - assert(Utils.byteStringAsBytes("9223372036854775807") === 9223372036854775807L) + assert(Utils.byteStringAsBytes("9223372036854775807") === 9223372036854775807L) assert(ByteUnit.PiB.toPiB(9223372036854775807L) === 9223372036854775807L) - + // Test overflow exception intercept[IllegalArgumentException] { - // This value exceeds Long.MAX when converted to bytes + // This value exceeds Long.MAX when converted to bytes Utils.byteStringAsBytes("9223372036854775808") } @@ -146,22 +147,22 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { // This value exceeds Long.MAX when converted to TB ByteUnit.PiB.toTiB(9223372036854775807L) } - + // Test fractional string intercept[NumberFormatException] { Utils.byteStringAsMb("0.064") } - + // Test fractional string intercept[NumberFormatException] { Utils.byteStringAsMb("0.064m") } - + // Test invalid strings intercept[NumberFormatException] { Utils.byteStringAsBytes("500ub") } - + // Test invalid strings intercept[NumberFormatException] { Utils.byteStringAsBytes("This breaks 600b") @@ -174,12 +175,12 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { intercept[NumberFormatException] { Utils.byteStringAsBytes("600gb This breaks") } - + intercept[NumberFormatException] { Utils.byteStringAsBytes("This 123mb breaks") } } - + test("bytesToString") { assert(Utils.bytesToString(10) === "10.0 B") assert(Utils.bytesToString(1500) === "1500.0 B") @@ -475,6 +476,15 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { } } + // Test for using the util function to change our log levels. + test("log4j log level change") { + Utils.setLogLevel(org.apache.log4j.Level.ALL) + assert(log.isInfoEnabled()) + Utils.setLogLevel(org.apache.log4j.Level.ERROR) + assert(!log.isInfoEnabled()) + assert(log.isErrorEnabled()) + } + test("deleteRecursively") { val tempDir1 = Utils.createTempDir() assert(tempDir1.exists()) diff --git a/docker/spark-mesos/Dockerfile b/docker/spark-mesos/Dockerfile new file mode 100644 index 0000000000000..b90aef3655dee --- /dev/null +++ b/docker/spark-mesos/Dockerfile @@ -0,0 +1,30 @@ +# This is an example Dockerfile for creating a Spark image which can be +# references by the Spark property 'spark.mesos.executor.docker.image' +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +FROM mesosphere/mesos:0.20.1 + +# Update the base ubuntu image with dependencies needed for Spark +RUN apt-get update && \ + apt-get install -y python libnss3 openjdk-7-jre-headless curl + +RUN mkdir /opt/spark && \ + curl http://www.apache.org/dyn/closer.cgi/spark/spark-1.4.0/spark-1.4.0-bin-hadoop2.4.tgz \ + | tar -xzC /opt +ENV SPARK_HOME /opt/spark +ENV MESOS_NATIVE_JAVA_LIBRARY /usr/local/lib/libmesos.so diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 8f53d8201a089..5f1d6daeb27f0 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -184,6 +184,16 @@ acquire. By default, it will acquire *all* cores in the cluster (that get offere only makes sense if you run just one application at a time. You can cap the maximum number of cores using `conf.set("spark.cores.max", "10")` (for example). +# Mesos Docker Support + +Spark can make use of a Mesos Docker containerizer by setting the property `spark.mesos.executor.docker.image` +in your [SparkConf](configuration.html#spark-properties). + +The Docker image used must have an appropriate version of Spark already part of the image, or you can +have Mesos download Spark via the usual methods. + +Requires Mesos version 0.20.1 or later. + # Running Alongside Hadoop You can run Spark and Mesos alongside your existing Hadoop cluster by just launching them as a @@ -237,6 +247,38 @@ See the [configuration page](configuration.html) for information on Spark config The value can be a floating point number. + + spark.mesos.executor.docker.image + (none) + + Set the name of the docker image that the Spark executors will run in. The selected + image must have Spark installed, as well as a compatible version of the Mesos library. + The installed path of Spark in the image can be specified with spark.mesos.executor.home; + the installed path of the Mesos library can be specified with spark.executorEnv.MESOS_NATIVE_LIBRARY. + + + + spark.mesos.executor.docker.volumes + (none) + + Set the list of volumes which will be mounted into the Docker image, which was set using + spark.mesos.executor.docker.image. The format of this property is a comma-separated list of + mappings following the form passed to docker run -v. That is they take the form: + +
[host_path:]container_path[:ro|:rw]
+ + + + spark.mesos.executor.docker.portmaps + (none) + + Set the list of incoming ports exposed by the Docker image, which was set using + spark.mesos.executor.docker.image. The format of this property is a comma-separated list of + mappings which take the form: + +
host_port:container_port[:tcp|:udp]
+ + spark.mesos.executor.home driver side SPARK_HOME diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 0968fc5ad632b..b6701b64c2925 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -189,6 +189,13 @@ Most of the configs are the same for Spark on YARN as for other deployment modes In cluster mode, use spark.driver.extraJavaOptions instead. + + spark.yarn.am.extraLibraryPath + (none) + + Set a special library path to use when launching the application master in client mode. + + spark.yarn.maxAppAttempts yarn.resourcemanager.am.max-attempts in YARN diff --git a/docs/security.md b/docs/security.md index c034ba12ff1fc..d4ffa60e59a33 100644 --- a/docs/security.md +++ b/docs/security.md @@ -32,6 +32,8 @@ SSL must be configured on each node and configured for each component involved i ### YARN mode The key-store can be prepared on the client side and then distributed and used by the executors as the part of the application. It is possible because the user is able to deploy files before the application is started in YARN by using `spark.yarn.dist.files` or `spark.yarn.dist.archives` configuration settings. The responsibility for encryption of transferring these files is on YARN side and has nothing to do with Spark. +For long-running apps like Spark Streaming apps to be able to write to HDFS, it is possible to pass a principal and keytab to `spark-submit` via the `--principal` and `--keytab` parameters respectively. The keytab passed in will be copied over to the machine running the Application Master via the Hadoop Distributed Cache (securely - if YARN is configured with SSL and HDFS encryption is enabled). The Kerberos login will be periodically renewed using this principal and keytab and the delegation tokens required for HDFS will be generated periodically so the application can continue writing to HDFS. + ### Standalone mode The user needs to provide key-stores and configuration options for master and workers. They have to be set by attaching appropriate Java system properties in `SPARK_MASTER_OPTS` and in `SPARK_WORKER_OPTS` environment variables, or just in `SPARK_DAEMON_JAVA_OPTS`. In this mode, the user may allow the executors to use the SSL settings inherited from the worker which spawned that executor. It can be accomplished by setting `spark.ssl.useNodeLocalConf` to `true`. If that parameter is set, the settings provided by user on the client side, are not used by the executors. 
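A minimal sketch, not part of the patch itself: the Mesos Docker properties documented in `running-on-mesos.md` above can be set programmatically through `SparkConf`. The image name, volume and port mappings below are illustrative assumptions borrowed from the `MesosSchedulerBackendSuite` test added earlier in this diff, and `/opt/spark` matches the install path used by the example `docker/spark-mesos/Dockerfile`.

```scala
import org.apache.spark.SparkConf

// Illustrative sketch only: the values mirror the test fixture in this patch,
// they are not required or recommended settings.
val conf = new SparkConf()
  // Image must already contain Spark and a compatible Mesos library
  .set("spark.mesos.executor.docker.image", "spark/mock")
  // Comma-separated volumes: [host_path:]container_path[:ro|:rw]
  .set("spark.mesos.executor.docker.volumes", "/c:/c:rw,/e:/e:ro")
  // Comma-separated port mappings: host_port:container_port[:tcp|:udp]
  .set("spark.mesos.executor.docker.portmaps", "80:8080,53:53:tcp")
  // Install path of Spark inside the image (see the example Dockerfile)
  .set("spark.mesos.executor.home", "/opt/spark")
```

The same keys can also be supplied at submit time with repeated `--conf key=value` arguments to `spark-submit` instead of being set in application code.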
diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index f695cff410a18..243ce6eaca658 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -44,7 +44,7 @@ org.apache.kafka kafka_${scala.binary.version} - 0.8.1.1 + 0.8.2.1 com.sun.jmx diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala index bd767031c1849..6cf254a7b69cb 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala @@ -20,9 +20,10 @@ package org.apache.spark.streaming.kafka import scala.util.control.NonFatal import scala.util.Random import scala.collection.mutable.ArrayBuffer +import scala.collection.JavaConverters._ import java.util.Properties import kafka.api._ -import kafka.common.{ErrorMapping, OffsetMetadataAndError, TopicAndPartition} +import kafka.common.{ErrorMapping, OffsetAndMetadata, OffsetMetadataAndError, TopicAndPartition} import kafka.consumer.{ConsumerConfig, SimpleConsumer} import org.apache.spark.SparkException @@ -220,12 +221,22 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable { // https://cwiki.apache.org/confluence/display/KAFKA/A+Guide+To+The+Kafka+Protocol#AGuideToTheKafkaProtocol-OffsetCommit/FetchAPI // scalastyle:on + // this 0 here indicates api version, in this case the original ZK backed api. + private def defaultConsumerApiVersion: Short = 0 + /** Requires Kafka >= 0.8.1.1 */ def getConsumerOffsets( groupId: String, topicAndPartitions: Set[TopicAndPartition] + ): Either[Err, Map[TopicAndPartition, Long]] = + getConsumerOffsets(groupId, topicAndPartitions, defaultConsumerApiVersion) + + def getConsumerOffsets( + groupId: String, + topicAndPartitions: Set[TopicAndPartition], + consumerApiVersion: Short ): Either[Err, Map[TopicAndPartition, Long]] = { - getConsumerOffsetMetadata(groupId, topicAndPartitions).right.map { r => + getConsumerOffsetMetadata(groupId, topicAndPartitions, consumerApiVersion).right.map { r => r.map { kv => kv._1 -> kv._2.offset } @@ -236,9 +247,16 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable { def getConsumerOffsetMetadata( groupId: String, topicAndPartitions: Set[TopicAndPartition] + ): Either[Err, Map[TopicAndPartition, OffsetMetadataAndError]] = + getConsumerOffsetMetadata(groupId, topicAndPartitions, defaultConsumerApiVersion) + + def getConsumerOffsetMetadata( + groupId: String, + topicAndPartitions: Set[TopicAndPartition], + consumerApiVersion: Short ): Either[Err, Map[TopicAndPartition, OffsetMetadataAndError]] = { var result = Map[TopicAndPartition, OffsetMetadataAndError]() - val req = OffsetFetchRequest(groupId, topicAndPartitions.toSeq) + val req = OffsetFetchRequest(groupId, topicAndPartitions.toSeq, consumerApiVersion) val errs = new Err withBrokers(Random.shuffle(config.seedBrokers), errs) { consumer => val resp = consumer.fetchOffsets(req) @@ -266,24 +284,39 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable { def setConsumerOffsets( groupId: String, offsets: Map[TopicAndPartition, Long] + ): Either[Err, Map[TopicAndPartition, Short]] = + setConsumerOffsets(groupId, offsets, defaultConsumerApiVersion) + + def setConsumerOffsets( + groupId: String, + offsets: Map[TopicAndPartition, Long], + consumerApiVersion: Short ): Either[Err, Map[TopicAndPartition, Short]] = { - 
setConsumerOffsetMetadata(groupId, offsets.map { kv => - kv._1 -> OffsetMetadataAndError(kv._2) - }) + val meta = offsets.map { kv => + kv._1 -> OffsetAndMetadata(kv._2) + } + setConsumerOffsetMetadata(groupId, meta, consumerApiVersion) } /** Requires Kafka >= 0.8.1.1 */ def setConsumerOffsetMetadata( groupId: String, - metadata: Map[TopicAndPartition, OffsetMetadataAndError] + metadata: Map[TopicAndPartition, OffsetAndMetadata] + ): Either[Err, Map[TopicAndPartition, Short]] = + setConsumerOffsetMetadata(groupId, metadata, defaultConsumerApiVersion) + + def setConsumerOffsetMetadata( + groupId: String, + metadata: Map[TopicAndPartition, OffsetAndMetadata], + consumerApiVersion: Short ): Either[Err, Map[TopicAndPartition, Short]] = { var result = Map[TopicAndPartition, Short]() - val req = OffsetCommitRequest(groupId, metadata) + val req = OffsetCommitRequest(groupId, metadata, consumerApiVersion) val errs = new Err val topicAndPartitions = metadata.keySet withBrokers(Random.shuffle(config.seedBrokers), errs) { consumer => val resp = consumer.commitOffsets(req) - val respMap = resp.requestInfo + val respMap = resp.commitStatus val needed = topicAndPartitions.diff(result.keySet) needed.foreach { tp: TopicAndPartition => respMap.get(tp).foreach { err: Short => diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala index 13e9475065979..6dc4e9517d5a4 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala @@ -29,10 +29,12 @@ import scala.language.postfixOps import scala.util.control.NonFatal import kafka.admin.AdminUtils +import kafka.api.Request +import kafka.common.TopicAndPartition import kafka.producer.{KeyedMessage, Producer, ProducerConfig} import kafka.serializer.StringEncoder import kafka.server.{KafkaConfig, KafkaServer} -import kafka.utils.ZKStringSerializer +import kafka.utils.{ZKStringSerializer, ZkUtils} import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer} import org.I0Itec.zkclient.ZkClient @@ -227,12 +229,35 @@ private class KafkaTestUtils extends Logging { tryAgain(1) } - private def waitUntilMetadataIsPropagated(topic: String, partition: Int): Unit = { + /** Wait until the leader offset for the given topic/partition equals the specified offset */ + def waitUntilLeaderOffset( + topic: String, + partition: Int, + offset: Long): Unit = { eventually(Time(10000), Time(100)) { + val kc = new KafkaCluster(Map("metadata.broker.list" -> brokerAddress)) + val tp = TopicAndPartition(topic, partition) + val llo = kc.getLatestLeaderOffsets(Set(tp)).right.get.apply(tp).offset assert( - server.apis.metadataCache.containsTopicAndPartition(topic, partition), - s"Partition [$topic, $partition] metadata not propagated after timeout" - ) + llo == offset, + s"$topic $partition $offset not reached after timeout") + } + } + + private def waitUntilMetadataIsPropagated(topic: String, partition: Int): Unit = { + def isPropagated = server.apis.metadataCache.getPartitionInfo(topic, partition) match { + case Some(partitionState) => + val leaderAndInSyncReplicas = partitionState.leaderIsrAndControllerEpoch.leaderAndIsr + + ZkUtils.getLeaderForPartition(zkClient, topic, partition).isDefined && + Request.isValidBrokerId(leaderAndInSyncReplicas.leader) && + leaderAndInSyncReplicas.isr.size >= 1 + + case _ => + false + } + 
eventually(Time(10000), Time(100)) { + assert(isPropagated, s"Partition [$topic, $partition] metadata not propagated after timeout") } } diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java index a9dc6e50613ca..5cf379635354f 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java @@ -72,6 +72,9 @@ public void testKafkaRDD() throws InterruptedException { HashMap kafkaParams = new HashMap(); kafkaParams.put("metadata.broker.list", kafkaTestUtils.brokerAddress()); + kafkaTestUtils.waitUntilLeaderOffset(topic1, 0, topic1data.length); + kafkaTestUtils.waitUntilLeaderOffset(topic2, 0, topic2data.length); + OffsetRange[] offsetRanges = { OffsetRange.create(topic1, 0, 0, 1), OffsetRange.create(topic2, 0, 0, 1) diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala index 7d26ce50875b3..39c3fb448ff57 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala @@ -53,14 +53,15 @@ class KafkaRDDSuite extends FunSuite with BeforeAndAfterAll { } test("basic usage") { - val topic = "topicbasic" + val topic = s"topicbasic-${Random.nextInt}" kafkaTestUtils.createTopic(topic) val messages = Set("the", "quick", "brown", "fox") kafkaTestUtils.sendMessages(topic, messages.toArray) - val kafkaParams = Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress, - "group.id" -> s"test-consumer-${Random.nextInt(10000)}") + "group.id" -> s"test-consumer-${Random.nextInt}") + + kafkaTestUtils.waitUntilLeaderOffset(topic, 0, messages.size) val offsetRanges = Array(OffsetRange(topic, 0, 0, messages.size)) @@ -73,27 +74,38 @@ class KafkaRDDSuite extends FunSuite with BeforeAndAfterAll { test("iterator boundary conditions") { // the idea is to find e.g. 
off-by-one errors between what kafka has available and the rdd - val topic = "topic1" + val topic = s"topicboundary-${Random.nextInt}" val sent = Map("a" -> 5, "b" -> 3, "c" -> 10) kafkaTestUtils.createTopic(topic) val kafkaParams = Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress, - "group.id" -> s"test-consumer-${Random.nextInt(10000)}") + "group.id" -> s"test-consumer-${Random.nextInt}") val kc = new KafkaCluster(kafkaParams) // this is the "lots of messages" case kafkaTestUtils.sendMessages(topic, sent) + val sentCount = sent.values.sum + kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sentCount) + // rdd defined from leaders after sending messages, should get the number sent val rdd = getRdd(kc, Set(topic)) assert(rdd.isDefined) - assert(rdd.get.count === sent.values.sum, "didn't get all sent messages") - val ranges = rdd.get.asInstanceOf[HasOffsetRanges] - .offsetRanges.map(o => TopicAndPartition(o.topic, o.partition) -> o.untilOffset).toMap + val ranges = rdd.get.asInstanceOf[HasOffsetRanges].offsetRanges + val rangeCount = ranges.map(o => o.untilOffset - o.fromOffset).sum - kc.setConsumerOffsets(kafkaParams("group.id"), ranges) + assert(rangeCount === sentCount, "offset range didn't include all sent messages") + assert(rdd.get.count === sentCount, "didn't get all sent messages") + + val rangesMap = ranges.map(o => TopicAndPartition(o.topic, o.partition) -> o.untilOffset).toMap + + // make sure consumer offsets are committed before the next getRdd call + kc.setConsumerOffsets(kafkaParams("group.id"), rangesMap).fold( + err => throw new Exception(err.mkString("\n")), + _ => () + ) // this is the "0 messages" case val rdd2 = getRdd(kc, Set(topic)) @@ -101,6 +113,8 @@ class KafkaRDDSuite extends FunSuite with BeforeAndAfterAll { val sentOnlyOne = Map("d" -> 1) kafkaTestUtils.sendMessages(topic, sentOnlyOne) + kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sentCount + 1) + assert(rdd2.isDefined) assert(rdd2.get.count === 0, "got messages when there shouldn't be any") diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala index 86f611d55aa8a..7edd627b20918 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala @@ -372,6 +372,31 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali PageRank.runUntilConvergence(graph, tol, resetProb) } + + /** + * Run personalized PageRank for a given vertex, such that all random walks + * are started relative to the source node. + * + * @see [[org.apache.spark.graphx.lib.PageRank$#runUntilConvergenceWithOptions]] + */ + def personalizedPageRank(src: VertexId, tol: Double, + resetProb: Double = 0.15) : Graph[Double, Double] = { + PageRank.runUntilConvergenceWithOptions(graph, tol, resetProb, Some(src)) + } + + /** + * Run Personalized PageRank for a fixed number of iterations with + * with all iterations originating at the source node + * returning a graph with vertex attributes + * containing the PageRank and edge attributes the normalized edge weight. 
+ * + * @see [[org.apache.spark.graphx.lib.PageRank$#runWithOptions]] + */ + def staticPersonalizedPageRank(src: VertexId, numIter: Int, + resetProb: Double = 0.15) : Graph[Double, Double] = { + PageRank.runWithOptions(graph, numIter, resetProb, Some(src)) + } + /** * Run PageRank for a fixed number of iterations returning a graph with vertex attributes * containing the PageRank and edge attributes the normalized edge weight. diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index 042e366a29f58..bc974b2f04e70 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -18,6 +18,7 @@ package org.apache.spark.graphx.lib import scala.reflect.ClassTag +import scala.language.postfixOps import org.apache.spark.Logging import org.apache.spark.graphx._ @@ -60,6 +61,7 @@ import org.apache.spark.graphx._ */ object PageRank extends Logging { + /** * Run PageRank for a fixed number of iterations returning a graph * with vertex attributes containing the PageRank and edge @@ -74,10 +76,33 @@ object PageRank extends Logging { * * @return the graph containing with each vertex containing the PageRank and each edge * containing the normalized weight. + */ + def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED], numIter: Int, + resetProb: Double = 0.15): Graph[Double, Double] = + { + runWithOptions(graph, numIter, resetProb) + } + + /** + * Run PageRank for a fixed number of iterations returning a graph + * with vertex attributes containing the PageRank and edge + * attributes the normalized edge weight. + * + * @tparam VD the original vertex attribute (not used) + * @tparam ED the original edge attribute (not used) + * + * @param graph the graph on which to compute PageRank + * @param numIter the number of iterations of PageRank to run + * @param resetProb the random reset probability (alpha) + * @param srcId the source vertex for a Personalized Page Rank (optional) + * + * @return the graph containing with each vertex containing the PageRank and each edge + * containing the normalized weight. * */ - def run[VD: ClassTag, ED: ClassTag]( - graph: Graph[VD, ED], numIter: Int, resetProb: Double = 0.15): Graph[Double, Double] = + def runWithOptions[VD: ClassTag, ED: ClassTag]( + graph: Graph[VD, ED], numIter: Int, resetProb: Double = 0.15, + srcId: Option[VertexId] = None): Graph[Double, Double] = { // Initialize the PageRank graph with each edge attribute having // weight 1/outDegree and each vertex with attribute 1.0. @@ -89,6 +114,10 @@ object PageRank extends Logging { // Set the vertex attributes to the initial pagerank values .mapVertices( (id, attr) => resetProb ) + val personalized = srcId isDefined + val src: VertexId = srcId.getOrElse(-1L) + def delta(u: VertexId, v: VertexId):Double = { if (u == v) 1.0 else 0.0 } + var iteration = 0 var prevRankGraph: Graph[Double, Double] = null while (iteration < numIter) { @@ -103,8 +132,14 @@ object PageRank extends Logging { // that didn't receive a message. Requires a shuffle for broadcasting updated ranks to the // edge partitions. 
prevRankGraph = rankGraph + val rPrb = if (personalized) { + (src: VertexId ,id: VertexId) => resetProb * delta(src,id) + } else { + (src: VertexId, id: VertexId) => resetProb + } + rankGraph = rankGraph.joinVertices(rankUpdates) { - (id, oldRank, msgSum) => resetProb + (1.0 - resetProb) * msgSum + (id, oldRank, msgSum) => rPrb(src,id) + (1.0 - resetProb) * msgSum }.cache() rankGraph.edges.foreachPartition(x => {}) // also materializes rankGraph.vertices @@ -133,7 +168,29 @@ object PageRank extends Logging { * containing the normalized weight. */ def runUntilConvergence[VD: ClassTag, ED: ClassTag]( - graph: Graph[VD, ED], tol: Double, resetProb: Double = 0.15): Graph[Double, Double] = + graph: Graph[VD, ED], tol: Double, resetProb: Double = 0.15): Graph[Double, Double] = + { + runUntilConvergenceWithOptions(graph, tol, resetProb) + } + + /** + * Run a dynamic version of PageRank returning a graph with vertex attributes containing the + * PageRank and edge attributes containing the normalized edge weight. + * + * @tparam VD the original vertex attribute (not used) + * @tparam ED the original edge attribute (not used) + * + * @param graph the graph on which to compute PageRank + * @param tol the tolerance allowed at convergence (smaller => more accurate). + * @param resetProb the random reset probability (alpha) + * @param srcId the source vertex for a Personalized Page Rank (optional) + * + * @return the graph containing with each vertex containing the PageRank and each edge + * containing the normalized weight. + */ + def runUntilConvergenceWithOptions[VD: ClassTag, ED: ClassTag]( + graph: Graph[VD, ED], tol: Double, resetProb: Double = 0.15, + srcId: Option[VertexId] = None): Graph[Double, Double] = { // Initialize the pagerankGraph with each edge attribute // having weight 1/outDegree and each vertex with attribute 1.0. @@ -148,6 +205,10 @@ object PageRank extends Logging { .mapVertices( (id, attr) => (0.0, 0.0) ) .cache() + val personalized = srcId.isDefined + val src: VertexId = srcId.getOrElse(-1L) + + // Define the three functions needed to implement PageRank in the GraphX // version of Pregel def vertexProgram(id: VertexId, attr: (Double, Double), msgSum: Double): (Double, Double) = { @@ -156,7 +217,18 @@ object PageRank extends Logging { (newPR, newPR - oldPR) } - def sendMessage(edge: EdgeTriplet[(Double, Double), Double]): Iterator[(VertexId, Double)] = { + def personalizedVertexProgram(id: VertexId, attr: (Double, Double), + msgSum: Double): (Double, Double) = { + val (oldPR, lastDelta) = attr + var teleport = oldPR + val delta = if (src==id) 1.0 else 0.0 + teleport = oldPR*delta + + val newPR = teleport + (1.0 - resetProb) * msgSum + (newPR, newPR - oldPR) + } + + def sendMessage(edge: EdgeTriplet[(Double, Double), Double]) = { if (edge.srcAttr._2 > tol) { Iterator((edge.dstId, edge.srcAttr._2 * edge.attr)) } else { @@ -170,8 +242,17 @@ object PageRank extends Logging { val initialMessage = resetProb / (1.0 - resetProb) // Execute a dynamic version of Pregel. 
+ val vp = if (personalized) { + (id: VertexId, attr: (Double, Double),msgSum: Double) => + personalizedVertexProgram(id, attr, msgSum) + } else { + (id: VertexId, attr: (Double, Double), msgSum: Double) => + vertexProgram(id, attr, msgSum) + } + Pregel(pagerankGraph, initialMessage, activeDirection = EdgeDirection.Out)( - vertexProgram, sendMessage, messageCombiner) + vp, sendMessage, messageCombiner) .mapVertices((vid, attr) => attr._1) } // end of deltaPageRank + } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala index 95804b07b1db0..3f3c9dfd7b3dd 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala @@ -92,6 +92,36 @@ class PageRankSuite extends FunSuite with LocalSparkContext { } } // end of test Star PageRank + test("Star PersonalPageRank") { + withSpark { sc => + val nVertices = 100 + val starGraph = GraphGenerators.starGraph(sc, nVertices).cache() + val resetProb = 0.15 + val errorTol = 1.0e-5 + + val staticRanks1 = starGraph.staticPersonalizedPageRank(0,numIter = 1, resetProb).vertices + val staticRanks2 = starGraph.staticPersonalizedPageRank(0,numIter = 2, resetProb) + .vertices.cache() + + // Static PageRank should only take 2 iterations to converge + val notMatching = staticRanks1.innerZipJoin(staticRanks2) { (vid, pr1, pr2) => + if (pr1 != pr2) 1 else 0 + }.map { case (vid, test) => test }.sum + assert(notMatching === 0) + + val staticErrors = staticRanks2.map { case (vid, pr) => + val correct = (vid > 0 && pr == resetProb) || + (vid == 0 && math.abs(pr - (resetProb + (1.0 - resetProb) * (resetProb * + (nVertices - 1)) )) < 1.0E-5) + if (!correct) 1 else 0 + } + assert(staticErrors.sum === 0) + + val dynamicRanks = starGraph.personalizedPageRank(0,0, resetProb).vertices.cache() + assert(compareRanks(staticRanks2, dynamicRanks) < errorTol) + } + } // end of test Star PageRank + test("Grid PageRank") { withSpark { sc => val rows = 10 @@ -128,4 +158,21 @@ class PageRankSuite extends FunSuite with LocalSparkContext { assert(compareRanks(staticRanks, dynamicRanks) < errorTol) } } + + test("Chain PersonalizedPageRank") { + withSpark { sc => + val chain1 = (0 until 9).map(x => (x, x + 1) ) + val rawEdges = sc.parallelize(chain1, 1).map { case (s,d) => (s.toLong, d.toLong) } + val chain = Graph.fromEdgeTuples(rawEdges, 1.0).cache() + val resetProb = 0.15 + val tol = 0.0001 + val numIter = 10 + val errorTol = 1.0e-1 + + val staticRanks = chain.staticPersonalizedPageRank(4, numIter, resetProb).vertices + val dynamicRanks = chain.personalizedPageRank(4, tol, resetProb).vertices + + assert(compareRanks(staticRanks, dynamicRanks) < errorTol) + } + } } diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java index 8526d2e7cfa3f..229000087688f 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java @@ -69,8 +69,10 @@ class SparkSubmitOptionParser { // YARN-only options. 
protected final String ARCHIVES = "--archives"; protected final String EXECUTOR_CORES = "--executor-cores"; - protected final String QUEUE = "--queue"; + protected final String KEYTAB = "--keytab"; protected final String NUM_EXECUTORS = "--num-executors"; + protected final String PRINCIPAL = "--principal"; + protected final String QUEUE = "--queue"; /** * This is the canonical list of spark-submit options. Each entry in the array contains the @@ -96,11 +98,13 @@ class SparkSubmitOptionParser { { EXECUTOR_MEMORY }, { FILES }, { JARS }, + { KEYTAB }, { KILL_SUBMISSION }, { MASTER }, { NAME }, { NUM_EXECUTORS }, { PACKAGES }, + { PRINCIPAL }, { PROPERTIES_FILE }, { PROXY_USER }, { PY_FILES }, diff --git a/make-distribution.sh b/make-distribution.sh index cb65932b4abc0..92177e19fe6be 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -26,6 +26,7 @@ set -o pipefail set -e +set -x # Figure out where the Spark framework is installed SPARK_HOME="$(cd "`dirname "$0"`"; pwd)" @@ -126,7 +127,7 @@ if [ ! $(command -v "$MVN") ] ; then exit -1; fi -VERSION=$("$MVN" help:evaluate -Dexpression=project.version 2>/dev/null | grep -v "INFO" | tail -n 1) +VERSION=$("$MVN" help:evaluate -Dexpression=project.version $@ 2>/dev/null | grep -v "INFO" | tail -n 1) SCALA_VERSION=$("$MVN" help:evaluate -Dexpression=scala.binary.version $@ 2>/dev/null\ | grep -v "INFO"\ | tail -n 1) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 7b2a451ca5ee5..5e781a326d98c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -25,9 +25,7 @@ import org.apache.spark.ml.Transformer import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} -import org.apache.spark.sql.{Column, DataFrame, Row} -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, CreateStruct} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -53,13 +51,12 @@ class VectorAssembler extends Transformer with HasInputCols with HasOutputCol { val inputColNames = map(inputCols) val args = inputColNames.map { c => schema(c).dataType match { - case DoubleType => UnresolvedAttribute(c) - case t if t.isInstanceOf[VectorUDT] => UnresolvedAttribute(c) - case _: NumericType | BooleanType => - Alias(Cast(UnresolvedAttribute(c), DoubleType), s"${c}_double_$uid")() + case DoubleType => dataset(c) + case _: VectorUDT => dataset(c) + case _: NumericType | BooleanType => dataset(c).cast(DoubleType).as(s"${c}_double_$uid") } } - dataset.select(col("*"), assembleFunc(new Column(CreateStruct(args))).as(map(outputCol))) + dataset.select(col("*"), assembleFunc(struct(args : _*)).as(map(outputCol))) } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java index 3fe69b1bd8851..b8d073fa16b4b 100644 --- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java +++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java @@ -36,6 +36,7 @@ import 
org.apache.spark.network.server.TransportChannelHandler; import org.apache.spark.network.server.TransportRequestHandler; import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.server.TransportServerBootstrap; import org.apache.spark.network.util.NettyUtils; import org.apache.spark.network.util.TransportConf; @@ -82,13 +83,21 @@ public TransportClientFactory createClientFactory() { } /** Create a server which will attempt to bind to a specific port. */ - public TransportServer createServer(int port) { - return new TransportServer(this, port); + public TransportServer createServer(int port, List bootstraps) { + return new TransportServer(this, port, rpcHandler, bootstraps); } /** Creates a new server, binding to any available ephemeral port. */ + public TransportServer createServer(List bootstraps) { + return createServer(0, bootstraps); + } + public TransportServer createServer() { - return new TransportServer(this, 0); + return createServer(0, Lists.newArrayList()); + } + + public TransportChannelHandler initializePipeline(SocketChannel channel) { + return initializePipeline(channel, rpcHandler); } /** @@ -96,13 +105,18 @@ public TransportServer createServer() { * has a {@link org.apache.spark.network.server.TransportChannelHandler} to handle request or * response messages. * + * @param channel The channel to initialize. + * @param channelRpcHandler The RPC handler to use for the channel. + * * @return Returns the created TransportChannelHandler, which includes a TransportClient that can * be used to communicate on this channel. The TransportClient is directly associated with a * ChannelHandler to ensure all users of the same channel get the same TransportClient object. */ - public TransportChannelHandler initializePipeline(SocketChannel channel) { + public TransportChannelHandler initializePipeline( + SocketChannel channel, + RpcHandler channelRpcHandler) { try { - TransportChannelHandler channelHandler = createChannelHandler(channel); + TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler); channel.pipeline() .addLast("encoder", encoder) .addLast("frameDecoder", NettyUtils.createFrameDecoder()) @@ -123,7 +137,7 @@ public TransportChannelHandler initializePipeline(SocketChannel channel) { * ResponseMessages. The channel is expected to have been successfully created, though certain * properties (such as the remoteAddress()) may not be available yet. */ - private TransportChannelHandler createChannelHandler(Channel channel) { + private TransportChannelHandler createChannelHandler(Channel channel, RpcHandler rpcHandler) { TransportResponseHandler responseHandler = new TransportResponseHandler(channel); TransportClient client = new TransportClient(channel, responseHandler); TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client, diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java index 65e8020e34121..eaae2ee043c5a 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java @@ -17,6 +17,8 @@ package org.apache.spark.network.client; +import io.netty.channel.Channel; + /** * A bootstrap which is executed on a TransportClient before it is returned to the user. 
* This enables an initial exchange of information (e.g., SASL authentication tokens) on a once-per- @@ -28,5 +30,5 @@ */ public interface TransportClientBootstrap { /** Performs the bootstrapping operation, throwing an exception on failure. */ - public void doBootstrap(TransportClient client) throws RuntimeException; + void doBootstrap(TransportClient client, Channel channel) throws RuntimeException; } diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index d26b9b4d6055f..4952ffb44bb8b 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -172,12 +172,14 @@ private TransportClient createClient(InetSocketAddress address) throws IOExcepti .option(ChannelOption.ALLOCATOR, pooledAllocator); final AtomicReference clientRef = new AtomicReference(); + final AtomicReference channelRef = new AtomicReference(); bootstrap.handler(new ChannelInitializer() { @Override public void initChannel(SocketChannel ch) { TransportChannelHandler clientHandler = context.initializePipeline(ch); clientRef.set(clientHandler.getClient()); + channelRef.set(ch); } }); @@ -192,6 +194,7 @@ public void initChannel(SocketChannel ch) { } TransportClient client = clientRef.get(); + Channel channel = channelRef.get(); assert client != null : "Channel future completed successfully with null client"; // Execute any client bootstraps synchronously before marking the Client as successful. @@ -199,7 +202,7 @@ public void initChannel(SocketChannel ch) { logger.debug("Connection to {} successful, running bootstraps...", address); try { for (TransportClientBootstrap clientBootstrap : clientBootstraps) { - clientBootstrap.doBootstrap(client); + clientBootstrap.doBootstrap(client, channel); } } catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap may be written in Scala long bootstrapTimeMs = (System.nanoTime() - preBootstrap) / 1000000; diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java index 33aa1344345ff..185ba2ef3bb1f 100644 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java @@ -17,8 +17,12 @@ package org.apache.spark.network.sasl; +import javax.security.sasl.Sasl; +import javax.security.sasl.SaslException; + import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -33,14 +37,24 @@ public class SaslClientBootstrap implements TransportClientBootstrap { private final Logger logger = LoggerFactory.getLogger(SaslClientBootstrap.class); + private final boolean encrypt; private final TransportConf conf; private final String appId; private final SecretKeyHolder secretKeyHolder; public SaslClientBootstrap(TransportConf conf, String appId, SecretKeyHolder secretKeyHolder) { + this(conf, appId, secretKeyHolder, false); + } + + public SaslClientBootstrap( + TransportConf conf, + String appId, + SecretKeyHolder secretKeyHolder, + boolean encrypt) { this.conf = conf; this.appId = appId; this.secretKeyHolder = secretKeyHolder; + this.encrypt = encrypt; } /** @@ -49,8 +63,8 @@ public 
SaslClientBootstrap(TransportConf conf, String appId, SecretKeyHolder sec * due to mismatch. */ @Override - public void doBootstrap(TransportClient client) { - SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder); + public void doBootstrap(TransportClient client, Channel channel) { + SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder, encrypt); try { byte[] payload = saslClient.firstToken(); @@ -62,13 +76,26 @@ public void doBootstrap(TransportClient client) { byte[] response = client.sendRpcSync(buf.array(), conf.saslRTTimeoutMs()); payload = saslClient.response(response); } + + if (encrypt) { + if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslClient.getNegotiatedProperty(Sasl.QOP))) { + throw new RuntimeException( + new SaslException("Encryption requested by negotiated non-encrypted connection.")); + } + SaslEncryption.addToChannel(channel, saslClient, conf.maxSaslEncryptedBlockSize()); + saslClient = null; + logger.debug("Channel {} configured for SASL encryption.", client); + } } finally { - try { - // Once authentication is complete, the server will trust all remaining communication. - saslClient.dispose(); - } catch (RuntimeException e) { - logger.error("Error while disposing SASL client", e); + if (saslClient != null) { + try { + // Once authentication is complete, the server will trust all remaining communication. + saslClient.dispose(); + } catch (RuntimeException e) { + logger.error("Error while disposing SASL client", e); + } } } } + } diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java new file mode 100644 index 0000000000000..127335e4d35fb --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryption.java @@ -0,0 +1,291 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; +import java.util.List; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPromise; +import io.netty.channel.FileRegion; +import io.netty.handler.codec.MessageToMessageDecoder; +import io.netty.util.AbstractReferenceCounted; +import io.netty.util.ReferenceCountUtil; + +import org.apache.spark.network.util.ByteArrayWritableChannel; +import org.apache.spark.network.util.NettyUtils; + +/** + * Provides SASL-based encryption for transport channels.
The single method exposed by this + * class installs the needed channel handlers on a connected channel. + */ +class SaslEncryption { + + @VisibleForTesting + static final String ENCRYPTION_HANDLER_NAME = "saslEncryption"; + + /** + * Adds channel handlers that perform encryption / decryption of data using SASL. + * + * @param channel The channel. + * @param backend The SASL backend. + * @param maxOutboundBlockSize Max size in bytes of outgoing encrypted blocks, to control + * memory usage. + */ + static void addToChannel( + Channel channel, + SaslEncryptionBackend backend, + int maxOutboundBlockSize) { + channel.pipeline() + .addFirst(ENCRYPTION_HANDLER_NAME, new EncryptionHandler(backend, maxOutboundBlockSize)) + .addFirst("saslDecryption", new DecryptionHandler(backend)) + .addFirst("saslFrameDecoder", NettyUtils.createFrameDecoder()); + } + + private static class EncryptionHandler extends ChannelOutboundHandlerAdapter { + + private final int maxOutboundBlockSize; + private final SaslEncryptionBackend backend; + + EncryptionHandler(SaslEncryptionBackend backend, int maxOutboundBlockSize) { + this.backend = backend; + this.maxOutboundBlockSize = maxOutboundBlockSize; + } + + /** + * Wrap the incoming message in an implementation that will perform encryption lazily. This is + * needed to guarantee ordering of the outgoing encrypted packets - they need to be decrypted in + * the same order, and netty doesn't have an atomic ChannelHandlerContext.write() API, so it + * does not guarantee any ordering. + */ + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) + throws Exception { + + ctx.write(new EncryptedMessage(backend, msg, maxOutboundBlockSize), promise); + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + try { + backend.dispose(); + } finally { + super.handlerRemoved(ctx); + } + } + + } + + private static class DecryptionHandler extends MessageToMessageDecoder { + + private final SaslEncryptionBackend backend; + + DecryptionHandler(SaslEncryptionBackend backend) { + this.backend = backend; + } + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf msg, List out) + throws Exception { + + byte[] data; + int offset; + int length = msg.readableBytes(); + if (msg.hasArray()) { + data = msg.array(); + offset = msg.arrayOffset(); + msg.skipBytes(length); + } else { + data = new byte[length]; + msg.readBytes(data); + offset = 0; + } + + out.add(Unpooled.wrappedBuffer(backend.unwrap(data, offset, length))); + } + + } + + @VisibleForTesting + static class EncryptedMessage extends AbstractReferenceCounted implements FileRegion { + + private final SaslEncryptionBackend backend; + private final boolean isByteBuf; + private final ByteBuf buf; + private final FileRegion region; + + /** + * A channel used to buffer input data for encryption. The channel has an upper size bound + * so that if the input is larger than the allowed buffer, it will be broken into multiple + * chunks. 
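(Editor's note: a hedged summary, written as comments, of the pipeline that addToChannel() above leaves behind; the ordering follows from the three successive addFirst() calls.)

// After SaslEncryption.addToChannel(channel, backend, maxOutboundBlockSize), reading from the
// channel head the handlers are:
//   "saslFrameDecoder"  inbound:  reassembles one length-prefixed encrypted block per message
//   "saslDecryption"    inbound:  calls backend.unwrap() on each reassembled block
//   "saslEncryption"    outbound: wraps every write in an EncryptedMessage, which calls
//                                 backend.wrap() lazily as the data is flushed
//   ...the encoder / frameDecoder / decoder / handler originally installed by TransportContext...
// So inbound bytes are decrypted before the normal frame decoding sees them, and outbound frames
// are encrypted after the normal encoder has produced them.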
+ */ + private final ByteArrayWritableChannel byteChannel; + + private ByteBuf currentHeader; + private ByteBuffer currentChunk; + private long currentChunkSize; + private long currentReportedBytes; + private long unencryptedChunkSize; + private long transferred; + + EncryptedMessage(SaslEncryptionBackend backend, Object msg, int maxOutboundBlockSize) { + Preconditions.checkArgument(msg instanceof ByteBuf || msg instanceof FileRegion, + "Unrecognized message type: %s", msg.getClass().getName()); + this.backend = backend; + this.isByteBuf = msg instanceof ByteBuf; + this.buf = isByteBuf ? (ByteBuf) msg : null; + this.region = isByteBuf ? null : (FileRegion) msg; + this.byteChannel = new ByteArrayWritableChannel(maxOutboundBlockSize); + } + + /** + * Returns the size of the original (unencrypted) message. + * + * This makes assumptions about how netty treats FileRegion instances, because there's no way + * to know beforehand what will be the size of the encrypted message. Namely, it assumes + * that netty will try to transfer data from this message while + * transfered() < count(). So these two methods return, technically, wrong data, + * but netty doesn't know better. + */ + @Override + public long count() { + return isByteBuf ? buf.readableBytes() : region.count(); + } + + @Override + public long position() { + return 0; + } + + /** + * Returns an approximation of the amount of data transferred. See {@link #count()}. + */ + @Override + public long transfered() { + return transferred; + } + + /** + * Transfers data from the original message to the channel, encrypting it in the process. + * + * This method also breaks down the original message into smaller chunks when needed. This + * is done to keep memory usage under control. This avoids having to copy the whole message + * data into memory at once, and can avoid ballooning memory usage when transferring large + * messages such as shuffle blocks. + * + * The {@link #transfered()} counter also behaves a little funny, in that it won't go forward + * until a whole chunk has been written. This is done because the code can't use the actual + * number of bytes written to the channel as the transferred count (see {@link #count()}). + * Instead, once an encrypted chunk is written to the output (including its header), the + * size of the original block will be added to the {@link #transfered()} amount. + */ + @Override + public long transferTo(final WritableByteChannel target, final long position) + throws IOException { + + Preconditions.checkArgument(position == transfered(), "Invalid position."); + + long reportedWritten = 0L; + long actuallyWritten = 0L; + do { + if (currentChunk == null) { + nextChunk(); + } + + if (currentHeader.readableBytes() > 0) { + int bytesWritten = target.write(currentHeader.nioBuffer()); + currentHeader.skipBytes(bytesWritten); + actuallyWritten += bytesWritten; + if (currentHeader.readableBytes() > 0) { + // Break out of loop if there are still header bytes left to write. + break; + } + } + + actuallyWritten += target.write(currentChunk); + if (!currentChunk.hasRemaining()) { + // Only update the count of written bytes once a full chunk has been written. + // See method javadoc. 
+ long chunkBytesRemaining = unencryptedChunkSize - currentReportedBytes; + reportedWritten += chunkBytesRemaining; + transferred += chunkBytesRemaining; + currentHeader.release(); + currentHeader = null; + currentChunk = null; + currentChunkSize = 0; + currentReportedBytes = 0; + } + } while (currentChunk == null && transfered() + reportedWritten < count()); + + // Returning 0 triggers a backoff mechanism in netty which may harm performance. Instead, + // we return 1 until we can (i.e. until the reported count would actually match the size + // of the current chunk), at which point we resort to returning 0 so that the counts still + // match, at the cost of some performance. That situation should be rare, though. + if (reportedWritten != 0L) { + return reportedWritten; + } + + if (actuallyWritten > 0 && currentReportedBytes < currentChunkSize - 1) { + transferred += 1L; + currentReportedBytes += 1L; + return 1L; + } + + return 0L; + } + + private void nextChunk() throws IOException { + byteChannel.reset(); + if (isByteBuf) { + int copied = byteChannel.write(buf.nioBuffer()); + buf.skipBytes(copied); + } else { + region.transferTo(byteChannel, region.transfered()); + } + + byte[] encrypted = backend.wrap(byteChannel.getData(), 0, byteChannel.length()); + this.currentChunk = ByteBuffer.wrap(encrypted); + this.currentChunkSize = encrypted.length; + this.currentHeader = Unpooled.copyLong(8 + currentChunkSize); + this.unencryptedChunkSize = byteChannel.length(); + } + + @Override + protected void deallocate() { + if (currentHeader != null) { + currentHeader.release(); + } + if (buf != null) { + buf.release(); + } + if (region != null) { + region.release(); + } + } + + } + +} diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryptionBackend.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryptionBackend.java new file mode 100644 index 0000000000000..89b78bc7e1df1 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslEncryptionBackend.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import javax.security.sasl.SaslException; + +interface SaslEncryptionBackend { + + /** Disposes of resources used by the backend. */ + void dispose(); + + /** Encrypt data. */ + byte[] wrap(byte[] data, int offset, int len) throws SaslException; + + /** Decrypt data. 
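(Editor's note: a small sketch of the outbound frame layout implied by nextChunk() above: an 8-byte length that counts itself, followed by one SASL-wrapped block. The helper below is illustrative only and assumes it sits in org.apache.spark.network.sasl so it can see the package-private SaslEncryptionBackend.)

package org.apache.spark.network.sasl;

import javax.security.sasl.SaslException;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;

final class EncryptedFrameSketch {
  /** Mirrors the frame layout produced by EncryptedMessage.nextChunk(); not part of the patch. */
  static ByteBuf encryptedFrame(SaslEncryptionBackend backend, byte[] plaintext)
      throws SaslException {
    byte[] wrapped = backend.wrap(plaintext, 0, plaintext.length);
    ByteBuf frame = Unpooled.buffer(8 + wrapped.length);
    frame.writeLong(8 + wrapped.length);  // the length prefix counts its own 8 bytes
    frame.writeBytes(wrapped);
    return frame;                         // the receiver's "saslFrameDecoder" strips it again
  }
}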
*/ + byte[] unwrap(byte[] data, int offset, int len) throws SaslException; + +} diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index 026cbd260d16c..be6165caf3c74 100644 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -17,10 +17,10 @@ package org.apache.spark.network.sasl; -import java.util.concurrent.ConcurrentMap; +import javax.security.sasl.Sasl; -import com.google.common.collect.Maps; import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -28,6 +28,7 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.util.TransportConf; /** * RPC Handler which performs SASL authentication before delegating to a child RPC handler. @@ -37,8 +38,14 @@ * Note that the authentication process consists of multiple challenge-response pairs, each of * which are individual RPCs. */ -public class SaslRpcHandler extends RpcHandler { - private final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class); +class SaslRpcHandler extends RpcHandler { + private static final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class); + + /** Transport configuration. */ + private final TransportConf conf; + + /** The client channel. */ + private final Channel channel; /** RpcHandler we will delegate to for authenticated connections. */ private final RpcHandler delegate; @@ -46,19 +53,25 @@ public class SaslRpcHandler extends RpcHandler { /** Class which provides secret keys which are shared by server and client on a per-app basis. */ private final SecretKeyHolder secretKeyHolder; - /** Maps each channel to its SASL authentication state. */ - private final ConcurrentMap channelAuthenticationMap; + private SparkSaslServer saslServer; + private boolean isComplete; - public SaslRpcHandler(RpcHandler delegate, SecretKeyHolder secretKeyHolder) { + SaslRpcHandler( + TransportConf conf, + Channel channel, + RpcHandler delegate, + SecretKeyHolder secretKeyHolder) { + this.conf = conf; + this.channel = channel; this.delegate = delegate; this.secretKeyHolder = secretKeyHolder; - this.channelAuthenticationMap = Maps.newConcurrentMap(); + this.saslServer = null; + this.isComplete = false; } @Override public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { - SparkSaslServer saslServer = channelAuthenticationMap.get(client); - if (saslServer != null && saslServer.isComplete()) { + if (isComplete) { // Authentication complete, delegate to base handler. delegate.receive(client, message, callback); return; @@ -68,15 +81,30 @@ public void receive(TransportClient client, byte[] message, RpcResponseCallback if (saslServer == null) { // First message in the handshake, setup the necessary state. - saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder); - channelAuthenticationMap.put(client, saslServer); + saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder, + conf.saslServerAlwaysEncrypt()); } byte[] response = saslServer.response(saslMessage.payload); + callback.onSuccess(response); + + // Setup encryption after the SASL response is sent, otherwise the client can't parse the + // response. 
It's ok to change the channel pipeline here since we are processing an incoming + // message, so the pipeline is busy and no new incoming messages will be fed to it before this + // method returns. This assumes that the code ensures, through other means, that no outbound + // messages are being written to the channel while negotiation is still going on. if (saslServer.isComplete()) { logger.debug("SASL authentication successful for channel {}", client); + isComplete = true; + if (SparkSaslServer.QOP_AUTH_CONF.equals(saslServer.getNegotiatedProperty(Sasl.QOP))) { + logger.debug("Enabling encryption for channel {}", client); + SaslEncryption.addToChannel(channel, saslServer, conf.maxSaslEncryptedBlockSize()); + saslServer = null; + } else { + saslServer.dispose(); + saslServer = null; + } } - callback.onSuccess(response); } @Override @@ -86,9 +114,9 @@ public StreamManager getStreamManager() { @Override public void connectionTerminated(TransportClient client) { - SparkSaslServer saslServer = channelAuthenticationMap.remove(client); if (saslServer != null) { saslServer.dispose(); } } + } diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslServerBootstrap.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslServerBootstrap.java new file mode 100644 index 0000000000000..f2f983856f444 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslServerBootstrap.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import io.netty.channel.Channel; + +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.TransportServerBootstrap; +import org.apache.spark.network.util.TransportConf; + +/** + * A bootstrap which is executed on a TransportServer's client channel once a client connects + * to the server. This allows customizing the client channel to allow for things such as SASL + * authentication. + */ +public class SaslServerBootstrap implements TransportServerBootstrap { + + private final TransportConf conf; + private final SecretKeyHolder secretKeyHolder; + + public SaslServerBootstrap(TransportConf conf, SecretKeyHolder secretKeyHolder) { + this.conf = conf; + this.secretKeyHolder = secretKeyHolder; + } + + /** + * Wrap the given application handler in a SaslRpcHandler that will handle the initial SASL + * negotiation. 
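(Editor's note: a hedged outline, as comments, of the handshake that SaslRpcHandler above drives; each challenge/response round is a separate RPC carried in a SaslMessage.)

// Client RPC #1: SaslMessage(appId, firstToken)            -> server creates a SparkSaslServer
// Client RPC #n: SaslMessage(appId, response to challenge) -> repeated until saslServer.isComplete()
// The final response is sent back first; then, if the negotiated QOP is "auth-conf",
// SaslEncryption.addToChannel(...) is installed on the server channel. Every later message on the
// connection is handed straight to the delegate RpcHandler.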
+ */ + public RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler) { + return new SaslRpcHandler(conf, channel, rpcHandler, secretKeyHolder); + } + +} diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java b/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java index 9abad1f30a259..94685e91b862e 100644 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java @@ -17,6 +17,8 @@ package org.apache.spark.network.sasl; +import java.io.IOException; +import java.util.Map; import javax.security.auth.callback.Callback; import javax.security.auth.callback.CallbackHandler; import javax.security.auth.callback.NameCallback; @@ -27,9 +29,9 @@ import javax.security.sasl.Sasl; import javax.security.sasl.SaslClient; import javax.security.sasl.SaslException; -import java.io.IOException; import com.google.common.base.Throwables; +import com.google.common.collect.ImmutableMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -40,19 +42,25 @@ * initial state to the "authenticated" state. This client initializes the protocol via a * firstToken, which is then followed by a set of challenges and responses. */ -public class SparkSaslClient { +public class SparkSaslClient implements SaslEncryptionBackend { private final Logger logger = LoggerFactory.getLogger(SparkSaslClient.class); private final String secretKeyId; private final SecretKeyHolder secretKeyHolder; + private final String expectedQop; private SaslClient saslClient; - public SparkSaslClient(String secretKeyId, SecretKeyHolder secretKeyHolder) { + public SparkSaslClient(String secretKeyId, SecretKeyHolder secretKeyHolder, boolean encrypt) { this.secretKeyId = secretKeyId; this.secretKeyHolder = secretKeyHolder; + this.expectedQop = encrypt ? QOP_AUTH_CONF : QOP_AUTH; + + Map saslProps = ImmutableMap.builder() + .put(Sasl.QOP, expectedQop) + .build(); try { this.saslClient = Sasl.createSaslClient(new String[] { DIGEST }, null, null, DEFAULT_REALM, - SASL_PROPS, new ClientCallbackHandler()); + saslProps, new ClientCallbackHandler()); } catch (SaslException e) { throw Throwables.propagate(e); } @@ -76,6 +84,11 @@ public synchronized boolean isComplete() { return saslClient != null && saslClient.isComplete(); } + /** Returns the value of a negotiated property. */ + public Object getNegotiatedProperty(String name) { + return saslClient.getNegotiatedProperty(name); + } + /** * Respond to server's SASL token. * @param token contains server's SASL token @@ -93,6 +106,7 @@ public synchronized byte[] response(byte[] token) { * Disposes of any system resources or security-sensitive information the * SaslClient might be using. 
*/ + @Override public synchronized void dispose() { if (saslClient != null) { try { @@ -134,4 +148,15 @@ public void handle(Callback[] callbacks) throws IOException, UnsupportedCallback } } } + + @Override + public byte[] wrap(byte[] data, int offset, int len) throws SaslException { + return saslClient.wrap(data, offset, len); + } + + @Override + public byte[] unwrap(byte[] data, int offset, int len) throws SaslException { + return saslClient.unwrap(data, offset, len); + } + } diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java b/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java index e87b17ead1e1a..431cb67a2ae0b 100644 --- a/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java +++ b/network/common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java @@ -44,7 +44,7 @@ * initial state to the "authenticated" state. (It is not a server in the sense of accepting * connections on some socket.) */ -public class SparkSaslServer { +public class SparkSaslServer implements SaslEncryptionBackend { private final Logger logger = LoggerFactory.getLogger(SparkSaslServer.class); /** @@ -60,26 +60,37 @@ public class SparkSaslServer { static final String DIGEST = "DIGEST-MD5"; /** - * The quality of protection is just "auth". This means that we are doing - * authentication only, we are not supporting integrity or privacy protection of the - * communication channel after authentication. This could be changed to be configurable - * in the future. + * Quality of protection value that includes encryption. */ - static final Map SASL_PROPS = ImmutableMap.builder() - .put(Sasl.QOP, "auth") - .put(Sasl.SERVER_AUTH, "true") - .build(); + static final String QOP_AUTH_CONF = "auth-conf"; + + /** + * Quality of protection value that does not include encryption. + */ + static final String QOP_AUTH = "auth"; /** Identifier for a certain secret key within the secretKeyHolder. */ private final String secretKeyId; private final SecretKeyHolder secretKeyHolder; private SaslServer saslServer; - public SparkSaslServer(String secretKeyId, SecretKeyHolder secretKeyHolder) { + public SparkSaslServer( + String secretKeyId, + SecretKeyHolder secretKeyHolder, + boolean alwaysEncrypt) { this.secretKeyId = secretKeyId; this.secretKeyHolder = secretKeyHolder; + + // Sasl.QOP is a comma-separated list of supported values. The value that allows encryption + // is listed first since it's preferred over the non-encrypted one (if the client also + // lists both in the request). + String qop = alwaysEncrypt ? QOP_AUTH_CONF : String.format("%s,%s", QOP_AUTH_CONF, QOP_AUTH); + Map saslProps = ImmutableMap.builder() + .put(Sasl.SERVER_AUTH, "true") + .put(Sasl.QOP, qop) + .build(); try { - this.saslServer = Sasl.createSaslServer(DIGEST, null, DEFAULT_REALM, SASL_PROPS, + this.saslServer = Sasl.createSaslServer(DIGEST, null, DEFAULT_REALM, saslProps, new DigestCallbackHandler()); } catch (SaslException e) { throw Throwables.propagate(e); @@ -93,6 +104,11 @@ public synchronized boolean isComplete() { return saslServer != null && saslServer.isComplete(); } + /** Returns the value of a negotiated property. */ + public Object getNegotiatedProperty(String name) { + return saslServer.getNegotiatedProperty(name); + } + /** * Used to respond to server SASL tokens. 
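(Editor's note: a hedged summary, as comments, of the quality-of-protection values advertised by the constructors above; which value wins follows standard SASL DIGEST-MD5 negotiation.)

// SparkSaslClient:  encrypt = true        -> Sasl.QOP = "auth-conf"
//                   encrypt = false       -> Sasl.QOP = "auth"
// SparkSaslServer:  alwaysEncrypt = true  -> Sasl.QOP = "auth-conf"
//                   alwaysEncrypt = false -> Sasl.QOP = "auth-conf,auth"  (encryption preferred)
// The mechanism settles on a QOP acceptable to both sides; SaslClientBootstrap then rejects the
// connection if it asked for encryption but the negotiated value is not "auth-conf".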
* @param token Server's SASL token @@ -110,6 +126,7 @@ public synchronized byte[] response(byte[] token) { * Disposes of any system resources or security-sensitive information the * SaslServer might be using. */ + @Override public synchronized void dispose() { if (saslServer != null) { try { @@ -122,6 +139,16 @@ public synchronized void dispose() { } } + @Override + public byte[] wrap(byte[] data, int offset, int len) throws SaslException { + return saslServer.wrap(data, offset, len); + } + + @Override + public byte[] unwrap(byte[] data, int offset, int len) throws SaslException { + return saslServer.unwrap(data, offset, len); + } + /** * Implementation of javax.security.auth.callback.CallbackHandler for SASL DIGEST-MD5 mechanism. */ diff --git a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java index a6d390e13f396..c95e64e8e2cda 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java +++ b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java @@ -20,14 +20,18 @@ import java.util.Iterator; import java.util.Map; import java.util.Random; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicLong; +import io.netty.channel.Channel; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.buffer.ManagedBuffer; +import com.google.common.base.Preconditions; + /** * StreamManager which allows registration of an Iterator<ManagedBuffer>, which are individually * fetched as chunks by the client. Each registered buffer is one chunk. @@ -36,18 +40,21 @@ public class OneForOneStreamManager extends StreamManager { private final Logger logger = LoggerFactory.getLogger(OneForOneStreamManager.class); private final AtomicLong nextStreamId; - private final Map streams; + private final ConcurrentHashMap streams; /** State of a single stream. */ private static class StreamState { final Iterator buffers; + // The channel associated to the stream + Channel associatedChannel = null; + // Used to keep track of the index of the buffer that the user has retrieved, just to ensure // that the caller only requests each chunk one at a time, in order. int curChunk = 0; StreamState(Iterator buffers) { - this.buffers = buffers; + this.buffers = Preconditions.checkNotNull(buffers); } } @@ -58,6 +65,13 @@ public OneForOneStreamManager() { streams = new ConcurrentHashMap(); } + @Override + public void registerChannel(Channel channel, long streamId) { + if (streams.containsKey(streamId)) { + streams.get(streamId).associatedChannel = channel; + } + } + @Override public ManagedBuffer getChunk(long streamId, int chunkIndex) { StreamState state = streams.get(streamId); @@ -80,12 +94,17 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) { } @Override - public void connectionTerminated(long streamId) { - // Release all remaining buffers. - StreamState state = streams.remove(streamId); - if (state != null && state.buffers != null) { - while (state.buffers.hasNext()) { - state.buffers.next().release(); + public void connectionTerminated(Channel channel) { + // Close all streams which have been associated with the channel. + for (Map.Entry entry: streams.entrySet()) { + StreamState state = entry.getValue(); + if (state.associatedChannel == channel) { + streams.remove(entry.getKey()); + + // Release all remaining buffers. 
+ while (state.buffers.hasNext()) { + state.buffers.next().release(); + } } } } diff --git a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java index 5a9a14a180c10..929f789bf9d24 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java +++ b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java @@ -17,6 +17,8 @@ package org.apache.spark.network.server; +import io.netty.channel.Channel; + import org.apache.spark.network.buffer.ManagedBuffer; /** @@ -44,9 +46,18 @@ public abstract class StreamManager { public abstract ManagedBuffer getChunk(long streamId, int chunkIndex); /** - * Indicates that the TCP connection that was tied to the given stream has been terminated. After - * this occurs, we are guaranteed not to read from the stream again, so any state can be cleaned - * up. + * Associates a stream with a single client connection, which is guaranteed to be the only reader + * of the stream. The getChunk() method will be called serially on this connection and once the + * connection is closed, the stream will never be used again, enabling cleanup. + * + * This must be called before the first getChunk() on the stream, but it may be invoked multiple + * times with the same channel and stream id. + */ + public void registerChannel(Channel channel, long streamId) { } + + /** + * Indicates that the given channel has been terminated. After this occurs, we are guaranteed not + * to read from the associated streams again, so any state can be cleaned up. */ - public void connectionTerminated(long streamId) { } + public void connectionTerminated(Channel channel) { } } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 1580180cc17e9..e5159ab56d0d4 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -17,10 +17,7 @@ package org.apache.spark.network.server; -import java.util.Set; - import com.google.common.base.Throwables; -import com.google.common.collect.Sets; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; @@ -62,9 +59,6 @@ public class TransportRequestHandler extends MessageHandler { /** Returns each chunk part of a stream. */ private final StreamManager streamManager; - /** List of all stream ids that have been read on this handler, used for cleanup. */ - private final Set streamIds; - public TransportRequestHandler( Channel channel, TransportClient reverseClient, @@ -73,7 +67,6 @@ public TransportRequestHandler( this.reverseClient = reverseClient; this.rpcHandler = rpcHandler; this.streamManager = rpcHandler.getStreamManager(); - this.streamIds = Sets.newHashSet(); } @Override @@ -82,10 +75,7 @@ public void exceptionCaught(Throwable cause) { @Override public void channelUnregistered() { - // Inform the StreamManager that these streams will no longer be read from. 
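(Editor's note: the call sequence implied by the StreamManager contract above and the TransportRequestHandler changes below, written out as comments for orientation.)

// For each chunk fetch (TransportRequestHandler.processFetchRequest):
//   streamManager.registerChannel(channel, streamId);   // ties the stream to its only reader
//   streamManager.getChunk(streamId, chunkIndex);        // called serially, chunks in order
// When the socket closes (TransportRequestHandler.channelUnregistered):
//   streamManager.connectionTerminated(channel);         // drop and release that channel's streams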
- for (long streamId : streamIds) { - streamManager.connectionTerminated(streamId); - } + streamManager.connectionTerminated(channel); rpcHandler.connectionTerminated(reverseClient); } @@ -102,12 +92,12 @@ public void handle(RequestMessage request) { private void processFetchRequest(final ChunkFetchRequest req) { final String client = NettyUtils.getRemoteAddress(channel); - streamIds.add(req.streamChunkId.streamId); logger.trace("Received req from {} to fetch block {}", client, req.streamChunkId); ManagedBuffer buf; try { + streamManager.registerChannel(channel, req.streamChunkId.streamId); buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex); } catch (Exception e) { logger.error(String.format( diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java index b7ce8541e565e..941ef95772e16 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -19,8 +19,11 @@ import java.io.Closeable; import java.net.InetSocketAddress; +import java.util.List; import java.util.concurrent.TimeUnit; +import com.google.common.base.Preconditions; +import com.google.common.collect.Lists; import io.netty.bootstrap.ServerBootstrap; import io.netty.buffer.PooledByteBufAllocator; import io.netty.channel.ChannelFuture; @@ -44,15 +47,23 @@ public class TransportServer implements Closeable { private final TransportContext context; private final TransportConf conf; + private final RpcHandler appRpcHandler; + private final List bootstraps; private ServerBootstrap bootstrap; private ChannelFuture channelFuture; private int port = -1; /** Creates a TransportServer that binds to the given port, or to any available if 0. */ - public TransportServer(TransportContext context, int portToBind) { + public TransportServer( + TransportContext context, + int portToBind, + RpcHandler appRpcHandler, + List bootstraps) { this.context = context; this.conf = context.getConf(); + this.appRpcHandler = appRpcHandler; + this.bootstraps = Lists.newArrayList(Preconditions.checkNotNull(bootstraps)); init(portToBind); } @@ -95,7 +106,11 @@ private void init(int portToBind) { bootstrap.childHandler(new ChannelInitializer() { @Override protected void initChannel(SocketChannel ch) throws Exception { - context.initializePipeline(ch); + RpcHandler rpcHandler = appRpcHandler; + for (TransportServerBootstrap bootstrap : bootstraps) { + rpcHandler = bootstrap.doBootstrap(ch, rpcHandler); + } + context.initializePipeline(ch, rpcHandler); } }); diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServerBootstrap.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServerBootstrap.java new file mode 100644 index 0000000000000..05803ab1bb059 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportServerBootstrap.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.server; + +import io.netty.channel.Channel; + +/** + * A bootstrap which is executed on a TransportServer's client channel once a client connects + * to the server. This allows customizing the client channel to allow for things such as SASL + * authentication. + */ +public interface TransportServerBootstrap { + /** + * Customizes the channel to include new features, if needed. + * + * @param channel The connected channel opened by the client. + * @param rpcHandler The RPC handler for the server. + * @return The RPC handler to use for the channel. + */ + RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler); +} diff --git a/network/common/src/test/java/org/apache/spark/network/ByteArrayWritableChannel.java b/network/common/src/main/java/org/apache/spark/network/util/ByteArrayWritableChannel.java similarity index 70% rename from network/common/src/test/java/org/apache/spark/network/ByteArrayWritableChannel.java rename to network/common/src/main/java/org/apache/spark/network/util/ByteArrayWritableChannel.java index b525ed69fc9fb..b1415720045e2 100644 --- a/network/common/src/test/java/org/apache/spark/network/ByteArrayWritableChannel.java +++ b/network/common/src/main/java/org/apache/spark/network/util/ByteArrayWritableChannel.java @@ -15,11 +15,14 @@ * limitations under the License. */ -package org.apache.spark.network; +package org.apache.spark.network.util; import java.nio.ByteBuffer; import java.nio.channels.WritableByteChannel; +/** + * A writable channel that stores the written data in a byte array in memory. + */ public class ByteArrayWritableChannel implements WritableByteChannel { private final byte[] data; @@ -27,19 +30,30 @@ public class ByteArrayWritableChannel implements WritableByteChannel { public ByteArrayWritableChannel(int size) { this.data = new byte[size]; - this.offset = 0; } public byte[] getData() { return data; } + public int length() { + return offset; + } + + /** Resets the channel so that writing to it will overwrite the existing buffer. */ + public void reset() { + offset = 0; + } + + /** + * Reads from the given buffer into the internal byte array. + */ @Override public int write(ByteBuffer src) { - int available = src.remaining(); - src.get(data, offset, available); - offset += available; - return available; + int toTransfer = Math.min(src.remaining(), data.length - offset); + src.get(data, offset, toTransfer); + offset += toTransfer; + return toTransfer; } @Override diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java index 0aef7f1987315..3b2eff377955a 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -17,6 +17,8 @@ package org.apache.spark.network.util; +import com.google.common.primitives.Ints; + /** * A central location that tracks all the settings we expose to users. 
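(Editor's note: a small sketch of exercising the two settings added to TransportConf in the following hunk; SystemPropertyConfigProvider is the provider used by the tests in this patch, and the values are arbitrary.)

System.setProperty("spark.network.sasl.maxEncryptedBlockSize", "1k");
System.setProperty("spark.network.sasl.serverAlwaysEncrypt", "true");

TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
assert conf.maxSaslEncryptedBlockSize() == 1024;   // "1k" parsed as a byte string
assert conf.saslServerAlwaysEncrypt();             // server will only offer the "auth-conf" QOP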
*/ @@ -112,4 +114,20 @@ public boolean lazyFileDescriptor() { public int portMaxRetries() { return conf.getInt("spark.port.maxRetries", 16); } + + /** + * Maximum number of bytes to be encrypted at a time when SASL encryption is enabled. + */ + public int maxSaslEncryptedBlockSize() { + return Ints.checkedCast(JavaUtils.byteStringAsBytes( + conf.get("spark.network.sasl.maxEncryptedBlockSize", "64k"))); + } + + /** + * Whether the server should enforce encryption on SASL-authenticated connections. + */ + public boolean saslServerAlwaysEncrypt() { + return conf.getBoolean("spark.network.sasl.serverAlwaysEncrypt", false); + } + } diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java index 860dd6d9b3915..d500bc3c98a78 100644 --- a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -39,6 +39,7 @@ import org.apache.spark.network.protocol.RpcRequest; import org.apache.spark.network.protocol.RpcResponse; import org.apache.spark.network.protocol.StreamChunkId; +import org.apache.spark.network.util.ByteArrayWritableChannel; import org.apache.spark.network.util.NettyUtils; public class ProtocolSuite { diff --git a/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java index ff985096d72d5..6c98e733b462f 100644 --- a/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java @@ -29,7 +29,7 @@ import static org.junit.Assert.*; -import org.apache.spark.network.ByteArrayWritableChannel; +import org.apache.spark.network.util.ByteArrayWritableChannel; public class MessageWithHeaderSuite { diff --git a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index 23b4e06f064e1..be6632bb8cf49 100644 --- a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -17,12 +17,47 @@ package org.apache.spark.network.sasl; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; +import static com.google.common.base.Charsets.UTF_8; +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; +import java.io.File; +import java.util.Arrays; +import java.util.List; +import java.util.Random; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import javax.security.sasl.SaslException; + +import com.google.common.collect.Lists; +import com.google.common.io.ByteStreams; +import com.google.common.io.Files; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPromise; import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import org.apache.spark.network.TestUtils; +import org.apache.spark.network.TransportContext; +import 
org.apache.spark.network.buffer.FileSegmentManagedBuffer; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.ChunkReceivedCallback; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.server.TransportServerBootstrap; +import org.apache.spark.network.util.ByteArrayWritableChannel; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; /** * Jointly tests SparkSaslClient and SparkSaslServer, as both are black boxes. @@ -44,8 +79,8 @@ public String getSecretKey(String appId) { @Test public void testMatching() { - SparkSaslClient client = new SparkSaslClient("shared-secret", secretKeyHolder); - SparkSaslServer server = new SparkSaslServer("shared-secret", secretKeyHolder); + SparkSaslClient client = new SparkSaslClient("shared-secret", secretKeyHolder, false); + SparkSaslServer server = new SparkSaslServer("shared-secret", secretKeyHolder, false); assertFalse(client.isComplete()); assertFalse(server.isComplete()); @@ -64,11 +99,10 @@ public void testMatching() { assertFalse(client.isComplete()); } - @Test public void testNonMatching() { - SparkSaslClient client = new SparkSaslClient("my-secret", secretKeyHolder); - SparkSaslServer server = new SparkSaslServer("your-secret", secretKeyHolder); + SparkSaslClient client = new SparkSaslClient("my-secret", secretKeyHolder, false); + SparkSaslServer server = new SparkSaslServer("your-secret", secretKeyHolder, false); assertFalse(client.isComplete()); assertFalse(server.isComplete()); @@ -86,4 +120,312 @@ public void testNonMatching() { assertFalse(server.isComplete()); } } + + @Test + public void testSaslAuthentication() throws Exception { + testBasicSasl(false); + } + + @Test + public void testSaslEncryption() throws Exception { + testBasicSasl(true); + } + + private void testBasicSasl(boolean encrypt) throws Exception { + RpcHandler rpcHandler = mock(RpcHandler.class); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocation) { + byte[] message = (byte[]) invocation.getArguments()[1]; + RpcResponseCallback cb = (RpcResponseCallback) invocation.getArguments()[2]; + assertEquals("Ping", new String(message, UTF_8)); + cb.onSuccess("Pong".getBytes(UTF_8)); + return null; + } + }) + .when(rpcHandler) + .receive(any(TransportClient.class), any(byte[].class), any(RpcResponseCallback.class)); + + SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false); + try { + byte[] response = ctx.client.sendRpcSync("Ping".getBytes(UTF_8), TimeUnit.SECONDS.toMillis(10)); + assertEquals("Pong", new String(response, UTF_8)); + } finally { + ctx.close(); + } + } + + @Test + public void testEncryptedMessage() throws Exception { + SaslEncryptionBackend backend = mock(SaslEncryptionBackend.class); + byte[] data = new byte[1024]; + new Random().nextBytes(data); + when(backend.wrap(any(byte[].class), anyInt(), anyInt())).thenReturn(data); + + ByteBuf msg = Unpooled.buffer(); + try { + msg.writeBytes(data); + + // Create a channel with a really small buffer compared to the data. 
This means that on each + // call, the outbound data will not be fully written, so the write() method should return a + // dummy count to keep the channel alive when possible. + ByteArrayWritableChannel channel = new ByteArrayWritableChannel(32); + + SaslEncryption.EncryptedMessage emsg = + new SaslEncryption.EncryptedMessage(backend, msg, 1024); + long count = emsg.transferTo(channel, emsg.transfered()); + assertTrue(count < data.length); + assertTrue(count > 0); + + // Here, the output buffer is full so nothing should be transferred. + assertEquals(0, emsg.transferTo(channel, emsg.transfered())); + + // Now there's room in the buffer, but not enough to transfer all the remaining data, + // so the dummy count should be returned. + channel.reset(); + assertEquals(1, emsg.transferTo(channel, emsg.transfered())); + + // Eventually, the whole message should be transferred. + for (int i = 0; i < data.length / 32 - 2; i++) { + channel.reset(); + assertEquals(1, emsg.transferTo(channel, emsg.transfered())); + } + + channel.reset(); + count = emsg.transferTo(channel, emsg.transfered()); + assertTrue("Unexpected count: " + count, count > 1 && count < data.length); + assertEquals(data.length, emsg.transfered()); + } finally { + msg.release(); + } + } + + @Test + public void testEncryptedMessageChunking() throws Exception { + File file = File.createTempFile("sasltest", ".txt"); + try { + TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + + byte[] data = new byte[8 * 1024]; + new Random().nextBytes(data); + Files.write(data, file); + + SaslEncryptionBackend backend = mock(SaslEncryptionBackend.class); + // It doesn't really matter what we return here, as long as it's not null. + when(backend.wrap(any(byte[].class), anyInt(), anyInt())).thenReturn(data); + + FileSegmentManagedBuffer msg = new FileSegmentManagedBuffer(conf, file, 0, file.length()); + SaslEncryption.EncryptedMessage emsg = + new SaslEncryption.EncryptedMessage(backend, msg.convertToNetty(), data.length / 8); + + ByteArrayWritableChannel channel = new ByteArrayWritableChannel(data.length); + while (emsg.transfered() < emsg.count()) { + channel.reset(); + emsg.transferTo(channel, emsg.transfered()); + } + + verify(backend, times(8)).wrap(any(byte[].class), anyInt(), anyInt()); + } finally { + file.delete(); + } + } + + @Test + public void testFileRegionEncryption() throws Exception { + final String blockSizeConf = "spark.network.sasl.maxEncryptedBlockSize"; + System.setProperty(blockSizeConf, "1k"); + + final AtomicReference response = new AtomicReference(); + final File file = File.createTempFile("sasltest", ".txt"); + SaslTestCtx ctx = null; + try { + final TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + StreamManager sm = mock(StreamManager.class); + when(sm.getChunk(anyLong(), anyInt())).thenAnswer(new Answer() { + @Override + public ManagedBuffer answer(InvocationOnMock invocation) { + return new FileSegmentManagedBuffer(conf, file, 0, file.length()); + } + }); + + RpcHandler rpcHandler = mock(RpcHandler.class); + when(rpcHandler.getStreamManager()).thenReturn(sm); + + byte[] data = new byte[8 * 1024]; + new Random().nextBytes(data); + Files.write(data, file); + + ctx = new SaslTestCtx(rpcHandler, true, false); + + final Object lock = new Object(); + + ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocation) { + response.set((ManagedBuffer) invocation.getArguments()[1]); + 
response.get().retain(); + synchronized (lock) { + lock.notifyAll(); + } + return null; + } + }).when(callback).onSuccess(anyInt(), any(ManagedBuffer.class)); + + synchronized (lock) { + ctx.client.fetchChunk(0, 0, callback); + lock.wait(10 * 1000); + } + + verify(callback, times(1)).onSuccess(anyInt(), any(ManagedBuffer.class)); + verify(callback, never()).onFailure(anyInt(), any(Throwable.class)); + + byte[] received = ByteStreams.toByteArray(response.get().createInputStream()); + assertTrue(Arrays.equals(data, received)); + } finally { + file.delete(); + if (ctx != null) { + ctx.close(); + } + if (response.get() != null) { + response.get().release(); + } + System.clearProperty(blockSizeConf); + } + } + + @Test + public void testServerAlwaysEncrypt() throws Exception { + final String alwaysEncryptConfName = "spark.network.sasl.serverAlwaysEncrypt"; + System.setProperty(alwaysEncryptConfName, "true"); + + SaslTestCtx ctx = null; + try { + ctx = new SaslTestCtx(mock(RpcHandler.class), false, false); + fail("Should have failed to connect without encryption."); + } catch (Exception e) { + assertTrue(e.getCause() instanceof SaslException); + } finally { + if (ctx != null) { + ctx.close(); + } + System.clearProperty(alwaysEncryptConfName); + } + } + + @Test + public void testDataEncryptionIsActuallyEnabled() throws Exception { + // This test sets up an encrypted connection but then, using a client bootstrap, removes + // the encryption handler from the client side. This should cause the server to not be + // able to understand RPCs sent to it and thus close the connection. + SaslTestCtx ctx = null; + try { + ctx = new SaslTestCtx(mock(RpcHandler.class), true, true); + ctx.client.sendRpcSync("Ping".getBytes(UTF_8), TimeUnit.SECONDS.toMillis(10)); + fail("Should have failed to send RPC to server."); + } catch (Exception e) { + assertFalse(e.getCause() instanceof TimeoutException); + } finally { + if (ctx != null) { + ctx.close(); + } + } + } + + private static class SaslTestCtx { + + final TransportClient client; + final TransportServer server; + + private final boolean encrypt; + private final boolean disableClientEncryption; + private final EncryptionCheckerBootstrap checker; + + SaslTestCtx( + RpcHandler rpcHandler, + boolean encrypt, + boolean disableClientEncryption) + throws Exception { + + TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + + SecretKeyHolder keyHolder = mock(SecretKeyHolder.class); + when(keyHolder.getSaslUser(anyString())).thenReturn("user"); + when(keyHolder.getSecretKey(anyString())).thenReturn("secret"); + + TransportContext ctx = new TransportContext(conf, rpcHandler); + + this.checker = new EncryptionCheckerBootstrap(); + this.server = ctx.createServer(Arrays.asList(new SaslServerBootstrap(conf, keyHolder), + checker)); + + try { + List clientBootstraps = Lists.newArrayList(); + clientBootstraps.add(new SaslClientBootstrap(conf, "user", keyHolder, encrypt)); + if (disableClientEncryption) { + clientBootstraps.add(new EncryptionDisablerBootstrap()); + } + + this.client = ctx.createClientFactory(clientBootstraps) + .createClient(TestUtils.getLocalHost(), server.getPort()); + } catch (Exception e) { + close(); + throw e; + } + + this.encrypt = encrypt; + this.disableClientEncryption = disableClientEncryption; + } + + void close() { + if (!disableClientEncryption) { + assertEquals(encrypt, checker.foundEncryptionHandler); + } + if (client != null) { + client.close(); + } + if (server != null) { + server.close(); + } + } + + } + + private static 
class EncryptionCheckerBootstrap extends ChannelOutboundHandlerAdapter + implements TransportServerBootstrap { + + boolean foundEncryptionHandler; + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) + throws Exception { + if (!foundEncryptionHandler) { + foundEncryptionHandler = + ctx.channel().pipeline().get(SaslEncryption.ENCRYPTION_HANDLER_NAME) != null; + } + ctx.write(msg, promise); + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + super.handlerRemoved(ctx); + } + + @Override + public RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler) { + channel.pipeline().addFirst("encryptionChecker", this); + return rpcHandler; + } + + } + + private static class EncryptionDisablerBootstrap implements TransportClientBootstrap { + + @Override + public void doBootstrap(TransportClient client, Channel channel) { + channel.pipeline().remove(SaslEncryption.ENCRYPTION_HANDLER_NAME); + } + + } + } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index 6e8018b723dc6..612bce571a493 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.util.List; +import com.google.common.base.Preconditions; import com.google.common.collect.Lists; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -46,6 +47,7 @@ public class ExternalShuffleClient extends ShuffleClient { private final TransportConf conf; private final boolean saslEnabled; + private final boolean saslEncryptionEnabled; private final SecretKeyHolder secretKeyHolder; private TransportClientFactory clientFactory; @@ -58,10 +60,15 @@ public class ExternalShuffleClient extends ShuffleClient { public ExternalShuffleClient( TransportConf conf, SecretKeyHolder secretKeyHolder, - boolean saslEnabled) { + boolean saslEnabled, + boolean saslEncryptionEnabled) { + Preconditions.checkArgument( + !saslEncryptionEnabled || saslEnabled, + "SASL encryption can only be enabled if SASL is also enabled."); this.conf = conf; this.secretKeyHolder = secretKeyHolder; this.saslEnabled = saslEnabled; + this.saslEncryptionEnabled = saslEncryptionEnabled; } @Override @@ -70,7 +77,7 @@ public void init(String appId) { TransportContext context = new TransportContext(conf, new NoOpRpcHandler()); List bootstraps = Lists.newArrayList(); if (saslEnabled) { - bootstraps.add(new SaslClientBootstrap(conf, appId, secretKeyHolder)); + bootstraps.add(new SaslClientBootstrap(conf, appId, secretKeyHolder, saslEncryptionEnabled)); } clientFactory = context.createClientFactory(bootstraps); } diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index d25283e46ef96..382f613ecbb1b 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.network.sasl; import java.io.IOException; +import java.util.Arrays; import com.google.common.collect.Lists; import org.junit.After; @@ -37,6 +38,7 @@ import org.apache.spark.network.server.RpcHandler; import 
org.apache.spark.network.server.StreamManager; import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.server.TransportServerBootstrap; import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler; import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; @@ -72,10 +74,11 @@ public String getSecretKey(String appId) { @BeforeClass public static void beforeAll() throws IOException { SecretKeyHolder secretKeyHolder = new TestSecretKeyHolder("good-key"); - SaslRpcHandler handler = new SaslRpcHandler(new TestRpcHandler(), secretKeyHolder); conf = new TransportConf(new SystemPropertyConfigProvider()); - context = new TransportContext(conf, handler); - server = context.createServer(); + context = new TransportContext(conf, new TestRpcHandler()); + + TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, secretKeyHolder); + server = context.createServer(Arrays.asList(bootstrap)); } diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index 02c10bcb7b261..39aa49911d9cb 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -136,7 +136,7 @@ private FetchResult fetchBlocks(String execId, String[] blockIds, int port) thro final Semaphore requestsRemaining = new Semaphore(0); - ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false); + ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false, false); client.init(APP_ID); client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds, new BlockFetchingListener() { @@ -274,7 +274,7 @@ public void testFetchNoServer() throws Exception { private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) throws IOException { - ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false); + ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false, false); client.init(APP_ID); client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), executorId, executorInfo); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java index 759a12910c94d..d4ec1956c1e29 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java @@ -18,6 +18,7 @@ package org.apache.spark.network.shuffle; import java.io.IOException; +import java.util.Arrays; import org.junit.After; import org.junit.Before; @@ -27,10 +28,11 @@ import org.apache.spark.network.TestUtils; import org.apache.spark.network.TransportContext; -import org.apache.spark.network.sasl.SaslRpcHandler; +import org.apache.spark.network.sasl.SaslServerBootstrap; import org.apache.spark.network.sasl.SecretKeyHolder; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.server.TransportServerBootstrap; import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; import 
org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.TransportConf; @@ -42,10 +44,10 @@ public class ExternalShuffleSecuritySuite { @Before public void beforeEach() { - RpcHandler handler = new SaslRpcHandler(new ExternalShuffleBlockHandler(conf), - new TestSecretKeyHolder("my-app-id", "secret")); - TransportContext context = new TransportContext(conf, handler); - this.server = context.createServer(); + TransportContext context = new TransportContext(conf, new ExternalShuffleBlockHandler(conf)); + TransportServerBootstrap bootstrap = new SaslServerBootstrap(conf, + new TestSecretKeyHolder("my-app-id", "secret")); + this.server = context.createServer(Arrays.asList(bootstrap)); } @After @@ -58,13 +60,13 @@ public void afterEach() { @Test public void testValid() throws IOException { - validate("my-app-id", "secret"); + validate("my-app-id", "secret", false); } @Test public void testBadAppId() { try { - validate("wrong-app-id", "secret"); + validate("wrong-app-id", "secret", false); } catch (Exception e) { assertTrue(e.getMessage(), e.getMessage().contains("Wrong appId!")); } @@ -73,16 +75,21 @@ public void testBadAppId() { @Test public void testBadSecret() { try { - validate("my-app-id", "bad-secret"); + validate("my-app-id", "bad-secret", false); } catch (Exception e) { assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response")); } } + @Test + public void testEncryption() throws IOException { + validate("my-app-id", "secret", true); + } + /** Creates an ExternalShuffleClient and attempts to register with the server. */ - private void validate(String appId, String secretKey) throws IOException { + private void validate(String appId, String secretKey, boolean encrypt) throws IOException { ExternalShuffleClient client = - new ExternalShuffleClient(conf, new TestSecretKeyHolder(appId, secretKey), true); + new ExternalShuffleClient(conf, new TestSecretKeyHolder(appId, secretKey), true, encrypt); client.init(appId); // Registration either succeeds or throws an exception. 
client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), "exec0", diff --git a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java index 63b21222e7b77..463f99ef3352d 100644 --- a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java +++ b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -17,9 +17,10 @@ package org.apache.spark.network.yarn; -import java.lang.Override; import java.nio.ByteBuffer; +import java.util.List; +import com.google.common.collect.Lists; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.yarn.api.records.ApplicationId; import org.apache.hadoop.yarn.api.records.ContainerId; @@ -32,10 +33,11 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.TransportContext; -import org.apache.spark.network.sasl.SaslRpcHandler; +import org.apache.spark.network.sasl.SaslServerBootstrap; import org.apache.spark.network.sasl.ShuffleSecretManager; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.server.TransportServerBootstrap; import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler; import org.apache.spark.network.util.TransportConf; import org.apache.spark.network.yarn.util.HadoopConfigProvider; @@ -103,16 +105,17 @@ protected void serviceInit(Configuration conf) { // special RPC handler that filters out unauthenticated fetch requests boolean authEnabled = conf.getBoolean(SPARK_AUTHENTICATE_KEY, DEFAULT_SPARK_AUTHENTICATE); blockHandler = new ExternalShuffleBlockHandler(transportConf); - RpcHandler rpcHandler = blockHandler; + + List bootstraps = Lists.newArrayList(); if (authEnabled) { secretManager = new ShuffleSecretManager(); - rpcHandler = new SaslRpcHandler(rpcHandler, secretManager); + bootstraps.add(new SaslServerBootstrap(transportConf, secretManager)); } int port = conf.getInt( SPARK_SHUFFLE_SERVICE_PORT_KEY, DEFAULT_SPARK_SHUFFLE_SERVICE_PORT); - TransportContext transportContext = new TransportContext(transportConf, rpcHandler); - shuffleServer = transportContext.createServer(port); + TransportContext transportContext = new TransportContext(transportConf, blockHandler); + shuffleServer = transportContext.createServer(port, bootstraps); String authEnabledString = authEnabled ? "enabled" : "not enabled"; logger.info("Started YARN shuffle service for Spark on port {}. " + "Authentication is {}.", port, authEnabledString); diff --git a/pom.xml b/pom.xml index c85c5feeaf383..4313f940036c8 100644 --- a/pom.xml +++ b/pom.xml @@ -117,7 +117,7 @@ 1.6 spark 2.0.1 - 0.21.0 + 0.21.1 shaded-protobuf 1.7.10 1.2.17 diff --git a/python/pyspark/context.py b/python/pyspark/context.py index b006120eb266d..31992795a9e45 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -267,6 +267,13 @@ def __exit__(self, type, value, trace): """ self.stop() + def setLogLevel(self, logLevel): + """ + Control our logLevel. This overrides any user-defined log settings. 
+ Valid log levels include: ALL, DEBUG, ERROR, FATAL, INFO, OFF, TRACE, WARN + """ + self._jsc.setLogLevel(logLevel) + @classmethod def setSystemProperty(cls, key, value): """ diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py index 6d54b9e49ed10..b60b991dd4d8b 100644 --- a/python/pyspark/sql/__init__.py +++ b/python/pyspark/sql/__init__.py @@ -54,7 +54,9 @@ from pyspark.sql.types import Row from pyspark.sql.context import SQLContext, HiveContext from pyspark.sql.dataframe import DataFrame, GroupedData, Column, SchemaRDD, DataFrameNaFunctions +from pyspark.sql.dataframe import DataFrameStatFunctions __all__ = [ - 'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row', 'DataFrameNaFunctions' + 'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row', + 'DataFrameNaFunctions', 'DataFrameStatFunctions' ] diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 5908ebc990a56..5ff49cac5522b 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1,6 +1,6 @@ # # Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with +# contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. # The ASF licenses this file to You under the Apache License, Version 2.0 # (the "License"); you may not use this file except in compliance with @@ -34,7 +34,8 @@ from pyspark.sql.types import _create_cls, _parse_datatype_json_string -__all__ = ["DataFrame", "GroupedData", "Column", "SchemaRDD", "DataFrameNaFunctions"] +__all__ = ["DataFrame", "GroupedData", "Column", "SchemaRDD", "DataFrameNaFunctions", + "DataFrameStatFunctions"] class DataFrame(object): @@ -93,6 +94,12 @@ def na(self): """ return DataFrameNaFunctions(self) + @property + def stat(self): + """Returns a :class:`DataFrameStatFunctions` for statistic functions. + """ + return DataFrameStatFunctions(self) + @ignore_unicode_prefix def toJSON(self, use_unicode=True): """Converts a :class:`DataFrame` into a :class:`RDD` of string. @@ -868,6 +875,20 @@ def fillna(self, value, subset=None): return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx) + def cov(self, col1, col2): + """ + Calculate the sample covariance for the given columns, specified by their names, as a + double value. :func:`DataFrame.cov` and :func:`DataFrameStatFunctions.cov` are aliases. + + :param col1: The name of the first column + :param col2: The name of the second column + """ + if not isinstance(col1, str): + raise ValueError("col1 should be a string.") + if not isinstance(col2, str): + raise ValueError("col2 should be a string.") + return self._jdf.stat().cov(col1, col2) + @ignore_unicode_prefix def withColumn(self, colName, col): """Returns a new :class:`DataFrame` by adding a column. @@ -1311,6 +1332,19 @@ def fill(self, value, subset=None): fill.__doc__ = DataFrame.fillna.__doc__ +class DataFrameStatFunctions(object): + """Functionality for statistic functions with :class:`DataFrame`.
+ """ + + def __init__(self, df): + self.df = df + + def cov(self, col1, col2): + return self.df.cov(col1, col2) + + cov.__doc__ = DataFrame.cov.__doc__ + + def _test(): import doctest from pyspark.context import SparkContext diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 241f82175726f..641220a264295 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -24,13 +24,20 @@ from itertools import imap as map from pyspark import SparkContext -from pyspark.rdd import _prepare_for_python_RDD +from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix from pyspark.serializers import PickleSerializer, AutoBatchedSerializer from pyspark.sql.types import StringType from pyspark.sql.dataframe import Column, _to_java_column, _to_seq -__all__ = ['countDistinct', 'approxCountDistinct', 'udf'] +__all__ = [ + 'approxCountDistinct', + 'countDistinct', + 'monotonicallyIncreasingId', + 'rand', + 'randn', + 'sparkPartitionId', + 'udf'] def _create_function(name, doc=""): @@ -74,27 +81,21 @@ def _(col): __all__.sort() -def rand(seed=None): - """ - Generate a random column with i.i.d. samples from U[0.0, 1.0]. - """ - sc = SparkContext._active_spark_context - if seed: - jc = sc._jvm.functions.rand(seed) - else: - jc = sc._jvm.functions.rand() - return Column(jc) +def array(*cols): + """Creates a new array column. + :param cols: list of column names (string) or list of :class:`Column` expressions that have + the same data type. -def randn(seed=None): - """ - Generate a column with i.i.d. samples from the standard normal distribution. + >>> df.select(array('age', 'age').alias("arr")).collect() + [Row(arr=[2, 2]), Row(arr=[5, 5])] + >>> df.select(array([df.age, df.age]).alias("arr")).collect() + [Row(arr=[2, 2]), Row(arr=[5, 5])] """ sc = SparkContext._active_spark_context - if seed: - jc = sc._jvm.functions.randn(seed) - else: - jc = sc._jvm.functions.randn() + if len(cols) == 1 and isinstance(cols[0], (list, set)): + cols = cols[0] + jc = sc._jvm.functions.array(_to_seq(sc, cols, _to_java_column)) return Column(jc) @@ -146,6 +147,28 @@ def monotonicallyIncreasingId(): return Column(sc._jvm.functions.monotonicallyIncreasingId()) +def rand(seed=None): + """Generates a random column with i.i.d. samples from U[0.0, 1.0]. + """ + sc = SparkContext._active_spark_context + if seed: + jc = sc._jvm.functions.rand(seed) + else: + jc = sc._jvm.functions.rand() + return Column(jc) + + +def randn(seed=None): + """Generates a column with i.i.d. samples from the standard normal distribution. + """ + sc = SparkContext._active_spark_context + if seed: + jc = sc._jvm.functions.randn(seed) + else: + jc = sc._jvm.functions.randn() + return Column(jc) + + def sparkPartitionId(): """A column for partition ID of the Spark task. @@ -158,6 +181,25 @@ def sparkPartitionId(): return Column(sc._jvm.functions.sparkPartitionId()) +@ignore_unicode_prefix +def struct(*cols): + """Creates a new struct column. + + :param cols: list of column names (string) or list of :class:`Column` expressions + that are named or aliased. 
+ + >>> df.select(struct('age', 'name').alias("struct")).collect() + [Row(struct=Row(age=2, name=u'Alice')), Row(struct=Row(age=5, name=u'Bob'))] + >>> df.select(struct([df.age, df.name]).alias("struct")).collect() + [Row(struct=Row(age=2, name=u'Alice')), Row(struct=Row(age=5, name=u'Bob'))] + """ + sc = SparkContext._active_spark_context + if len(cols) == 1 and isinstance(cols[0], (list, set)): + cols = cols[0] + jc = sc._jvm.functions.struct(_to_seq(sc, cols, _to_java_column)) + return Column(jc) + + class UserDefinedFunction(object): """ User defined function in Python diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 5640bb5ea2346..44c8b6a1aac13 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -387,6 +387,11 @@ def test_aggregator(self): self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0]) self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0]) + def test_cov(self): + df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF() + cov = df.stat.cov("a", "b") + self.assertTrue(abs(cov - 55.0 / 3) < 1e-6) + def test_math_functions(self): df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF() from pyspark.sql import mathfunctions as functions diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 7c06c203455d9..33ea8c9293d74 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -606,7 +606,6 @@ def _validateRddResult(self, sendData, rdd): result = {} for i in rdd.map(lambda x: x[1]).collect(): result[i] = result.get(i, 0) + 1 - self.assertEqual(sendData, result) def test_kafka_stream(self): @@ -616,6 +615,7 @@ def test_kafka_stream(self): self._kafkaTestUtils.createTopic(topic) self._kafkaTestUtils.sendMessages(topic, sendData) + self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values())) stream = KafkaUtils.createStream(self.ssc, self._kafkaTestUtils.zkAddress(), "test-streaming-consumer", {topic: 1}, @@ -631,6 +631,7 @@ def test_kafka_direct_stream(self): self._kafkaTestUtils.createTopic(topic) self._kafkaTestUtils.sendMessages(topic, sendData) + self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values())) stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams) self._validateStreamResult(sendData, stream) @@ -645,6 +646,7 @@ def test_kafka_direct_stream_from_offset(self): self._kafkaTestUtils.createTopic(topic) self._kafkaTestUtils.sendMessages(topic, sendData) + self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values())) stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams, fromOffsets) self._validateStreamResult(sendData, stream) @@ -659,7 +661,7 @@ def test_kafka_rdd(self): self._kafkaTestUtils.createTopic(topic) self._kafkaTestUtils.sendMessages(topic, sendData) - + self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values())) rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges) self._validateRddResult(sendData, rdd) @@ -675,7 +677,7 @@ def test_kafka_rdd_with_leaders(self): self._kafkaTestUtils.createTopic(topic) self._kafkaTestUtils.sendMessages(topic, sendData) - + self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values())) rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges, leaders) self._validateRddResult(sendData, rdd) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 2225621dbaabd..c6217f07c452d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -28,13 +28,21 @@ import org.apache.spark.sql.catalyst.trees * the layout of intermediate tuples, BindReferences should be run after all such transformations. */ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) - extends Expression with trees.LeafNode[Expression] { + extends NamedExpression with trees.LeafNode[Expression] { type EvaluatedType = Any override def toString: String = s"input[$ordinal]" override def eval(input: Row): Any = input(ordinal) + + override def name: String = s"i[$ordinal]" + + override def toAttribute: Attribute = throw new UnsupportedOperationException + + override def qualifiers: Seq[String] = throw new UnsupportedOperationException + + override def exprId: ExprId = throw new UnsupportedOperationException } object BindReferences extends Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 42e5cbc05e1e0..23652aeb7c7bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.execution.stat.FrequentItems +import org.apache.spark.sql.execution.stat._ /** * :: Experimental :: @@ -65,4 +65,14 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { def freqItems(cols: List[String]): DataFrame = { FrequentItems.singlePassFreqItems(df, cols, 0.01) } + + /** + * Calculate the sample covariance of two numerical columns of a DataFrame. + * @param col1 the name of the first column + * @param col2 the name of the second column + * @return the covariance of the two columns. + */ + def cov(col1: String, col2: String): Double = { + StatFunctions.calculateCov(df, Seq(col1, col2)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index bd4a55fa132fb..5116fcefd4bf2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -428,20 +428,6 @@ class SQLContext(@transient val sparkContext: SparkContext) createDataFrame(rowRDD.rdd, schema) } - /** - * Creates a [[DataFrame]] from an [[JavaRDD]] containing [[Row]]s by applying - * a seq of names of columns to this RDD, the data type for each column will - * be inferred by the first row. - * - * @param rowRDD an JavaRDD of Row - * @param columns names for each column - * @return DataFrame - * @group dataframes - */ - def createDataFrame(rowRDD: JavaRDD[Row], columns: java.util.List[String]): DataFrame = { - createDataFrame(rowRDD.rdd, columns.toSeq) - } - /** * Applies a schema to an RDD of Java Beans. 
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala new file mode 100644 index 0000000000000..d4a94c24d9866 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.stat + +import org.apache.spark.sql.catalyst.expressions.Cast +import org.apache.spark.sql.{Column, DataFrame} +import org.apache.spark.sql.types.{DoubleType, NumericType} + +private[sql] object StatFunctions { + + /** Helper class to simplify tracking and merging counts. */ + private class CovarianceCounter extends Serializable { + var xAvg = 0.0 + var yAvg = 0.0 + var Ck = 0.0 + var count = 0L + // add an example to the calculation + def add(x: Double, y: Double): this.type = { + val oldX = xAvg + count += 1 + xAvg += (x - xAvg) / count + yAvg += (y - yAvg) / count + Ck += (y - yAvg) * (x - oldX) + this + } + // merge counters from other partitions. Formula can be found at: + // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Covariance + def merge(other: CovarianceCounter): this.type = { + val totalCount = count + other.count + Ck += other.Ck + + (xAvg - other.xAvg) * (yAvg - other.yAvg) * count / totalCount * other.count + xAvg = (xAvg * count + other.xAvg * other.count) / totalCount + yAvg = (yAvg * count + other.yAvg * other.count) / totalCount + count = totalCount + this + } + // return the sample covariance for the observed examples + def cov: Double = Ck / (count - 1) + } + + /** + * Calculate the covariance of two numerical columns of a DataFrame. + * @param df The DataFrame + * @param cols the column names + * @return the covariance of the two columns. 
+ */ + private[sql] def calculateCov(df: DataFrame, cols: Seq[String]): Double = { + require(cols.length == 2, "Currently cov supports calculating the covariance " + + "between two columns.") + cols.map(name => (name, df.schema.fields.find(_.name == name))).foreach { case (name, data) => + require(data.nonEmpty, s"Couldn't find column with name $name") + require(data.get.dataType.isInstanceOf[NumericType], "Covariance calculation for columns " + + s"with dataType ${data.get.dataType} not supported.") + } + val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType))) + val counts = df.select(columns:_*).rdd.aggregate(new CovarianceCounter)( + seqOp = (counter, row) => { + counter.add(row.getDouble(0), row.getDouble(1)) + }, + combOp = (baseCounter, other) => { + baseCounter.merge(other) + }) + counts.cov + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 242e64d3ff881..7e283393d0563 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -22,7 +22,7 @@ import scala.reflect.runtime.universe.{TypeTag, typeTag} import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, Star} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -283,6 +283,23 @@ object functions { */ def abs(e: Column): Column = Abs(e.expr) + /** + * Creates a new array column. The input columns must all have the same data type. + * + * @group normal_funcs + */ + @scala.annotation.varargs + def array(cols: Column*): Column = CreateArray(cols.map(_.expr)) + + /** + * Creates a new array column. The input columns must all have the same data type. + * + * @group normal_funcs + */ + def array(colName: String, colNames: String*): Column = { + array((colName +: colNames).map(col) : _*) + } + /** * Returns the first column that is not null. * {{{ @@ -390,6 +407,28 @@ object functions { */ def sqrt(e: Column): Column = Sqrt(e.expr) + /** + * Creates a new struct column. The input column must be a column in a [[DataFrame]], or + * a derived column expression that is named (i.e. aliased). + * + * @group normal_funcs + */ + @scala.annotation.varargs + def struct(cols: Column*): Column = { + require(cols.forall(_.expr.isInstanceOf[NamedExpression]), + s"struct input columns must all be named or aliased ($cols)") + CreateStruct(cols.map(_.expr.asInstanceOf[NamedExpression])) + } + + /** + * Creates a new struct column that composes multiple input columns. + * + * @group normal_funcs + */ + def struct(colName: String, colNames: String*): Column = { + struct((colName +: colNames).map(col) : _*) + } + /** * Converts a string expression to upper case. 
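For reference, StatFunctions.calculateCov above aggregates one CovarianceCounter per partition and then merges the counters, so the covariance comes out of a single pass over the data. Below is a minimal, Spark-free Scala sketch of the same update-and-merge rule (the names CovMergeSketch and Counter are illustrative only, not part of this patch); it checks the merged single-pass result against a direct two-pass computation on the same 10-row, y = 2x data the new tests use, where the sample covariance is 55.0 / 3:

object CovMergeSketch {

  /** Mirrors the add/merge/cov logic of CovarianceCounter in the patch. */
  final class Counter {
    var xAvg = 0.0
    var yAvg = 0.0
    var ck = 0.0
    var count = 0L

    def add(x: Double, y: Double): this.type = {
      val oldX = xAvg
      count += 1
      xAvg += (x - xAvg) / count
      yAvg += (y - yAvg) / count
      ck += (y - yAvg) * (x - oldX)
      this
    }

    def merge(other: Counter): this.type = {
      val total = count + other.count
      ck += other.ck + (xAvg - other.xAvg) * (yAvg - other.yAvg) * count / total * other.count
      xAvg = (xAvg * count + other.xAvg * other.count) / total
      yAvg = (yAvg * count + other.yAvg * other.count) / total
      count = total
      this
    }

    def cov: Double = ck / (count - 1)
  }

  def main(args: Array[String]): Unit = {
    val xs = (0 until 10).map(_.toDouble)
    val ys = xs.map(_ * 2.0)

    // Two "partitions", one counter each, then a merge -- the same shape as
    // rdd.aggregate(new CovarianceCounter)(seqOp, combOp) in calculateCov.
    val (left, right) = xs.zip(ys).splitAt(5)
    val merged = left.foldLeft(new Counter) { case (c, (x, y)) => c.add(x, y) }
      .merge(right.foldLeft(new Counter) { case (c, (x, y)) => c.add(x, y) })

    // Direct two-pass sample covariance for comparison.
    val xMean = xs.sum / xs.size
    val yMean = ys.sum / ys.size
    val direct = xs.zip(ys).map { case (x, y) => (x - xMean) * (y - yMean) }.sum / (xs.size - 1)

    assert(math.abs(merged.cov - direct) < 1e-9) // both are 55.0 / 3, matching the new tests
  }
}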
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala index ae9af1eabe68e..3a6c2c1e9101f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.sql.{Connection, Driver, DriverManager, DriverPropertyInfo, PreparedStatement} +import java.sql.{Connection, Driver, DriverManager, DriverPropertyInfo, PreparedStatement, SQLFeatureNotSupportedException} import java.util.Properties import scala.collection.mutable @@ -195,7 +195,9 @@ package object jdbc { override def getMinorVersion: Int = wrapped.getMinorVersion - override def getParentLogger: java.util.logging.Logger = wrapped.getParentLogger + def getParentLogger: java.util.logging.Logger = + throw new SQLFeatureNotSupportedException( + s"${this.getClass().getName}.getParentLogger is not yet implemented.") override def connect(url: String, info: Properties): Connection = wrapped.connect(url, info) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index ebe96e649d940..96fe66d0b84a6 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -186,4 +186,11 @@ public void testFrequentItems() { DataFrame results = df.stat().freqItems(cols, 0.2); Assert.assertTrue(results.collect()[0].getSeq(0).contains(1)); } + + @Test + public void testCovariance() { + DataFrame df = context.table("testData2"); + Double result = df.stat().cov("a", "b"); + Assert.assertTrue(Math.abs(result) < 1e-6); + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala new file mode 100644 index 0000000000000..ca03713ef4658 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.TestSQLContext.implicits._ +import org.apache.spark.sql.types._ + +/** + * Test suite for functions in [[org.apache.spark.sql.functions]]. 
+ */ +class DataFrameFunctionsSuite extends QueryTest { + + test("array with column name") { + val df = Seq((0, 1)).toDF("a", "b") + val row = df.select(array("a", "b")).first() + + val expectedType = ArrayType(IntegerType, containsNull = false) + assert(row.schema(0).dataType === expectedType) + assert(row.getAs[Seq[Int]](0) === Seq(0, 1)) + } + + test("array with column expression") { + val df = Seq((0, 1)).toDF("a", "b") + val row = df.select(array(col("a"), col("b") + col("b"))).first() + + val expectedType = ArrayType(IntegerType, containsNull = false) + assert(row.schema(0).dataType === expectedType) + assert(row.getAs[Seq[Int]](0) === Seq(0, 2)) + } + + // Turn this on once we add a rule to the analyzer to throw a friendly exception + ignore("array: throw exception if putting columns of different types into an array") { + val df = Seq((0, "str")).toDF("a", "b") + intercept[AnalysisException] { + df.select(array("a", "b")) + } + } + + test("struct with column name") { + val df = Seq((1, "str")).toDF("a", "b") + val row = df.select(struct("a", "b")).first() + + val expectedType = StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", StringType) + )) + assert(row.schema(0).dataType === expectedType) + assert(row.getAs[Row](0) === Row(1, "str")) + } + + test("struct with column expression") { + val df = Seq((1, "str")).toDF("a", "b") + val row = df.select(struct((col("a") * 2).as("c"), col("b"))).first() + + val expectedType = StructType(Seq( + StructField("c", IntegerType, nullable = false), + StructField("b", StringType) + )) + assert(row.schema(0).dataType === expectedType) + assert(row.getAs[Row](0) === Row(2, "str")) + } + + test("struct: must use named column expression") { + intercept[IllegalArgumentException] { + struct(col("a") * 2) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index bb1d29c71d23b..4f5a2ff696789 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -25,10 +25,11 @@ import org.apache.spark.sql.test.TestSQLContext.implicits._ class DataFrameStatSuite extends FunSuite { + import TestData._ val sqlCtx = TestSQLContext - + def toLetter(i: Int): String = (i + 97).toChar.toString + test("Frequent Items") { - def toLetter(i: Int): String = (i + 96).toChar.toString val rows = Array.tabulate(1000) { i => if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0) } @@ -44,4 +45,17 @@ class DataFrameStatSuite extends FunSuite { items2.getSeq[Double](0) should contain (-1.0) } + + test("covariance") { + val rows = Array.tabulate(10)(i => (i, 2.0 * i, toLetter(i))) + val df = sqlCtx.sparkContext.parallelize(rows).toDF("singles", "doubles", "letters") + + val results = df.stat.cov("singles", "doubles") + assert(math.abs(results - 55.0 / 3) < 1e-6) + intercept[IllegalArgumentException] { + df.stat.cov("singles", "letters") // doesn't accept non-numerical dataTypes + } + val decimalRes = decimalData.stat.cov("a", "b") + assert(math.abs(decimalRes) < 1e-6) + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 90c8b47aebce0..117cb59fb61c9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ 
b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -159,7 +159,7 @@ class StreamingContext private[streaming] ( } } - private val nextReceiverInputStreamId = new AtomicInteger(0) + private val nextInputStreamId = new AtomicInteger(0) private[streaming] var checkpointDir: String = { if (isCheckpointPresent) { @@ -241,7 +241,7 @@ class StreamingContext private[streaming] ( if (isCheckpointPresent) cp_ else null } - private[streaming] def getNewReceiverStreamId() = nextReceiverInputStreamId.getAndIncrement() + private[streaming] def getNewInputStreamId() = nextInputStreamId.getAndIncrement() /** * Create an input stream with any arbitrary user implemented receiver. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala index e652702e213ef..e4ad4b509d8d8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala @@ -41,6 +41,9 @@ abstract class InputDStream[T: ClassTag] (@transient ssc_ : StreamingContext) ssc.graph.addInputStream(this) + /** This is a unique identifier for the input stream. */ + val id = ssc.getNewInputStreamId() + /** * Checks whether the 'time' is valid wrt slideDuration for generating RDD. * Additionally it also ensures valid times are in strictly increasing order. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index 4c7fd2c57c006..ba88416ef4009 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -24,7 +24,7 @@ import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.streaming._ import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD import org.apache.spark.streaming.receiver.{Receiver, WriteAheadLogBasedStoreResult} -import org.apache.spark.streaming.scheduler.ReceivedBlockInfo +import org.apache.spark.streaming.scheduler.{InputInfo, ReceivedBlockInfo} /** * Abstract class for defining any [[org.apache.spark.streaming.dstream.InputDStream]] @@ -39,9 +39,6 @@ import org.apache.spark.streaming.scheduler.ReceivedBlockInfo abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingContext) extends InputDStream[T](ssc_) { - /** This is an unique identifier for the receiver input stream. */ - val id = ssc.getNewReceiverStreamId() - /** * Gets the receiver object that will be sent to the worker nodes * to receive data.
This method needs to defined by any specific implementation @@ -72,6 +69,10 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont val blockStoreResults = blockInfos.map { _.blockStoreResult } val blockIds = blockStoreResults.map { _.blockId.asInstanceOf[BlockId] }.toArray + // Register the input blocks information into InputInfoTracker + val inputInfo = InputInfo(id, blockInfos.map(_.numRecords).sum) + ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo) + // Check whether all the results are of the same type val resultTypes = blockStoreResults.map { _.getClass }.distinct if (resultTypes.size > 1) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala index 4b3d9ee4b0090..651b534ac1900 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -190,6 +190,7 @@ private[streaming] class WriteAheadLogBasedBlockHandler( def stop() { writeAheadLog.close() + executionContext.shutdown() } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala index 92dc113f397ca..5b9bfbf9b01e3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala @@ -24,6 +24,7 @@ import org.apache.spark.streaming.Time * :: DeveloperApi :: * Class having information on completed batches. * @param batchTime Time of the batch + * @param streamIdToNumRecords A map of input stream id to record number + * @param submissionTime Clock time of when jobs of this batch was submitted to * the streaming scheduler queue * @param processingStartTime Clock time of when the first job of this batch started processing @@ -32,7 +33,7 @@ import org.apache.spark.streaming.Time @DeveloperApi case class BatchInfo( batchTime: Time, - receivedBlockInfo: Map[Int, Array[ReceivedBlockInfo]], + streamIdToNumRecords: Map[Int, Long], submissionTime: Long, processingStartTime: Option[Long], processingEndTime: Option[Long] @@ -58,4 +59,9 @@ case class BatchInfo( */ def totalDelay: Option[Long] = schedulingDelay.zip(processingDelay) .map(x => x._1 + x._2).headOption + + /** + * The number of records received by the receivers in this batch. + */ + def numRecords: Long = streamIdToNumRecords.values.sum } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala new file mode 100644 index 0000000000000..a72efccf2f994 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import scala.collection.mutable + +import org.apache.spark.Logging +import org.apache.spark.streaming.{Time, StreamingContext} + +/** To track the information of input stream at specified batch time. */ +private[streaming] case class InputInfo(inputStreamId: Int, numRecords: Long) + +/** + * This class manages all the input streams as well as their input data statistics. The information + * will be exposed through StreamingListener for monitoring. + */ +private[streaming] class InputInfoTracker(ssc: StreamingContext) extends Logging { + + // Map to track all the InputInfo related to specific batch time and input stream. + private val batchTimeToInputInfos = new mutable.HashMap[Time, mutable.HashMap[Int, InputInfo]] + + /** Report the input information with batch time to the tracker */ + def reportInfo(batchTime: Time, inputInfo: InputInfo): Unit = synchronized { + val inputInfos = batchTimeToInputInfos.getOrElseUpdate(batchTime, + new mutable.HashMap[Int, InputInfo]()) + + if (inputInfos.contains(inputInfo.inputStreamId)) { + throw new IllegalStateException(s"Input stream ${inputInfo.inputStreamId} for batch " + + s"$batchTime is already added into InputInfoTracker, this is an illegal state") + } + inputInfos += ((inputInfo.inputStreamId, inputInfo)) + } + + /** Get all the input streams' information for the specified batch time */ + def getInfo(batchTime: Time): Map[Int, InputInfo] = synchronized { + val inputInfos = batchTimeToInputInfos.get(batchTime) + // Convert mutable HashMap to immutable Map for the caller + inputInfos.map(_.toMap).getOrElse(Map[Int, InputInfo]()) + } + + /** Cleanup the tracked input information older than threshold batch time */ + def cleanup(batchThreshTime: Time): Unit = synchronized { + val timesToCleanup = batchTimeToInputInfos.keys.filter(_ < batchThreshTime) + logInfo(s"remove old batch metadata: ${timesToCleanup.mkString(" ")}") + batchTimeToInputInfos --= timesToCleanup + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 2467d50839add..9f93d6cbc3c20 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -243,9 +243,9 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { graph.generateJobs(time) // generate jobs using allocated block } match { case Success(jobs) => - val receivedBlockInfos = - jobScheduler.receiverTracker.getBlocksOfBatch(time).mapValues { _.toArray } - jobScheduler.submitJobSet(JobSet(time, jobs, receivedBlockInfos)) + val streamIdToInputInfos = jobScheduler.inputInfoTracker.getInfo(time) + val streamIdToNumRecords = streamIdToInputInfos.mapValues(_.numRecords) + jobScheduler.submitJobSet(JobSet(time, jobs, streamIdToNumRecords)) case Failure(e) => jobScheduler.reportError("Error generating jobs for time " + time, e) } @@ -266,6 +266,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends
Logging { // checkpointing of this batch to complete. val maxRememberDuration = graph.getMaxInputStreamRememberDuration() jobScheduler.receiverTracker.cleanupOldBlocksAndBatches(time - maxRememberDuration) + jobScheduler.inputInfoTracker.cleanup(time - maxRememberDuration) markBatchFullyProcessed(time) } } @@ -278,6 +279,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { // been saved to checkpoints, so its safe to delete block metadata and data WAL files val maxRememberDuration = graph.getMaxInputStreamRememberDuration() jobScheduler.receiverTracker.cleanupOldBlocksAndBatches(time - maxRememberDuration) + jobScheduler.inputInfoTracker.cleanup(time - maxRememberDuration) markBatchFullyProcessed(time) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index c7a2c1141a128..1d1ddaaccf217 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -50,6 +50,9 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { // These two are created only when scheduler starts. // eventLoop not being null means the scheduler has been started and not stopped var receiverTracker: ReceiverTracker = null + // A tracker to track all the input stream information as well as processed record number + var inputInfoTracker: InputInfoTracker = null + private var eventLoop: EventLoop[JobSchedulerEvent] = null def start(): Unit = synchronized { @@ -65,6 +68,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { listenerBus.start(ssc.sparkContext) receiverTracker = new ReceiverTracker(ssc) + inputInfoTracker = new InputInfoTracker(ssc) receiverTracker.start() jobGenerator.start() logInfo("Started JobScheduler") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala index 24b3794236ea5..e6be63b2ddbdc 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala @@ -28,7 +28,7 @@ private[streaming] case class JobSet( time: Time, jobs: Seq[Job], - receivedBlockInfo: Map[Int, Array[ReceivedBlockInfo]] = Map.empty) { + streamIdToNumRecords: Map[Int, Long] = Map.empty) { private val incompleteJobs = new HashSet[Job]() private val submissionTime = System.currentTimeMillis() // when this jobset was submitted @@ -64,7 +64,7 @@ case class JobSet( def toBatchInfo: BatchInfo = { new BatchInfo( time, - receivedBlockInfo, + streamIdToNumRecords, submissionTime, if (processingStartTime >= 0 ) Some(processingStartTime) else None, if (processingEndTime >= 0 ) Some(processingEndTime) else None diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala index f45c291b7c0fe..99e10d2b0be12 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchUIData.scala @@ -66,7 +66,7 @@ private[ui] object BatchUIData { def apply(batchInfo: BatchInfo): BatchUIData = { new BatchUIData( batchInfo.batchTime, - batchInfo.receivedBlockInfo.mapValues(_.map(_.numRecords).sum), + batchInfo.streamIdToNumRecords, batchInfo.submissionTime, 
batchInfo.processingStartTime, batchInfo.processingEndTime diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index e6ac4975c5e68..eb136758249d5 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -27,17 +27,18 @@ import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer, SynchronizedQu import scala.language.postfixOps import com.google.common.io.Files +import org.apache.hadoop.io.{Text, LongWritable} +import org.apache.hadoop.mapreduce.lib.input.TextInputFormat +import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually._ import org.apache.spark.Logging +import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.util.{ManualClock, Utils} +import org.apache.spark.streaming.dstream.{InputDStream, ReceiverInputDStream} import org.apache.spark.streaming.receiver.Receiver -import org.apache.spark.rdd.RDD -import org.apache.hadoop.io.{Text, LongWritable} -import org.apache.hadoop.mapreduce.lib.input.TextInputFormat -import org.apache.hadoop.fs.Path class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { @@ -278,6 +279,30 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } } + test("test track the number of input stream") { + val ssc = new StreamingContext(conf, batchDuration) + + class TestInputDStream extends InputDStream[String](ssc) { + def start() { } + def stop() { } + def compute(validTime: Time): Option[RDD[String]] = None + } + + class TestReceiverInputDStream extends ReceiverInputDStream[String](ssc) { + def getReceiver: Receiver[String] = null + } + + // Register input streams + val receiverInputStreams = Array(new TestReceiverInputDStream, new TestReceiverInputDStream) + val inputStreams = Array(new TestInputDStream, new TestInputDStream, new TestInputDStream) + + assert(ssc.graph.getInputStreams().length == receiverInputStreams.length + inputStreams.length) + assert(ssc.graph.getReceiverInputStreams().length == receiverInputStreams.length) + assert(ssc.graph.getReceiverInputStreams() === receiverInputStreams) + assert(ssc.graph.getInputStreams().map(_.id) === Array.tabulate(5)(i => i)) + assert(receiverInputStreams.map(_.id) === Array(0, 1)) + } + def testFileStream(newFilesOnly: Boolean) { val testDir: File = null try { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index 9020be166acf0..312cce408cfe7 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -57,6 +57,11 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { info.totalDelay should be (None) }) + batchInfosSubmitted.foreach { info => + info.numRecords should be (1L) + info.streamIdToNumRecords should be (Map(0 -> 1L)) + } + isInIncreasingOrder(batchInfosSubmitted.map(_.submissionTime)) should be (true) // SPARK-6766: processingStartTime of batch info should not be None when starting @@ -70,6 +75,11 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { info.totalDelay should be (None) }) + batchInfosStarted.foreach { info => + info.numRecords should be (1L) + 
info.streamIdToNumRecords should be (Map(0 -> 1L)) + } + isInIncreasingOrder(batchInfosStarted.map(_.submissionTime)) should be (true) isInIncreasingOrder(batchInfosStarted.map(_.processingStartTime.get)) should be (true) @@ -86,6 +96,11 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { info.totalDelay.get should be >= 0L }) + batchInfosCompleted.foreach { info => + info.numRecords should be (1L) + info.streamIdToNumRecords should be (Map(0 -> 1L)) + } + isInIncreasingOrder(batchInfosCompleted.map(_.submissionTime)) should be (true) isInIncreasingOrder(batchInfosCompleted.map(_.processingStartTime.get)) should be (true) isInIncreasingOrder(batchInfosCompleted.map(_.processingEndTime.get)) should be (true) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index c3cae8aeb6d15..2ba86aeaf9d51 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -29,10 +29,10 @@ import org.scalatest.time.{Span, Seconds => ScalaTestSeconds} import org.scalatest.concurrent.Eventually.timeout import org.scalatest.concurrent.PatienceConfiguration -import org.apache.spark.streaming.dstream.{DStream, InputDStream, ForEachDStream} -import org.apache.spark.streaming.scheduler.{StreamingListenerBatchStarted, StreamingListenerBatchCompleted, StreamingListener} import org.apache.spark.{SparkConf, Logging} import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.dstream.{DStream, InputDStream, ForEachDStream} +import org.apache.spark.streaming.scheduler._ import org.apache.spark.util.{ManualClock, Utils} /** @@ -57,6 +57,10 @@ class TestInputStream[T: ClassTag](ssc_ : StreamingContext, input: Seq[Seq[T]], return None } + // Report the input data's information to InputInfoTracker for testing + val inputInfo = InputInfo(id, selectedInput.length.toLong) + ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo) + val rdd = ssc.sc.makeRDD(selectedInput, numPartitions) logInfo("Created RDD " + rdd.id + " with " + selectedInput) Some(rdd) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala new file mode 100644 index 0000000000000..5478b41845943 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.scheduler + +import org.scalatest.{BeforeAndAfter, FunSuite} + +import org.apache.spark.SparkConf +import org.apache.spark.streaming.{Time, Duration, StreamingContext} + +class InputInfoTrackerSuite extends FunSuite with BeforeAndAfter { + + private var ssc: StreamingContext = _ + + before { + val conf = new SparkConf().setMaster("local[2]").setAppName("DirectStreamTacker") + if (ssc == null) { + ssc = new StreamingContext(conf, Duration(1000)) + } + } + + after { + if (ssc != null) { + ssc.stop() + ssc = null + } + } + + test("test report and get InputInfo from InputInfoTracker") { + val inputInfoTracker = new InputInfoTracker(ssc) + + val streamId1 = 0 + val streamId2 = 1 + val time = Time(0L) + val inputInfo1 = InputInfo(streamId1, 100L) + val inputInfo2 = InputInfo(streamId2, 300L) + inputInfoTracker.reportInfo(time, inputInfo1) + inputInfoTracker.reportInfo(time, inputInfo2) + + val batchTimeToInputInfos = inputInfoTracker.getInfo(time) + assert(batchTimeToInputInfos.size == 2) + assert(batchTimeToInputInfos.keys === Set(streamId1, streamId2)) + assert(batchTimeToInputInfos(streamId1) === inputInfo1) + assert(batchTimeToInputInfos(streamId2) === inputInfo2) + assert(inputInfoTracker.getInfo(time)(streamId1) === inputInfo1) + } + + test("test cleanup InputInfo from InputInfoTracker") { + val inputInfoTracker = new InputInfoTracker(ssc) + + val streamId1 = 0 + val inputInfo1 = InputInfo(streamId1, 100L) + val inputInfo2 = InputInfo(streamId1, 300L) + inputInfoTracker.reportInfo(Time(0), inputInfo1) + inputInfoTracker.reportInfo(Time(1), inputInfo2) + + inputInfoTracker.cleanup(Time(0)) + assert(inputInfoTracker.getInfo(Time(0))(streamId1) === inputInfo1) + assert(inputInfoTracker.getInfo(Time(1))(streamId1) === inputInfo2) + + inputInfoTracker.cleanup(Time(1)) + assert(inputInfoTracker.getInfo(Time(0)).get(streamId1) === None) + assert(inputInfoTracker.getInfo(Time(1))(streamId1) === inputInfo2) + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala index 5103187b68aa9..d72dcdc017fe4 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala @@ -49,13 +49,10 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { val ssc = setupStreams(input, operation) val listener = new StreamingJobProgressListener(ssc) - val receivedBlockInfo = Map( - 0 -> Array(ReceivedBlockInfo(0, 100, null), ReceivedBlockInfo(0, 200, null)), - 1 -> Array(ReceivedBlockInfo(1, 300, null)) - ) + val streamIdToNumRecords = Map(0 -> 300L, 1 -> 300L) // onBatchSubmitted - val batchInfoSubmitted = BatchInfo(Time(1000), receivedBlockInfo, 1000, None, None) + val batchInfoSubmitted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, None, None) listener.onBatchSubmitted(StreamingListenerBatchSubmitted(batchInfoSubmitted)) listener.waitingBatches should be (List(BatchUIData(batchInfoSubmitted))) listener.runningBatches should be (Nil) @@ -67,7 +64,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { listener.numTotalReceivedRecords should be (0) // onBatchStarted - val batchInfoStarted = BatchInfo(Time(1000), receivedBlockInfo, 1000, Some(2000), None) + val batchInfoStarted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, 
Some(2000), None) listener.onBatchStarted(StreamingListenerBatchStarted(batchInfoStarted)) listener.waitingBatches should be (Nil) listener.runningBatches should be (List(BatchUIData(batchInfoStarted))) @@ -106,7 +103,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { OutputOpIdAndSparkJobId(1, 1)) // onBatchCompleted - val batchInfoCompleted = BatchInfo(Time(1000), receivedBlockInfo, 1000, Some(2000), None) + val batchInfoCompleted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, Some(2000), None) listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) listener.waitingBatches should be (Nil) listener.runningBatches should be (Nil) @@ -144,11 +141,9 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { val limit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 1000) val listener = new StreamingJobProgressListener(ssc) - val receivedBlockInfo = Map( - 0 -> Array(ReceivedBlockInfo(0, 100, null), ReceivedBlockInfo(0, 200, null)), - 1 -> Array(ReceivedBlockInfo(1, 300, null)) - ) - val batchInfoCompleted = BatchInfo(Time(1000), receivedBlockInfo, 1000, Some(2000), None) + val streamIdToNumRecords = Map(0 -> 300L, 1 -> 300L) + + val batchInfoCompleted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, Some(2000), None) for(_ <- 0 until (limit + 10)) { listener.onBatchCompleted(StreamingListenerBatchCompleted(batchInfoCompleted)) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala new file mode 100644 index 0000000000000..aaae6f9734a85 --- /dev/null +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.yarn + +import java.security.PrivilegedExceptionAction +import java.util.concurrent.{Executors, TimeUnit} + +import scala.language.postfixOps + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.security.UserGroupInformation +import org.apache.spark.deploy.SparkHadoopUtil + +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.util.ThreadUtils + +/* + * The following methods are primarily meant to make sure long-running apps like Spark + * Streaming apps can run without interruption while writing to secure HDFS. The + * scheduleLoginFromKeytab method is called on the driver when the + * CoarseGrainedScheduledBackend starts up. This method wakes up a thread that logs into the KDC + * once 75% of the renewal interval of the original delegation tokens used for the container + * has elapsed. 
It then creates new delegation tokens and writes them to HDFS in a + * pre-specified location - the prefix of which is specified in the sparkConf by + * spark.yarn.credentials.file (so the file(s) would be named c-1, c-2 etc. - each update goes + * to a new file, with a monotonically increasing suffix). After this, the credentials are + * updated once 75% of the new tokens' renewal interval has elapsed. + * + * On the executor side, the updateCredentialsIfRequired method is called once 80% of the + * validity of the original tokens has elapsed. At that time the executor finds the + * credentials file with the latest timestamp and checks if it has read those credentials + * before (by keeping track of the suffix of the last file it read). If a new file has + * appeared, it will read the credentials and update the currently running UGI with it. This + * process happens again once 80% of the validity of these new tokens has expired. + */ +private[yarn] class AMDelegationTokenRenewer( + sparkConf: SparkConf, + hadoopConf: Configuration) extends Logging { + + private var lastCredentialsFileSuffix = 0 + + private val delegationTokenRenewer = + Executors.newSingleThreadScheduledExecutor( + ThreadUtils.namedThreadFactory("Delegation Token Refresh Thread")) + + private val hadoopUtil = YarnSparkHadoopUtil.get + + private val daysToKeepFiles = sparkConf.getInt("spark.yarn.credentials.file.retention.days", 5) + private val numFilesToKeep = sparkConf.getInt("spark.yarn.credentials.file.retention.count", 5) + + /** + * Schedule a login from the keytab and principal set using the --principal and --keytab + * arguments to spark-submit. This login happens only when the credentials of the current user + * are about to expire. This method reads spark.yarn.principal and spark.yarn.keytab from + * SparkConf to do the login. This method is a no-op in non-YARN mode. + */ + private[spark] def scheduleLoginFromKeytab(): Unit = { + val principal = sparkConf.get("spark.yarn.principal") + val keytab = sparkConf.get("spark.yarn.keytab") + + /** + * Schedule re-login and creation of new tokens. If tokens have already expired, this method + * will synchronously create new ones. + */ + def scheduleRenewal(runnable: Runnable): Unit = { + val credentials = UserGroupInformation.getCurrentUser.getCredentials + val renewalInterval = hadoopUtil.getTimeFromNowToRenewal(sparkConf, 0.75, credentials) + // Run now! + if (renewalInterval <= 0) { + logInfo("HDFS tokens have expired, creating new tokens now.") + runnable.run() + } else { + logInfo(s"Scheduling login from keytab in $renewalInterval millis.") + delegationTokenRenewer.schedule(runnable, renewalInterval, TimeUnit.MILLISECONDS) + } + } + + // This thread periodically runs on the driver to update the delegation tokens on HDFS. + val driverTokenRenewerRunnable = + new Runnable { + override def run(): Unit = { + try { + writeNewTokensToHDFS(principal, keytab) + cleanupOldFiles() + } catch { + case e: Exception => + // Log the error and try to write new tokens back in an hour + logWarning("Failed to write out new credentials to HDFS, will try again in an " + + "hour! If this happens too often tasks will fail.", e) + delegationTokenRenewer.schedule(this, 1, TimeUnit.HOURS) + return + } + scheduleRenewal(this) + } + } + // Schedule update of credentials. This handles the case of updating the tokens right now + // as well, since the renewal interval will be 0, and the thread will get scheduled + // immediately.
+ scheduleRenewal(driverTokenRenewerRunnable) + } + + // Keeps only files that are newer than daysToKeepFiles days, and deletes everything else. At + // least numFilesToKeep files are kept for safety. + private def cleanupOldFiles(): Unit = { + import scala.concurrent.duration._ + try { + val remoteFs = FileSystem.get(hadoopConf) + val credentialsPath = new Path(sparkConf.get("spark.yarn.credentials.file")) + val thresholdTime = System.currentTimeMillis() - (daysToKeepFiles days).toMillis + hadoopUtil.listFilesSorted( + remoteFs, credentialsPath.getParent, + credentialsPath.getName, SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) + .dropRight(numFilesToKeep) + .takeWhile(_.getModificationTime < thresholdTime) + .foreach(x => remoteFs.delete(x.getPath, true)) + } catch { + // Such errors are not fatal, so don't throw. Make sure they are logged, though. + case e: Exception => + logWarning("Error while attempting to cleanup old tokens. If you are seeing many such " + + "warnings there may be an issue with your HDFS cluster.", e) + } + } + + private def writeNewTokensToHDFS(principal: String, keytab: String): Unit = { + // Keytab is copied by YARN to the working directory of the AM, so full path is + // not needed. + + // HACK: + // HDFS will not issue new delegation tokens if the Credentials object + // passed in already has tokens for that FS, even if the tokens are expired (it really only + // checks if there are tokens for the service, and not if they are valid). So the only real + // way to get new tokens is to make sure a different Credentials object is used each time to + // get new tokens and then the new tokens are copied over to the current user's Credentials. + // So: + // - we login as a different user and get the UGI + // - use that UGI to get the tokens (see doAs block below) + // - copy the tokens over to the current user's credentials (this will overwrite the tokens + // in the current user's Credentials object for this FS). + // The login to KDC happens each time new tokens are required, but this is rare enough to not + // have to worry about (like once every day or so). This makes this code clearer than having + // to login and then relogin every time (the HDFS API may not relogin since we don't use this + // UGI directly for HDFS communication). + logInfo(s"Attempting to login to KDC using principal: $principal") + val keytabLoggedInUGI = UserGroupInformation.loginUserFromKeytabAndReturnUGI(principal, keytab) + logInfo("Successfully logged into KDC.") + val tempCreds = keytabLoggedInUGI.getCredentials + val credentialsPath = new Path(sparkConf.get("spark.yarn.credentials.file")) + val dst = credentialsPath.getParent + keytabLoggedInUGI.doAs(new PrivilegedExceptionAction[Void] { + // Get a copy of the credentials + override def run(): Void = { + val nns = YarnSparkHadoopUtil.get.getNameNodesToAccess(sparkConf) + dst + hadoopUtil.obtainTokensForNamenodes(nns, hadoopConf, tempCreds) + null + } + }) + // Add the temp credentials back to the original ones. + UserGroupInformation.getCurrentUser.addCredentials(tempCreds) + val remoteFs = FileSystem.get(hadoopConf) + // If lastCredentialsFileSuffix is 0, then the AM has either just started or been restarted. If + // the AM was restarted, the lastCredentialsFileSuffix might be > 0, so find the newest file + // and update the lastCredentialsFileSuffix.
+ if (lastCredentialsFileSuffix == 0) { + hadoopUtil.listFilesSorted( + remoteFs, credentialsPath.getParent, + credentialsPath.getName, SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) + .lastOption.foreach { status => + lastCredentialsFileSuffix = hadoopUtil.getSuffixForCredentialsPath(status.getPath) + } + } + val nextSuffix = lastCredentialsFileSuffix + 1 + val tokenPathStr = + sparkConf.get("spark.yarn.credentials.file") + + SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM + nextSuffix + val tokenPath = new Path(tokenPathStr) + val tempTokenPath = new Path(tokenPathStr + SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) + logInfo("Writing out delegation tokens to " + tempTokenPath.toString) + val credentials = UserGroupInformation.getCurrentUser.getCredentials + credentials.writeTokenStorageFile(tempTokenPath, hadoopConf) + logInfo(s"Delegation Tokens written out successfully. Renaming file to $tokenPathStr") + remoteFs.rename(tempTokenPath, tokenPath) + logInfo("Delegation token file rename complete.") + lastCredentialsFileSuffix = nextSuffix + } + + def stop(): Unit = { + delegationTokenRenewer.shutdown() + } +} diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 27f804782f355..e1694c1f64a9f 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -75,6 +75,8 @@ private[spark] class ApplicationMaster( // Fields used in cluster mode. private val sparkContextRef = new AtomicReference[SparkContext](null) + private var delegationTokenRenewerOption: Option[AMDelegationTokenRenewer] = None + final def run(): Int = { try { val appAttemptId = client.getAttemptId() @@ -129,6 +131,15 @@ private[spark] class ApplicationMaster( // doAs in order for the credentials to be passed on to the executor containers. val securityMgr = new SecurityManager(sparkConf) + // If the credentials file config is present, we must periodically renew tokens. 
So create + // a new AMDelegationTokenRenewer + if (sparkConf.contains("spark.yarn.credentials.file")) { + delegationTokenRenewerOption = Some(new AMDelegationTokenRenewer(sparkConf, yarnConf)) + // If a principal and keytab have been set, use that to create new credentials for executors + // periodically + delegationTokenRenewerOption.foreach(_.scheduleLoginFromKeytab()) + } + if (isClusterMode) { runDriver(securityMgr) } else { @@ -193,6 +204,7 @@ private[spark] class ApplicationMaster( logDebug("shutting down user thread") userClassThread.interrupt() } + if (!inShutdown) delegationTokenRenewerOption.foreach(_.stop()) } } } @@ -240,12 +252,12 @@ private[spark] class ApplicationMaster( host: String, port: String, isClusterMode: Boolean): Unit = { - val driverEndpont = rpcEnv.setupEndpointRef( + val driverEndpoint = rpcEnv.setupEndpointRef( SparkEnv.driverActorSystemName, RpcAddress(host, port.toInt), YarnSchedulerBackend.ENDPOINT_NAME) amEndpoint = - rpcEnv.setupEndpoint("YarnAM", new AMEndpoint(rpcEnv, driverEndpont, isClusterMode)) + rpcEnv.setupEndpoint("YarnAM", new AMEndpoint(rpcEnv, driverEndpoint, isClusterMode)) } private def runDriver(securityMgr: SecurityManager): Unit = { @@ -499,6 +511,7 @@ private[spark] class ApplicationMaster( override def onStart(): Unit = { driver.send(RegisterClusterManager(self)) + } override def receive: PartialFunction[Any, Unit] = { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 4abcf7307a388..20ecaf092e3f8 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -17,9 +17,11 @@ package org.apache.spark.deploy.yarn -import java.io.{File, FileOutputStream} +import java.io.{ByteArrayInputStream, DataInputStream, File, FileOutputStream} import java.net.{InetAddress, UnknownHostException, URI, URISyntaxException} import java.nio.ByteBuffer +import java.security.PrivilegedExceptionAction +import java.util.UUID import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.JavaConversions._ @@ -36,7 +38,6 @@ import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifie import org.apache.hadoop.fs._ import org.apache.hadoop.fs.permission.FsPermission import org.apache.hadoop.io.Text -import org.apache.hadoop.mapred.Master import org.apache.hadoop.mapreduce.MRJobConfig import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.hadoop.security.token.{TokenIdentifier, Token} @@ -50,8 +51,8 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.exceptions.ApplicationNotFoundException import org.apache.hadoop.yarn.util.Records -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkException} import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkException} import org.apache.spark.util.Utils private[spark] class Client( @@ -69,11 +70,13 @@ private[spark] class Client( private val yarnClient = YarnClient.createYarnClient private val yarnConf = new YarnConfiguration(hadoopConf) - private val credentials = UserGroupInformation.getCurrentUser.getCredentials + private var credentials: Credentials = null private val amMemoryOverhead = args.amMemoryOverhead // MB private val executorMemoryOverhead = args.executorMemoryOverhead // MB private val distCacheMgr = new ClientDistributedCacheManager() 
private val isClusterMode = args.isClusterMode + + private var loginFromKeytab = false private val fireAndForget = isClusterMode && !sparkConf.getBoolean("spark.yarn.submit.waitAppCompletion", true) @@ -88,6 +91,8 @@ private[spark] class Client( * available in the alpha API. */ def submitApplication(): ApplicationId = { + // Set up the credentials before doing anything else, so we don't have issues at any point. + setupCredentials() yarnClient.init(yarnConf) yarnClient.start() @@ -219,12 +224,12 @@ private[spark] class Client( // and add them as local resources to the application master. val fs = FileSystem.get(hadoopConf) val dst = new Path(fs.getHomeDirectory(), appStagingDir) - val nns = getNameNodesToAccess(sparkConf) + dst + val nns = YarnSparkHadoopUtil.get.getNameNodesToAccess(sparkConf) + dst + YarnSparkHadoopUtil.get.obtainTokensForNamenodes(nns, hadoopConf, credentials) // Used to keep track of URIs added to the distributed cache. If the same URI is added // multiple times, YARN will fail to launch containers for the app with an internal // error. val distributedUris = new HashSet[String] - obtainTokensForNamenodes(nns, hadoopConf, credentials) obtainTokenForHiveMetastore(hadoopConf, credentials) obtainTokenForHBase(hadoopConf, credentials) @@ -243,6 +248,20 @@ private[spark] class Client( "for alternatives.") } + // If we passed in a keytab, make sure we copy the keytab to the staging directory on + // HDFS, and set up the relevant environment vars, so the AM can log in again. + if (loginFromKeytab) { + logInfo("To enable the AM to login from keytab, credentials are being copied over to the AM" + + " via the YARN Secure Distributed Cache.") + val localUri = new URI(args.keytab) + val localPath = getQualifiedLocalPath(localUri, hadoopConf) + val destinationPath = copyFileToRemote(dst, localPath, replication) + val destFs = FileSystem.get(destinationPath.toUri(), hadoopConf) + distCacheMgr.addResource( + destFs, hadoopConf, destinationPath, localResources, LocalResourceType.FILE, + sparkConf.get("spark.yarn.keytab"), statCache, appMasterOnly = true) + } + def addDistributedUri(uri: URI): Boolean = { val uriStr = uri.toString() if (distributedUris.contains(uriStr)) { @@ -371,9 +390,11 @@ private[spark] class Client( try { hadoopConfStream.setLevel(0) hadoopConfFiles.foreach { case (name, file) => - hadoopConfStream.putNextEntry(new ZipEntry(name)) - Files.copy(file, hadoopConfStream) - hadoopConfStream.closeEntry() + if (file.canRead()) { + hadoopConfStream.putNextEntry(new ZipEntry(name)) + Files.copy(file, hadoopConfStream) + hadoopConfStream.closeEntry() + } } } finally { hadoopConfStream.close() @@ -385,6 +406,28 @@ private[spark] class Client( } } + /** + * Get the renewal interval for tokens. + */ + private def getTokenRenewalInterval(stagingDirPath: Path): Long = { + // We cannot use the tokens generated above since those have the renewer set to yarn. Trying + // to renew those will fail with an access control issue. So create new tokens with the + // logged-in user as the renewer.
+ val creds = new Credentials() + val nns = YarnSparkHadoopUtil.get.getNameNodesToAccess(sparkConf) + stagingDirPath + YarnSparkHadoopUtil.get.obtainTokensForNamenodes( + nns, hadoopConf, creds, Some(sparkConf.get("spark.yarn.principal"))) + val t = creds.getAllTokens + .filter(_.getKind == DelegationTokenIdentifier.HDFS_DELEGATION_KIND) + .head + val newExpiration = t.renew(hadoopConf) + val identifier = new DelegationTokenIdentifier() + identifier.readFields(new DataInputStream(new ByteArrayInputStream(t.getIdentifier))) + val interval = newExpiration - identifier.getIssueDate + logInfo(s"Renewal Interval set to $interval") + interval + } + /** * Set up the environment for launching our ApplicationMaster container. */ @@ -396,7 +439,16 @@ private[spark] class Client( env("SPARK_YARN_MODE") = "true" env("SPARK_YARN_STAGING_DIR") = stagingDir env("SPARK_USER") = UserGroupInformation.getCurrentUser().getShortUserName() - + if (loginFromKeytab) { + val remoteFs = FileSystem.get(hadoopConf) + val stagingDirPath = new Path(remoteFs.getHomeDirectory, stagingDir) + val credentialsFile = "credentials-" + UUID.randomUUID().toString + sparkConf.set( + "spark.yarn.credentials.file", new Path(stagingDirPath, credentialsFile).toString) + logInfo(s"Credentials file set to: $credentialsFile") + val renewalInterval = getTokenRenewalInterval(stagingDirPath) + sparkConf.set("spark.yarn.token.renewal.interval", renewalInterval.toString) + } // Set the environment variables to be passed on to the executors. distCacheMgr.setDistFilesEnv(env) distCacheMgr.setDistArchivesEnv(env) @@ -461,7 +513,6 @@ private[spark] class Client( private def createContainerLaunchContext(newAppResponse: GetNewApplicationResponse) : ContainerLaunchContext = { logInfo("Setting up container launch context for our AM") - val appId = newAppResponse.getApplicationId val appStagingDir = getAppStagingDir(appId) val localResources = prepareLocalResources(appStagingDir) @@ -542,6 +593,10 @@ private[spark] class Client( } javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) } + + sparkConf.getOption("spark.yarn.am.extraLibraryPath").foreach { paths => + prefixEnv = Some(Utils.libraryPathEnvPrefix(Seq(paths))) + } } // For log4j configuration to reference @@ -632,6 +687,24 @@ private[spark] class Client( amContainer } + def setupCredentials(): Unit = { + if (args.principal != null) { + require(args.keytab != null, "Keytab must be specified when principal is specified.") + logInfo("Attempting to login to the Kerberos KDC" + + s" using principal: ${args.principal} and keytab: ${args.keytab}") + val f = new File(args.keytab) + // Generate a file name that can be used for the keytab file and that does not conflict + // with any user file.
+ val keytabFileName = f.getName + "-" + UUID.randomUUID().toString + UserGroupInformation.loginUserFromKeytab(args.principal, args.keytab) + loginFromKeytab = true + sparkConf.set("spark.yarn.keytab", keytabFileName) + sparkConf.set("spark.yarn.principal", args.principal) + logInfo("Successfully logged into the KDC.") + } + credentials = UserGroupInformation.getCurrentUser.getCredentials + } + /** * Report the state of an application until it has exited, either successfully or * due to some failure, then return a pair of the yarn application state (FINISHED, FAILED, @@ -987,46 +1060,6 @@ object Client extends Logging { private def addClasspathEntry(path: String, env: HashMap[String, String]): Unit = YarnSparkHadoopUtil.addPathToEnvironment(env, Environment.CLASSPATH.name, path) - /** - * Get the list of namenodes the user may access. - */ - private[yarn] def getNameNodesToAccess(sparkConf: SparkConf): Set[Path] = { - sparkConf.get("spark.yarn.access.namenodes", "") - .split(",") - .map(_.trim()) - .filter(!_.isEmpty) - .map(new Path(_)) - .toSet - } - - private[yarn] def getTokenRenewer(conf: Configuration): String = { - val delegTokenRenewer = Master.getMasterPrincipal(conf) - logDebug("delegation token renewer is: " + delegTokenRenewer) - if (delegTokenRenewer == null || delegTokenRenewer.length() == 0) { - val errorMessage = "Can't get Master Kerberos principal for use as renewer" - logError(errorMessage) - throw new SparkException(errorMessage) - } - delegTokenRenewer - } - - /** - * Obtains tokens for the namenodes passed in and adds them to the credentials. - */ - private def obtainTokensForNamenodes( - paths: Set[Path], - conf: Configuration, - creds: Credentials): Unit = { - if (UserGroupInformation.isSecurityEnabled()) { - val delegTokenRenewer = getTokenRenewer(conf) - paths.foreach { dst => - val dstFs = dst.getFileSystem(conf) - logDebug("getting token for namenode: " + dst) - dstFs.addDelegationTokens(delegTokenRenewer, creds) - } - } - } - /** * Obtains token for the Hive metastore and adds them to the credentials. */ diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 1423533470fc0..5653c9f14dc6d 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -42,6 +42,8 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) var amCores: Int = 1 var appName: String = "Spark" var priority = 0 + var principal: String = null + var keytab: String = null def isClusterMode: Boolean = userClass != null private var driverMemory: Int = 512 // MB @@ -231,6 +233,14 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) archives = value args = tail + case ("--principal") :: value :: tail => + principal = value + args = tail + + case ("--keytab") :: value :: tail => + keytab = value + args = tail + case Nil => case _ => diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala new file mode 100644 index 0000000000000..229c2c4d5eb36 --- /dev/null +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.yarn + +import java.util.concurrent.{Executors, TimeUnit} + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.security.{Credentials, UserGroupInformation} + +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.util.{ThreadUtils, Utils} + +import scala.util.control.NonFatal + +private[spark] class ExecutorDelegationTokenUpdater( + sparkConf: SparkConf, + hadoopConf: Configuration) extends Logging { + + @volatile private var lastCredentialsFileSuffix = 0 + + private val credentialsFile = sparkConf.get("spark.yarn.credentials.file") + + private val delegationTokenRenewer = + Executors.newSingleThreadScheduledExecutor( + ThreadUtils.namedThreadFactory("Delegation Token Refresh Thread")) + + // On the executor, this thread wakes up and picks up new tokens from HDFS, if any. + private val executorUpdaterRunnable = + new Runnable { + override def run(): Unit = Utils.logUncaughtExceptions(updateCredentialsIfRequired()) + } + + def updateCredentialsIfRequired(): Unit = { + try { + val credentialsFilePath = new Path(credentialsFile) + val remoteFs = FileSystem.get(hadoopConf) + SparkHadoopUtil.get.listFilesSorted( + remoteFs, credentialsFilePath.getParent, + credentialsFilePath.getName, SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) + .lastOption.foreach { credentialsStatus => + val suffix = SparkHadoopUtil.get.getSuffixForCredentialsPath(credentialsStatus.getPath) + if (suffix > lastCredentialsFileSuffix) { + logInfo("Reading new delegation tokens from " + credentialsStatus.getPath) + val newCredentials = getCredentialsFromHDFSFile(remoteFs, credentialsStatus.getPath) + lastCredentialsFileSuffix = suffix + UserGroupInformation.getCurrentUser.addCredentials(newCredentials) + logInfo("Tokens updated from credentials file.") + } else { + // Check every hour to see if new credentials arrived. 
+ logInfo("Updated delegation tokens were expected, but the driver has not updated the " + + "tokens yet, will check again in an hour.") + delegationTokenRenewer.schedule(executorUpdaterRunnable, 1, TimeUnit.HOURS) + return + } + } + val timeFromNowToRenewal = + SparkHadoopUtil.get.getTimeFromNowToRenewal( + sparkConf, 0.8, UserGroupInformation.getCurrentUser.getCredentials) + if (timeFromNowToRenewal <= 0) { + executorUpdaterRunnable.run() + } else { + logInfo(s"Scheduling token refresh from HDFS in $timeFromNowToRenewal millis.") + delegationTokenRenewer.schedule( + executorUpdaterRunnable, timeFromNowToRenewal, TimeUnit.MILLISECONDS) + } + } catch { + // Since the file may get deleted while we are reading it, catch the Exception and come + // back in an hour to try again + case NonFatal(e) => + logWarning("Error while trying to update credentials, will try again in 1 hour", e) + delegationTokenRenewer.schedule(executorUpdaterRunnable, 1, TimeUnit.HOURS) + } + } + + private def getCredentialsFromHDFSFile(remoteFs: FileSystem, tokenPath: Path): Credentials = { + val stream = remoteFs.open(tokenPath) + try { + val newCredentials = new Credentials() + newCredentials.readTokenStorageStream(stream) + newCredentials + } finally { + stream.close() + } + } + + def stop(): Unit = { + delegationTokenRenewer.shutdown() + } + +} diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 5881dc5ffa3ad..ba91872107d0c 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -24,18 +24,19 @@ import java.util.regex.Pattern import scala.collection.mutable.HashMap import scala.util.Try +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path import org.apache.hadoop.io.Text -import org.apache.hadoop.mapred.JobConf +import org.apache.hadoop.mapred.{Master, JobConf} import org.apache.hadoop.security.Credentials import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.yarn.api.records.{Priority, ApplicationAccessType} -import org.apache.hadoop.conf.Configuration -import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.util.Utils /** @@ -43,6 +44,8 @@ import org.apache.spark.util.Utils */ class YarnSparkHadoopUtil extends SparkHadoopUtil { + private var tokenRenewer: Option[ExecutorDelegationTokenUpdater] = None + override def transferCredentials(source: UserGroupInformation, dest: UserGroupInformation) { dest.addCredentials(source.getCredentials()) } @@ -82,6 +85,57 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { if (credentials != null) credentials.getSecretKey(new Text(key)) else null } + /** + * Get the list of namenodes the user may access. 
+ */ + def getNameNodesToAccess(sparkConf: SparkConf): Set[Path] = { + sparkConf.get("spark.yarn.access.namenodes", "") + .split(",") + .map(_.trim()) + .filter(!_.isEmpty) + .map(new Path(_)) + .toSet + } + + def getTokenRenewer(conf: Configuration): String = { + val delegTokenRenewer = Master.getMasterPrincipal(conf) + logDebug("delegation token renewer is: " + delegTokenRenewer) + if (delegTokenRenewer == null || delegTokenRenewer.length() == 0) { + val errorMessage = "Can't get Master Kerberos principal for use as renewer" + logError(errorMessage) + throw new SparkException(errorMessage) + } + delegTokenRenewer + } + + /** + * Obtains tokens for the namenodes passed in and adds them to the credentials. + */ + def obtainTokensForNamenodes( + paths: Set[Path], + conf: Configuration, + creds: Credentials, + renewer: Option[String] = None + ): Unit = { + if (UserGroupInformation.isSecurityEnabled()) { + val delegTokenRenewer = renewer.getOrElse(getTokenRenewer(conf)) + paths.foreach { dst => + val dstFs = dst.getFileSystem(conf) + logInfo("getting token for namenode: " + dst) + dstFs.addDelegationTokens(delegTokenRenewer, creds) + } + } + } + + private[spark] override def startExecutorDelegationTokenRenewer(sparkConf: SparkConf): Unit = { + tokenRenewer = Some(new ExecutorDelegationTokenUpdater(sparkConf, conf)) + tokenRenewer.get.updateCredentialsIfRequired() + } + + private[spark] override def stopExecutorDelegationTokenRenewer(): Unit = { + tokenRenewer.foreach(_.stop()) + } + } object YarnSparkHadoopUtil { @@ -100,6 +154,14 @@ object YarnSparkHadoopUtil { // request types (like map/reduce in hadoop for example) val RM_REQUEST_PRIORITY = Priority.newInstance(1) + def get: YarnSparkHadoopUtil = { + val yarnMode = java.lang.Boolean.valueOf( + System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE"))) + if (!yarnMode) { + throw new SparkException("YarnSparkHadoopUtil is not available in non-YARN mode!") + } + SparkHadoopUtil.get.asInstanceOf[YarnSparkHadoopUtil] + } /** * Add a path variable to the given environment map. * If the map already contains this key, append the value to the existing value instead. 
@@ -212,3 +274,4 @@ object YarnSparkHadoopUtil { classPathSeparatorField.get(null).asInstanceOf[String] } } + diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index a51c2005cb472..508819e242a26 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -151,57 +151,6 @@ class ClientSuite extends FunSuite with Matchers with BeforeAndAfterAll { } } - test("check access nns empty") { - val sparkConf = new SparkConf() - sparkConf.set("spark.yarn.access.namenodes", "") - val nns = Client.getNameNodesToAccess(sparkConf) - nns should be(Set()) - } - - test("check access nns unset") { - val sparkConf = new SparkConf() - val nns = Client.getNameNodesToAccess(sparkConf) - nns should be(Set()) - } - - test("check access nns") { - val sparkConf = new SparkConf() - sparkConf.set("spark.yarn.access.namenodes", "hdfs://nn1:8032") - val nns = Client.getNameNodesToAccess(sparkConf) - nns should be(Set(new Path("hdfs://nn1:8032"))) - } - - test("check access nns space") { - val sparkConf = new SparkConf() - sparkConf.set("spark.yarn.access.namenodes", "hdfs://nn1:8032, ") - val nns = Client.getNameNodesToAccess(sparkConf) - nns should be(Set(new Path("hdfs://nn1:8032"))) - } - - test("check access two nns") { - val sparkConf = new SparkConf() - sparkConf.set("spark.yarn.access.namenodes", "hdfs://nn1:8032,hdfs://nn2:8032") - val nns = Client.getNameNodesToAccess(sparkConf) - nns should be(Set(new Path("hdfs://nn1:8032"), new Path("hdfs://nn2:8032"))) - } - - test("check token renewer") { - val hadoopConf = new Configuration() - hadoopConf.set("yarn.resourcemanager.address", "myrm:8033") - hadoopConf.set("yarn.resourcemanager.principal", "yarn/myrm:8032@SPARKTEST.COM") - val renewer = Client.getTokenRenewer(hadoopConf) - renewer should be ("yarn/myrm:8032@SPARKTEST.COM") - } - - test("check token renewer default") { - val hadoopConf = new Configuration() - val caught = - intercept[SparkException] { - Client.getTokenRenewer(hadoopConf) - } - assert(caught.getMessage === "Can't get Master Kerberos principal for use as renewer") - } - object Fixtures { val knownDefYarnAppCP: Seq[String] = diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 3877da4120e7c..d3c606e0ed998 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -86,6 +86,7 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit tempDir = Utils.createTempDir() logConfDir = new File(tempDir, "log4j") logConfDir.mkdir() + System.setProperty("SPARK_YARN_MODE", "true") val logConfFile = new File(logConfDir, "log4j.properties") Files.write(LOG4J_CONF, logConfFile, UTF_8) @@ -128,6 +129,7 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit override def afterAll() { yarnCluster.stop() + System.clearProperty("SPARK_YARN_MODE") super.afterAll() } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala index 9395316b71ff4..e10b985c3c236 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala +++ 
b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala @@ -20,6 +20,8 @@ package org.apache.spark.deploy.yarn import java.io.{File, IOException} import com.google.common.io.{ByteStreams, Files} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.yarn.conf.YarnConfiguration @@ -27,7 +29,7 @@ import org.scalatest.{FunSuite, Matchers} import org.apache.hadoop.yarn.api.records.ApplicationAccessType -import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.util.Utils @@ -173,4 +175,62 @@ class YarnSparkHadoopUtilSuite extends FunSuite with Matchers with Logging { YarnSparkHadoopUtil.getClassPathSeparator() should be (":") } } + + test("check access nns empty") { + val sparkConf = new SparkConf() + val util = new YarnSparkHadoopUtil + sparkConf.set("spark.yarn.access.namenodes", "") + val nns = util.getNameNodesToAccess(sparkConf) + nns should be(Set()) + } + + test("check access nns unset") { + val sparkConf = new SparkConf() + val util = new YarnSparkHadoopUtil + val nns = util.getNameNodesToAccess(sparkConf) + nns should be(Set()) + } + + test("check access nns") { + val sparkConf = new SparkConf() + sparkConf.set("spark.yarn.access.namenodes", "hdfs://nn1:8032") + val util = new YarnSparkHadoopUtil + val nns = util.getNameNodesToAccess(sparkConf) + nns should be(Set(new Path("hdfs://nn1:8032"))) + } + + test("check access nns space") { + val sparkConf = new SparkConf() + sparkConf.set("spark.yarn.access.namenodes", "hdfs://nn1:8032, ") + val util = new YarnSparkHadoopUtil + val nns = util.getNameNodesToAccess(sparkConf) + nns should be(Set(new Path("hdfs://nn1:8032"))) + } + + test("check access two nns") { + val sparkConf = new SparkConf() + sparkConf.set("spark.yarn.access.namenodes", "hdfs://nn1:8032,hdfs://nn2:8032") + val util = new YarnSparkHadoopUtil + val nns = util.getNameNodesToAccess(sparkConf) + nns should be(Set(new Path("hdfs://nn1:8032"), new Path("hdfs://nn2:8032"))) + } + + test("check token renewer") { + val hadoopConf = new Configuration() + hadoopConf.set("yarn.resourcemanager.address", "myrm:8033") + hadoopConf.set("yarn.resourcemanager.principal", "yarn/myrm:8032@SPARKTEST.COM") + val util = new YarnSparkHadoopUtil + val renewer = util.getTokenRenewer(hadoopConf) + renewer should be ("yarn/myrm:8032@SPARKTEST.COM") + } + + test("check token renewer default") { + val hadoopConf = new Configuration() + val util = new YarnSparkHadoopUtil + val caught = + intercept[SparkException] { + util.getTokenRenewer(hadoopConf) + } + assert(caught.getMessage === "Can't get Master Kerberos principal for use as renewer") + } }
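The delegation-token hand-off in this patch is spread across several files, so a compact, self-contained model may help when reading it end to end: the AM (AMDelegationTokenRenewer) writes credential files whose names end in an increasing numeric suffix, each executor (ExecutorDelegationTokenUpdater) remembers the largest suffix it has already read and only consumes newer files, and both sides schedule their next refresh at a fraction of the tokens' validity (0.75 on the AM, 0.8 on executors). The Scala sketch below is illustrative only and is not part of the patch; the object and method names (CredentialsRotationSketch, nextWritePath, refreshDelayMs) and the "-" delimiter are assumptions standing in for SparkHadoopUtil's real constants and getTimeFromNowToRenewal.

  // Minimal sketch of the rotation scheme described in AMDelegationTokenRenewer's header comment.
  object CredentialsRotationSketch {
    val CounterDelim = "-" // assumed delimiter, matching the "c-1, c-2, ..." example in the comment

    // Path the AM would write next, given the configured prefix and the last suffix it wrote.
    def nextWritePath(prefix: String, lastWrittenSuffix: Int): String =
      prefix + CounterDelim + (lastWrittenSuffix + 1)

    // Suffix encoded at the end of a written path, e.g. ".../credentials-uuid-3" -> 3.
    def suffixOf(path: String): Int =
      path.substring(path.lastIndexOf(CounterDelim) + 1).toInt

    // An executor only picks up a file whose suffix is newer than the last one it read.
    def shouldRead(candidatePath: String, lastReadSuffix: Int): Boolean =
      suffixOf(candidatePath) > lastReadSuffix

    // Delay until the next refresh, given the tokens' issue/expiry times and a fraction
    // (0.75 for the AM's re-login, 0.8 for the executors' reload, per the header comment).
    def refreshDelayMs(issueDateMs: Long, expiryMs: Long, fraction: Double, nowMs: Long): Long =
      issueDateMs + ((expiryMs - issueDateMs) * fraction).toLong - nowMs

    def main(args: Array[String]): Unit = {
      val base = "/user/spark/.sparkStaging/app_0001/credentials-uuid"
      println(nextWritePath(base, 2))                // .../credentials-uuid-3
      println(shouldRead(base + "-3", 2))            // true: suffix 3 is newer than the last read (2)
      println(refreshDelayMs(0L, 1000L, 0.75, 500L)) // 250: three quarters of validity minus elapsed
    }
  }

A positive result from refreshDelayMs corresponds to the "schedule later" branch in scheduleRenewal and updateCredentialsIfRequired above, while a result <= 0 corresponds to their "run now" branch.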