From f9d6220c792b779be385f3022d146911a22c2130 Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Mon, 22 Sep 2014 13:47:43 -0700 Subject: [PATCH 01/22] [SPARK-3578] Fix upper bound in GraphGenerators.sampleLogNormal GraphGenerators.sampleLogNormal is supposed to return an integer strictly less than maxVal. However, it violates this guarantee. It generates its return value as follows: ```scala var X: Double = maxVal while (X >= maxVal) { val Z = rand.nextGaussian() X = math.exp(mu + sigma*Z) } math.round(X.toFloat) ``` When X is sampled close to (but still less than) maxVal, the loop exits, but rounding can push the result up to maxVal, violating the guarantee. For example, if maxVal is 5 and X is 4.9, then X < maxVal, but `math.round(X.toFloat)` is 5. This PR instead rounds X down with `math.floor`, guaranteeing that the return value is strictly less than maxVal. Author: Ankur Dave Closes #2439 from ankurdave/SPARK-3578 and squashes the following commits: f6655e5 [Ankur Dave] Go back to math.floor 5900c22 [Ankur Dave] Round X in loop condition 6fd5fb1 [Ankur Dave] Run sampleLogNormal bounds check 1000 times 1638598 [Ankur Dave] Round down in sampleLogNormal to guarantee upper bound --- .../org/apache/spark/graphx/util/GraphGenerators.scala | 2 +- .../apache/spark/graphx/util/GraphGeneratorsSuite.scala | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala index b8309289fe475..8a13c74221546 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala @@ -118,7 +118,7 @@ object GraphGenerators { val Z = rand.nextGaussian() X = math.exp(mu + sigma*Z) } - math.round(X.toFloat) + math.floor(X).toInt } /** diff --git a/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala index b346d4db2ef96..3abefbe52fa8a 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala @@ -64,8 +64,11 @@ class GraphGeneratorsSuite extends FunSuite with LocalSparkContext { val sigma = 1.3 val maxVal = 100 - val dstId = GraphGenerators.sampleLogNormal(mu, sigma, maxVal) - assert(dstId < maxVal) + val trials = 1000 + for (i <- 1 to trials) { + val dstId = GraphGenerators.sampleLogNormal(mu, sigma, maxVal) + assert(dstId < maxVal) + } val dstId_round1 = GraphGenerators.sampleLogNormal(mu, sigma, maxVal, 12345) val dstId_round2 = GraphGenerators.sampleLogNormal(mu, sigma, maxVal, 12345) From 14f8c340402366cb998c563b3f7d9ff7d9940271 Mon Sep 17 00:00:00 2001 From: "peng.zhang" Date: Tue, 23 Sep 2014 08:45:56 -0500 Subject: [PATCH 02/22] [YARN] SPARK-2668: Add variable of yarn log directory for reference from the log4j configuration MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Assign the value of the YARN container log directory to the java opt "spark.yarn.app.container.log.dir", so that a user-defined log4j.properties can reference this value and write logs to the YARN container's log directory.
Otherwise, a user-defined file appender will only write to the container's CWD, and log files in the CWD will neither be displayed on the YARN UI nor aggregated to the HDFS log directory after the job finishes. User-defined log4j.properties reference example: log4j.appender.rolling_file.File = ${spark.yarn.app.container.log.dir}/spark.log Author: peng.zhang Closes #1573 from renozhang/yarn-log-dir and squashes the following commits: 16c5cb8 [peng.zhang] Update doc f2b5e2a [peng.zhang] Change variable's name, and update running-on-yarn.md 503ea2d [peng.zhang] Support log4j log to yarn container dir --- docs/running-on-yarn.md | 2 ++ .../main/scala/org/apache/spark/deploy/yarn/ClientBase.scala | 3 +++ .../org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala | 3 +++ 3 files changed, 8 insertions(+) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 74bcc2eeb65f6..4b3a49eca7007 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -205,6 +205,8 @@ Note that for the first option, both executors and the application master will share the same log4j configuration, which may cause issues when they run on the same node (e.g. trying to write to the same log file). +If you need a reference to the proper location to put log files in YARN so that YARN can properly display and aggregate them, use "${spark.yarn.app.container.log.dir}" in your log4j.properties. For example, log4j.appender.file_appender.File=${spark.yarn.app.container.log.dir}/spark.log. For streaming applications, configuring RollingFileAppender and setting the file location to YARN's log directory will avoid disk overflow caused by large log files, and logs can be accessed using YARN's log utility. + # Important notes - Before Hadoop 2.2, YARN does not support cores in container resource requests. Thus, when running against an earlier version, the numbers of cores given via command line arguments cannot be passed to YARN. Whether core requests are honored in scheduling decisions depends on which scheduler is in use and how it is configured. diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index c96f731923d22..6ae4d496220a5 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -388,6 +388,9 @@ trait ClientBase extends Logging { .foreach(p => javaOpts += s"-Djava.library.path=$p") } + // For log4j configuration to reference + javaOpts += "-Dspark.yarn.app.container.log.dir=" + ApplicationConstants.LOG_DIR_EXPANSION_VAR + val userClass = if (args.userClass != null) { Seq("--class", YarnSparkHadoopUtil.escapeForShell(args.userClass)) diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala index 312d82a649792..f56f72cafe50e 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala @@ -98,6 +98,9 @@ trait ExecutorRunnableUtil extends Logging { } */ + // For log4j configuration to reference + javaOpts += "-Dspark.yarn.app.container.log.dir=" + ApplicationConstants.LOG_DIR_EXPANSION_VAR + val commands = Seq(Environment.JAVA_HOME.$() + "/bin/java", "-server", // Kill if OOM is raised - leverage yarn's failure handling to cause rescheduling.
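As a concrete illustration of the configuration this patch documents, a log4j.properties along the following lines writes its output into the YARN container's log directory, so the YARN UI can display the files and log aggregation can collect them after the application finishes. This is only a minimal sketch: the appender name, rotation sizes, and layout pattern are illustrative assumptions, not part of the patch.

```properties
# ${spark.yarn.app.container.log.dir} is resolved from the system property that the
# patch above adds to the AM and executor JVM options
# (-Dspark.yarn.app.container.log.dir=<LOG_DIR>).
log4j.rootLogger=INFO, rolling_file

# Hypothetical appender name; any file-based appender can reference the same property.
log4j.appender.rolling_file=org.apache.log4j.RollingFileAppender
log4j.appender.rolling_file.File=${spark.yarn.app.container.log.dir}/spark.log
log4j.appender.rolling_file.MaxFileSize=50MB
log4j.appender.rolling_file.MaxBackupIndex=5
log4j.appender.rolling_file.layout=org.apache.log4j.PatternLayout
log4j.appender.rolling_file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
```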
From c4022dd52b4827323ff956632dc7623f546da937 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 23 Sep 2014 11:20:52 -0500 Subject: [PATCH 03/22] [SPARK-3477] Clean up code in Yarn Client / ClientBase This is part of a broader effort to clean up the Yarn integration code after #2020. The high-level changes in this PR include: - Removing duplicate code, especially across the alpha and stable APIs - Simplifying unnecessarily complex method signatures and hierarchies - Renaming unclear variable and method names - Organizing logging output produced when the user runs Spark on Yarn - Adding extensive documentation - Privatizing classes where possible I have tested the stable API on a Hadoop 2.4 cluster. I tested submitting a jar that references classes in other jars in both client and cluster mode. I also made changes in the alpha API, though I do not have access to an alpha cluster. I have verified that it compiles, but it would be ideal if others could help test it. For those interested in some examples in detail, please read on. -------------------------------------------------------------------------------------------------------- ***Appendix*** - The loop to `getApplicationReport` from the RM is duplicated in 4 places: in the stable `Client`, alpha `Client`, and twice in `YarnClientSchedulerBackend`. We should not have different loops for client and cluster deploy modes. - There are many fragmented small helper methods that are only used once and should just be inlined. For instance, `ClientBase#getLocalPath` returns `null` under certain conditions, and its only caller `ClientBase#addFileToClasspath` checks whether the value returned is `null`. We could just have the caller check that same condition to avoid passing `null`s around. - In `YarnSparkHadoopUtil#addToEnvironment`, we take in an argument `classpathSeparator` that always has the same value upstream (i.e. `File.pathSeparator`). This argument is now removed from the signature and from all callers of this method upstream (see the sketch after this list). - `ClientBase#copyRemoteFile` is now renamed to `copyFileToRemote`. It was unclear whether we are copying a remote file to our local file system, or copying a locally visible file to a remote file system. Also, even the content of the method has inaccurately named variables. We use `val remoteFs` to signify the file system of the locally visible file and `val fs` to signify the remote, destination file system. These are now renamed `srcFs` and `destFs` respectively. - We currently log the AM container's environment and resource mappings directly as Scala collections. This is incredibly hard to read and probably too verbose for the average Spark user. In other modes (e.g. standalone), we also don't log the launch commands by default, so the logging level of this information is now set to `DEBUG`. - None of these classes (`Client`, `ClientBase`, `YarnSparkHadoopUtil` etc.) is intended to be used by a Spark application (the user should go through Spark submit instead). At the very least they should be `private[spark]`.
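To make the `addToEnvironment` simplification above concrete, here is a small before/after sketch. The `YarnSparkHadoopUtil` diff itself is not part of this excerpt, so the method bodies below are an approximation inferred from the call sites visible in the ClientBase diff further down (`addToEnvironment(env, key, value, File.pathSeparator)` before, `addPathToEnvironment(env, key, value)` after); only the dropped `classpathSeparator` parameter comes from the description, and the object name is hypothetical.

```scala
import java.io.File
import scala.collection.mutable.HashMap

object EnvHelperSketch {
  // Before: every caller had to pass the separator, even though it was always File.pathSeparator.
  def addToEnvironment(
      env: HashMap[String, String],
      key: String,
      value: String,
      classpathSeparator: String): Unit = {
    env(key) = env.get(key).map(_ + classpathSeparator + value).getOrElse(value)
  }

  // After: the separator is fixed inside the helper, shrinking every call site to three arguments.
  def addPathToEnvironment(
      env: HashMap[String, String],
      key: String,
      value: String): Unit = {
    env(key) = env.get(key).map(_ + File.pathSeparator + value).getOrElse(value)
  }
}
```

The same pattern of pushing an always-identical argument into the helper shows up in the ClientBase diff below, where `addClasspathEntry` now simply delegates to `addPathToEnvironment`.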
Author: Andrew Or Closes #2350 from andrewor14/yarn-cleanup and squashes the following commits: 39e8c7b [Andrew Or] Address review comments 6619f9b [Andrew Or] Merge branch 'master' of github.com:apache/spark into yarn-cleanup 2ca6d64 [Andrew Or] Improve logging in application monitor a3b9693 [Andrew Or] Minor changes 7dd6298 [Andrew Or] Simplify ClientBase#monitorApplication 547487c [Andrew Or] Provide default values for null application report entries a0ad1e9 [Andrew Or] Fix class not found error 1590141 [Andrew Or] Address review comments 45ccdea [Andrew Or] Remove usages of getAMMemory d8e33b6 [Andrew Or] Merge branch 'master' of github.com:apache/spark into yarn-cleanup ed0b42d [Andrew Or] Fix alpha compilation error c0587b4 [Andrew Or] Merge branch 'master' of github.com:apache/spark into yarn-cleanup 6d74888 [Andrew Or] Minor comment changes 6573c1d [Andrew Or] Clean up, simplify and document code for setting classpaths e4779b6 [Andrew Or] Clean up log messages + variable naming in ClientBase 8766d37 [Andrew Or] Heavily add documentation to Client* classes + various clean-ups 6c94d79 [Andrew Or] Various cleanups in ClientBase and ClientArguments ef7069a [Andrew Or] Clean up YarnClientSchedulerBackend more 6de9072 [Andrew Or] Guard against potential NPE in debug logging mode fabe4c4 [Andrew Or] Reuse more code in YarnClientSchedulerBackend 3f941dc [Andrew Or] First cut at simplifying the Client (stable and alpha) --- .../org/apache/spark/deploy/yarn/Client.scala | 145 ++-- .../spark/deploy/yarn/ClientArguments.scala | 67 +- .../apache/spark/deploy/yarn/ClientBase.scala | 682 +++++++++++------- .../yarn/ClientDistributedCacheManager.scala | 97 +-- .../deploy/yarn/ExecutorRunnableUtil.scala | 16 +- .../deploy/yarn/YarnSparkHadoopUtil.scala | 63 +- .../cluster/YarnClientSchedulerBackend.scala | 145 ++-- .../spark/deploy/yarn/ClientBaseSuite.scala | 18 +- .../org/apache/spark/deploy/yarn/Client.scala | 167 ++--- 9 files changed, 738 insertions(+), 662 deletions(-) diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index aff9ab71f0937..5a20532315e59 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -23,13 +23,11 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.DataOutputBuffer import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.api._ -import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.yarn.api.protocolrecords._ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.YarnClientImpl import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.hadoop.yarn.ipc.YarnRPC -import org.apache.hadoop.yarn.util.{Apps, Records} +import org.apache.hadoop.yarn.util.Records import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil @@ -37,7 +35,10 @@ import org.apache.spark.deploy.SparkHadoopUtil /** * Version of [[org.apache.spark.deploy.yarn.ClientBase]] tailored to YARN's alpha API. 
*/ -class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: SparkConf) +private[spark] class Client( + val args: ClientArguments, + val hadoopConf: Configuration, + val sparkConf: SparkConf) extends YarnClientImpl with ClientBase with Logging { def this(clientArgs: ClientArguments, spConf: SparkConf) = @@ -45,112 +46,86 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa def this(clientArgs: ClientArguments) = this(clientArgs, new SparkConf()) - val args = clientArgs - val conf = hadoopConf - val sparkConf = spConf - var rpc: YarnRPC = YarnRPC.create(conf) - val yarnConf: YarnConfiguration = new YarnConfiguration(conf) + val yarnConf: YarnConfiguration = new YarnConfiguration(hadoopConf) + /* ------------------------------------------------------------------------------------- * + | The following methods have much in common in the stable and alpha versions of Client, | + | but cannot be implemented in the parent trait due to subtle API differences across | + | hadoop versions. | + * ------------------------------------------------------------------------------------- */ - // for client user who want to monitor app status by itself. - def runApp() = { - validateArgs() - + /** Submit an application running our ApplicationMaster to the ResourceManager. */ + override def submitApplication(): ApplicationId = { init(yarnConf) start() - logClusterResourceDetails() - val newApp = super.getNewApplication() - val appId = newApp.getApplicationId() + logInfo("Requesting a new application from cluster with %d NodeManagers" + .format(getYarnClusterMetrics.getNumNodeManagers)) - verifyClusterResources(newApp) - val appContext = createApplicationSubmissionContext(appId) - val appStagingDir = getAppStagingDir(appId) - val localResources = prepareLocalResources(appStagingDir) - val env = setupLaunchEnv(localResources, appStagingDir) - val amContainer = createContainerLaunchContext(newApp, localResources, env) + // Get a new application from our RM + val newAppResponse = getNewApplication() + val appId = newAppResponse.getApplicationId() - val capability = Records.newRecord(classOf[Resource]).asInstanceOf[Resource] - // Memory for the ApplicationMaster. - capability.setMemory(args.amMemory + memoryOverhead) - amContainer.setResource(capability) + // Verify whether the cluster has enough resources for our AM + verifyClusterResources(newAppResponse) - appContext.setQueue(args.amQueue) - appContext.setAMContainerSpec(amContainer) - appContext.setUser(UserGroupInformation.getCurrentUser().getShortUserName()) + // Set up the appropriate contexts to launch our AM + val containerContext = createContainerLaunchContext(newAppResponse) + val appContext = createApplicationSubmissionContext(appId, containerContext) - submitApp(appContext) + // Finally, submit and monitor the application + logInfo(s"Submitting application ${appId.getId} to ResourceManager") + submitApplication(appContext) appId } - def run() { - val appId = runApp() - monitorApplication(appId) - } - - def logClusterResourceDetails() { - val clusterMetrics: YarnClusterMetrics = super.getYarnClusterMetrics - logInfo("Got cluster metric info from ASM, numNodeManagers = " + - clusterMetrics.getNumNodeManagers) + /** + * Set up a context for launching our ApplicationMaster container. + * In the Yarn alpha API, the memory requirements of this container must be set in + * the ContainerLaunchContext instead of the ApplicationSubmissionContext. 
+ */ + override def createContainerLaunchContext(newAppResponse: GetNewApplicationResponse) + : ContainerLaunchContext = { + val containerContext = super.createContainerLaunchContext(newAppResponse) + val capability = Records.newRecord(classOf[Resource]) + capability.setMemory(args.amMemory + amMemoryOverhead) + containerContext.setResource(capability) + containerContext } - - def createApplicationSubmissionContext(appId: ApplicationId): ApplicationSubmissionContext = { - logInfo("Setting up application submission context for ASM") + /** Set up the context for submitting our ApplicationMaster. */ + def createApplicationSubmissionContext( + appId: ApplicationId, + containerContext: ContainerLaunchContext): ApplicationSubmissionContext = { val appContext = Records.newRecord(classOf[ApplicationSubmissionContext]) appContext.setApplicationId(appId) appContext.setApplicationName(args.appName) + appContext.setQueue(args.amQueue) + appContext.setAMContainerSpec(containerContext) + appContext.setUser(UserGroupInformation.getCurrentUser.getShortUserName) appContext } - def setupSecurityToken(amContainer: ContainerLaunchContext) = { - // Setup security tokens. + /** + * Set up security tokens for launching our ApplicationMaster container. + * ContainerLaunchContext#setContainerTokens is renamed `setTokens` in the stable API. + */ + override def setupSecurityToken(amContainer: ContainerLaunchContext): Unit = { val dob = new DataOutputBuffer() credentials.writeTokenStorageToStream(dob) amContainer.setContainerTokens(ByteBuffer.wrap(dob.getData())) } - def submitApp(appContext: ApplicationSubmissionContext) = { - // Submit the application to the applications manager. - logInfo("Submitting application to ASM") - super.submitApplication(appContext) - } - - def monitorApplication(appId: ApplicationId): Boolean = { - val interval = sparkConf.getLong("spark.yarn.report.interval", 1000) - - while (true) { - Thread.sleep(interval) - val report = super.getApplicationReport(appId) - - logInfo("Application report from ASM: \n" + - "\t application identifier: " + appId.toString() + "\n" + - "\t appId: " + appId.getId() + "\n" + - "\t clientToken: " + report.getClientToken() + "\n" + - "\t appDiagnostics: " + report.getDiagnostics() + "\n" + - "\t appMasterHost: " + report.getHost() + "\n" + - "\t appQueue: " + report.getQueue() + "\n" + - "\t appMasterRpcPort: " + report.getRpcPort() + "\n" + - "\t appStartTime: " + report.getStartTime() + "\n" + - "\t yarnAppState: " + report.getYarnApplicationState() + "\n" + - "\t distributedFinalState: " + report.getFinalApplicationStatus() + "\n" + - "\t appTrackingUrl: " + report.getTrackingUrl() + "\n" + - "\t appUser: " + report.getUser() - ) - - val state = report.getYarnApplicationState() - if (state == YarnApplicationState.FINISHED || - state == YarnApplicationState.FAILED || - state == YarnApplicationState.KILLED) { - return true - } - } - true - } + /** + * Return the security token used by this client to communicate with the ApplicationMaster. + * If no security is enabled, the token returned by the report is null. + * ApplicationReport#getClientToken is renamed `getClientToAMToken` in the stable API. 
+ */ + override def getClientToken(report: ApplicationReport): String = + Option(report.getClientToken).getOrElse("") } object Client { - def main(argStrings: Array[String]) { if (!sys.props.contains("SPARK_SUBMIT")) { println("WARNING: This client is deprecated and will be removed in a " + @@ -158,19 +133,17 @@ object Client { } // Set an env variable indicating we are running in YARN mode. - // Note that anything with SPARK prefix gets propagated to all (remote) processes + // Note that any env variable with the SPARK_ prefix gets propagated to all (remote) processes System.setProperty("SPARK_YARN_MODE", "true") - val sparkConf = new SparkConf try { val args = new ClientArguments(argStrings, sparkConf) new Client(args, sparkConf).run() } catch { - case e: Exception => { + case e: Exception => Console.err.println(e.getMessage) System.exit(1) - } } System.exit(0) diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 40d8d6d6e6961..201b742736c6e 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -17,15 +17,14 @@ package org.apache.spark.deploy.yarn -import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkConf -import org.apache.spark.scheduler.InputFormatInfo import org.apache.spark.util.{Utils, IntParam, MemoryParam} // TODO: Add code and support for ensuring that yarn resource 'tasks' are location aware ! -class ClientArguments(val args: Array[String], val sparkConf: SparkConf) { +private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) { var addJars: String = null var files: String = null var archives: String = null @@ -35,28 +34,56 @@ class ClientArguments(val args: Array[String], val sparkConf: SparkConf) { var executorMemory = 1024 // MB var executorCores = 1 var numExecutors = 2 - var amQueue = sparkConf.get("QUEUE", "default") + var amQueue = sparkConf.get("spark.yarn.queue", "default") var amMemory: Int = 512 // MB var appName: String = "Spark" var priority = 0 - parseArgs(args.toList) + // Additional memory to allocate to containers + // For now, use driver's memory overhead as our AM container's memory overhead + val amMemoryOverhead = sparkConf.getInt( + "spark.yarn.driver.memoryOverhead", YarnSparkHadoopUtil.DEFAULT_MEMORY_OVERHEAD) + val executorMemoryOverhead = sparkConf.getInt( + "spark.yarn.executor.memoryOverhead", YarnSparkHadoopUtil.DEFAULT_MEMORY_OVERHEAD) - // env variable SPARK_YARN_DIST_ARCHIVES/SPARK_YARN_DIST_FILES set in yarn-client then - // it should default to hdfs:// - files = Option(files).getOrElse(sys.env.get("SPARK_YARN_DIST_FILES").orNull) - archives = Option(archives).getOrElse(sys.env.get("SPARK_YARN_DIST_ARCHIVES").orNull) + parseArgs(args.toList) + loadEnvironmentArgs() + validateArgs() + + /** Load any default arguments provided through environment variables and Spark properties. */ + private def loadEnvironmentArgs(): Unit = { + // For backward compatibility, SPARK_YARN_DIST_{ARCHIVES/FILES} should be resolved to hdfs://, + // while spark.yarn.dist.{archives/files} should be resolved to file:// (SPARK-2051). 
+ files = Option(files) + .orElse(sys.env.get("SPARK_YARN_DIST_FILES")) + .orElse(sparkConf.getOption("spark.yarn.dist.files").map(p => Utils.resolveURIs(p))) + .orNull + archives = Option(archives) + .orElse(sys.env.get("SPARK_YARN_DIST_ARCHIVES")) + .orElse(sparkConf.getOption("spark.yarn.dist.archives").map(p => Utils.resolveURIs(p))) + .orNull + } - // spark.yarn.dist.archives/spark.yarn.dist.files defaults to use file:// if not specified, - // for both yarn-client and yarn-cluster - files = Option(files).getOrElse(sparkConf.getOption("spark.yarn.dist.files"). - map(p => Utils.resolveURIs(p)).orNull) - archives = Option(archives).getOrElse(sparkConf.getOption("spark.yarn.dist.archives"). - map(p => Utils.resolveURIs(p)).orNull) + /** + * Fail fast if any arguments provided are invalid. + * This is intended to be called only after the provided arguments have been parsed. + */ + private def validateArgs(): Unit = { + // TODO: memory checks are outdated (SPARK-3476) + Map[Boolean, String]( + (numExecutors <= 0) -> "You must specify at least 1 executor!", + (amMemory <= amMemoryOverhead) -> s"AM memory must be > $amMemoryOverhead MB", + (executorMemory <= executorMemoryOverhead) -> + s"Executor memory must be > $executorMemoryOverhead MB" + ).foreach { case (errorCondition, errorMessage) => + if (errorCondition) { + throw new IllegalArgumentException(errorMessage + "\n" + getUsageMessage()) + } + } + } private def parseArgs(inputArgs: List[String]): Unit = { - val userArgsBuffer: ArrayBuffer[String] = new ArrayBuffer[String]() - + val userArgsBuffer = new ArrayBuffer[String]() var args = inputArgs while (!args.isEmpty) { @@ -138,16 +165,14 @@ class ClientArguments(val args: Array[String], val sparkConf: SparkConf) { userArgs = userArgsBuffer.readOnly } - - def getUsageMessage(unknownParam: Any = null): String = { + private def getUsageMessage(unknownParam: List[String] = null): String = { val message = if (unknownParam != null) s"Unknown/unsupported param $unknownParam\n" else "" - message + "Usage: org.apache.spark.deploy.yarn.Client [options] \n" + "Options:\n" + " --jar JAR_PATH Path to your application's JAR file (required in yarn-cluster mode)\n" + " --class CLASS_NAME Name of your application's main class (required)\n" + - " --arg ARGS Argument to be passed to your application's main class.\n" + + " --arg ARG Argument to be passed to your application's main class.\n" + " Multiple invocations are possible, each will be passed in order.\n" + " --num-executors NUM Number of executors to start (Default: 2)\n" + " --executor-cores NUM Number of cores for the executors (Default: 1).\n" + diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index 6ae4d496220a5..4870b0cb3ddaf 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -17,7 +17,6 @@ package org.apache.spark.deploy.yarn -import java.io.File import java.net.{InetAddress, UnknownHostException, URI, URISyntaxException} import scala.collection.JavaConversions._ @@ -37,154 +36,107 @@ import org.apache.hadoop.yarn.api.protocolrecords._ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.util.Records + import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkException} /** - * The entry point (starting in Client#main() and 
Client#run()) for launching Spark on YARN. The - * Client submits an application to the YARN ResourceManager. + * The entry point (starting in Client#main() and Client#run()) for launching Spark on YARN. + * The Client submits an application to the YARN ResourceManager. */ -trait ClientBase extends Logging { - val args: ClientArguments - val conf: Configuration - val sparkConf: SparkConf - val yarnConf: YarnConfiguration - val credentials = UserGroupInformation.getCurrentUser().getCredentials() - private val SPARK_STAGING: String = ".sparkStaging" +private[spark] trait ClientBase extends Logging { + import ClientBase._ + + protected val args: ClientArguments + protected val hadoopConf: Configuration + protected val sparkConf: SparkConf + protected val yarnConf: YarnConfiguration + protected val credentials = UserGroupInformation.getCurrentUser.getCredentials + protected val amMemoryOverhead = args.amMemoryOverhead // MB + protected val executorMemoryOverhead = args.executorMemoryOverhead // MB private val distCacheMgr = new ClientDistributedCacheManager() - // Staging directory is private! -> rwx-------- - val STAGING_DIR_PERMISSION: FsPermission = - FsPermission.createImmutable(Integer.parseInt("700", 8).toShort) - // App files are world-wide readable and owner writable -> rw-r--r-- - val APP_FILE_PERMISSION: FsPermission = - FsPermission.createImmutable(Integer.parseInt("644", 8).toShort) - - // Additional memory overhead - in mb. - protected def memoryOverhead: Int = sparkConf.getInt("spark.yarn.driver.memoryOverhead", - YarnSparkHadoopUtil.DEFAULT_MEMORY_OVERHEAD) - - // TODO(harvey): This could just go in ClientArguments. - def validateArgs() = { - Map( - (args.numExecutors <= 0) -> "Error: You must specify at least 1 executor!", - (args.amMemory <= memoryOverhead) -> ("Error: AM memory size must be" + - "greater than: " + memoryOverhead), - (args.executorMemory <= memoryOverhead) -> ("Error: Executor memory size" + - "must be greater than: " + memoryOverhead.toString) - ).foreach { case(cond, errStr) => - if (cond) { - logError(errStr) - throw new IllegalArgumentException(args.getUsageMessage()) - } - } - } - - def getAppStagingDir(appId: ApplicationId): String = { - SPARK_STAGING + Path.SEPARATOR + appId.toString() + Path.SEPARATOR - } - - def verifyClusterResources(app: GetNewApplicationResponse) = { - val maxMem = app.getMaximumResourceCapability().getMemory() - logInfo("Max mem capabililty of a single resource in this cluster " + maxMem) - - // If we have requested more then the clusters max for a single resource then exit. - if (args.executorMemory > maxMem) { - val errorMessage = - "Required executor memory (%d MB), is above the max threshold (%d MB) of this cluster." - .format(args.executorMemory, maxMem) - - logError(errorMessage) - throw new IllegalArgumentException(errorMessage) - } - val amMem = args.amMemory + memoryOverhead + /** + * Fail fast if we have requested more resources per container than is available in the cluster. 
+ */ + protected def verifyClusterResources(newAppResponse: GetNewApplicationResponse): Unit = { + val maxMem = newAppResponse.getMaximumResourceCapability().getMemory() + logInfo("Verifying our application has not requested more than the maximum " + + s"memory capability of the cluster ($maxMem MB per container)") + val executorMem = args.executorMemory + executorMemoryOverhead + if (executorMem > maxMem) { + throw new IllegalArgumentException(s"Required executor memory ($executorMem MB) " + + s"is above the max threshold ($maxMem MB) of this cluster!") + } + val amMem = args.amMemory + amMemoryOverhead if (amMem > maxMem) { - - val errorMessage = "Required AM memory (%d) is above the max threshold (%d) of this cluster." - .format(amMem, maxMem) - logError(errorMessage) - throw new IllegalArgumentException(errorMessage) + throw new IllegalArgumentException(s"Required AM memory ($amMem MB) " + + s"is above the max threshold ($maxMem MB) of this cluster!") } - // We could add checks to make sure the entire cluster has enough resources but that involves // getting all the node reports and computing ourselves. } - /** See if two file systems are the same or not. */ - private def compareFs(srcFs: FileSystem, destFs: FileSystem): Boolean = { - val srcUri = srcFs.getUri() - val dstUri = destFs.getUri() - if (srcUri.getScheme() == null) { - return false - } - if (!srcUri.getScheme().equals(dstUri.getScheme())) { - return false - } - var srcHost = srcUri.getHost() - var dstHost = dstUri.getHost() - if ((srcHost != null) && (dstHost != null)) { - try { - srcHost = InetAddress.getByName(srcHost).getCanonicalHostName() - dstHost = InetAddress.getByName(dstHost).getCanonicalHostName() - } catch { - case e: UnknownHostException => - return false - } - if (!srcHost.equals(dstHost)) { - return false - } - } else if (srcHost == null && dstHost != null) { - return false - } else if (srcHost != null && dstHost == null) { - return false - } - if (srcUri.getPort() != dstUri.getPort()) { - false - } else { - true - } - } - - /** Copy the file into HDFS if needed. */ - private[yarn] def copyRemoteFile( - dstDir: Path, - originalPath: Path, + /** + * Copy the given file to a remote file system (e.g. HDFS) if needed. + * The file is only copied if the source and destination file systems are different. This is used + * for preparing resources for launching the ApplicationMaster container. Exposed for testing. + */ + def copyFileToRemote( + destDir: Path, + srcPath: Path, replication: Short, setPerms: Boolean = false): Path = { - val fs = FileSystem.get(conf) - val remoteFs = originalPath.getFileSystem(conf) - var newPath = originalPath - if (!compareFs(remoteFs, fs)) { - newPath = new Path(dstDir, originalPath.getName()) - logInfo("Uploading " + originalPath + " to " + newPath) - FileUtil.copy(remoteFs, originalPath, fs, newPath, false, conf) - fs.setReplication(newPath, replication) - if (setPerms) fs.setPermission(newPath, new FsPermission(APP_FILE_PERMISSION)) + val destFs = destDir.getFileSystem(hadoopConf) + val srcFs = srcPath.getFileSystem(hadoopConf) + var destPath = srcPath + if (!compareFs(srcFs, destFs)) { + destPath = new Path(destDir, srcPath.getName()) + logInfo(s"Uploading resource $srcPath -> $destPath") + FileUtil.copy(srcFs, srcPath, destFs, destPath, false, hadoopConf) + destFs.setReplication(destPath, replication) + if (setPerms) { + destFs.setPermission(destPath, new FsPermission(APP_FILE_PERMISSION)) + } + } else { + logInfo(s"Source and destination file systems are the same. 
Not copying $srcPath") } // Resolve any symlinks in the URI path so using a "current" symlink to point to a specific // version shows the specific version in the distributed cache configuration - val qualPath = fs.makeQualified(newPath) - val fc = FileContext.getFileContext(qualPath.toUri(), conf) - val destPath = fc.resolvePath(qualPath) - destPath + val qualifiedDestPath = destFs.makeQualified(destPath) + val fc = FileContext.getFileContext(qualifiedDestPath.toUri(), hadoopConf) + fc.resolvePath(qualifiedDestPath) } - private def qualifyForLocal(localURI: URI): Path = { - var qualifiedURI = localURI - // If not specified, assume these are in the local filesystem to keep behavior like Hadoop - if (qualifiedURI.getScheme() == null) { - qualifiedURI = new URI(FileSystem.getLocal(conf).makeQualified(new Path(qualifiedURI)).toString) - } + /** + * Given a local URI, resolve it and return a qualified local path that corresponds to the URI. + * This is used for preparing local resources to be included in the container launch context. + */ + private def getQualifiedLocalPath(localURI: URI): Path = { + val qualifiedURI = + if (localURI.getScheme == null) { + // If not specified, assume this is in the local filesystem to keep the behavior + // consistent with that of Hadoop + new URI(FileSystem.getLocal(hadoopConf).makeQualified(new Path(localURI)).toString) + } else { + localURI + } new Path(qualifiedURI) } + /** + * Upload any resources to the distributed cache if needed. If a resource is intended to be + * consumed locally, set up the appropriate config for downstream code to handle it properly. + * This is used for setting up a container launch context for our ApplicationMaster. + * Exposed for testing. + */ def prepareLocalResources(appStagingDir: String): HashMap[String, LocalResource] = { - logInfo("Preparing Local resources") - // Upload Spark and the application JAR to the remote file system if necessary. Add them as - // local resources to the application master. - val fs = FileSystem.get(conf) + logInfo("Preparing resources for our AM container") + // Upload Spark and the application JAR to the remote file system if necessary, + // and add them as local resources to the application master. + val fs = FileSystem.get(hadoopConf) val dst = new Path(fs.getHomeDirectory(), appStagingDir) - val nns = ClientBase.getNameNodesToAccess(sparkConf) + dst - ClientBase.obtainTokensForNamenodes(nns, conf, credentials) + val nns = getNameNodesToAccess(sparkConf) + dst + obtainTokensForNamenodes(nns, hadoopConf, credentials) val replication = sparkConf.getInt("spark.yarn.submit.file.replication", 3).toShort val localResources = HashMap[String, LocalResource]() @@ -200,73 +152,84 @@ trait ClientBase extends Logging { "for alternatives.") } + /** + * Copy the given main resource to the distributed cache if the scheme is not "local". + * Otherwise, set the corresponding key in our SparkConf to handle it downstream. 
+ * Each resource is represented by a 4-tuple of: + * (1) destination resource name, + * (2) local path to the resource, + * (3) Spark property key to set if the scheme is not local, and + * (4) whether to set permissions for this resource + */ List( - (ClientBase.SPARK_JAR, ClientBase.sparkJar(sparkConf), ClientBase.CONF_SPARK_JAR), - (ClientBase.APP_JAR, args.userJar, ClientBase.CONF_SPARK_USER_JAR), - ("log4j.properties", oldLog4jConf.getOrElse(null), null) - ).foreach { case(destName, _localPath, confKey) => + (SPARK_JAR, sparkJar(sparkConf), CONF_SPARK_JAR, false), + (APP_JAR, args.userJar, CONF_SPARK_USER_JAR, true), + ("log4j.properties", oldLog4jConf.orNull, null, false) + ).foreach { case (destName, _localPath, confKey, setPermissions) => val localPath: String = if (_localPath != null) _localPath.trim() else "" - if (! localPath.isEmpty()) { + if (!localPath.isEmpty()) { val localURI = new URI(localPath) - if (!ClientBase.LOCAL_SCHEME.equals(localURI.getScheme())) { - val setPermissions = destName.equals(ClientBase.APP_JAR) - val destPath = copyRemoteFile(dst, qualifyForLocal(localURI), replication, setPermissions) - val destFs = FileSystem.get(destPath.toUri(), conf) - distCacheMgr.addResource(destFs, conf, destPath, localResources, LocalResourceType.FILE, - destName, statCache) + if (localURI.getScheme != LOCAL_SCHEME) { + val src = getQualifiedLocalPath(localURI) + val destPath = copyFileToRemote(dst, src, replication, setPermissions) + val destFs = FileSystem.get(destPath.toUri(), hadoopConf) + distCacheMgr.addResource(destFs, hadoopConf, destPath, + localResources, LocalResourceType.FILE, destName, statCache) } else if (confKey != null) { + // If the resource is intended for local use only, handle this downstream + // by setting the appropriate property sparkConf.set(confKey, localPath) } } } + /** + * Do the same for any additional resources passed in through ClientArguments. 
+ * Each resource category is represented by a 3-tuple of: + * (1) comma separated list of resources in this category, + * (2) resource type, and + * (3) whether to add these resources to the classpath + */ val cachedSecondaryJarLinks = ListBuffer.empty[String] - val fileLists = List( (args.addJars, LocalResourceType.FILE, true), + List( + (args.addJars, LocalResourceType.FILE, true), (args.files, LocalResourceType.FILE, false), - (args.archives, LocalResourceType.ARCHIVE, false) ) - fileLists.foreach { case (flist, resType, addToClasspath) => + (args.archives, LocalResourceType.ARCHIVE, false) + ).foreach { case (flist, resType, addToClasspath) => if (flist != null && !flist.isEmpty()) { - flist.split(',').foreach { case file: String => + flist.split(',').foreach { file => val localURI = new URI(file.trim()) - if (!ClientBase.LOCAL_SCHEME.equals(localURI.getScheme())) { + if (localURI.getScheme != LOCAL_SCHEME) { val localPath = new Path(localURI) val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName()) - val destPath = copyRemoteFile(dst, localPath, replication) - distCacheMgr.addResource(fs, conf, destPath, localResources, resType, - linkname, statCache) + val destPath = copyFileToRemote(dst, localPath, replication) + distCacheMgr.addResource( + fs, hadoopConf, destPath, localResources, resType, linkname, statCache) if (addToClasspath) { cachedSecondaryJarLinks += linkname } } else if (addToClasspath) { + // Resource is intended for local use only and should be added to the class path cachedSecondaryJarLinks += file.trim() } } } } - logInfo("Prepared Local resources " + localResources) - sparkConf.set(ClientBase.CONF_SPARK_YARN_SECONDARY_JARS, cachedSecondaryJarLinks.mkString(",")) + if (cachedSecondaryJarLinks.nonEmpty) { + sparkConf.set(CONF_SPARK_YARN_SECONDARY_JARS, cachedSecondaryJarLinks.mkString(",")) + } - UserGroupInformation.getCurrentUser().addCredentials(credentials) localResources } - /** Get all application master environment variables set on this SparkConf */ - def getAppMasterEnv: Seq[(String, String)] = { - val prefix = "spark.yarn.appMasterEnv." - sparkConf.getAll.filter{case (k, v) => k.startsWith(prefix)} - .map{case (k, v) => (k.substring(prefix.length), v)} - } - - - def setupLaunchEnv( - localResources: HashMap[String, LocalResource], - stagingDir: String): HashMap[String, String] = { - logInfo("Setting up the launch environment") - + /** + * Set up the environment for launching our ApplicationMaster container. + */ + private def setupLaunchEnv(stagingDir: String): HashMap[String, String] = { + logInfo("Setting up the launch environment for our AM container") val env = new HashMap[String, String]() - val extraCp = sparkConf.getOption("spark.driver.extraClassPath") - ClientBase.populateClasspath(args, yarnConf, sparkConf, env, extraCp) + populateClasspath(args, yarnConf, sparkConf, env, extraCp) env("SPARK_YARN_MODE") = "true" env("SPARK_YARN_STAGING_DIR") = stagingDir env("SPARK_USER") = UserGroupInformation.getCurrentUser().getShortUserName() @@ -275,42 +238,20 @@ trait ClientBase extends Logging { distCacheMgr.setDistFilesEnv(env) distCacheMgr.setDistArchivesEnv(env) - getAppMasterEnv.foreach { case (key, value) => - YarnSparkHadoopUtil.addToEnvironment(env, key, value, File.pathSeparator) - } + // Pick up any environment variables for the AM provided through spark.yarn.appMasterEnv.* + val amEnvPrefix = "spark.yarn.appMasterEnv." 
+ sparkConf.getAll + .filter { case (k, v) => k.startsWith(amEnvPrefix) } + .map { case (k, v) => (k.substring(amEnvPrefix.length), v) } + .foreach { case (k, v) => YarnSparkHadoopUtil.addPathToEnvironment(env, k, v) } // Keep this for backwards compatibility but users should move to the config sys.env.get("SPARK_YARN_USER_ENV").foreach { userEnvs => // Allow users to specify some environment variables. - YarnSparkHadoopUtil.setEnvFromInputString(env, userEnvs, File.pathSeparator) - + YarnSparkHadoopUtil.setEnvFromInputString(env, userEnvs) // Pass SPARK_YARN_USER_ENV itself to the AM so it can use it to set up executor environments. env("SPARK_YARN_USER_ENV") = userEnvs } - env - } - - def userArgsToString(clientArgs: ClientArguments): String = { - val prefix = " --arg " - val args = clientArgs.userArgs - val retval = new StringBuilder() - for (arg <- args) { - retval.append(prefix).append(" ").append(YarnSparkHadoopUtil.escapeForShell(arg)) - } - retval.toString - } - - def setupSecurityToken(amContainer: ContainerLaunchContext) - - def createContainerLaunchContext( - newApp: GetNewApplicationResponse, - localResources: HashMap[String, LocalResource], - env: HashMap[String, String]): ContainerLaunchContext = { - logInfo("Setting up container launch context") - val amContainer = Records.newRecord(classOf[ContainerLaunchContext]) - amContainer.setLocalResources(localResources) - - val isLaunchingDriver = args.userClass != null // In cluster mode, if the deprecated SPARK_JAVA_OPTS is set, we need to propagate it to // executors. But we can't just set spark.executor.extraJavaOptions, because the driver's @@ -320,6 +261,7 @@ trait ClientBase extends Logging { // Note that to warn the user about the deprecation in cluster mode, some code from // SparkConf#validateSettings() is duplicated here (to avoid triggering the condition // described above). + val isLaunchingDriver = args.userClass != null if (isLaunchingDriver) { sys.env.get("SPARK_JAVA_OPTS").foreach { value => val warning = @@ -342,14 +284,30 @@ trait ClientBase extends Logging { env("SPARK_JAVA_OPTS") = value } } - amContainer.setEnvironment(env) - val amMemory = args.amMemory + env + } + + /** + * Set up a ContainerLaunchContext to launch our ApplicationMaster container. + * This sets up the launch environment, java options, and the command for launching the AM. + */ + protected def createContainerLaunchContext(newAppResponse: GetNewApplicationResponse) + : ContainerLaunchContext = { + logInfo("Setting up container launch context for our AM") + + val appId = newAppResponse.getApplicationId + val appStagingDir = getAppStagingDir(appId) + val localResources = prepareLocalResources(appStagingDir) + val launchEnv = setupLaunchEnv(appStagingDir) + val amContainer = Records.newRecord(classOf[ContainerLaunchContext]) + amContainer.setLocalResources(localResources) + amContainer.setEnvironment(launchEnv) val javaOpts = ListBuffer[String]() // Add Xmx for AM memory - javaOpts += "-Xmx" + amMemory + "m" + javaOpts += "-Xmx" + args.amMemory + "m" val tmpDir = new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR) javaOpts += "-Djava.io.tmpdir=" + tmpDir @@ -361,8 +319,7 @@ trait ClientBase extends Logging { // Instead of using this, rely on cpusets by YARN to enforce "proper" Spark behavior in // multi-tenant environments. Not sure how default Java GC behaves if it is limited to subset // of cores on a node. 
- val useConcurrentAndIncrementalGC = env.isDefinedAt("SPARK_USE_CONC_INCR_GC") && - java.lang.Boolean.parseBoolean(env("SPARK_USE_CONC_INCR_GC")) + val useConcurrentAndIncrementalGC = launchEnv.get("SPARK_USE_CONC_INCR_GC").exists(_.toBoolean) if (useConcurrentAndIncrementalGC) { // In our expts, using (default) throughput collector has severe perf ramifications in // multi-tenant machines @@ -380,6 +337,8 @@ trait ClientBase extends Logging { javaOpts += YarnSparkHadoopUtil.escapeForShell(s"-D$k=$v") } + // Include driver-specific java options if we are launching a driver + val isLaunchingDriver = args.userClass != null if (isLaunchingDriver) { sparkConf.getOption("spark.driver.extraJavaOptions") .orElse(sys.env.get("SPARK_JAVA_OPTS")) @@ -397,19 +356,27 @@ trait ClientBase extends Logging { } else { Nil } + val userJar = + if (args.userJar != null) { + Seq("--jar", args.userJar) + } else { + Nil + } val amClass = if (isLaunchingDriver) { - classOf[ApplicationMaster].getName() + Class.forName("org.apache.spark.deploy.yarn.ApplicationMaster").getName } else { - classOf[ApplicationMaster].getName().replace("ApplicationMaster", "ExecutorLauncher") + Class.forName("org.apache.spark.deploy.yarn.ExecutorLauncher").getName } + val userArgs = args.userArgs.flatMap { arg => + Seq("--arg", YarnSparkHadoopUtil.escapeForShell(arg)) + } val amArgs = - Seq(amClass) ++ userClass ++ - (if (args.userJar != null) Seq("--jar", args.userJar) else Nil) ++ - Seq("--executor-memory", args.executorMemory.toString, + Seq(amClass) ++ userClass ++ userJar ++ userArgs ++ + Seq( + "--executor-memory", args.executorMemory.toString, "--executor-cores", args.executorCores.toString, - "--num-executors ", args.numExecutors.toString, - userArgsToString(args)) + "--num-executors ", args.numExecutors.toString) // Command for the ApplicationMaster val commands = Seq(Environment.JAVA_HOME.$() + "/bin/java", "-server") ++ @@ -418,41 +385,153 @@ trait ClientBase extends Logging { "1>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout", "2>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr") - logInfo("Yarn AM launch context:") - logInfo(s" user class: ${args.userClass}") - logInfo(s" env: $env") - logInfo(s" command: ${commands.mkString(" ")}") - // TODO: it would be nicer to just make sure there are no null commands here val printableCommands = commands.map(s => if (s == null) "null" else s).toList amContainer.setCommands(printableCommands) - setupSecurityToken(amContainer) + logDebug("===============================================================================") + logDebug("Yarn AM launch context:") + logDebug(s" user class: ${Option(args.userClass).getOrElse("N/A")}") + logDebug(" env:") + launchEnv.foreach { case (k, v) => logDebug(s" $k -> $v") } + logDebug(" resources:") + localResources.foreach { case (k, v) => logDebug(s" $k -> $v")} + logDebug(" command:") + logDebug(s" ${printableCommands.mkString(" ")}") + logDebug("===============================================================================") // send the acl settings into YARN to control who has access via YARN interfaces val securityManager = new SecurityManager(sparkConf) amContainer.setApplicationACLs(YarnSparkHadoopUtil.getApplicationAclsForYarn(securityManager)) + setupSecurityToken(amContainer) + UserGroupInformation.getCurrentUser().addCredentials(credentials) amContainer } + + /** + * Report the state of an application until it has exited, either successfully or + * due to some failure, then return the application state. 
+ * + * @param appId ID of the application to monitor. + * @param returnOnRunning Whether to also return the application state when it is RUNNING. + * @param logApplicationReport Whether to log details of the application report every iteration. + * @return state of the application, one of FINISHED, FAILED, KILLED, and RUNNING. + */ + def monitorApplication( + appId: ApplicationId, + returnOnRunning: Boolean = false, + logApplicationReport: Boolean = true): YarnApplicationState = { + val interval = sparkConf.getLong("spark.yarn.report.interval", 1000) + var lastState: YarnApplicationState = null + while (true) { + Thread.sleep(interval) + val report = getApplicationReport(appId) + val state = report.getYarnApplicationState + + if (logApplicationReport) { + logInfo(s"Application report for $appId (state: $state)") + val details = Seq[(String, String)]( + ("client token", getClientToken(report)), + ("diagnostics", report.getDiagnostics), + ("ApplicationMaster host", report.getHost), + ("ApplicationMaster RPC port", report.getRpcPort.toString), + ("queue", report.getQueue), + ("start time", report.getStartTime.toString), + ("final status", report.getFinalApplicationStatus.toString), + ("tracking URL", report.getTrackingUrl), + ("user", report.getUser) + ) + + // Use more loggable format if value is null or empty + val formattedDetails = details + .map { case (k, v) => + val newValue = Option(v).filter(_.nonEmpty).getOrElse("N/A") + s"\n\t $k: $newValue" } + .mkString("") + + // If DEBUG is enabled, log report details every iteration + // Otherwise, log them every time the application changes state + if (log.isDebugEnabled) { + logDebug(formattedDetails) + } else if (lastState != state) { + logInfo(formattedDetails) + } + } + + if (state == YarnApplicationState.FINISHED || + state == YarnApplicationState.FAILED || + state == YarnApplicationState.KILLED) { + return state + } + + if (returnOnRunning && state == YarnApplicationState.RUNNING) { + return state + } + + lastState = state + } + + // Never reached, but keeps compiler happy + throw new SparkException("While loop is depleted! This should never happen...") + } + + /** + * Submit an application to the ResourceManager and monitor its state. + * This continues until the application has exited for any reason. + */ + def run(): Unit = monitorApplication(submitApplication()) + + /* --------------------------------------------------------------------------------------- * + | Methods that cannot be implemented here due to API differences across hadoop versions | + * --------------------------------------------------------------------------------------- */ + + /** Submit an application running our ApplicationMaster to the ResourceManager. */ + def submitApplication(): ApplicationId + + /** Set up security tokens for launching our ApplicationMaster container. */ + protected def setupSecurityToken(containerContext: ContainerLaunchContext): Unit + + /** Get the application report from the ResourceManager for an application we have submitted. */ + protected def getApplicationReport(appId: ApplicationId): ApplicationReport + + /** + * Return the security token used by this client to communicate with the ApplicationMaster. + * If no security is enabled, the token returned by the report is null. 
+ */ + protected def getClientToken(report: ApplicationReport): String } -object ClientBase extends Logging { +private[spark] object ClientBase extends Logging { + + // Alias for the Spark assembly jar and the user jar val SPARK_JAR: String = "__spark__.jar" val APP_JAR: String = "__app__.jar" + + // URI scheme that identifies local resources val LOCAL_SCHEME = "local" + + // Staging directory for any temporary jars or files + val SPARK_STAGING: String = ".sparkStaging" + + // Location of any user-defined Spark jars val CONF_SPARK_JAR = "spark.yarn.jar" - /** - * This is an internal config used to propagate the location of the user's jar file to the - * driver/executors. - */ + val ENV_SPARK_JAR = "SPARK_JAR" + + // Internal config to propagate the location of the user's jar to the driver/executors val CONF_SPARK_USER_JAR = "spark.yarn.user.jar" - /** - * This is an internal config used to propagate the list of extra jars to add to the classpath - * of executors. - */ + + // Internal config to propagate the locations of any extra jars to add to the classpath + // of the executors val CONF_SPARK_YARN_SECONDARY_JARS = "spark.yarn.secondary.jars" - val ENV_SPARK_JAR = "SPARK_JAR" + + // Staging directory is private! -> rwx-------- + val STAGING_DIR_PERMISSION: FsPermission = + FsPermission.createImmutable(Integer.parseInt("700", 8).toShort) + + // App files are world-wide readable and owner writable -> rw-r--r-- + val APP_FILE_PERMISSION: FsPermission = + FsPermission.createImmutable(Integer.parseInt("644", 8).toShort) /** * Find the user-defined Spark jar if configured, or return the jar containing this @@ -461,7 +540,7 @@ object ClientBase extends Logging { * This method first looks in the SparkConf object for the CONF_SPARK_JAR key, and in the * user environment if that is not found (for backwards compatibility). */ - def sparkJar(conf: SparkConf) = { + private def sparkJar(conf: SparkConf): String = { if (conf.contains(CONF_SPARK_JAR)) { conf.get(CONF_SPARK_JAR) } else if (System.getenv(ENV_SPARK_JAR) != null) { @@ -474,16 +553,22 @@ object ClientBase extends Logging { } } - def populateHadoopClasspath(conf: Configuration, env: HashMap[String, String]) = { + /** + * Return the path to the given application's staging directory. + */ + private def getAppStagingDir(appId: ApplicationId): String = { + SPARK_STAGING + Path.SEPARATOR + appId.toString() + Path.SEPARATOR + } + + /** + * Populate the classpath entry in the given environment map with any application + * classpath specified through the Hadoop and Yarn configurations. + */ + def populateHadoopClasspath(conf: Configuration, env: HashMap[String, String]): Unit = { val classPathElementsToAdd = getYarnAppClasspath(conf) ++ getMRAppClasspath(conf) for (c <- classPathElementsToAdd.flatten) { - YarnSparkHadoopUtil.addToEnvironment( - env, - Environment.CLASSPATH.name, - c.trim, - File.pathSeparator) + YarnSparkHadoopUtil.addPathToEnvironment(env, Environment.CLASSPATH.name, c.trim) } - classPathElementsToAdd } private def getYarnAppClasspath(conf: Configuration): Option[Seq[String]] = @@ -519,7 +604,7 @@ object ClientBase extends Logging { /** * In Hadoop 0.23, the MR application classpath comes with the YARN application - * classpath. In Hadoop 2.0, it's an array of Strings, and in 2.2+ it's a String. + * classpath. In Hadoop 2.0, it's an array of Strings, and in 2.2+ it's a String. * So we need to use reflection to retrieve it. 
*/ def getDefaultMRApplicationClasspath: Option[Seq[String]] = { @@ -545,8 +630,16 @@ object ClientBase extends Logging { triedDefault.toOption } - def populateClasspath(args: ClientArguments, conf: Configuration, sparkConf: SparkConf, - env: HashMap[String, String], extraClassPath: Option[String] = None) { + /** + * Populate the classpath entry in the given environment map. + * This includes the user jar, Spark jar, and any extra application jars. + */ + def populateClasspath( + args: ClientArguments, + conf: Configuration, + sparkConf: SparkConf, + env: HashMap[String, String], + extraClassPath: Option[String] = None): Unit = { extraClassPath.foreach(addClasspathEntry(_, env)) addClasspathEntry(Environment.PWD.$(), env) @@ -554,36 +647,40 @@ object ClientBase extends Logging { if (sparkConf.get("spark.yarn.user.classpath.first", "false").toBoolean) { addUserClasspath(args, sparkConf, env) addFileToClasspath(sparkJar(sparkConf), SPARK_JAR, env) - ClientBase.populateHadoopClasspath(conf, env) + populateHadoopClasspath(conf, env) } else { addFileToClasspath(sparkJar(sparkConf), SPARK_JAR, env) - ClientBase.populateHadoopClasspath(conf, env) + populateHadoopClasspath(conf, env) addUserClasspath(args, sparkConf, env) } // Append all jar files under the working directory to the classpath. - addClasspathEntry(Environment.PWD.$() + Path.SEPARATOR + "*", env); + addClasspathEntry(Environment.PWD.$() + Path.SEPARATOR + "*", env) } /** * Adds the user jars which have local: URIs (or alternate names, such as APP_JAR) explicitly * to the classpath. */ - private def addUserClasspath(args: ClientArguments, conf: SparkConf, - env: HashMap[String, String]) = { - if (args != null) { - addFileToClasspath(args.userJar, APP_JAR, env) - if (args.addJars != null) { - args.addJars.split(",").foreach { case file: String => - addFileToClasspath(file, null, env) - } + private def addUserClasspath( + args: ClientArguments, + conf: SparkConf, + env: HashMap[String, String]): Unit = { + + // If `args` is not null, we are launching an AM container. + // Otherwise, we are launching executor containers. + val (mainJar, secondaryJars) = + if (args != null) { + (args.userJar, args.addJars) + } else { + (conf.get(CONF_SPARK_USER_JAR, null), conf.get(CONF_SPARK_YARN_SECONDARY_JARS, null)) } - } else { - val userJar = conf.get(CONF_SPARK_USER_JAR, null) - addFileToClasspath(userJar, APP_JAR, env) - val cachedSecondaryJarLinks = conf.get(CONF_SPARK_YARN_SECONDARY_JARS, "").split(",") - cachedSecondaryJarLinks.foreach(jar => addFileToClasspath(jar, null, env)) + addFileToClasspath(mainJar, APP_JAR, env) + if (secondaryJars != null) { + secondaryJars.split(",").filter(_.nonEmpty).foreach { jar => + addFileToClasspath(jar, null, env) + } } } @@ -599,46 +696,44 @@ object ClientBase extends Logging { * @param fileName Alternate name for the file (optional). * @param env Map holding the environment variables. 
*/ - private def addFileToClasspath(path: String, fileName: String, - env: HashMap[String, String]) : Unit = { + private def addFileToClasspath( + path: String, + fileName: String, + env: HashMap[String, String]): Unit = { if (path != null) { scala.util.control.Exception.ignoring(classOf[URISyntaxException]) { - val localPath = getLocalPath(path) - if (localPath != null) { - addClasspathEntry(localPath, env) + val uri = new URI(path) + if (uri.getScheme == LOCAL_SCHEME) { + addClasspathEntry(uri.getPath, env) return } } } if (fileName != null) { - addClasspathEntry(Environment.PWD.$() + Path.SEPARATOR + fileName, env); + addClasspathEntry(Environment.PWD.$() + Path.SEPARATOR + fileName, env) } } /** - * Returns the local path if the URI is a "local:" URI, or null otherwise. + * Add the given path to the classpath entry of the given environment map. + * If the classpath is already set, this appends the new path to the existing classpath. */ - private def getLocalPath(resource: String): String = { - val uri = new URI(resource) - if (LOCAL_SCHEME.equals(uri.getScheme())) { - return uri.getPath() - } - null - } - - private def addClasspathEntry(path: String, env: HashMap[String, String]) = - YarnSparkHadoopUtil.addToEnvironment(env, Environment.CLASSPATH.name, path, - File.pathSeparator) + private def addClasspathEntry(path: String, env: HashMap[String, String]): Unit = + YarnSparkHadoopUtil.addPathToEnvironment(env, Environment.CLASSPATH.name, path) /** * Get the list of namenodes the user may access. */ - private[yarn] def getNameNodesToAccess(sparkConf: SparkConf): Set[Path] = { - sparkConf.get("spark.yarn.access.namenodes", "").split(",").map(_.trim()).filter(!_.isEmpty) - .map(new Path(_)).toSet + def getNameNodesToAccess(sparkConf: SparkConf): Set[Path] = { + sparkConf.get("spark.yarn.access.namenodes", "") + .split(",") + .map(_.trim()) + .filter(!_.isEmpty) + .map(new Path(_)) + .toSet } - private[yarn] def getTokenRenewer(conf: Configuration): String = { + def getTokenRenewer(conf: Configuration): String = { val delegTokenRenewer = Master.getMasterPrincipal(conf) logDebug("delegation token renewer is: " + delegTokenRenewer) if (delegTokenRenewer == null || delegTokenRenewer.length() == 0) { @@ -652,17 +747,54 @@ object ClientBase extends Logging { /** * Obtains tokens for the namenodes passed in and adds them to the credentials. */ - private[yarn] def obtainTokensForNamenodes(paths: Set[Path], conf: Configuration, - creds: Credentials) { + def obtainTokensForNamenodes( + paths: Set[Path], + conf: Configuration, + creds: Credentials): Unit = { if (UserGroupInformation.isSecurityEnabled()) { val delegTokenRenewer = getTokenRenewer(conf) + paths.foreach { dst => + val dstFs = dst.getFileSystem(conf) + logDebug("getting token for namenode: " + dst) + dstFs.addDelegationTokens(delegTokenRenewer, creds) + } + } + } - paths.foreach { - dst => - val dstFs = dst.getFileSystem(conf) - logDebug("getting token for namenode: " + dst) - dstFs.addDelegationTokens(delegTokenRenewer, creds) + /** + * Return whether the two file systems are the same. 
+ */ + private def compareFs(srcFs: FileSystem, destFs: FileSystem): Boolean = { + val srcUri = srcFs.getUri() + val dstUri = destFs.getUri() + if (srcUri.getScheme() == null) { + return false + } + if (!srcUri.getScheme().equals(dstUri.getScheme())) { + return false + } + var srcHost = srcUri.getHost() + var dstHost = dstUri.getHost() + if ((srcHost != null) && (dstHost != null)) { + try { + srcHost = InetAddress.getByName(srcHost).getCanonicalHostName() + dstHost = InetAddress.getByName(dstHost).getCanonicalHostName() + } catch { + case e: UnknownHostException => + return false } + if (!srcHost.equals(dstHost)) { + return false + } + } else if (srcHost == null && dstHost != null) { + return false + } else if (srcHost != null && dstHost == null) { + return false + } + if (srcUri.getPort() != dstUri.getPort()) { + false + } else { + true } } diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala index 9b7f1fca96c6d..c592ecfdfce06 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala @@ -19,29 +19,24 @@ package org.apache.spark.deploy.yarn import java.net.URI +import scala.collection.mutable.{HashMap, LinkedHashMap, Map} + import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FileStatus -import org.apache.hadoop.fs.FileSystem -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.fs.permission.FsAction -import org.apache.hadoop.yarn.api.records.LocalResource -import org.apache.hadoop.yarn.api.records.LocalResourceVisibility -import org.apache.hadoop.yarn.api.records.LocalResourceType +import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.util.{Records, ConverterUtils} -import org.apache.spark.Logging - -import scala.collection.mutable.HashMap -import scala.collection.mutable.LinkedHashMap -import scala.collection.mutable.Map - +import org.apache.spark.Logging /** Client side methods to setup the Hadoop distributed cache */ -class ClientDistributedCacheManager() extends Logging { - private val distCacheFiles: Map[String, Tuple3[String, String, String]] = - LinkedHashMap[String, Tuple3[String, String, String]]() - private val distCacheArchives: Map[String, Tuple3[String, String, String]] = - LinkedHashMap[String, Tuple3[String, String, String]]() +private[spark] class ClientDistributedCacheManager() extends Logging { + + // Mappings from remote URI to (file status, modification time, visibility) + private val distCacheFiles: Map[String, (String, String, String)] = + LinkedHashMap[String, (String, String, String)]() + private val distCacheArchives: Map[String, (String, String, String)] = + LinkedHashMap[String, (String, String, String)]() /** @@ -68,9 +63,9 @@ class ClientDistributedCacheManager() extends Logging { resourceType: LocalResourceType, link: String, statCache: Map[URI, FileStatus], - appMasterOnly: Boolean = false) = { + appMasterOnly: Boolean = false): Unit = { val destStatus = fs.getFileStatus(destPath) - val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource] + val amJarRsrc = Records.newRecord(classOf[LocalResource]) amJarRsrc.setType(resourceType) val visibility = getVisibility(conf, destPath.toUri(), statCache) amJarRsrc.setVisibility(visibility) @@ -80,7 
+75,7 @@ class ClientDistributedCacheManager() extends Logging { if (link == null || link.isEmpty()) throw new Exception("You must specify a valid link name") localResources(link) = amJarRsrc - if (appMasterOnly == false) { + if (!appMasterOnly) { val uri = destPath.toUri() val pathURI = new URI(uri.getScheme(), uri.getAuthority(), uri.getPath(), null, link) if (resourceType == LocalResourceType.FILE) { @@ -95,12 +90,10 @@ class ClientDistributedCacheManager() extends Logging { /** * Adds the necessary cache file env variables to the env passed in - * @param env */ - def setDistFilesEnv(env: Map[String, String]) = { + def setDistFilesEnv(env: Map[String, String]): Unit = { val (keys, tupleValues) = distCacheFiles.unzip val (sizes, timeStamps, visibilities) = tupleValues.unzip3 - if (keys.size > 0) { env("SPARK_YARN_CACHE_FILES") = keys.reduceLeft[String] { (acc,n) => acc + "," + n } env("SPARK_YARN_CACHE_FILES_TIME_STAMPS") = @@ -114,12 +107,10 @@ class ClientDistributedCacheManager() extends Logging { /** * Adds the necessary cache archive env variables to the env passed in - * @param env */ - def setDistArchivesEnv(env: Map[String, String]) = { + def setDistArchivesEnv(env: Map[String, String]): Unit = { val (keys, tupleValues) = distCacheArchives.unzip val (sizes, timeStamps, visibilities) = tupleValues.unzip3 - if (keys.size > 0) { env("SPARK_YARN_CACHE_ARCHIVES") = keys.reduceLeft[String] { (acc,n) => acc + "," + n } env("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") = @@ -133,25 +124,21 @@ class ClientDistributedCacheManager() extends Logging { /** * Returns the local resource visibility depending on the cache file permissions - * @param conf - * @param uri - * @param statCache * @return LocalResourceVisibility */ - def getVisibility(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]): - LocalResourceVisibility = { + def getVisibility( + conf: Configuration, + uri: URI, + statCache: Map[URI, FileStatus]): LocalResourceVisibility = { if (isPublic(conf, uri, statCache)) { - return LocalResourceVisibility.PUBLIC - } - LocalResourceVisibility.PRIVATE + LocalResourceVisibility.PUBLIC + } else { + LocalResourceVisibility.PRIVATE + } } /** - * Returns a boolean to denote whether a cache file is visible to all(public) - * or not - * @param conf - * @param uri - * @param statCache + * Returns a boolean to denote whether a cache file is visible to all (public) * @return true if the path in the uri is visible to all, false otherwise */ def isPublic(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]): Boolean = { @@ -167,13 +154,12 @@ class ClientDistributedCacheManager() extends Logging { /** * Returns true if all ancestors of the specified path have the 'execute' * permission set for all users (i.e. 
that other users can traverse - * the directory heirarchy to the given path) - * @param fs - * @param path - * @param statCache + * the directory hierarchy to the given path) * @return true if all ancestors have the 'execute' permission set for all users */ - def ancestorsHaveExecutePermissions(fs: FileSystem, path: Path, + def ancestorsHaveExecutePermissions( + fs: FileSystem, + path: Path, statCache: Map[URI, FileStatus]): Boolean = { var current = path while (current != null) { @@ -187,32 +173,25 @@ class ClientDistributedCacheManager() extends Logging { } /** - * Checks for a given path whether the Other permissions on it + * Checks for a given path whether the Other permissions on it * imply the permission in the passed FsAction - * @param fs - * @param path - * @param action - * @param statCache * @return true if the path in the uri is visible to all, false otherwise */ - def checkPermissionOfOther(fs: FileSystem, path: Path, - action: FsAction, statCache: Map[URI, FileStatus]): Boolean = { + def checkPermissionOfOther( + fs: FileSystem, + path: Path, + action: FsAction, + statCache: Map[URI, FileStatus]): Boolean = { val status = getFileStatus(fs, path.toUri(), statCache) val perms = status.getPermission() val otherAction = perms.getOtherAction() - if (otherAction.implies(action)) { - return true - } - false + otherAction.implies(action) } /** - * Checks to see if the given uri exists in the cache, if it does it + * Checks to see if the given uri exists in the cache, if it does it * returns the existing FileStatus, otherwise it stats the uri, stores * it in the cache, and returns the FileStatus. - * @param fs - * @param uri - * @param statCache * @return FileStatus */ def getFileStatus(fs: FileSystem, uri: URI, statCache: Map[URI, FileStatus]): FileStatus = { diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala index f56f72cafe50e..bbbf615510762 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala @@ -17,7 +17,6 @@ package org.apache.spark.deploy.yarn -import java.io.File import java.net.URI import scala.collection.JavaConversions._ @@ -128,9 +127,9 @@ trait ExecutorRunnableUtil extends Logging { localResources: HashMap[String, LocalResource], timestamp: String, size: String, - vis: String) = { + vis: String): Unit = { val uri = new URI(file) - val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource] + val amJarRsrc = Records.newRecord(classOf[LocalResource]) amJarRsrc.setType(rtype) amJarRsrc.setVisibility(LocalResourceVisibility.valueOf(vis)) amJarRsrc.setResource(ConverterUtils.getYarnUrlFromURI(uri)) @@ -175,14 +174,17 @@ trait ExecutorRunnableUtil extends Logging { ClientBase.populateClasspath(null, yarnConf, sparkConf, env, extraCp) sparkConf.getExecutorEnv.foreach { case (key, value) => - YarnSparkHadoopUtil.addToEnvironment(env, key, value, File.pathSeparator) + // This assumes each executor environment variable set here is a path + // This is kept for backward compatibility and consistency with hadoop + YarnSparkHadoopUtil.addPathToEnvironment(env, key, value) } // Keep this for backwards compatibility but users should move to the config - YarnSparkHadoopUtil.setEnvFromInputString(env, System.getenv("SPARK_YARN_USER_ENV"), - File.pathSeparator) + sys.env.get("SPARK_YARN_USER_ENV").foreach { 
userEnvs => + YarnSparkHadoopUtil.setEnvFromInputString(env, userEnvs) + } - System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k,v) => env(k) = v } + System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k, v) => env(k) = v } env } diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 4a33e34c3bfc7..0b712c201904a 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.yarn import java.lang.{Boolean => JBoolean} +import java.io.File import java.util.{Collections, Set => JSet} import java.util.regex.Matcher import java.util.regex.Pattern @@ -29,14 +30,12 @@ import org.apache.hadoop.io.Text import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.security.Credentials import org.apache.hadoop.security.UserGroupInformation -import org.apache.hadoop.util.StringInterner import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.records.ApplicationAccessType import org.apache.hadoop.yarn.util.RackResolver import org.apache.hadoop.conf.Configuration -import org.apache.spark.{SecurityManager, SparkConf, SparkContext} +import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.util.Utils @@ -100,30 +99,26 @@ object YarnSparkHadoopUtil { private val hostToRack = new ConcurrentHashMap[String, String]() private val rackToHostSet = new ConcurrentHashMap[String, JSet[String]]() - def addToEnvironment( - env: HashMap[String, String], - variable: String, - value: String, - classPathSeparator: String) = { - var envVariable = "" - if (env.get(variable) == None) { - envVariable = value - } else { - envVariable = env.get(variable).get + classPathSeparator + value - } - env put (StringInterner.weakIntern(variable), StringInterner.weakIntern(envVariable)) + /** + * Add a path variable to the given environment map. + * If the map already contains this key, append the value to the existing value instead. + */ + def addPathToEnvironment(env: HashMap[String, String], key: String, value: String): Unit = { + val newValue = if (env.contains(key)) { env(key) + File.pathSeparator + value } else value + env.put(key, newValue) } - def setEnvFromInputString( - env: HashMap[String, String], - envString: String, - classPathSeparator: String) = { - if (envString != null && envString.length() > 0) { - var childEnvs = envString.split(",") - var p = Pattern.compile(getEnvironmentVariableRegex()) + /** + * Set zero or more environment variables specified by the given input string. + * The input string is expected to take the form "KEY1=VAL1,KEY2=VAL2,KEY3=VAL3". 
+ */ + def setEnvFromInputString(env: HashMap[String, String], inputString: String): Unit = { + if (inputString != null && inputString.length() > 0) { + val childEnvs = inputString.split(",") + val p = Pattern.compile(environmentVariableRegex) for (cEnv <- childEnvs) { - var parts = cEnv.split("=") // split on '=' - var m = p.matcher(parts(1)) + val parts = cEnv.split("=") // split on '=' + val m = p.matcher(parts(1)) val sb = new StringBuffer while (m.find()) { val variable = m.group(1) @@ -131,8 +126,7 @@ object YarnSparkHadoopUtil { if (env.get(variable) != None) { replace = env.get(variable).get } else { - // if this key is not configured for the child .. get it - // from the env + // if this key is not configured for the child .. get it from the env replace = System.getenv(variable) if (replace == null) { // the env key is note present anywhere .. simply set it @@ -142,14 +136,15 @@ object YarnSparkHadoopUtil { m.appendReplacement(sb, Matcher.quoteReplacement(replace)) } m.appendTail(sb) - addToEnvironment(env, parts(0), sb.toString(), classPathSeparator) + // This treats the environment variable as path variable delimited by `File.pathSeparator` + // This is kept for backward compatibility and consistency with Hadoop's behavior + addPathToEnvironment(env, parts(0), sb.toString) } } } - private def getEnvironmentVariableRegex() : String = { - val osName = System.getProperty("os.name") - if (osName startsWith "Windows") { + private val environmentVariableRegex: String = { + if (Utils.isWindows) { "%([A-Za-z_][A-Za-z0-9_]*?)%" } else { "\\$([A-Za-z_][A-Za-z0-9_]*)" @@ -181,14 +176,14 @@ object YarnSparkHadoopUtil { } } - private[spark] def lookupRack(conf: Configuration, host: String): String = { + def lookupRack(conf: Configuration, host: String): String = { if (!hostToRack.contains(host)) { populateRackInfo(conf, host) } hostToRack.get(host) } - private[spark] def populateRackInfo(conf: Configuration, hostname: String) { + def populateRackInfo(conf: Configuration, hostname: String) { Utils.checkHost(hostname) if (!hostToRack.containsKey(hostname)) { @@ -212,8 +207,8 @@ object YarnSparkHadoopUtil { } } - private[spark] def getApplicationAclsForYarn(securityMgr: SecurityManager): - Map[ApplicationAccessType, String] = { + def getApplicationAclsForYarn(securityMgr: SecurityManager) + : Map[ApplicationAccessType, String] = { Map[ApplicationAccessType, String] ( ApplicationAccessType.VIEW_APP -> securityMgr.getViewAcls, ApplicationAccessType.MODIFY_APP -> securityMgr.getModifyAcls diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 6aa6475fe4a18..200a30899290b 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -19,7 +19,7 @@ package org.apache.spark.scheduler.cluster import org.apache.hadoop.yarn.api.records.{ApplicationId, YarnApplicationState} import org.apache.spark.{SparkException, Logging, SparkContext} -import org.apache.spark.deploy.yarn.{Client, ClientArguments, YarnSparkHadoopUtil} +import org.apache.spark.deploy.yarn.{Client, ClientArguments} import org.apache.spark.scheduler.TaskSchedulerImpl import scala.collection.mutable.ArrayBuffer @@ -34,115 +34,120 @@ private[spark] class YarnClientSchedulerBackend( minRegisteredRatio = 0.8 } - var client: Client = null - var 
appId: ApplicationId = null - var checkerThread: Thread = null - var stopping: Boolean = false - var totalExpectedExecutors = 0 - - private[spark] def addArg(optionName: String, envVar: String, sysProp: String, - arrayBuf: ArrayBuffer[String]) { - if (System.getenv(envVar) != null) { - arrayBuf += (optionName, System.getenv(envVar)) - } else if (sc.getConf.contains(sysProp)) { - arrayBuf += (optionName, sc.getConf.get(sysProp)) - } - } + private var client: Client = null + private var appId: ApplicationId = null + private var stopping: Boolean = false + private var totalExpectedExecutors = 0 + /** + * Create a Yarn client to submit an application to the ResourceManager. + * This waits until the application is running. + */ override def start() { super.start() - val driverHost = conf.get("spark.driver.host") val driverPort = conf.get("spark.driver.port") val hostport = driverHost + ":" + driverPort sc.ui.foreach { ui => conf.set("spark.driver.appUIAddress", ui.appUIHostPort) } val argsArrayBuf = new ArrayBuffer[String]() - argsArrayBuf += ( - "--args", hostport - ) - - // process any optional arguments, given either as environment variables - // or system properties. use the defaults already defined in ClientArguments - // if things aren't specified. system properties override environment - // variables. - List(("--driver-memory", "SPARK_MASTER_MEMORY", "spark.master.memory"), - ("--driver-memory", "SPARK_DRIVER_MEMORY", "spark.driver.memory"), - ("--num-executors", "SPARK_WORKER_INSTANCES", "spark.executor.instances"), - ("--num-executors", "SPARK_EXECUTOR_INSTANCES", "spark.executor.instances"), - ("--executor-memory", "SPARK_WORKER_MEMORY", "spark.executor.memory"), - ("--executor-memory", "SPARK_EXECUTOR_MEMORY", "spark.executor.memory"), - ("--executor-cores", "SPARK_WORKER_CORES", "spark.executor.cores"), - ("--executor-cores", "SPARK_EXECUTOR_CORES", "spark.executor.cores"), - ("--queue", "SPARK_YARN_QUEUE", "spark.yarn.queue"), - ("--name", "SPARK_YARN_APP_NAME", "spark.app.name")) - .foreach { case (optName, envVar, sysProp) => addArg(optName, envVar, sysProp, argsArrayBuf) } - - logDebug("ClientArguments called with: " + argsArrayBuf) + argsArrayBuf += ("--arg", hostport) + argsArrayBuf ++= getExtraClientArguments + + logDebug("ClientArguments called with: " + argsArrayBuf.mkString(" ")) val args = new ClientArguments(argsArrayBuf.toArray, conf) totalExpectedExecutors = args.numExecutors client = new Client(args, conf) - appId = client.runApp() - waitForApp() - checkerThread = yarnApplicationStateCheckerThread() + appId = client.submitApplication() + waitForApplication() + asyncMonitorApplication() } - def waitForApp() { - - // TODO : need a better way to find out whether the executors are ready or not - // maybe by resource usage report? - while(true) { - val report = client.getApplicationReport(appId) - - logInfo("Application report from ASM: \n" + - "\t appMasterRpcPort: " + report.getRpcPort() + "\n" + - "\t appStartTime: " + report.getStartTime() + "\n" + - "\t yarnAppState: " + report.getYarnApplicationState() + "\n" + /** + * Return any extra command line arguments to be passed to Client provided in the form of + * environment variables or Spark properties. 
+ */ + private def getExtraClientArguments: Seq[String] = { + val extraArgs = new ArrayBuffer[String] + val optionTuples = // List of (target Client argument, environment variable, Spark property) + List( + ("--driver-memory", "SPARK_MASTER_MEMORY", "spark.master.memory"), + ("--driver-memory", "SPARK_DRIVER_MEMORY", "spark.driver.memory"), + ("--num-executors", "SPARK_WORKER_INSTANCES", "spark.executor.instances"), + ("--num-executors", "SPARK_EXECUTOR_INSTANCES", "spark.executor.instances"), + ("--executor-memory", "SPARK_WORKER_MEMORY", "spark.executor.memory"), + ("--executor-memory", "SPARK_EXECUTOR_MEMORY", "spark.executor.memory"), + ("--executor-cores", "SPARK_WORKER_CORES", "spark.executor.cores"), + ("--executor-cores", "SPARK_EXECUTOR_CORES", "spark.executor.cores"), + ("--queue", "SPARK_YARN_QUEUE", "spark.yarn.queue"), + ("--name", "SPARK_YARN_APP_NAME", "spark.app.name") ) - - // Ready to go, or already gone. - val state = report.getYarnApplicationState() - if (state == YarnApplicationState.RUNNING) { - return - } else if (state == YarnApplicationState.FINISHED || - state == YarnApplicationState.FAILED || - state == YarnApplicationState.KILLED) { - throw new SparkException("Yarn application already ended," + - "might be killed or not able to launch application master.") + optionTuples.foreach { case (optionName, envVar, sparkProp) => + if (System.getenv(envVar) != null) { + extraArgs += (optionName, System.getenv(envVar)) + } else if (sc.getConf.contains(sparkProp)) { + extraArgs += (optionName, sc.getConf.get(sparkProp)) } + } + extraArgs + } - Thread.sleep(1000) + /** + * Report the state of the application until it is running. + * If the application has finished, failed or been killed in the process, throw an exception. + * This assumes both `client` and `appId` have already been set. + */ + private def waitForApplication(): Unit = { + assert(client != null && appId != null, "Application has not been submitted yet!") + val state = client.monitorApplication(appId, returnOnRunning = true) // blocking + if (state == YarnApplicationState.FINISHED || + state == YarnApplicationState.FAILED || + state == YarnApplicationState.KILLED) { + throw new SparkException("Yarn application has already ended! " + + "It might have been killed or unable to launch application master.") + } + if (state == YarnApplicationState.RUNNING) { + logInfo(s"Application $appId has started running.") } } - private def yarnApplicationStateCheckerThread(): Thread = { + /** + * Monitor the application state in a separate thread. + * If the application has exited for any reason, stop the SparkContext. + * This assumes both `client` and `appId` have already been set. 
+ */ + private def asyncMonitorApplication(): Unit = { + assert(client != null && appId != null, "Application has not been submitted yet!") val t = new Thread { override def run() { while (!stopping) { val report = client.getApplicationReport(appId) val state = report.getYarnApplicationState() - if (state == YarnApplicationState.FINISHED || state == YarnApplicationState.KILLED - || state == YarnApplicationState.FAILED) { - logError(s"Yarn application already ended: $state") + if (state == YarnApplicationState.FINISHED || + state == YarnApplicationState.KILLED || + state == YarnApplicationState.FAILED) { + logError(s"Yarn application has already exited with state $state!") sc.stop() stopping = true } Thread.sleep(1000L) } - checkerThread = null Thread.currentThread().interrupt() } } - t.setName("Yarn Application State Checker") + t.setName("Yarn application state monitor") t.setDaemon(true) t.start() - t } + /** + * Stop the scheduler. This assumes `start()` has already been called. + */ override def stop() { + assert(client != null, "Attempted to stop this scheduler before starting it!") stopping = true super.stop() - client.stop + client.stop() logInfo("Stopped") } diff --git a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala index c3b7a2c8f02e5..9bd916100dd2c 100644 --- a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala +++ b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala @@ -27,7 +27,7 @@ import org.apache.hadoop.mapreduce.MRJobConfig import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.yarn.api.protocolrecords.GetNewApplicationResponse -import org.apache.hadoop.yarn.api.records.ContainerLaunchContext +import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.mockito.Matchers._ import org.mockito.Mockito._ @@ -90,7 +90,7 @@ class ClientBaseSuite extends FunSuite with Matchers { val env = new MutableHashMap[String, String]() val args = new ClientArguments(Array("--jar", USER, "--addJars", ADDED), sparkConf) - ClientBase.populateClasspath(args, conf, sparkConf, env, None) + ClientBase.populateClasspath(args, conf, sparkConf, env) val cp = env("CLASSPATH").split(File.pathSeparator) s"$SPARK,$USER,$ADDED".split(",").foreach({ entry => @@ -114,10 +114,10 @@ class ClientBaseSuite extends FunSuite with Matchers { val args = new ClientArguments(Array("--jar", USER, "--addJars", ADDED), sparkConf) val client = spy(new DummyClient(args, conf, sparkConf, yarnConf)) - doReturn(new Path("/")).when(client).copyRemoteFile(any(classOf[Path]), + doReturn(new Path("/")).when(client).copyFileToRemote(any(classOf[Path]), any(classOf[Path]), anyShort(), anyBoolean()) - var tempDir = Files.createTempDir(); + val tempDir = Files.createTempDir() try { client.prepareLocalResources(tempDir.getAbsolutePath()) sparkConf.getOption(ClientBase.CONF_SPARK_USER_JAR) should be (Some(USER)) @@ -247,13 +247,13 @@ class ClientBaseSuite extends FunSuite with Matchers { private class DummyClient( val args: ClientArguments, - val conf: Configuration, + val hadoopConf: Configuration, val sparkConf: SparkConf, val yarnConf: YarnConfiguration) extends ClientBase { - - override def setupSecurityToken(amContainer: ContainerLaunchContext): Unit = - throw new UnsupportedOperationException() - + override def 
setupSecurityToken(amContainer: ContainerLaunchContext): Unit = ??? + override def submitApplication(): ApplicationId = ??? + override def getApplicationReport(appId: ApplicationId): ApplicationReport = ??? + override def getClientToken(report: ApplicationReport): String = ??? } } diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 82e45e3e7ad54..0b43e6ee20538 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -21,11 +21,9 @@ import java.nio.ByteBuffer import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.DataOutputBuffer -import org.apache.hadoop.yarn.api.protocolrecords._ import org.apache.hadoop.yarn.api.records._ -import org.apache.hadoop.yarn.client.api.YarnClient +import org.apache.hadoop.yarn.client.api.{YarnClient, YarnClientApplication} import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.Records import org.apache.spark.{Logging, SparkConf} @@ -34,128 +32,98 @@ import org.apache.spark.deploy.SparkHadoopUtil /** * Version of [[org.apache.spark.deploy.yarn.ClientBase]] tailored to YARN's stable API. */ -class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: SparkConf) +private[spark] class Client( + val args: ClientArguments, + val hadoopConf: Configuration, + val sparkConf: SparkConf) extends ClientBase with Logging { - val yarnClient = YarnClient.createYarnClient - def this(clientArgs: ClientArguments, spConf: SparkConf) = this(clientArgs, SparkHadoopUtil.get.newConfiguration(spConf), spConf) def this(clientArgs: ClientArguments) = this(clientArgs, new SparkConf()) - val args = clientArgs - val conf = hadoopConf - val sparkConf = spConf - var rpc: YarnRPC = YarnRPC.create(conf) - val yarnConf: YarnConfiguration = new YarnConfiguration(conf) - - def runApp(): ApplicationId = { - validateArgs() - // Initialize and start the client service. + val yarnClient = YarnClient.createYarnClient + val yarnConf = new YarnConfiguration(hadoopConf) + + def stop(): Unit = yarnClient.stop() + + /* ------------------------------------------------------------------------------------- * + | The following methods have much in common in the stable and alpha versions of Client, | + | but cannot be implemented in the parent trait due to subtle API differences across | + | hadoop versions. | + * ------------------------------------------------------------------------------------- */ + + /** + * Submit an application running our ApplicationMaster to the ResourceManager. + * + * The stable Yarn API provides a convenience method (YarnClient#createApplication) for + * creating applications and setting up the application submission context. This was not + * available in the alpha API. + */ + override def submitApplication(): ApplicationId = { yarnClient.init(yarnConf) yarnClient.start() - // Log details about this YARN cluster (e.g, the number of slave machines/NodeManagers). - logClusterResourceDetails() - - // Prepare to submit a request to the ResourcManager (specifically its ApplicationsManager (ASM) - // interface). + logInfo("Requesting a new application from cluster with %d NodeManagers" + .format(yarnClient.getYarnClusterMetrics.getNumNodeManagers)) - // Get a new client application. 
+ // Get a new application from our RM val newApp = yarnClient.createApplication() val newAppResponse = newApp.getNewApplicationResponse() val appId = newAppResponse.getApplicationId() + // Verify whether the cluster has enough resources for our AM verifyClusterResources(newAppResponse) - // Set up resource and environment variables. - val appStagingDir = getAppStagingDir(appId) - val localResources = prepareLocalResources(appStagingDir) - val launchEnv = setupLaunchEnv(localResources, appStagingDir) - val amContainer = createContainerLaunchContext(newAppResponse, localResources, launchEnv) + // Set up the appropriate contexts to launch our AM + val containerContext = createContainerLaunchContext(newAppResponse) + val appContext = createApplicationSubmissionContext(newApp, containerContext) - // Set up an application submission context. - val appContext = newApp.getApplicationSubmissionContext() - appContext.setApplicationName(args.appName) - appContext.setQueue(args.amQueue) - appContext.setAMContainerSpec(amContainer) - appContext.setApplicationType("SPARK") - - // Memory for the ApplicationMaster. - val memoryResource = Records.newRecord(classOf[Resource]).asInstanceOf[Resource] - memoryResource.setMemory(args.amMemory + memoryOverhead) - appContext.setResource(memoryResource) - - // Finally, submit and monitor the application. - submitApp(appContext) + // Finally, submit and monitor the application + logInfo(s"Submitting application ${appId.getId} to ResourceManager") + yarnClient.submitApplication(appContext) appId } - def run() { - val appId = runApp() - monitorApplication(appId) - } - - def logClusterResourceDetails() { - val clusterMetrics: YarnClusterMetrics = yarnClient.getYarnClusterMetrics - logInfo("Got cluster metric info from ResourceManager, number of NodeManagers: " + - clusterMetrics.getNumNodeManagers) + /** + * Set up the context for submitting our ApplicationMaster. + * This uses the YarnClientApplication not available in the Yarn alpha API. + */ + def createApplicationSubmissionContext( + newApp: YarnClientApplication, + containerContext: ContainerLaunchContext): ApplicationSubmissionContext = { + val appContext = newApp.getApplicationSubmissionContext + appContext.setApplicationName(args.appName) + appContext.setQueue(args.amQueue) + appContext.setAMContainerSpec(containerContext) + appContext.setApplicationType("SPARK") + val capability = Records.newRecord(classOf[Resource]) + capability.setMemory(args.amMemory + amMemoryOverhead) + appContext.setResource(capability) + appContext } - def setupSecurityToken(amContainer: ContainerLaunchContext) = { - // Setup security tokens. - val dob = new DataOutputBuffer() + /** Set up security tokens for launching our ApplicationMaster container. */ + override def setupSecurityToken(amContainer: ContainerLaunchContext): Unit = { + val dob = new DataOutputBuffer credentials.writeTokenStorageToStream(dob) - amContainer.setTokens(ByteBuffer.wrap(dob.getData())) + amContainer.setTokens(ByteBuffer.wrap(dob.getData)) } - def submitApp(appContext: ApplicationSubmissionContext) = { - // Submit the application to the applications manager. - logInfo("Submitting application to ResourceManager") - yarnClient.submitApplication(appContext) - } + /** Get the application report from the ResourceManager for an application we have submitted. 
*/ + override def getApplicationReport(appId: ApplicationId): ApplicationReport = + yarnClient.getApplicationReport(appId) - def getApplicationReport(appId: ApplicationId) = - yarnClient.getApplicationReport(appId) - - def stop = yarnClient.stop - - def monitorApplication(appId: ApplicationId): Boolean = { - val interval = sparkConf.getLong("spark.yarn.report.interval", 1000) - - while (true) { - Thread.sleep(interval) - val report = yarnClient.getApplicationReport(appId) - - logInfo("Application report from ResourceManager: \n" + - "\t application identifier: " + appId.toString() + "\n" + - "\t appId: " + appId.getId() + "\n" + - "\t clientToAMToken: " + report.getClientToAMToken() + "\n" + - "\t appDiagnostics: " + report.getDiagnostics() + "\n" + - "\t appMasterHost: " + report.getHost() + "\n" + - "\t appQueue: " + report.getQueue() + "\n" + - "\t appMasterRpcPort: " + report.getRpcPort() + "\n" + - "\t appStartTime: " + report.getStartTime() + "\n" + - "\t yarnAppState: " + report.getYarnApplicationState() + "\n" + - "\t distributedFinalState: " + report.getFinalApplicationStatus() + "\n" + - "\t appTrackingUrl: " + report.getTrackingUrl() + "\n" + - "\t appUser: " + report.getUser() - ) - - val state = report.getYarnApplicationState() - if (state == YarnApplicationState.FINISHED || - state == YarnApplicationState.FAILED || - state == YarnApplicationState.KILLED) { - return true - } - } - true - } + /** + * Return the security token used by this client to communicate with the ApplicationMaster. + * If no security is enabled, the token returned by the report is null. + */ + override def getClientToken(report: ApplicationReport): String = + Option(report.getClientToAMToken).map(_.toString).getOrElse("") } object Client { - def main(argStrings: Array[String]) { if (!sys.props.contains("SPARK_SUBMIT")) { println("WARNING: This client is deprecated and will be removed in a " + @@ -163,22 +131,19 @@ object Client { } // Set an env variable indicating we are running in YARN mode. - // Note: anything env variable with SPARK_ prefix gets propagated to all (remote) processes - - // see Client#setupLaunchEnv(). 
+ // Note that any env variable with the SPARK_ prefix gets propagated to all (remote) processes System.setProperty("SPARK_YARN_MODE", "true") - val sparkConf = new SparkConf() + val sparkConf = new SparkConf try { val args = new ClientArguments(argStrings, sparkConf) new Client(args, sparkConf).run() } catch { - case e: Exception => { + case e: Exception => Console.err.println(e.getMessage) System.exit(1) - } } System.exit(0) } - } From 11c10df825419372df61a8d23c51e8c3cc78047f Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Tue, 23 Sep 2014 11:40:14 -0500 Subject: [PATCH 04/22] [SPARK-3304] [YARN] ApplicationMaster's Finish status is wrong when uncaught exception is thrown from ReporterThread Author: Kousuke Saruta Closes #2198 from sarutak/SPARK-3304 and squashes the following commits: 2696237 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-3304 5b80363 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-3304 4eb0a3e [Kousuke Saruta] Remoed the description about spark.yarn.scheduler.reporterThread.maxFailure 9741597 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-3304 f7538d4 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-3304 358ef8d [Kousuke Saruta] Merge branch 'SPARK-3304' of github.com:sarutak/spark into SPARK-3304 0d138c6 [Kousuke Saruta] Revert "tmp" f8da10a [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-3304 b6e9879 [Kousuke Saruta] tmp 8d256ed [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-3304 13b2652 [Kousuke Saruta] Merge branch 'SPARK-3304' of github.com:sarutak/spark into SPARK-3304 2711e15 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-3304 c081f8e [Kousuke Saruta] Modified ApplicationMaster to handle exception in ReporterThread itself 0bbd3a6 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-3304 a6982ad [Kousuke Saruta] Added ability handling uncaught exception thrown from Reporter thread --- .../spark/deploy/yarn/ApplicationMaster.scala | 66 +++++++++++++++---- 1 file changed, 54 insertions(+), 12 deletions(-) diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index cde5fff637a39..9050808157257 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -17,7 +17,10 @@ package org.apache.spark.deploy.yarn +import scala.util.control.NonFatal + import java.io.IOException +import java.lang.reflect.InvocationTargetException import java.net.Socket import java.util.concurrent.atomic.AtomicReference @@ -55,6 +58,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, @volatile private var finished = false @volatile private var finalStatus = FinalApplicationStatus.UNDEFINED + @volatile private var userClassThread: Thread = _ private var reporterThread: Thread = _ private var allocator: YarnAllocator = _ @@ -221,18 +225,48 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, // must be <= expiryInterval / 2. 
val interval = math.max(0, math.min(expiryInterval / 2, schedulerInterval)) + // The number of failures in a row until Reporter thread give up + val reporterMaxFailures = sparkConf.getInt("spark.yarn.scheduler.reporterThread.maxFailures", 5) + val t = new Thread { override def run() { + var failureCount = 0 + while (!finished) { - checkNumExecutorsFailed() - if (!finished) { - logDebug("Sending progress") - allocator.allocateResources() - try { - Thread.sleep(interval) - } catch { - case e: InterruptedException => + try { + checkNumExecutorsFailed() + if (!finished) { + logDebug("Sending progress") + allocator.allocateResources() } + failureCount = 0 + } catch { + case e: Throwable => { + failureCount += 1 + if (!NonFatal(e) || failureCount >= reporterMaxFailures) { + logError("Exception was thrown from Reporter thread.", e) + finish(FinalApplicationStatus.FAILED, "Exception was thrown" + + s"${failureCount} time(s) from Reporter thread.") + + /** + * If exception is thrown from ReporterThread, + * interrupt user class to stop. + * Without this interrupting, if exception is + * thrown before allocating enough executors, + * YarnClusterScheduler waits until timeout even though + * we cannot allocate executors. + */ + logInfo("Interrupting user class to stop.") + userClassThread.interrupt + } else { + logWarning(s"Reporter thread fails ${failureCount} time(s) in a row.", e) + } + } + } + try { + Thread.sleep(interval) + } catch { + case e: InterruptedException => } } } @@ -355,7 +389,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, val mainMethod = Class.forName(args.userClass, false, Thread.currentThread.getContextClassLoader).getMethod("main", classOf[Array[String]]) - val t = new Thread { + userClassThread = new Thread { override def run() { var status = FinalApplicationStatus.FAILED try { @@ -366,15 +400,23 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, // Some apps have "System.exit(0)" at the end. The user thread will stop here unless // it has an uncaught exception thrown out. It needs a shutdown hook to set SUCCEEDED. status = FinalApplicationStatus.SUCCEEDED + } catch { + case e: InvocationTargetException => { + e.getCause match { + case _: InterruptedException => { + // Reporter thread can interrupt to stop user class + } + } + } } finally { logDebug("Finishing main") } finalStatus = status } } - t.setName("Driver") - t.start() - t + userClassThread.setName("Driver") + userClassThread.start() + userClassThread } // Actor used to monitor the driver when running in client deploy mode. From 66bc0f2d675d06cdd48638f124a1ff32be2bf456 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Tue, 23 Sep 2014 11:45:44 -0700 Subject: [PATCH 05/22] [SPARK-3598][SQL]cast to timestamp should be the same as hive this patch fixes timestamp smaller than 0 and cast int as timestamp select cast(1000 as timestamp) from src limit 1; should return 1970-01-01 00:00:01, but we now take it as 1000 seconds. also, current implementation has bug when the time is before 1970-01-01 00:00:00. 
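As a concrete illustration of the intended semantics (a plain-Scala sketch, not the patched `Cast` code itself): `java.sql.Timestamp` takes milliseconds since the epoch, so an integral value must not be multiplied by 1000 again, and values before the epoch need flooring rather than truncation toward zero.

```scala
import java.sql.Timestamp

// 1000 ms after the epoch, i.e. 1970-01-01 00:00:01 UTC -- the expected result of
// cast(1000 as timestamp) described above
val fromInt = new Timestamp(1000L)

// Times before the epoch are representable too; the previous longValue()-based
// conversion truncated such values toward zero instead of flooring them
val beforeEpoch = new Timestamp(-1200L)
```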
rxin marmbrus chenghao-intel Author: Daoyuan Wang Closes #2458 from adrian-wang/timestamp and squashes the following commits: 4274b1d [Daoyuan Wang] set test not related to timezone 1234f66 [Daoyuan Wang] fix timestamp smaller than 0 and cast int as timestamp --- .../spark/sql/catalyst/expressions/Cast.scala | 17 +++++++------ .../ExpressionEvaluationSuite.scala | 16 ++++++++----- ...cast #1-0-69fc614ccea92bbe39f4decc299edcc6 | 1 + ...cast #2-0-732ed232ac592c5e7f7c913a88874fd2 | 1 + ... cast #3-0-76ee270337f664b36cacfc6528ac109 | 1 + ...cast #4-0-732ed232ac592c5e7f7c913a88874fd2 | 1 + ...cast #5-0-dbd7bcd167d322d6617b884c02c7f247 | 1 + ...cast #6-0-6d2da5cfada03605834e38bc4075bc79 | 1 + ...cast #7-0-1d70654217035f8ce5f64344f4c5a80f | 1 + ...cast #8-0-6d2da5cfada03605834e38bc4075bc79 | 1 + .../sql/hive/execution/HiveQuerySuite.scala | 24 +++++++++++++++++++ 11 files changed, 50 insertions(+), 15 deletions(-) create mode 100644 sql/hive/src/test/resources/golden/timestamp cast #1-0-69fc614ccea92bbe39f4decc299edcc6 create mode 100644 sql/hive/src/test/resources/golden/timestamp cast #2-0-732ed232ac592c5e7f7c913a88874fd2 create mode 100644 sql/hive/src/test/resources/golden/timestamp cast #3-0-76ee270337f664b36cacfc6528ac109 create mode 100644 sql/hive/src/test/resources/golden/timestamp cast #4-0-732ed232ac592c5e7f7c913a88874fd2 create mode 100644 sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 create mode 100644 sql/hive/src/test/resources/golden/timestamp cast #6-0-6d2da5cfada03605834e38bc4075bc79 create mode 100644 sql/hive/src/test/resources/golden/timestamp cast #7-0-1d70654217035f8ce5f64344f4c5a80f create mode 100644 sql/hive/src/test/resources/golden/timestamp cast #8-0-6d2da5cfada03605834e38bc4075bc79 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 0379275121bf2..f626d09f037bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -86,15 +86,15 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { try Timestamp.valueOf(n) catch { case _: java.lang.IllegalArgumentException => null } }) case BooleanType => - buildCast[Boolean](_, b => new Timestamp((if (b) 1 else 0) * 1000)) + buildCast[Boolean](_, b => new Timestamp((if (b) 1 else 0))) case LongType => - buildCast[Long](_, l => new Timestamp(l * 1000)) + buildCast[Long](_, l => new Timestamp(l)) case IntegerType => - buildCast[Int](_, i => new Timestamp(i * 1000)) + buildCast[Int](_, i => new Timestamp(i)) case ShortType => - buildCast[Short](_, s => new Timestamp(s * 1000)) + buildCast[Short](_, s => new Timestamp(s)) case ByteType => - buildCast[Byte](_, b => new Timestamp(b * 1000)) + buildCast[Byte](_, b => new Timestamp(b)) // TimestampWritable.decimalToTimestamp case DecimalType => buildCast[BigDecimal](_, d => decimalToTimestamp(d)) @@ -107,11 +107,10 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { } private[this] def decimalToTimestamp(d: BigDecimal) = { - val seconds = d.longValue() + val seconds = Math.floor(d.toDouble).toLong val bd = (d - seconds) * 1000000000 val nanos = bd.intValue() - // Convert to millis val millis = seconds * 1000 val t = new Timestamp(millis) @@ -121,11 +120,11 @@ case class Cast(child: Expression, dataType: DataType) extends 
UnaryExpression { } // Timestamp to long, converting milliseconds to seconds - private[this] def timestampToLong(ts: Timestamp) = ts.getTime / 1000 + private[this] def timestampToLong(ts: Timestamp) = Math.floor(ts.getTime / 1000.0).toLong private[this] def timestampToDouble(ts: Timestamp) = { // First part is the seconds since the beginning of time, followed by nanosecs. - ts.getTime / 1000 + ts.getNanos.toDouble / 1000000000 + Math.floor(ts.getTime / 1000.0).toLong + ts.getNanos.toDouble / 1000000000 } // Converts Timestamp to string according to Hive TimestampWritable convention diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index b961346dfc995..8b6721d5d8125 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -231,7 +231,9 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation("12.65" cast DecimalType, BigDecimal(12.65)) checkEvaluation(Literal(1) cast LongType, 1) - checkEvaluation(Cast(Literal(1) cast TimestampType, LongType), 1) + checkEvaluation(Cast(Literal(1000) cast TimestampType, LongType), 1.toLong) + checkEvaluation(Cast(Literal(-1200) cast TimestampType, LongType), -2.toLong) + checkEvaluation(Cast(Literal(1.toDouble) cast TimestampType, DoubleType), 1.toDouble) checkEvaluation(Cast(Literal(1.toDouble) cast TimestampType, DoubleType), 1.toDouble) checkEvaluation(Cast(Literal(sts) cast TimestampType, StringType), sts) @@ -242,11 +244,11 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Cast(Cast(Cast(Cast( Cast("5" cast ByteType, ShortType), IntegerType), FloatType), DoubleType), LongType), 5) checkEvaluation(Cast(Cast(Cast(Cast( - Cast("5" cast ByteType, TimestampType), DecimalType), LongType), StringType), ShortType), 5) + Cast("5" cast ByteType, TimestampType), DecimalType), LongType), StringType), ShortType), 0) checkEvaluation(Cast(Cast(Cast(Cast( Cast("5" cast TimestampType, ByteType), DecimalType), LongType), StringType), ShortType), null) checkEvaluation(Cast(Cast(Cast(Cast( - Cast("5" cast DecimalType, ByteType), TimestampType), LongType), StringType), ShortType), 5) + Cast("5" cast DecimalType, ByteType), TimestampType), LongType), StringType), ShortType), 0) checkEvaluation(Literal(true) cast IntegerType, 1) checkEvaluation(Literal(false) cast IntegerType, 0) checkEvaluation(Cast(Literal(1) cast BooleanType, IntegerType), 1) @@ -293,16 +295,18 @@ class ExpressionEvaluationSuite extends FunSuite { test("timestamp casting") { val millis = 15 * 1000 + 2 + val seconds = millis * 1000 + 2 val ts = new Timestamp(millis) val ts1 = new Timestamp(15 * 1000) // a timestamp without the milliseconds part + val tss = new Timestamp(seconds) checkEvaluation(Cast(ts, ShortType), 15) checkEvaluation(Cast(ts, IntegerType), 15) checkEvaluation(Cast(ts, LongType), 15) checkEvaluation(Cast(ts, FloatType), 15.002f) checkEvaluation(Cast(ts, DoubleType), 15.002) - checkEvaluation(Cast(Cast(ts, ShortType), TimestampType), ts1) - checkEvaluation(Cast(Cast(ts, IntegerType), TimestampType), ts1) - checkEvaluation(Cast(Cast(ts, LongType), TimestampType), ts1) + checkEvaluation(Cast(Cast(tss, ShortType), TimestampType), ts) + checkEvaluation(Cast(Cast(tss, IntegerType), TimestampType), ts) + 
checkEvaluation(Cast(Cast(tss, LongType), TimestampType), ts) checkEvaluation(Cast(Cast(millis.toFloat / 1000, TimestampType), FloatType), millis.toFloat / 1000) checkEvaluation(Cast(Cast(millis.toDouble / 1000, TimestampType), DoubleType), diff --git a/sql/hive/src/test/resources/golden/timestamp cast #1-0-69fc614ccea92bbe39f4decc299edcc6 b/sql/hive/src/test/resources/golden/timestamp cast #1-0-69fc614ccea92bbe39f4decc299edcc6 new file mode 100644 index 0000000000000..8ebf695ba7d20 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp cast #1-0-69fc614ccea92bbe39f4decc299edcc6 @@ -0,0 +1 @@ +0.001 diff --git a/sql/hive/src/test/resources/golden/timestamp cast #2-0-732ed232ac592c5e7f7c913a88874fd2 b/sql/hive/src/test/resources/golden/timestamp cast #2-0-732ed232ac592c5e7f7c913a88874fd2 new file mode 100644 index 0000000000000..5625e59da8873 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp cast #2-0-732ed232ac592c5e7f7c913a88874fd2 @@ -0,0 +1 @@ +1.2 diff --git a/sql/hive/src/test/resources/golden/timestamp cast #3-0-76ee270337f664b36cacfc6528ac109 b/sql/hive/src/test/resources/golden/timestamp cast #3-0-76ee270337f664b36cacfc6528ac109 new file mode 100644 index 0000000000000..d00491fd7e5bb --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp cast #3-0-76ee270337f664b36cacfc6528ac109 @@ -0,0 +1 @@ +1 diff --git a/sql/hive/src/test/resources/golden/timestamp cast #4-0-732ed232ac592c5e7f7c913a88874fd2 b/sql/hive/src/test/resources/golden/timestamp cast #4-0-732ed232ac592c5e7f7c913a88874fd2 new file mode 100644 index 0000000000000..5625e59da8873 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp cast #4-0-732ed232ac592c5e7f7c913a88874fd2 @@ -0,0 +1 @@ +1.2 diff --git a/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 b/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 new file mode 100644 index 0000000000000..27de46fdf22ac --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 @@ -0,0 +1 @@ +-0.0010000000000000009 diff --git a/sql/hive/src/test/resources/golden/timestamp cast #6-0-6d2da5cfada03605834e38bc4075bc79 b/sql/hive/src/test/resources/golden/timestamp cast #6-0-6d2da5cfada03605834e38bc4075bc79 new file mode 100644 index 0000000000000..1d94c8a014fb4 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp cast #6-0-6d2da5cfada03605834e38bc4075bc79 @@ -0,0 +1 @@ +-1.2 diff --git a/sql/hive/src/test/resources/golden/timestamp cast #7-0-1d70654217035f8ce5f64344f4c5a80f b/sql/hive/src/test/resources/golden/timestamp cast #7-0-1d70654217035f8ce5f64344f4c5a80f new file mode 100644 index 0000000000000..3fbedf693b51d --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp cast #7-0-1d70654217035f8ce5f64344f4c5a80f @@ -0,0 +1 @@ +-2 diff --git a/sql/hive/src/test/resources/golden/timestamp cast #8-0-6d2da5cfada03605834e38bc4075bc79 b/sql/hive/src/test/resources/golden/timestamp cast #8-0-6d2da5cfada03605834e38bc4075bc79 new file mode 100644 index 0000000000000..1d94c8a014fb4 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp cast #8-0-6d2da5cfada03605834e38bc4075bc79 @@ -0,0 +1 @@ +-1.2 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 56bcd95eab4bc..6fc891ba4cca5 100644 --- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -303,6 +303,30 @@ class HiveQuerySuite extends HiveComparisonTest { createQueryTest("case statements WITHOUT key #4", "SELECT (CASE WHEN key > 2 THEN 3 WHEN 2 > key THEN 2 ELSE 0 END) FROM src WHERE key < 15") + createQueryTest("timestamp cast #1", + "SELECT CAST(CAST(1 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") + + createQueryTest("timestamp cast #2", + "SELECT CAST(CAST(1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") + + createQueryTest("timestamp cast #3", + "SELECT CAST(CAST(1200 AS TIMESTAMP) AS INT) FROM src LIMIT 1") + + createQueryTest("timestamp cast #4", + "SELECT CAST(CAST(1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") + + createQueryTest("timestamp cast #5", + "SELECT CAST(CAST(-1 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") + + createQueryTest("timestamp cast #6", + "SELECT CAST(CAST(-1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") + + createQueryTest("timestamp cast #7", + "SELECT CAST(CAST(-1200 AS TIMESTAMP) AS INT) FROM src LIMIT 1") + + createQueryTest("timestamp cast #8", + "SELECT CAST(CAST(-1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") + test("implement identity function using case statement") { val actual = sql("SELECT (CASE key WHEN key THEN key END) FROM src") .map { case Row(i: Int) => i } From 116016b481cecbd8ad6e9717d92f977a164a6653 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Tue, 23 Sep 2014 11:47:53 -0700 Subject: [PATCH 06/22] [SPARK-3582][SQL] not limit argument type for hive simple udf Since we have moved to `ConventionHelper`, it is quite easy to avoid call `javaClassToDataType` in hive simple udf. This will solve SPARK-3582. Author: Daoyuan Wang Closes #2506 from adrian-wang/spark3582 and squashes the following commits: 450c28e [Daoyuan Wang] not limit argument type for hive simple udf --- .../spark/sql/hive/HiveInspectors.scala | 4 ++-- .../org/apache/spark/sql/hive/hiveUdfs.scala | 22 ++----------------- 2 files changed, 4 insertions(+), 22 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 943bbaa8ce25e..fa889ec104c6e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -137,7 +137,7 @@ private[hive] trait HiveInspectors { /** Converts native catalyst types to the types expected by Hive */ def wrap(a: Any): AnyRef = a match { - case s: String => new hadoopIo.Text(s) // TODO why should be Text? 
+ case s: String => s: java.lang.String case i: Int => i: java.lang.Integer case b: Boolean => b: java.lang.Boolean case f: Float => f: java.lang.Float @@ -145,7 +145,7 @@ private[hive] trait HiveInspectors { case l: Long => l: java.lang.Long case l: Short => l: java.lang.Short case l: Byte => l: java.lang.Byte - case b: BigDecimal => b.bigDecimal + case b: BigDecimal => new HiveDecimal(b.underlying()) case b: Array[Byte] => b case t: java.sql.Timestamp => t case s: Seq[_] => seqAsJavaList(s.map(wrap)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 19ff3b66ad7ed..68944ed4ef21d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -51,19 +51,7 @@ private[hive] abstract class HiveFunctionRegistry val functionClassName = functionInfo.getFunctionClass.getName if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) { - val function = functionInfo.getFunctionClass.newInstance().asInstanceOf[UDF] - val method = function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo)) - - val expectedDataTypes = method.getParameterTypes.map(javaClassToDataType) - - HiveSimpleUdf( - functionClassName, - children.zip(expectedDataTypes).map { - case (e, NullType) => e - case (e, t) if (e.dataType == t) => e - case (e, t) => Cast(e, t) - } - ) + HiveSimpleUdf(functionClassName, children) } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) { HiveGenericUdf(functionClassName, children) } else if ( @@ -117,15 +105,9 @@ private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[ @transient lazy val dataType = javaClassToDataType(method.getReturnType) - def catalystToHive(value: Any): Object = value match { - // TODO need more types here? or can we use wrap() - case bd: BigDecimal => new HiveDecimal(bd.underlying()) - case d => d.asInstanceOf[Object] - } - // TODO: Finish input output types. override def eval(input: Row): Any = { - val evaluatedChildren = children.map(c => catalystToHive(c.eval(input))) + val evaluatedChildren = children.map(c => wrap(c.eval(input))) unwrap(FunctionRegistry.invoke(method, function, conversionHelper .convertIfNecessary(evaluatedChildren: _*): _*)) From 3b8eefa9b843c7f1e0e8dda6023272bc9f011c5c Mon Sep 17 00:00:00 2001 From: ravipesala Date: Tue, 23 Sep 2014 11:52:13 -0700 Subject: [PATCH 07/22] [SPARK-3536][SQL] SELECT on empty parquet table throws exception It returns null metadata from parquet if querying on empty parquet file while calculating splits.So added null check and returns the empty splits. Author : ravipesala ravindra.pesalahuawei.com Author: ravipesala Closes #2456 from ravipesala/SPARK-3536 and squashes the following commits: 1e81a50 [ravipesala] Fixed the issue when querying on empty parquet file. 
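For context, a minimal sketch of the guard this change introduces (hypothetical helper, not the patch itself; the real code sits inside `FilteringParquetRowInputFormat.getSplits` below): an empty Parquet table yields null global metadata, and that case should simply mean "no splits".

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical, simplified shape of the fix: if Parquet reports no global metadata
// (the empty-table case), report zero splits instead of letting the null flow onward.
def splitsOrEmpty[T](globalMetaData: AnyRef)(computeSplits: => ArrayBuffer[T]): ArrayBuffer[T] =
  if (globalMetaData == null) ArrayBuffer.empty[T]
  else computeSplits
```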
--- .../spark/sql/parquet/ParquetTableOperations.scala | 7 +++++-- .../org/apache/spark/sql/parquet/ParquetQuerySuite.scala | 9 +++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index a5a5d139a65cb..d39e31a7fa195 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -427,11 +427,15 @@ private[parquet] class FilteringParquetRowInputFormat s"maxSplitSize or minSplitSie should not be negative: maxSplitSize = $maxSplitSize;" + s" minSplitSize = $minSplitSize") } - + val splits = mutable.ArrayBuffer.empty[ParquetInputSplit] val getGlobalMetaData = classOf[ParquetFileWriter].getDeclaredMethod("getGlobalMetaData", classOf[JList[Footer]]) getGlobalMetaData.setAccessible(true) val globalMetaData = getGlobalMetaData.invoke(null, footers).asInstanceOf[GlobalMetaData] + // if parquet file is empty, return empty splits. + if (globalMetaData == null) { + return splits + } val readContext = getReadSupport(configuration).init( new InitContext(configuration, @@ -442,7 +446,6 @@ private[parquet] class FilteringParquetRowInputFormat classOf[ParquetInputFormat[_]].getDeclaredMethods.find(_.getName == "generateSplits").get generateSplits.setAccessible(true) - val splits = mutable.ArrayBuffer.empty[ParquetInputSplit] for (footer <- footers) { val fs = footer.getFile.getFileSystem(configuration) val file = footer.getFile diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 08f7358446b29..07adf731405af 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -789,4 +789,13 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA assert(result3(0)(1) === "the answer") Utils.deleteRecursively(tmpdir) } + + test("Querying on empty parquet throws exception (SPARK-3536)") { + val tmpdir = Utils.createTempDir() + Utils.deleteRecursively(tmpdir) + createParquetFile[TestRDDEntry](tmpdir.toString()).registerTempTable("tmpemptytable") + val result1 = sql("SELECT * FROM tmpemptytable").collect() + assert(result1.size === 0) + Utils.deleteRecursively(tmpdir) + } } From e73b48ace0a7e0f249221240140235d33eeac36b Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 23 Sep 2014 11:58:05 -0700 Subject: [PATCH 08/22] SPARK-2745 [STREAMING] Add Java friendly methods to Duration class tdas is this what you had in mind for this JIRA? I saw this one and thought it would be easy to take care of, and helpful as I use streaming from Java. I could do the same for `Time`? Happy to do so. Author: Sean Owen Closes #2403 from srowen/SPARK-2745 and squashes the following commits: 5a9e706 [Sean Owen] Change "Duration" to "Durations" to avoid changing Duration case class API bda301c [Sean Owen] Just delegate to Scala binary operator syntax to avoid scalastyle warning 7dde949 [Sean Owen] Disable scalastyle for false positives. Add Java static factory methods seconds(), minutes() to Duration. Add Java-friendly methods to Time too, and unit tests. 
Remove unnecessary math.floor from Time.floor() 4dee32e [Sean Owen] Add named methods to Duration in parallel to symbolic methods for Java-friendliness. Also add unit tests for Duration, in Scala and Java. --- .../org/apache/spark/streaming/Duration.scala | 39 ++++++ .../org/apache/spark/streaming/Time.scala | 20 +++- .../spark/streaming/JavaDurationSuite.java | 84 +++++++++++++ .../apache/spark/streaming/JavaTimeSuite.java | 63 ++++++++++ .../spark/streaming/DurationSuite.scala | 110 +++++++++++++++++ .../apache/spark/streaming/TimeSuite.scala | 111 ++++++++++++++++++ 6 files changed, 425 insertions(+), 2 deletions(-) create mode 100644 streaming/src/test/java/org/apache/spark/streaming/JavaDurationSuite.java create mode 100644 streaming/src/test/java/org/apache/spark/streaming/JavaTimeSuite.java create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/DurationSuite.scala create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/TimeSuite.scala diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Duration.scala b/streaming/src/main/scala/org/apache/spark/streaming/Duration.scala index 6bf275f5afcb2..a0d8fb5ab93ec 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Duration.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Duration.scala @@ -37,6 +37,25 @@ case class Duration (private val millis: Long) { def / (that: Duration): Double = millis.toDouble / that.millis.toDouble + // Java-friendlier versions of the above. + + def less(that: Duration): Boolean = this < that + + def lessEq(that: Duration): Boolean = this <= that + + def greater(that: Duration): Boolean = this > that + + def greaterEq(that: Duration): Boolean = this >= that + + def plus(that: Duration): Duration = this + that + + def minus(that: Duration): Duration = this - that + + def times(times: Int): Duration = this * times + + def div(that: Duration): Double = this / that + + def isMultipleOf(that: Duration): Boolean = (this.millis % that.millis == 0) @@ -80,4 +99,24 @@ object Minutes { def apply(minutes: Long) = new Duration(minutes * 60000) } +// Java-friendlier versions of the objects above. +// Named "Durations" instead of "Duration" to avoid changing the case class's implied API. + +object Durations { + + /** + * @return [[org.apache.spark.streaming.Duration]] representing given number of milliseconds. + */ + def milliseconds(milliseconds: Long) = Milliseconds(milliseconds) + /** + * @return [[org.apache.spark.streaming.Duration]] representing given number of seconds. + */ + def seconds(seconds: Long) = Seconds(seconds) + + /** + * @return [[org.apache.spark.streaming.Duration]] representing given number of minutes. + */ + def minutes(minutes: Long) = Minutes(minutes) + +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Time.scala b/streaming/src/main/scala/org/apache/spark/streaming/Time.scala index 37b3b28fa01cb..42c49678d24f0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Time.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Time.scala @@ -41,10 +41,26 @@ case class Time(private val millis: Long) { def - (that: Duration): Time = new Time(millis - that.milliseconds) + // Java-friendlier versions of the above. 
+ + def less(that: Time): Boolean = this < that + + def lessEq(that: Time): Boolean = this <= that + + def greater(that: Time): Boolean = this > that + + def greaterEq(that: Time): Boolean = this >= that + + def plus(that: Duration): Time = this + that + + def minus(that: Time): Duration = this - that + + def minus(that: Duration): Time = this - that + + def floor(that: Duration): Time = { val t = that.milliseconds - val m = math.floor(this.millis / t).toLong - new Time(m * t) + new Time((this.millis / t) * t) } def isMultipleOf(that: Duration): Boolean = diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaDurationSuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaDurationSuite.java new file mode 100644 index 0000000000000..76425fe2aa2d3 --- /dev/null +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaDurationSuite.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.spark.streaming; + +import org.junit.Assert; +import org.junit.Test; + +public class JavaDurationSuite { + + // Just testing the methods that are specially exposed for Java. + // This does not repeat all tests found in the Scala suite. 
+ + @Test + public void testLess() { + Assert.assertTrue(new Duration(999).less(new Duration(1000))); + } + + @Test + public void testLessEq() { + Assert.assertTrue(new Duration(1000).lessEq(new Duration(1000))); + } + + @Test + public void testGreater() { + Assert.assertTrue(new Duration(1000).greater(new Duration(999))); + } + + @Test + public void testGreaterEq() { + Assert.assertTrue(new Duration(1000).greaterEq(new Duration(1000))); + } + + @Test + public void testPlus() { + Assert.assertEquals(new Duration(1100), new Duration(1000).plus(new Duration(100))); + } + + @Test + public void testMinus() { + Assert.assertEquals(new Duration(900), new Duration(1000).minus(new Duration(100))); + } + + @Test + public void testTimes() { + Assert.assertEquals(new Duration(200), new Duration(100).times(2)); + } + + @Test + public void testDiv() { + Assert.assertEquals(200.0, new Duration(1000).div(new Duration(5)), 1.0e-12); + } + + @Test + public void testMilliseconds() { + Assert.assertEquals(new Duration(100), Durations.milliseconds(100)); + } + + @Test + public void testSeconds() { + Assert.assertEquals(new Duration(30 * 1000), Durations.seconds(30)); + } + + @Test + public void testMinutes() { + Assert.assertEquals(new Duration(2 * 60 * 1000), Durations.minutes(2)); + } + +} diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaTimeSuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaTimeSuite.java new file mode 100644 index 0000000000000..ad6b1853e3d12 --- /dev/null +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaTimeSuite.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming; + +import org.junit.Assert; +import org.junit.Test; + +public class JavaTimeSuite { + + // Just testing the methods that are specially exposed for Java. + // This does not repeat all tests found in the Scala suite. 
+ + @Test + public void testLess() { + Assert.assertTrue(new Time(999).less(new Time(1000))); + } + + @Test + public void testLessEq() { + Assert.assertTrue(new Time(1000).lessEq(new Time(1000))); + } + + @Test + public void testGreater() { + Assert.assertTrue(new Time(1000).greater(new Time(999))); + } + + @Test + public void testGreaterEq() { + Assert.assertTrue(new Time(1000).greaterEq(new Time(1000))); + } + + @Test + public void testPlus() { + Assert.assertEquals(new Time(1100), new Time(1000).plus(new Duration(100))); + } + + @Test + public void testMinusTime() { + Assert.assertEquals(new Duration(900), new Time(1000).minus(new Time(100))); + } + + @Test + public void testMinusDuration() { + Assert.assertEquals(new Time(900), new Time(1000).minus(new Duration(100))); + } + +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DurationSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DurationSuite.scala new file mode 100644 index 0000000000000..6202250e897f2 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/DurationSuite.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming + +class DurationSuite extends TestSuiteBase { + + test("less") { + assert(new Duration(999) < new Duration(1000)) + assert(new Duration(0) < new Duration(1)) + assert(!(new Duration(1000) < new Duration(999))) + assert(!(new Duration(1000) < new Duration(1000))) + } + + test("lessEq") { + assert(new Duration(999) <= new Duration(1000)) + assert(new Duration(0) <= new Duration(1)) + assert(!(new Duration(1000) <= new Duration(999))) + assert(new Duration(1000) <= new Duration(1000)) + } + + test("greater") { + assert(!(new Duration(999) > new Duration(1000))) + assert(!(new Duration(0) > new Duration(1))) + assert(new Duration(1000) > new Duration(999)) + assert(!(new Duration(1000) > new Duration(1000))) + } + + test("greaterEq") { + assert(!(new Duration(999) >= new Duration(1000))) + assert(!(new Duration(0) >= new Duration(1))) + assert(new Duration(1000) >= new Duration(999)) + assert(new Duration(1000) >= new Duration(1000)) + } + + test("plus") { + assert((new Duration(1000) + new Duration(100)) == new Duration(1100)) + assert((new Duration(1000) + new Duration(0)) == new Duration(1000)) + } + + test("minus") { + assert((new Duration(1000) - new Duration(100)) == new Duration(900)) + assert((new Duration(1000) - new Duration(0)) == new Duration(1000)) + assert((new Duration(1000) - new Duration(1000)) == new Duration(0)) + } + + test("times") { + assert((new Duration(100) * 2) == new Duration(200)) + assert((new Duration(100) * 1) == new Duration(100)) + assert((new Duration(100) * 0) == new Duration(0)) + } + + test("div") { + assert((new Duration(1000) / new Duration(5)) == 200.0) + assert((new Duration(1000) / new Duration(1)) == 1000.0) + assert((new Duration(1000) / new Duration(1000)) == 1.0) + assert((new Duration(1000) / new Duration(2000)) == 0.5) + } + + test("isMultipleOf") { + assert(new Duration(1000).isMultipleOf(new Duration(5))) + assert(new Duration(1000).isMultipleOf(new Duration(1000))) + assert(new Duration(1000).isMultipleOf(new Duration(1))) + assert(!new Duration(1000).isMultipleOf(new Duration(6))) + } + + test("min") { + assert(new Duration(999).min(new Duration(1000)) == new Duration(999)) + assert(new Duration(1000).min(new Duration(999)) == new Duration(999)) + assert(new Duration(1000).min(new Duration(1000)) == new Duration(1000)) + } + + test("max") { + assert(new Duration(999).max(new Duration(1000)) == new Duration(1000)) + assert(new Duration(1000).max(new Duration(999)) == new Duration(1000)) + assert(new Duration(1000).max(new Duration(1000)) == new Duration(1000)) + } + + test("isZero") { + assert(new Duration(0).isZero) + assert(!(new Duration(1).isZero)) + } + + test("Milliseconds") { + assert(new Duration(100) == Milliseconds(100)) + } + + test("Seconds") { + assert(new Duration(30 * 1000) == Seconds(30)) + } + + test("Minutes") { + assert(new Duration(2 * 60 * 1000) == Minutes(2)) + } + +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TimeSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/TimeSuite.scala new file mode 100644 index 0000000000000..5579ac364346c --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/TimeSuite.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +class TimeSuite extends TestSuiteBase { + + test("less") { + assert(new Time(999) < new Time(1000)) + assert(new Time(0) < new Time(1)) + assert(!(new Time(1000) < new Time(999))) + assert(!(new Time(1000) < new Time(1000))) + } + + test("lessEq") { + assert(new Time(999) <= new Time(1000)) + assert(new Time(0) <= new Time(1)) + assert(!(new Time(1000) <= new Time(999))) + assert(new Time(1000) <= new Time(1000)) + } + + test("greater") { + assert(!(new Time(999) > new Time(1000))) + assert(!(new Time(0) > new Time(1))) + assert(new Time(1000) > new Time(999)) + assert(!(new Time(1000) > new Time(1000))) + } + + test("greaterEq") { + assert(!(new Time(999) >= new Time(1000))) + assert(!(new Time(0) >= new Time(1))) + assert(new Time(1000) >= new Time(999)) + assert(new Time(1000) >= new Time(1000)) + } + + test("plus") { + assert((new Time(1000) + new Duration(100)) == new Time(1100)) + assert((new Time(1000) + new Duration(0)) == new Time(1000)) + } + + test("minus Time") { + assert((new Time(1000) - new Time(100)) == new Duration(900)) + assert((new Time(1000) - new Time(0)) == new Duration(1000)) + assert((new Time(1000) - new Time(1000)) == new Duration(0)) + } + + test("minus Duration") { + assert((new Time(1000) - new Duration(100)) == new Time(900)) + assert((new Time(1000) - new Duration(0)) == new Time(1000)) + assert((new Time(1000) - new Duration(1000)) == new Time(0)) + } + + test("floor") { + assert(new Time(1350).floor(new Duration(200)) == new Time(1200)) + assert(new Time(1200).floor(new Duration(200)) == new Time(1200)) + assert(new Time(199).floor(new Duration(200)) == new Time(0)) + assert(new Time(1).floor(new Duration(1)) == new Time(1)) + } + + test("isMultipleOf") { + assert(new Time(1000).isMultipleOf(new Duration(5))) + assert(new Time(1000).isMultipleOf(new Duration(1000))) + assert(new Time(1000).isMultipleOf(new Duration(1))) + assert(!new Time(1000).isMultipleOf(new Duration(6))) + } + + test("min") { + assert(new Time(999).min(new Time(1000)) == new Time(999)) + assert(new Time(1000).min(new Time(999)) == new Time(999)) + assert(new Time(1000).min(new Time(1000)) == new Time(1000)) + } + + test("max") { + assert(new Time(999).max(new Time(1000)) == new Time(1000)) + assert(new Time(1000).max(new Time(999)) == new Time(1000)) + assert(new Time(1000).max(new Time(1000)) == new Time(1000)) + } + + test("until") { + assert(new Time(1000).until(new Time(1100), new Duration(100)) == + Seq(Time(1000))) + assert(new Time(1000).until(new Time(1000), new Duration(100)) == + Seq()) + assert(new Time(1000).until(new Time(1100), new Duration(30)) == + Seq(Time(1000), Time(1030), Time(1060), Time(1090))) + } + + test("to") { + assert(new Time(1000).to(new Time(1100), new Duration(100)) == + Seq(Time(1000), Time(1100))) + assert(new Time(1000).to(new Time(1000), new Duration(100)) == + Seq(Time(1000))) + assert(new Time(1000).to(new Time(1100), new 
Duration(30)) == + Seq(Time(1000), Time(1030), Time(1060), Time(1090))) + } + +} From ae60f8fb2d879ee1ebc0746bcbe05b89ab6ed3c9 Mon Sep 17 00:00:00 2001 From: wangfei Date: Tue, 23 Sep 2014 11:59:44 -0700 Subject: [PATCH 09/22] [SPARK-3481][SQL] removes the evil MINOR HACK a follow up of https://github.com/apache/spark/pull/2377 and https://github.com/apache/spark/pull/2352, see detail there. Author: wangfei Closes #2505 from scwf/patch-6 and squashes the following commits: 4874ec8 [wangfei] removes the evil MINOR HACK --- .../org/apache/spark/sql/hive/execution/PruningSuite.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index 8275e2d3bcce3..8474d850c9c6c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -28,8 +28,6 @@ import scala.collection.JavaConversions._ * A set of test cases that validate partition and column pruning. */ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { - // MINOR HACK: You must run a query before calling reset the first time. - TestHive.sql("SHOW TABLES") TestHive.cacheTables = false // Column/partition pruning is not implemented for `InMemoryColumnarTableScan` yet, need to reset From 1c62f97e94de96ca3dc6daf778f008176e92888a Mon Sep 17 00:00:00 2001 From: Venkata Ramana Gollamudi Date: Tue, 23 Sep 2014 12:17:47 -0700 Subject: [PATCH 10/22] [SPARK-3268][SQL] DoubleType, FloatType and DecimalType modulus support Supported modulus operation using % operator on fractional datatypes FloatType, DoubleType and DecimalType Example: SELECT 1388632775.0 % 60 from tablename LIMIT 1 Author : Venkata Ramana Gollamudi ramana.gollamudihuawei.com Author: Venkata Ramana Gollamudi Closes #2457 from gvramana/double_modulus_support and squashes the following commits: 79172a8 [Venkata Ramana Gollamudi] Add hive cache to testcase c09bd5b [Venkata Ramana Gollamudi] Added a HiveQuerySuite testcase 193fa81 [Venkata Ramana Gollamudi] corrected testcase 3624471 [Venkata Ramana Gollamudi] modified testcase e112c09 [Venkata Ramana Gollamudi] corrected the testcase 513d0e0 [Venkata Ramana Gollamudi] modified to add modulus support to fractional types float,double,decimal 296d253 [Venkata Ramana Gollamudi] modified to add modulus support to fractional types float,double,decimal --- .../sql/catalyst/expressions/Expression.scala | 3 ++ .../spark/sql/catalyst/types/dataTypes.scala | 5 +++ .../ExpressionEvaluationSuite.scala | 32 +++++++++++++++++++ ...modulus-0-6afd4a359a478cfa3ebd9ad00ae3868e | 1 + .../sql/hive/execution/HiveQuerySuite.scala | 3 ++ 5 files changed, 44 insertions(+) create mode 100644 sql/hive/src/test/resources/golden/modulus-0-6afd4a359a478cfa3ebd9ad00ae3868e diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 70507e7ee2be8..1eb260efa6387 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -179,6 +179,9 @@ abstract class Expression extends TreeNode[Expression] { case i: IntegralType => f.asInstanceOf[(Integral[i.JvmType], i.JvmType, i.JvmType) => i.JvmType]( i.integral, 
evalE1.asInstanceOf[i.JvmType], evalE2.asInstanceOf[i.JvmType]) + case i: FractionalType => + f.asInstanceOf[(Integral[i.JvmType], i.JvmType, i.JvmType) => i.JvmType]( + i.asIntegral, evalE1.asInstanceOf[i.JvmType], evalE2.asInstanceOf[i.JvmType]) case other => sys.error(s"Type $other does not support numeric operations") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index e3050e5397937..c7d73d3990c3a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.types import java.sql.Timestamp +import scala.math.Numeric.{FloatAsIfIntegral, BigDecimalAsIfIntegral, DoubleAsIfIntegral} import scala.reflect.ClassTag import scala.reflect.runtime.universe.{typeTag, TypeTag, runtimeMirror} import scala.util.parsing.combinator.RegexParsers @@ -250,6 +251,7 @@ object FractionalType { } abstract class FractionalType extends NumericType { private[sql] val fractional: Fractional[JvmType] + private[sql] val asIntegral: Integral[JvmType] } case object DecimalType extends FractionalType { @@ -258,6 +260,7 @@ case object DecimalType extends FractionalType { private[sql] val numeric = implicitly[Numeric[BigDecimal]] private[sql] val fractional = implicitly[Fractional[BigDecimal]] private[sql] val ordering = implicitly[Ordering[JvmType]] + private[sql] val asIntegral = BigDecimalAsIfIntegral def simpleString: String = "decimal" } @@ -267,6 +270,7 @@ case object DoubleType extends FractionalType { private[sql] val numeric = implicitly[Numeric[Double]] private[sql] val fractional = implicitly[Fractional[Double]] private[sql] val ordering = implicitly[Ordering[JvmType]] + private[sql] val asIntegral = DoubleAsIfIntegral def simpleString: String = "double" } @@ -276,6 +280,7 @@ case object FloatType extends FractionalType { private[sql] val numeric = implicitly[Numeric[Float]] private[sql] val fractional = implicitly[Fractional[Float]] private[sql] val ordering = implicitly[Ordering[JvmType]] + private[sql] val asIntegral = FloatAsIfIntegral def simpleString: String = "float" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 8b6721d5d8125..63931af4bac3d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.Timestamp import org.scalatest.FunSuite +import org.scalatest.Matchers._ +import org.scalautils.TripleEqualsSupport.Spread import org.apache.spark.sql.catalyst.types._ @@ -129,6 +131,13 @@ class ExpressionEvaluationSuite extends FunSuite { } } + def checkDoubleEvaluation(expression: Expression, expected: Spread[Double], inputRow: Row = EmptyRow): Unit = { + val actual = try evaluate(expression, inputRow) catch { + case e: Exception => fail(s"Exception evaluating $expression", e) + } + actual.asInstanceOf[Double] shouldBe expected + } + test("IN") { checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true) checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true) @@ 
-471,6 +480,29 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(c1 % c2, 1, row) } + test("fractional arithmetic") { + val row = new GenericRow(Array[Any](1.1, 2.0, 3.1, null)) + val c1 = 'a.double.at(0) + val c2 = 'a.double.at(1) + val c3 = 'a.double.at(2) + val c4 = 'a.double.at(3) + + checkEvaluation(UnaryMinus(c1), -1.1, row) + checkEvaluation(UnaryMinus(Literal(100.0, DoubleType)), -100.0) + checkEvaluation(Add(c1, c4), null, row) + checkEvaluation(Add(c1, c2), 3.1, row) + checkEvaluation(Add(c1, Literal(null, DoubleType)), null, row) + checkEvaluation(Add(Literal(null, DoubleType), c2), null, row) + checkEvaluation(Add(Literal(null, DoubleType), Literal(null, DoubleType)), null, row) + + checkEvaluation(-c1, -1.1, row) + checkEvaluation(c1 + c2, 3.1, row) + checkDoubleEvaluation(c1 - c2, (-0.9 +- 0.001), row) + checkDoubleEvaluation(c1 * c2, (2.2 +- 0.001), row) + checkDoubleEvaluation(c1 / c2, (0.55 +- 0.001), row) + checkDoubleEvaluation(c3 % c2, (1.1 +- 0.001), row) + } + test("BinaryComparison") { val row = new GenericRow(Array[Any](1, 2, 3, null, 3, null)) val c1 = 'a.int.at(0) diff --git a/sql/hive/src/test/resources/golden/modulus-0-6afd4a359a478cfa3ebd9ad00ae3868e b/sql/hive/src/test/resources/golden/modulus-0-6afd4a359a478cfa3ebd9ad00ae3868e new file mode 100644 index 0000000000000..52eab0653c505 --- /dev/null +++ b/sql/hive/src/test/resources/golden/modulus-0-6afd4a359a478cfa3ebd9ad00ae3868e @@ -0,0 +1 @@ +1 true 0.5 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 6fc891ba4cca5..426f5fcee6157 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -138,6 +138,9 @@ class HiveQuerySuite extends HiveComparisonTest { createQueryTest("division", "SELECT 2 / 1, 1 / 2, 1 / 3, 1 / COUNT(*) FROM src LIMIT 1") + createQueryTest("modulus", + "SELECT 11 % 10, IF((101.1 % 100.0) BETWEEN 1.01 AND 1.11, \"true\", \"false\"), (101 / 2) % 10 FROM src LIMIT 1") + test("Query expressed in SQL") { setConf("spark.sql.dialect", "sql") assert(sql("SELECT 1").collect() === Array(Seq(1))) From a08153f8a3e7bad81bae330ec4152651da5e7804 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 23 Sep 2014 12:27:12 -0700 Subject: [PATCH 11/22] [SPARK-3646][SQL] Copy SQL configuration from SparkConf when a SQLContext is created. This will allow us to take advantage of things like the spark.defaults file. Author: Michael Armbrust Closes #2493 from marmbrus/copySparkConf and squashes the following commits: 0bd1377 [Michael Armbrust] Copy SQL configuration from SparkConf when a SQLContext is created. 
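A short usage sketch of what this enables (hypothetical application code, not part of the patch): any `spark.sql.*` entry supplied through `SparkConf`, and therefore through `spark-defaults.conf` or `--conf` on spark-submit, becomes visible to a freshly created `SQLContext` without an explicit `setConf` call.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

// Hypothetical app: set a spark.sql.* key on the SparkConf only.
val conf = new SparkConf()
  .setMaster("local[2]")
  .setAppName("SqlConfPropagation")
  .set("spark.sql.shuffle.partitions", "10")

val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc)

// With this change the SQLContext picks the value up automatically.
assert(sqlContext.getConf("spark.sql.shuffle.partitions", "200") == "10")
```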
--- .../main/scala/org/apache/spark/sql/SQLContext.scala | 5 +++++ .../org/apache/spark/sql/test/TestSQLContext.scala | 6 +++++- .../scala/org/apache/spark/sql/SQLConfSuite.scala | 11 ++++++++++- 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index b245e1a863cc3..a42bedbe6c04e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -75,6 +75,11 @@ class SQLContext(@transient val sparkContext: SparkContext) protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = new this.QueryExecution { val logical = plan } + sparkContext.getConf.getAll.foreach { + case (key, value) if key.startsWith("spark.sql") => setConf(key, value) + case _ => + } + /** * :: DeveloperApi :: * Allows catalyst LogicalPlans to be executed as a SchemaRDD. Note that the LogicalPlan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala index 265b67737c475..6bb81c76ed8bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -22,7 +22,11 @@ import org.apache.spark.sql.{SQLConf, SQLContext} /** A SQLContext that can be used for local testing. */ object TestSQLContext - extends SQLContext(new SparkContext("local[2]", "TestSQLContext", new SparkConf())) { + extends SQLContext( + new SparkContext( + "local[2]", + "TestSQLContext", + new SparkConf().set("spark.sql.testkey", "true"))) { /** Fewer partitions to speed up testing. */ override private[spark] def numShufflePartitions: Int = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index 584f71b3c13d5..60701f0e154f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -17,16 +17,25 @@ package org.apache.spark.sql +import org.scalatest.FunSuiteLike + import org.apache.spark.sql.test._ /* Implicits */ import TestSQLContext._ -class SQLConfSuite extends QueryTest { +class SQLConfSuite extends QueryTest with FunSuiteLike { val testKey = "test.key.0" val testVal = "test.val.0" + test("propagate from spark conf") { + // We create a new context here to avoid order dependence with other tests that might call + // clear(). + val newContext = new SQLContext(TestSQLContext.sparkContext) + assert(newContext.getConf("spark.sql.testkey", "false") == "true") + } + test("programmatic ways of basic setting and getting") { clear() assert(getAllConfs.size === 0) From 8dfe79ffb204807945e3c09b75c7255b09ad2a97 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 23 Sep 2014 13:42:00 -0700 Subject: [PATCH 12/22] [SPARK-3647] Add more exceptions to Guava relocation. Guava's Optional refers to some package private classes / methods, and when those are relocated the code stops working, throwing exceptions. So add the affected classes to the exception list too, and add a unit test. (Note that this unit test only really makes sense in maven, since we don't relocate in the sbt build. Also, JavaAPISuite doesn't seem to be run by "mvn test" - I had to manually add command line options to enable it.) 
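For illustration (hypothetical snippet, not part of the patch): Guava's `Optional` is backed by the package-private `Absent` and `Present` classes, so any code that materializes an `Optional`, as the Java API test added below does, fails at runtime if those two are relocated away while `Optional` itself is preserved.

```scala
import com.google.common.base.Optional

// Optional.fromNullable(null) hands back the package-private Absent singleton and
// Optional.of(...) a Present instance, so both classes must survive shading intact.
val none: Optional[Integer] = Optional.fromNullable(null: Integer)
val some: Optional[Integer] = Optional.of(Int.box(42))
assert(!none.isPresent && some.get() == 42)
```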
Author: Marcelo Vanzin Closes #2496 from vanzin/SPARK-3647 and squashes the following commits: 84f58d7 [Marcelo Vanzin] [SPARK-3647] Add more exceptions to Guava relocation. --- assembly/pom.xml | 4 ++- core/pom.xml | 2 ++ .../java/org/apache/spark/JavaAPISuite.java | 26 +++++++++++++++++++ 3 files changed, 31 insertions(+), 1 deletion(-) diff --git a/assembly/pom.xml b/assembly/pom.xml index 604b1ab3de6a8..5ec9da22ae83f 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -141,7 +141,9 @@ com.google.common.** - com.google.common.base.Optional** + com/google/common/base/Absent* + com/google/common/base/Optional* + com/google/common/base/Present* diff --git a/core/pom.xml b/core/pom.xml index 2a81f6df289c0..e012c5e673b74 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -343,7 +343,9 @@ com.google.guava:guava + com/google/common/base/Absent* com/google/common/base/Optional* + com/google/common/base/Present* diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index b8574dfb42e6b..b8c23d524e00b 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -1307,4 +1307,30 @@ public void collectUnderlyingScalaRDD() { SomeCustomClass[] collected = (SomeCustomClass[]) rdd.rdd().retag(SomeCustomClass.class).collect(); Assert.assertEquals(data.size(), collected.length); } + + /** + * Test for SPARK-3647. This test needs to use the maven-built assembly to trigger the issue, + * since that's the only artifact where Guava classes have been relocated. + */ + @Test + public void testGuavaOptional() { + // Stop the context created in setUp() and start a local-cluster one, to force usage of the + // assembly. + sc.stop(); + JavaSparkContext localCluster = new JavaSparkContext("local-cluster[1,1,512]", "JavaAPISuite"); + try { + JavaRDD rdd1 = localCluster.parallelize(Arrays.asList(1, 2, null), 3); + JavaRDD> rdd2 = rdd1.map( + new Function>() { + @Override + public Optional call(Integer i) { + return Optional.fromNullable(i); + } + }); + rdd2.collect(); + } finally { + localCluster.stop(); + } + } + } From d79238d03a2ffe0cf5fc6166543d67768693ddbe Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Tue, 23 Sep 2014 13:44:18 -0700 Subject: [PATCH 13/22] SPARK-3612. Executor shouldn't quit if heartbeat message fails to reach ... ...the driver Author: Sandy Ryza Closes #2487 from sryza/sandy-spark-3612 and squashes the following commits: 2b7353d [Sandy Ryza] SPARK-3612. 
Executor shouldn't quit if heartbeat message fails to reach the driver --- .../org/apache/spark/executor/Executor.scala | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index acae448a9c66f..d7211ae465902 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -24,6 +24,7 @@ import java.util.concurrent._ import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.util.control.NonFatal import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil @@ -375,12 +376,17 @@ private[spark] class Executor( } val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId) - val response = AkkaUtils.askWithReply[HeartbeatResponse](message, heartbeatReceiverRef, - retryAttempts, retryIntervalMs, timeout) - if (response.reregisterBlockManager) { - logWarning("Told to re-register on heartbeat") - env.blockManager.reregister() + try { + val response = AkkaUtils.askWithReply[HeartbeatResponse](message, heartbeatReceiverRef, + retryAttempts, retryIntervalMs, timeout) + if (response.reregisterBlockManager) { + logWarning("Told to re-register on heartbeat") + env.blockManager.reregister() + } + } catch { + case NonFatal(t) => logWarning("Issue communicating with driver in heartbeater", t) } + Thread.sleep(interval) } } From b3fef50e22fb3fe499f627179d17836a92dcb33a Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 23 Sep 2014 14:00:33 -0700 Subject: [PATCH 14/22] [SPARK-3653] Respect SPARK_*_MEMORY for cluster mode `SPARK_DRIVER_MEMORY` was only used to start the `SparkSubmit` JVM, which becomes the driver only in client mode but not cluster mode. In cluster mode, this property is simply not propagated to the worker nodes. `SPARK_EXECUTOR_MEMORY` is picked up from `SparkContext`, but in cluster mode the driver runs on one of the worker machines, where this environment variable may not be set. 
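A minimal sketch of the seeding pattern this adds (hypothetical simplification of `SparkSubmitArguments`; the real two-line change is in the diff below): the environment variables are read on the launching machine before command-line parsing, so an explicit `--driver-memory` or `--executor-memory` flag parsed afterwards still takes precedence.

```scala
// Hypothetical simplification: seed from the environment, falling back to null when unset.
var driverMemory: String = sys.env.get("SPARK_DRIVER_MEMORY").orNull
var executorMemory: String = sys.env.get("SPARK_EXECUTOR_MEMORY").orNull
// parseOpts(args) runs next in the real code, so e.g. "--driver-memory 4g" overwrites the seed.
```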
Author: Andrew Or Closes #2500 from andrewor14/memory-env-vars and squashes the following commits: 6217b38 [Andrew Or] Respect SPARK_*_MEMORY for cluster mode --- .../scala/org/apache/spark/deploy/SparkSubmitArguments.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 92e0917743ed1..2b72c61cc8177 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -75,6 +75,10 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { defaultProperties } + // Respect SPARK_*_MEMORY for cluster mode + driverMemory = sys.env.get("SPARK_DRIVER_MEMORY").orNull + executorMemory = sys.env.get("SPARK_EXECUTOR_MEMORY").orNull + parseOpts(args.toList) mergeSparkProperties() checkRequiredArguments() From 729952a5efce755387c76cdf29280ee6f49fdb72 Mon Sep 17 00:00:00 2001 From: Mubarak Seyed Date: Tue, 23 Sep 2014 15:09:12 -0700 Subject: [PATCH 15/22] [SPARK-1853] Show Streaming application code context (file, line number) in Spark Stages UI This is a refactored version of the original PR https://github.com/apache/spark/pull/1723 my mubarak Please take a look andrewor14, mubarak Author: Mubarak Seyed Author: Tathagata Das Closes #2464 from tdas/streaming-callsite and squashes the following commits: dc54c71 [Tathagata Das] Made changes based on PR comments. 390b45d [Tathagata Das] Fixed minor bugs. 904cd92 [Tathagata Das] Merge remote-tracking branch 'apache-github/master' into streaming-callsite 7baa427 [Tathagata Das] Refactored getCallSite and setCallSite to make it simpler. Also added unit test for DStream creation site. b9ed945 [Mubarak Seyed] Adding streaming utils c461cf4 [Mubarak Seyed] Merge remote-tracking branch 'upstream/master' ceb43da [Mubarak Seyed] Changing default regex function name 8c5d443 [Mubarak Seyed] Merge remote-tracking branch 'upstream/master' 196121b [Mubarak Seyed] Merge remote-tracking branch 'upstream/master' 491a1eb [Mubarak Seyed] Removing streaming visibility from getRDDCreationCallSite in DStream 33a7295 [Mubarak Seyed] Fixing review comments: Merging both setCallSite methods c26d933 [Mubarak Seyed] Merge remote-tracking branch 'upstream/master' f51fd9f [Mubarak Seyed] Fixing scalastyle, Regex for Utils.getCallSite, and changing method names in DStream 5051c58 [Mubarak Seyed] Getting return value of compute() into variable and call setCallSite(prevCallSite) only once. 
Adding return for other code paths (for None) a207eb7 [Mubarak Seyed] Fixing code review comments ccde038 [Mubarak Seyed] Removing Utils import from MappedDStream 2a09ad6 [Mubarak Seyed] Changes in Utils.scala for SPARK-1853 1d90cc3 [Mubarak Seyed] Changes for SPARK-1853 5f3105a [Mubarak Seyed] Merge remote-tracking branch 'upstream/master' 70f494f [Mubarak Seyed] Changes for SPARK-1853 1500deb [Mubarak Seyed] Changes in Spark Streaming UI 9d38d3c [Mubarak Seyed] [SPARK-1853] Show Streaming application code context (file, line number) in Spark Stages UI d466d75 [Mubarak Seyed] Changes for spark streaming UI --- .../scala/org/apache/spark/SparkContext.scala | 32 +++++-- .../main/scala/org/apache/spark/rdd/RDD.scala | 7 +- .../scala/org/apache/spark/util/Utils.scala | 27 ++++-- .../spark/streaming/StreamingContext.scala | 4 +- .../spark/streaming/dstream/DStream.scala | 96 ++++++++++++------- .../streaming/StreamingContextSuite.scala | 45 ++++++++- 6 files changed, 153 insertions(+), 58 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 428f019b02a23..979d178c35969 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1030,28 +1030,40 @@ class SparkContext(config: SparkConf) extends Logging { } /** - * Support function for API backtraces. + * Set the thread-local property for overriding the call sites + * of actions and RDDs. */ - def setCallSite(site: String) { - setLocalProperty("externalCallSite", site) + def setCallSite(shortCallSite: String) { + setLocalProperty(CallSite.SHORT_FORM, shortCallSite) } /** - * Support function for API backtraces. + * Set the thread-local property for overriding the call sites + * of actions and RDDs. + */ + private[spark] def setCallSite(callSite: CallSite) { + setLocalProperty(CallSite.SHORT_FORM, callSite.shortForm) + setLocalProperty(CallSite.LONG_FORM, callSite.longForm) + } + + /** + * Clear the thread-local property for overriding the call sites + * of actions and RDDs. */ def clearCallSite() { - setLocalProperty("externalCallSite", null) + setLocalProperty(CallSite.SHORT_FORM, null) + setLocalProperty(CallSite.LONG_FORM, null) } /** * Capture the current user callsite and return a formatted version for printing. If the user - * has overridden the call site, this will return the user's version. + * has overridden the call site using `setCallSite()`, this will return the user's version. 
*/ private[spark] def getCallSite(): CallSite = { - Option(getLocalProperty("externalCallSite")) match { - case Some(callSite) => CallSite(callSite, longForm = "") - case None => Utils.getCallSite - } + Option(getLocalProperty(CallSite.SHORT_FORM)).map { case shortCallSite => + val longCallSite = Option(getLocalProperty(CallSite.LONG_FORM)).getOrElse("") + CallSite(shortCallSite, longCallSite) + }.getOrElse(Utils.getCallSite()) } /** diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index a9b905b0d1a63..0e90caa5c9ca7 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.rdd -import java.util.Random +import java.util.{Properties, Random} import scala.collection.{mutable, Map} import scala.collection.mutable.ArrayBuffer @@ -41,7 +41,7 @@ import org.apache.spark.partial.CountEvaluator import org.apache.spark.partial.GroupedCountEvaluator import org.apache.spark.partial.PartialResult import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.{BoundedPriorityQueue, Utils} +import org.apache.spark.util.{BoundedPriorityQueue, Utils, CallSite} import org.apache.spark.util.collection.OpenHashMap import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingUtils} @@ -1224,7 +1224,8 @@ abstract class RDD[T: ClassTag]( private var storageLevel: StorageLevel = StorageLevel.NONE /** User code that created this RDD (e.g. `textFile`, `parallelize`). */ - @transient private[spark] val creationSite = Utils.getCallSite + @transient private[spark] val creationSite = sc.getCallSite() + private[spark] def getCreationSite: String = Option(creationSite).map(_.shortForm).getOrElse("") private[spark] def elementClassTag: ClassTag[T] = classTag[T] diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index ed063844323af..2755887feeeff 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -49,6 +49,11 @@ import org.apache.spark.serializer.{DeserializationStream, SerializationStream, /** CallSite represents a place in user code. It can have a short and a long form. */ private[spark] case class CallSite(shortForm: String, longForm: String) +private[spark] object CallSite { + val SHORT_FORM = "callSite.short" + val LONG_FORM = "callSite.long" +} + /** * Various utility methods used by Spark. */ @@ -859,18 +864,26 @@ private[spark] object Utils extends Logging { } } - /** - * A regular expression to match classes of the "core" Spark API that we want to skip when - * finding the call site of a method. - */ - private val SPARK_CLASS_REGEX = """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?\.[A-Z]""".r + /** Default filtering function for finding call sites using `getCallSite`. */ + private def coreExclusionFunction(className: String): Boolean = { + // A regular expression to match classes of the "core" Spark API that we want to skip when + // finding the call site of a method. + val SPARK_CORE_CLASS_REGEX = """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?\.[A-Z]""".r + val SCALA_CLASS_REGEX = """^scala""".r + val isSparkCoreClass = SPARK_CORE_CLASS_REGEX.findFirstIn(className).isDefined + val isScalaClass = SCALA_CLASS_REGEX.findFirstIn(className).isDefined + // If the class is a Spark internal class or a Scala class, then exclude. 
+ isSparkCoreClass || isScalaClass + } /** * When called inside a class in the spark package, returns the name of the user code class * (outside the spark package) that called into Spark, as well as which Spark method they called. * This is used, for example, to tell users where in their code each RDD got created. + * + * @param skipClass Function that is used to exclude non-user-code classes. */ - def getCallSite: CallSite = { + def getCallSite(skipClass: String => Boolean = coreExclusionFunction): CallSite = { val trace = Thread.currentThread.getStackTrace() .filterNot { ste:StackTraceElement => // When running under some profilers, the current stack trace might contain some bogus @@ -891,7 +904,7 @@ private[spark] object Utils extends Logging { for (el <- trace) { if (insideSpark) { - if (SPARK_CLASS_REGEX.findFirstIn(el.getClassName).isDefined) { + if (skipClass(el.getClassName)) { lastSparkMethod = if (el.getMethodName == "") { // Spark method is a constructor; get its class name el.getClassName.substring(el.getClassName.lastIndexOf('.') + 1) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index f63560dcb5b89..5a8eef1372e23 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -35,10 +35,9 @@ import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream._ -import org.apache.spark.streaming.receiver.{ActorSupervisorStrategy, ActorReceiver, Receiver} +import org.apache.spark.streaming.receiver.{ActorReceiver, ActorSupervisorStrategy, Receiver} import org.apache.spark.streaming.scheduler._ import org.apache.spark.streaming.ui.{StreamingJobProgressListener, StreamingTab} -import org.apache.spark.util.MetadataCleaner /** * Main entry point for Spark Streaming functionality. 
It provides methods used to create @@ -448,6 +447,7 @@ class StreamingContext private[streaming] ( throw new SparkException("StreamingContext has already been stopped") } validate() + sparkContext.setCallSite(DStream.getCreationSite()) scheduler.start() state = Started } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index e05db236addca..65f7ccd318684 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -23,6 +23,7 @@ import java.io.{IOException, ObjectInputStream, ObjectOutputStream} import scala.deprecated import scala.collection.mutable.HashMap import scala.reflect.ClassTag +import scala.util.matching.Regex import org.apache.spark.{Logging, SparkException} import org.apache.spark.rdd.{BlockRDD, RDD} @@ -30,7 +31,7 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ import org.apache.spark.streaming.StreamingContext._ import org.apache.spark.streaming.scheduler.Job -import org.apache.spark.util.MetadataCleaner +import org.apache.spark.util.{CallSite, MetadataCleaner} /** * A Discretized Stream (DStream), the basic abstraction in Spark Streaming, is a continuous @@ -106,6 +107,9 @@ abstract class DStream[T: ClassTag] ( /** Return the StreamingContext associated with this DStream */ def context = ssc + /* Set the creation call site */ + private[streaming] val creationSite = DStream.getCreationSite() + /** Persist the RDDs of this DStream with the given storage level */ def persist(level: StorageLevel): DStream[T] = { if (this.isInitialized) { @@ -272,43 +276,41 @@ abstract class DStream[T: ClassTag] ( } /** - * Retrieve a precomputed RDD of this DStream, or computes the RDD. This is an internal - * method that should not be called directly. + * Get the RDD corresponding to the given time; either retrieve it from cache + * or compute-and-cache it. */ private[streaming] def getOrCompute(time: Time): Option[RDD[T]] = { - // If this DStream was not initialized (i.e., zeroTime not set), then do it - // If RDD was already generated, then retrieve it from HashMap - generatedRDDs.get(time) match { - - // If an RDD was already generated and is being reused, then - // probably all RDDs in this DStream will be reused and hence should be cached - case Some(oldRDD) => Some(oldRDD) - - // if RDD was not generated, and if the time is valid - // (based on sliding time of this DStream), then generate the RDD - case None => { - if (isTimeValid(time)) { - compute(time) match { - case Some(newRDD) => - if (storageLevel != StorageLevel.NONE) { - newRDD.persist(storageLevel) - logInfo("Persisting RDD " + newRDD.id + " for time " + - time + " to " + storageLevel + " at time " + time) - } - if (checkpointDuration != null && - (time - zeroTime).isMultipleOf(checkpointDuration)) { - newRDD.checkpoint() - logInfo("Marking RDD " + newRDD.id + " for time " + time + - " for checkpointing at time " + time) - } - generatedRDDs.put(time, newRDD) - Some(newRDD) - case None => - None + // If RDD was already generated, then retrieve it from HashMap, + // or else compute the RDD + generatedRDDs.get(time).orElse { + // Compute the RDD if time is valid (e.g. correct time in a sliding window) + // of RDD generation, else generate nothing. 
+ if (isTimeValid(time)) { + // Set the thread-local property for call sites to this DStream's creation site + // such that RDDs generated by compute gets that as their creation site. + // Note that this `getOrCompute` may get called from another DStream which may have + // set its own call site. So we store its call site in a temporary variable, + // set this DStream's creation site, generate RDDs and then restore the previous call site. + val prevCallSite = ssc.sparkContext.getCallSite() + ssc.sparkContext.setCallSite(creationSite) + val rddOption = compute(time) + ssc.sparkContext.setCallSite(prevCallSite) + + rddOption.foreach { case newRDD => + // Register the generated RDD for caching and checkpointing + if (storageLevel != StorageLevel.NONE) { + newRDD.persist(storageLevel) + logDebug(s"Persisting RDD ${newRDD.id} for time $time to $storageLevel") } - } else { - None + if (checkpointDuration != null && (time - zeroTime).isMultipleOf(checkpointDuration)) { + newRDD.checkpoint() + logInfo(s"Marking RDD ${newRDD.id} for time $time for checkpointing") + } + generatedRDDs.put(time, newRDD) } + rddOption + } else { + None } } } @@ -799,3 +801,29 @@ abstract class DStream[T: ClassTag] ( this } } + +private[streaming] object DStream { + + /** Get the creation site of a DStream from the stack trace of when the DStream is created. */ + def getCreationSite(): CallSite = { + val SPARK_CLASS_REGEX = """^org\.apache\.spark""".r + val SPARK_STREAMING_TESTCLASS_REGEX = """^org\.apache\.spark\.streaming\.test""".r + val SPARK_EXAMPLES_CLASS_REGEX = """^org\.apache\.spark\.examples""".r + val SCALA_CLASS_REGEX = """^scala""".r + + /** Filtering function that excludes non-user classes for a streaming application */ + def streamingExclustionFunction(className: String): Boolean = { + def doesMatch(r: Regex) = r.findFirstIn(className).isDefined + val isSparkClass = doesMatch(SPARK_CLASS_REGEX) + val isSparkExampleClass = doesMatch(SPARK_EXAMPLES_CLASS_REGEX) + val isSparkStreamingTestClass = doesMatch(SPARK_STREAMING_TESTCLASS_REGEX) + val isScalaClass = doesMatch(SCALA_CLASS_REGEX) + + // If the class is a spark example class or a streaming test class then it is considered + // as a streaming application class and don't exclude. Otherwise, exclude any + // non-Spark and non-Scala class, as the rest would streaming application classes. 
+ (isSparkClass || isScalaClass) && !isSparkExampleClass && !isSparkStreamingTestClass + } + org.apache.spark.util.Utils.getCallSite(streamingExclustionFunction) + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index a3cabd6be02fe..ebf83748ffa28 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -19,13 +19,16 @@ package org.apache.spark.streaming import java.util.concurrent.atomic.AtomicInteger +import scala.language.postfixOps + import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.receiver.Receiver -import org.apache.spark.util.{MetadataCleaner, Utils} -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.apache.spark.util.Utils +import org.scalatest.{Assertions, BeforeAndAfter, FunSuite} import org.scalatest.concurrent.Timeouts +import org.scalatest.concurrent.Eventually._ import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.SpanSugar._ @@ -257,6 +260,10 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w assert(exception.getMessage.contains("transform"), "Expected exception not thrown") } + test("DStream and generated RDD creation sites") { + testPackage.test() + } + def addInputStream(s: StreamingContext): DStream[Int] = { val input = (1 to 100).map(i => (1 to i)) val inputStream = new TestInputStream(s, input, 1) @@ -293,3 +300,37 @@ class TestReceiver extends Receiver[Int](StorageLevel.MEMORY_ONLY) with Logging object TestReceiver { val counter = new AtomicInteger(1) } + +/** Streaming application for testing DStream and RDD creation sites */ +package object testPackage extends Assertions { + def test() { + val conf = new SparkConf().setMaster("local").setAppName("CreationSite test") + val ssc = new StreamingContext(conf , Milliseconds(100)) + try { + val inputStream = ssc.receiverStream(new TestReceiver) + + // Verify creation site of DStream + val creationSite = inputStream.creationSite + assert(creationSite.shortForm.contains("receiverStream") && + creationSite.shortForm.contains("StreamingContextSuite") + ) + assert(creationSite.longForm.contains("testPackage")) + + // Verify creation site of generated RDDs + var rddGenerated = false + var rddCreationSiteCorrect = true + + inputStream.foreachRDD { rdd => + rddCreationSiteCorrect = rdd.creationSite == creationSite + rddGenerated = true + } + ssc.start() + + eventually(timeout(10000 millis), interval(10 millis)) { + assert(rddGenerated && rddCreationSiteCorrect, "RDD creation site was not correct") + } + } finally { + ssc.stop() + } + } +} From c429126066f766396b706894b6942f1ca7fcb528 Mon Sep 17 00:00:00 2001 From: Nicholas Chammas Date: Wed, 24 Sep 2014 11:33:58 -0700 Subject: [PATCH 16/22] [Build] Diff from branch point Sometimes Jenkins posts [spurious reports of new classes being added](https://github.com/apache/spark/pull/2339#issuecomment-56570170). I believe this stems from diffing the patch against `master`, as opposed to against `master...`, which starts from the commit the PR was branched from. This patch fixes that behavior. 
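For reference, a minimal illustration of the difference between the two diff forms changed below (this is generic `git` behavior, not code from the patch): the three-dot form diffs from the merge base, so only the commits the PR itself adds are compared.

```bash
# Two-argument diff against master's current tip; this can pick up unrelated
# changes that landed on master after the PR branched.
git diff master --name-only

# Three-dot form: diff from the branch point (the merge base), equivalent to
git diff "$(git merge-base master HEAD)" --name-only
```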
Author: Nicholas Chammas Closes #2512 from nchammas/diff-only-commits-ahead and squashes the following commits: c065599 [Nicholas Chammas] comment typo fix a453c67 [Nicholas Chammas] diff from branch point --- dev/run-tests-jenkins | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index 06c3781eb3ccf..a6ecf3196d7d4 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -92,13 +92,13 @@ function post_message () { merge_note=" * This patch merges cleanly." source_files=$( - git diff master --name-only \ + git diff master... --name-only `# diff patch against master from branch point` \ | grep -v -e "\/test" `# ignore files in test directories` \ | grep -e "\.py$" -e "\.java$" -e "\.scala$" `# include only code files` \ | tr "\n" " " ) new_public_classes=$( - git diff master ${source_files} `# diff this patch against master and...` \ + git diff master... ${source_files} `# diff patch against master from branch point` \ | grep "^\+" `# filter in only added lines` \ | sed -r -e "s/^\+//g" `# remove the leading +` \ | grep -e "trait " -e "class " `# filter in lines with these key words` \ From 50f863365348d52a9285fc779efbedbf1567ea11 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Wed, 24 Sep 2014 11:34:39 -0700 Subject: [PATCH 17/22] [SPARK-3659] Set EC2 version to 1.1.0 and update version map This brings the master branch in sync with branch-1.1 Author: Shivaram Venkataraman Closes #2510 from shivaram/spark-ec2-version and squashes the following commits: bb0dd16 [Shivaram Venkataraman] Set EC2 version to 1.1.0 and update version map --- ec2/spark_ec2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index fbeccd89b43b3..7f2cd7d94de39 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -38,7 +38,7 @@ from boto.ec2.blockdevicemapping import BlockDeviceMapping, BlockDeviceType, EBSBlockDeviceType from boto import ec2 -DEFAULT_SPARK_VERSION = "1.0.0" +DEFAULT_SPARK_VERSION = "1.1.0" # A URL prefix from which to fetch AMI information AMI_PREFIX = "https://raw.github.com/mesos/spark-ec2/v2/ami-list" @@ -218,7 +218,7 @@ def is_active(instance): def get_spark_shark_version(opts): spark_shark_map = { "0.7.3": "0.7.1", "0.8.0": "0.8.0", "0.8.1": "0.8.1", "0.9.0": "0.9.0", "0.9.1": "0.9.1", - "1.0.0": "1.0.0" + "1.0.0": "1.0.0", "1.0.1": "1.0.1", "1.0.2": "1.0.2", "1.1.0": "1.1.0" } version = opts.spark_version.replace("v", "") if version not in spark_shark_map: From c854b9fcb5595b1d70b6ce257fc7574602ac5e49 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 24 Sep 2014 12:10:09 -0700 Subject: [PATCH 18/22] [SPARK-3634] [PySpark] User's module should take precedence over system modules Python modules added through addPyFile should take precedence over system modules. This patch put the path for user added module in the front of sys.path (just after ''). 
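As a rough sketch of the intended precedence (the directory name below is made up for illustration), inserting just after `sys.path[0]` lets a file shipped with the job shadow an installed module of the same name, whereas `append` leaves the system copy winning:

```python
import sys

# Hypothetical directory holding files distributed via SparkContext.addPyFile
user_dir = "/tmp/spark-user-files"

# sys.path[0] is typically '' (or the script's directory). Inserting at index 1
# puts user-supplied modules ahead of site-packages and the standard library,
# so an uploaded foo.py shadows any installed module named foo.
sys.path.insert(1, user_dir)
```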
Author: Davies Liu Closes #2492 from davies/path and squashes the following commits: 4a2af78 [Davies Liu] fix tests f7ff4da [Davies Liu] ad license header 6b0002f [Davies Liu] add tests c16c392 [Davies Liu] put addPyFile in front of sys.path --- python/pyspark/context.py | 11 +++++------ python/pyspark/tests.py | 12 ++++++++++++ python/pyspark/worker.py | 11 +++++++++-- python/test_support/SimpleHTTPServer.py | 22 ++++++++++++++++++++++ 4 files changed, 48 insertions(+), 8 deletions(-) create mode 100644 python/test_support/SimpleHTTPServer.py diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 064a24bff539c..8e7b00469e246 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -171,7 +171,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, SparkFiles._sc = self root_dir = SparkFiles.getRootDirectory() - sys.path.append(root_dir) + sys.path.insert(1, root_dir) # Deploy any code dependencies specified in the constructor self._python_includes = list() @@ -183,10 +183,9 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, for path in self._conf.get("spark.submit.pyFiles", "").split(","): if path != "": (dirname, filename) = os.path.split(path) - self._python_includes.append(filename) - sys.path.append(path) - if dirname not in sys.path: - sys.path.append(dirname) + if filename.lower().endswith("zip") or filename.lower().endswith("egg"): + self._python_includes.append(filename) + sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename)) # Create a temporary directory inside spark.local.dir: local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir(self._jsc.sc().conf()) @@ -667,7 +666,7 @@ def addPyFile(self, path): if filename.endswith('.zip') or filename.endswith('.ZIP') or filename.endswith('.egg'): self._python_includes.append(filename) # for tests in local mode - sys.path.append(os.path.join(SparkFiles.getRootDirectory(), filename)) + sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename)) def setCheckpointDir(self, dirName): """ diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 1b8afb763b26a..4483bf80dbe06 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -323,6 +323,18 @@ def func(): from userlib import UserClass self.assertEqual("Hello World from inside a package!", UserClass().hello()) + def test_overwrite_system_module(self): + self.sc.addPyFile(os.path.join(SPARK_HOME, "python/test_support/SimpleHTTPServer.py")) + + import SimpleHTTPServer + self.assertEqual("My Server", SimpleHTTPServer.__name__) + + def func(x): + import SimpleHTTPServer + return SimpleHTTPServer.__name__ + + self.assertEqual(["My Server"], self.sc.parallelize(range(1)).map(func).collect()) + class TestRDDFunctions(PySparkTestCase): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index d6c06e2dbef62..c1f6e3e4a1f40 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -43,6 +43,13 @@ def report_times(outfile, boot, init, finish): write_long(1000 * finish, outfile) +def add_path(path): + # worker can be used, so donot add path multiple times + if path not in sys.path: + # overwrite system packages + sys.path.insert(1, path) + + def main(infile, outfile): try: boot_time = time.time() @@ -61,11 +68,11 @@ def main(infile, outfile): SparkFiles._is_running_on_worker = True # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH - sys.path.append(spark_files_dir) # *.py files that were added 
will be copied here + add_path(spark_files_dir) # *.py files that were added will be copied here num_python_includes = read_int(infile) for _ in range(num_python_includes): filename = utf8_deserializer.loads(infile) - sys.path.append(os.path.join(spark_files_dir, filename)) + add_path(os.path.join(spark_files_dir, filename)) # fetch names and values of broadcast variables num_broadcast_variables = read_int(infile) diff --git a/python/test_support/SimpleHTTPServer.py b/python/test_support/SimpleHTTPServer.py new file mode 100644 index 0000000000000..eddbd588e02dc --- /dev/null +++ b/python/test_support/SimpleHTTPServer.py @@ -0,0 +1,22 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Used to test override standard SimpleHTTPServer module. +""" + +__name__ = "My Server" From bb96012b7360b099a19fecc80f0209b30f118ada Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 24 Sep 2014 13:00:05 -0700 Subject: [PATCH 19/22] [SPARK-3679] [PySpark] pickle the exact globals of functions function.func_code.co_names has all the names used in the function, including name of attributes. It will pickle some unnecessary globals if there is a global having the same name with attribute (in co_names). There is a regression introduced by #2144, revert part of changes in that PR. 
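A small standalone sketch of the problem (the names below are only for illustration): `co_names` lists attribute names as well as globals, so treating every entry as a global picks up objects the function never actually loads, while scanning the bytecode for the `*_GLOBAL` opcodes does not.

```python
import dis
import sys

exit = object()  # unrelated (and possibly unpicklable) global that shares a name with the attribute sys.exit

def foo():
    sys.exit(0)

print(foo.__code__.co_names)  # ('sys', 'exit') -- 'exit' appears only because of the attribute access
dis.dis(foo)                  # the disassembly has a LOAD_GLOBAL for 'sys' but no global load of 'exit'
```

A pickler that walked `co_names` alone would try to serialize the global `exit` above even though `foo` never touches it.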
cc JoshRosen Author: Davies Liu Closes #2522 from davies/globals and squashes the following commits: dfbccf5 [Davies Liu] fix bug while pickle globals of function --- python/pyspark/cloudpickle.py | 42 ++++++++++++++++++++++++++++++----- python/pyspark/tests.py | 18 +++++++++++++++ 2 files changed, 54 insertions(+), 6 deletions(-) diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index 32dda3888c62d..bb0783555aa77 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -52,6 +52,7 @@ import itertools from copy_reg import _extension_registry, _inverted_registry, _extension_cache import new +import dis import traceback import platform @@ -61,6 +62,14 @@ import logging cloudLog = logging.getLogger("Cloud.Transport") +#relevant opcodes +STORE_GLOBAL = chr(dis.opname.index('STORE_GLOBAL')) +DELETE_GLOBAL = chr(dis.opname.index('DELETE_GLOBAL')) +LOAD_GLOBAL = chr(dis.opname.index('LOAD_GLOBAL')) +GLOBAL_OPS = [STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL] + +HAVE_ARGUMENT = chr(dis.HAVE_ARGUMENT) +EXTENDED_ARG = chr(dis.EXTENDED_ARG) if PyImp == "PyPy": # register builtin type in `new` @@ -304,16 +313,37 @@ def save_function_tuple(self, func, forced_imports): write(pickle.REDUCE) # applies _fill_function on the tuple @staticmethod - def extract_code_globals(code): + def extract_code_globals(co): """ Find all globals names read or written to by codeblock co """ - names = set(code.co_names) - if code.co_consts: # see if nested function have any global refs - for const in code.co_consts: + code = co.co_code + names = co.co_names + out_names = set() + + n = len(code) + i = 0 + extended_arg = 0 + while i < n: + op = code[i] + + i = i+1 + if op >= HAVE_ARGUMENT: + oparg = ord(code[i]) + ord(code[i+1])*256 + extended_arg + extended_arg = 0 + i = i+2 + if op == EXTENDED_ARG: + extended_arg = oparg*65536L + if op in GLOBAL_OPS: + out_names.add(names[oparg]) + #print 'extracted', out_names, ' from ', names + + if co.co_consts: # see if nested function have any global refs + for const in co.co_consts: if type(const) is types.CodeType: - names |= CloudPickler.extract_code_globals(const) - return names + out_names |= CloudPickler.extract_code_globals(const) + + return out_names def extract_func_data(self, func): """ diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 4483bf80dbe06..d1bb2033b7a16 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -213,6 +213,24 @@ def test_pickling_file_handles(self): out2 = ser.loads(ser.dumps(out1)) self.assertEquals(out1, out2) + def test_func_globals(self): + + class Unpicklable(object): + def __reduce__(self): + raise Exception("not picklable") + + global exit + exit = Unpicklable() + + ser = CloudPickleSerializer() + self.assertRaises(Exception, lambda: ser.dumps(exit)) + + def foo(): + sys.exit(0) + + self.assertTrue("exit" in foo.func_code.co_names) + ser.dumps(foo) + class PySparkTestCase(unittest.TestCase): From 74fb2ecf7afc2d314f6477f8f2e6134614387453 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Wed, 24 Sep 2014 17:18:55 -0700 Subject: [PATCH 20/22] [SPARK-3615][Streaming]Fix Kafka unit test hard coded Zookeeper port issue Details can be seen in [SPARK-3615](https://issues.apache.org/jira/browse/SPARK-3615). 
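For the embedded ZooKeeper, the fix lets the OS assign the port at bind time (and retries nearby ports for the Kafka broker). As a minimal sketch of that idea using plain JDK sockets (not the actual test utilities), binding to port 0 asks the kernel for any free port, which the caller then reads back:

```scala
import java.net.ServerSocket

// Bind to port 0 so the OS picks a free port, then record the port that was
// actually assigned instead of hard coding one such as 2181.
val socket = new ServerSocket(0)
val freePort = socket.getLocalPort
socket.close()
println(s"start the embedded server on port $freePort")
```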
Author: jerryshao Closes #2483 from jerryshao/SPARK_3615 and squashes the following commits: 8555563 [jerryshao] Fix Kafka unit test hard coded Zookeeper port issue --- .../streaming/kafka/JavaKafkaStreamSuite.java | 2 +- .../streaming/kafka/KafkaStreamSuite.scala | 46 +++++++++++++------ 2 files changed, 34 insertions(+), 14 deletions(-) diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java index 0571454c01dae..efb0099c7c850 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java @@ -81,7 +81,7 @@ public void testKafkaStream() throws InterruptedException { Predef.>conforms())); HashMap kafkaParams = new HashMap(); - kafkaParams.put("zookeeper.connect", testSuite.zkConnect()); + kafkaParams.put("zookeeper.connect", testSuite.zkHost() + ":" + testSuite.zkPort()); kafkaParams.put("group.id", "test-consumer-" + KafkaTestUtils.random().nextInt(10000)); kafkaParams.put("auto.offset.reset", "smallest"); diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala index c0b55e9340253..6943326eb750e 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala @@ -24,7 +24,7 @@ import java.util.{Properties, Random} import scala.collection.mutable import kafka.admin.CreateTopicCommand -import kafka.common.TopicAndPartition +import kafka.common.{KafkaException, TopicAndPartition} import kafka.producer.{KeyedMessage, ProducerConfig, Producer} import kafka.utils.ZKStringSerializer import kafka.serializer.{StringDecoder, StringEncoder} @@ -42,14 +42,13 @@ import org.apache.spark.util.Utils class KafkaStreamSuite extends TestSuiteBase { import KafkaTestUtils._ - val zkConnect = "localhost:2181" + val zkHost = "localhost" + var zkPort: Int = 0 val zkConnectionTimeout = 6000 val zkSessionTimeout = 6000 - val brokerPort = 9092 - val brokerProps = getBrokerConfig(brokerPort, zkConnect) - val brokerConf = new KafkaConfig(brokerProps) - + protected var brokerPort = 9092 + protected var brokerConf: KafkaConfig = _ protected var zookeeper: EmbeddedZookeeper = _ protected var zkClient: ZkClient = _ protected var server: KafkaServer = _ @@ -59,16 +58,35 @@ class KafkaStreamSuite extends TestSuiteBase { override def beforeFunction() { // Zookeeper server startup - zookeeper = new EmbeddedZookeeper(zkConnect) + zookeeper = new EmbeddedZookeeper(s"$zkHost:$zkPort") + // Get the actual zookeeper binding port + zkPort = zookeeper.actualPort logInfo("==================== 0 ====================") - zkClient = new ZkClient(zkConnect, zkSessionTimeout, zkConnectionTimeout, ZKStringSerializer) + + zkClient = new ZkClient(s"$zkHost:$zkPort", zkSessionTimeout, zkConnectionTimeout, + ZKStringSerializer) logInfo("==================== 1 ====================") // Kafka broker startup - server = new KafkaServer(brokerConf) - logInfo("==================== 2 ====================") - server.startup() - logInfo("==================== 3 ====================") + var bindSuccess: Boolean = false + while(!bindSuccess) { + try { + val brokerProps = getBrokerConfig(brokerPort, s"$zkHost:$zkPort") + brokerConf = new 
KafkaConfig(brokerProps) + server = new KafkaServer(brokerConf) + logInfo("==================== 2 ====================") + server.startup() + logInfo("==================== 3 ====================") + bindSuccess = true + } catch { + case e: KafkaException => + if (e.getMessage != null && e.getMessage.contains("Socket server failed to bind to")) { + brokerPort += 1 + } + case e: Exception => throw new Exception("Kafka server create failed", e) + } + } + Thread.sleep(2000) logInfo("==================== 4 ====================") super.beforeFunction() @@ -92,7 +110,7 @@ class KafkaStreamSuite extends TestSuiteBase { createTopic(topic) produceAndSendMessage(topic, sent) - val kafkaParams = Map("zookeeper.connect" -> zkConnect, + val kafkaParams = Map("zookeeper.connect" -> s"$zkHost:$zkPort", "group.id" -> s"test-consumer-${random.nextInt(10000)}", "auto.offset.reset" -> "smallest") @@ -200,6 +218,8 @@ object KafkaTestUtils { factory.configure(new InetSocketAddress(ip, port), 16) factory.startup(zookeeper) + val actualPort = factory.getLocalPort + def shutdown() { factory.shutdown() Utils.deleteRecursively(snapshotDir) From 8ca4ecb6a56b96bae21b33e27f6abdb53676683a Mon Sep 17 00:00:00 2001 From: Aaron Staple Date: Wed, 24 Sep 2014 20:39:09 -0700 Subject: [PATCH 21/22] [SPARK-546] Add full outer join to RDD and DStream. leftOuterJoin and rightOuterJoin are already implemented. This patch adds fullOuterJoin. Author: Aaron Staple Closes #1395 from staple/SPARK-546 and squashes the following commits: 1f5595c [Aaron Staple] Fix python style 7ac0aa9 [Aaron Staple] [SPARK-546] Add full outer join to RDD and DStream. 3b5d137 [Aaron Staple] In JavaPairDStream, make class tag specification in rightOuterJoin consistent with other functions. 31f2956 [Aaron Staple] Fix left outer join documentation comments. --- .../apache/spark/api/java/JavaPairRDD.scala | 48 +++++++++++++++++ .../apache/spark/rdd/PairRDDFunctions.scala | 42 +++++++++++++++ .../org/apache/spark/PartitioningSuite.scala | 3 ++ .../spark/rdd/PairRDDFunctionsSuite.scala | 15 ++++++ .../scala/org/apache/spark/rdd/RDDSuite.scala | 1 + docs/programming-guide.md | 2 +- python/pyspark/join.py | 16 ++++++ python/pyspark/rdd.py | 25 ++++++++- .../streaming/api/java/JavaPairDStream.scala | 54 +++++++++++++++++-- .../dstream/PairDStreamFunctions.scala | 36 +++++++++++++ .../streaming/BasicOperationsSuite.scala | 15 ++++++ 11 files changed, 250 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 880f61c49726e..0846225e4f992 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -469,6 +469,22 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) fromRDD(joinResult.mapValues{case (v, w) => (JavaUtils.optionToOptional(v), w)}) } + /** + * Perform a full outer join of `this` and `other`. For each element (k, v) in `this`, the + * resulting RDD will either contain all pairs (k, (Some(v), Some(w))) for w in `other`, or + * the pair (k, (Some(v), None)) if no elements in `other` have key k. Similarly, for each + * element (k, w) in `other`, the resulting RDD will either contain all pairs + * (k, (Some(v), Some(w))) for v in `this`, or the pair (k, (None, Some(w))) if no elements + * in `this` have key k. Uses the given Partitioner to partition the output RDD. 
+ */ + def fullOuterJoin[W](other: JavaPairRDD[K, W], partitioner: Partitioner) + : JavaPairRDD[K, (Optional[V], Optional[W])] = { + val joinResult = rdd.fullOuterJoin(other, partitioner) + fromRDD(joinResult.mapValues{ case (v, w) => + (JavaUtils.optionToOptional(v), JavaUtils.optionToOptional(w)) + }) + } + /** * Simplified version of combineByKey that hash-partitions the resulting RDD using the existing * partitioner/parallelism level. @@ -563,6 +579,38 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) fromRDD(joinResult.mapValues{case (v, w) => (JavaUtils.optionToOptional(v), w)}) } + /** + * Perform a full outer join of `this` and `other`. For each element (k, v) in `this`, the + * resulting RDD will either contain all pairs (k, (Some(v), Some(w))) for w in `other`, or + * the pair (k, (Some(v), None)) if no elements in `other` have key k. Similarly, for each + * element (k, w) in `other`, the resulting RDD will either contain all pairs + * (k, (Some(v), Some(w))) for v in `this`, or the pair (k, (None, Some(w))) if no elements + * in `this` have key k. Hash-partitions the resulting RDD using the existing partitioner/ + * parallelism level. + */ + def fullOuterJoin[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (Optional[V], Optional[W])] = { + val joinResult = rdd.fullOuterJoin(other) + fromRDD(joinResult.mapValues{ case (v, w) => + (JavaUtils.optionToOptional(v), JavaUtils.optionToOptional(w)) + }) + } + + /** + * Perform a full outer join of `this` and `other`. For each element (k, v) in `this`, the + * resulting RDD will either contain all pairs (k, (Some(v), Some(w))) for w in `other`, or + * the pair (k, (Some(v), None)) if no elements in `other` have key k. Similarly, for each + * element (k, w) in `other`, the resulting RDD will either contain all pairs + * (k, (Some(v), Some(w))) for v in `this`, or the pair (k, (None, Some(w))) if no elements + * in `this` have key k. Hash-partitions the resulting RDD into the given number of partitions. + */ + def fullOuterJoin[W](other: JavaPairRDD[K, W], numPartitions: Int) + : JavaPairRDD[K, (Optional[V], Optional[W])] = { + val joinResult = rdd.fullOuterJoin(other, numPartitions) + fromRDD(joinResult.mapValues{ case (v, w) => + (JavaUtils.optionToOptional(v), JavaUtils.optionToOptional(w)) + }) + } + /** * Return the key-value pairs in this RDD to the master as a Map. */ diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 51ba8c2d17834..7f578bc5dac39 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -506,6 +506,23 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } } + /** + * Perform a full outer join of `this` and `other`. For each element (k, v) in `this`, the + * resulting RDD will either contain all pairs (k, (Some(v), Some(w))) for w in `other`, or + * the pair (k, (Some(v), None)) if no elements in `other` have key k. Similarly, for each + * element (k, w) in `other`, the resulting RDD will either contain all pairs + * (k, (Some(v), Some(w))) for v in `this`, or the pair (k, (None, Some(w))) if no elements + * in `this` have key k. Uses the given Partitioner to partition the output RDD. 
+ */ + def fullOuterJoin[W](other: RDD[(K, W)], partitioner: Partitioner) + : RDD[(K, (Option[V], Option[W]))] = { + this.cogroup(other, partitioner).flatMapValues { + case (vs, Seq()) => vs.map(v => (Some(v), None)) + case (Seq(), ws) => ws.map(w => (None, Some(w))) + case (vs, ws) => for (v <- vs; w <- ws) yield (Some(v), Some(w)) + } + } + /** * Simplified version of combineByKey that hash-partitions the resulting RDD using the * existing partitioner/parallelism level. @@ -585,6 +602,31 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) rightOuterJoin(other, new HashPartitioner(numPartitions)) } + /** + * Perform a full outer join of `this` and `other`. For each element (k, v) in `this`, the + * resulting RDD will either contain all pairs (k, (Some(v), Some(w))) for w in `other`, or + * the pair (k, (Some(v), None)) if no elements in `other` have key k. Similarly, for each + * element (k, w) in `other`, the resulting RDD will either contain all pairs + * (k, (Some(v), Some(w))) for v in `this`, or the pair (k, (None, Some(w))) if no elements + * in `this` have key k. Hash-partitions the resulting RDD using the existing partitioner/ + * parallelism level. + */ + def fullOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (Option[V], Option[W]))] = { + fullOuterJoin(other, defaultPartitioner(self, other)) + } + + /** + * Perform a full outer join of `this` and `other`. For each element (k, v) in `this`, the + * resulting RDD will either contain all pairs (k, (Some(v), Some(w))) for w in `other`, or + * the pair (k, (Some(v), None)) if no elements in `other` have key k. Similarly, for each + * element (k, w) in `other`, the resulting RDD will either contain all pairs + * (k, (Some(v), Some(w))) for v in `this`, or the pair (k, (None, Some(w))) if no elements + * in `this` have key k. Hash-partitions the resulting RDD into the given number of partitions. + */ + def fullOuterJoin[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (Option[V], Option[W]))] = { + fullOuterJoin(other, new HashPartitioner(numPartitions)) + } + /** * Return the key-value pairs in this RDD to the master as a Map. 
* diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala index fc0cee3e8749d..646ede30ae6ff 100644 --- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala +++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala @@ -193,11 +193,13 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet assert(grouped2.join(grouped4).partitioner === grouped4.partitioner) assert(grouped2.leftOuterJoin(grouped4).partitioner === grouped4.partitioner) assert(grouped2.rightOuterJoin(grouped4).partitioner === grouped4.partitioner) + assert(grouped2.fullOuterJoin(grouped4).partitioner === grouped4.partitioner) assert(grouped2.cogroup(grouped4).partitioner === grouped4.partitioner) assert(grouped2.join(reduced2).partitioner === grouped2.partitioner) assert(grouped2.leftOuterJoin(reduced2).partitioner === grouped2.partitioner) assert(grouped2.rightOuterJoin(reduced2).partitioner === grouped2.partitioner) + assert(grouped2.fullOuterJoin(reduced2).partitioner === grouped2.partitioner) assert(grouped2.cogroup(reduced2).partitioner === grouped2.partitioner) assert(grouped2.map(_ => 1).partitioner === None) @@ -218,6 +220,7 @@ class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMet assert(intercept[SparkException]{ arrPairs.join(arrPairs) }.getMessage.contains("array")) assert(intercept[SparkException]{ arrPairs.leftOuterJoin(arrPairs) }.getMessage.contains("array")) assert(intercept[SparkException]{ arrPairs.rightOuterJoin(arrPairs) }.getMessage.contains("array")) + assert(intercept[SparkException]{ arrPairs.fullOuterJoin(arrPairs) }.getMessage.contains("array")) assert(intercept[SparkException]{ arrPairs.groupByKey() }.getMessage.contains("array")) assert(intercept[SparkException]{ arrPairs.countByKey() }.getMessage.contains("array")) assert(intercept[SparkException]{ arrPairs.countByKeyApprox(1) }.getMessage.contains("array")) diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index e84cc69592339..75b01191901b8 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -298,6 +298,21 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { )) } + test("fullOuterJoin") { + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) + val joined = rdd1.fullOuterJoin(rdd2).collect() + assert(joined.size === 6) + assert(joined.toSet === Set( + (1, (Some(1), Some('x'))), + (1, (Some(2), Some('x'))), + (2, (Some(1), Some('y'))), + (2, (Some(1), Some('z'))), + (3, (Some(1), None)), + (4, (None, Some('w'))) + )) + } + test("join with no matches") { val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) val rdd2 = sc.parallelize(Array((4, 'x'), (5, 'y'), (5, 'z'), (6, 'w'))) diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index c1b501a75c8b8..465c1a8a43a79 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -193,6 +193,7 @@ class RDDSuite extends FunSuite with SharedSparkContext { assert(rdd.join(emptyKv).collect().size === 0) assert(rdd.rightOuterJoin(emptyKv).collect().size === 0) 
assert(rdd.leftOuterJoin(emptyKv).collect().size === 2) + assert(rdd.fullOuterJoin(emptyKv).collect().size === 2) assert(rdd.cogroup(emptyKv).collect().size === 2) assert(rdd.union(emptyKv).collect().size === 2) } diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 01d378af574b5..510b47a2aaad1 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -906,7 +906,7 @@ for details. join(otherDataset, [numTasks]) When called on datasets of type (K, V) and (K, W), returns a dataset of (K, (V, W)) pairs with all pairs of elements for each key. - Outer joins are also supported through leftOuterJoin and rightOuterJoin. + Outer joins are supported through leftOuterJoin, rightOuterJoin, and fullOuterJoin. diff --git a/python/pyspark/join.py b/python/pyspark/join.py index b0f1cc1927066..b4a844713745a 100644 --- a/python/pyspark/join.py +++ b/python/pyspark/join.py @@ -80,6 +80,22 @@ def dispatch(seq): return _do_python_join(rdd, other, numPartitions, dispatch) +def python_full_outer_join(rdd, other, numPartitions): + def dispatch(seq): + vbuf, wbuf = [], [] + for (n, v) in seq: + if n == 1: + vbuf.append(v) + elif n == 2: + wbuf.append(v) + if not vbuf: + vbuf.append(None) + if not wbuf: + wbuf.append(None) + return [(v, w) for v in vbuf for w in wbuf] + return _do_python_join(rdd, other, numPartitions, dispatch) + + def python_cogroup(rdds, numPartitions): def make_mapper(i): return lambda (k, v): (k, (i, v)) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 8ef233bc80c5c..680140d72d03c 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -36,7 +36,7 @@ BatchedSerializer, CloudPickleSerializer, PairDeserializer, \ PickleSerializer, pack_long, AutoBatchedSerializer from pyspark.join import python_join, python_left_outer_join, \ - python_right_outer_join, python_cogroup + python_right_outer_join, python_full_outer_join, python_cogroup from pyspark.statcounter import StatCounter from pyspark.rddsampler import RDDSampler, RDDStratifiedSampler from pyspark.storagelevel import StorageLevel @@ -1375,7 +1375,7 @@ def leftOuterJoin(self, other, numPartitions=None): For each element (k, v) in C{self}, the resulting RDD will either contain all pairs (k, (v, w)) for w in C{other}, or the pair - (k, (v, None)) if no elements in other have key k. + (k, (v, None)) if no elements in C{other} have key k. Hash-partitions the resulting RDD into the given number of partitions. @@ -1403,6 +1403,27 @@ def rightOuterJoin(self, other, numPartitions=None): """ return python_right_outer_join(self, other, numPartitions) + def fullOuterJoin(self, other, numPartitions=None): + """ + Perform a right outer join of C{self} and C{other}. + + For each element (k, v) in C{self}, the resulting RDD will either + contain all pairs (k, (v, w)) for w in C{other}, or the pair + (k, (v, None)) if no elements in C{other} have key k. + + Similarly, for each element (k, w) in C{other}, the resulting RDD will + either contain all pairs (k, (v, w)) for v in C{self}, or the pair + (k, (None, w)) if no elements in C{self} have key k. + + Hash-partitions the resulting RDD into the given number of partitions. 
+ + >>> x = sc.parallelize([("a", 1), ("b", 4)]) + >>> y = sc.parallelize([("a", 2), ("c", 8)]) + >>> sorted(x.fullOuterJoin(y).collect()) + [('a', (1, 2)), ('b', (4, None)), ('c', (None, 8))] + """ + return python_full_outer_join(self, other, numPartitions) + # TODO: add option to control map-side combining # portable_hash is used as default, because builtin hash of None is different # cross machines. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index c00e11d11910f..59d4423086ef0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -606,8 +606,9 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } /** - * Return a new DStream by applying 'join' between RDDs of `this` DStream and `other` DStream. - * The supplied org.apache.spark.Partitioner is used to control the partitioning of each RDD. + * Return a new DStream by applying 'left outer join' between RDDs of `this` DStream and + * `other` DStream. The supplied org.apache.spark.Partitioner is used to control + * the partitioning of each RDD. */ def leftOuterJoin[W]( other: JavaPairDStream[K, W], @@ -624,8 +625,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * number of partitions. */ def rightOuterJoin[W](other: JavaPairDStream[K, W]): JavaPairDStream[K, (Optional[V], W)] = { - implicit val cm: ClassTag[W] = - implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]] + implicit val cm: ClassTag[W] = fakeClassTag val joinResult = dstream.rightOuterJoin(other.dstream) joinResult.mapValues{case (v, w) => (JavaUtils.optionToOptional(v), w)} } @@ -658,6 +658,52 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( joinResult.mapValues{case (v, w) => (JavaUtils.optionToOptional(v), w)} } + /** + * Return a new DStream by applying 'full outer join' between RDDs of `this` DStream and + * `other` DStream. Hash partitioning is used to generate the RDDs with Spark's default + * number of partitions. + */ + def fullOuterJoin[W](other: JavaPairDStream[K, W]) + : JavaPairDStream[K, (Optional[V], Optional[W])] = { + implicit val cm: ClassTag[W] = fakeClassTag + val joinResult = dstream.fullOuterJoin(other.dstream) + joinResult.mapValues{ case (v, w) => + (JavaUtils.optionToOptional(v), JavaUtils.optionToOptional(w)) + } + } + + /** + * Return a new DStream by applying 'full outer join' between RDDs of `this` DStream and + * `other` DStream. Hash partitioning is used to generate the RDDs with `numPartitions` + * partitions. + */ + def fullOuterJoin[W]( + other: JavaPairDStream[K, W], + numPartitions: Int + ): JavaPairDStream[K, (Optional[V], Optional[W])] = { + implicit val cm: ClassTag[W] = fakeClassTag + val joinResult = dstream.fullOuterJoin(other.dstream, numPartitions) + joinResult.mapValues{ case (v, w) => + (JavaUtils.optionToOptional(v), JavaUtils.optionToOptional(w)) + } + } + + /** + * Return a new DStream by applying 'full outer join' between RDDs of `this` DStream and + * `other` DStream. The supplied org.apache.spark.Partitioner is used to control + * the partitioning of each RDD. 
+ */ + def fullOuterJoin[W]( + other: JavaPairDStream[K, W], + partitioner: Partitioner + ): JavaPairDStream[K, (Optional[V], Optional[W])] = { + implicit val cm: ClassTag[W] = fakeClassTag + val joinResult = dstream.fullOuterJoin(other.dstream, partitioner) + joinResult.mapValues{ case (v, w) => + (JavaUtils.optionToOptional(v), JavaUtils.optionToOptional(w)) + } + } + /** * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala index 826bf39e860e1..9467595d307a2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala @@ -568,6 +568,42 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) ) } + /** + * Return a new DStream by applying 'full outer join' between RDDs of `this` DStream and + * `other` DStream. Hash partitioning is used to generate the RDDs with Spark's default + * number of partitions. + */ + def fullOuterJoin[W: ClassTag](other: DStream[(K, W)]): DStream[(K, (Option[V], Option[W]))] = { + fullOuterJoin[W](other, defaultPartitioner()) + } + + /** + * Return a new DStream by applying 'full outer join' between RDDs of `this` DStream and + * `other` DStream. Hash partitioning is used to generate the RDDs with `numPartitions` + * partitions. + */ + def fullOuterJoin[W: ClassTag]( + other: DStream[(K, W)], + numPartitions: Int + ): DStream[(K, (Option[V], Option[W]))] = { + fullOuterJoin[W](other, defaultPartitioner(numPartitions)) + } + + /** + * Return a new DStream by applying 'full outer join' between RDDs of `this` DStream and + * `other` DStream. The supplied org.apache.spark.Partitioner is used to control + * the partitioning of each RDD. + */ + def fullOuterJoin[W: ClassTag]( + other: DStream[(K, W)], + partitioner: Partitioner + ): DStream[(K, (Option[V], Option[W]))] = { + self.transformWith( + other, + (rdd1: RDD[(K, V)], rdd2: RDD[(K, W)]) => rdd1.fullOuterJoin(rdd2, partitioner) + ) + } + /** * Save each RDD in `this` DStream as a Hadoop file. 
The file name at each batch interval * is generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix" diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index 059ac6c2dbee2..6c8bb50145367 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -303,6 +303,21 @@ class BasicOperationsSuite extends TestSuiteBase { testOperation(inputData1, inputData2, operation, outputData, true) } + test("fullOuterJoin") { + val inputData1 = Seq( Seq("a", "b"), Seq("a", ""), Seq(""), Seq() ) + val inputData2 = Seq( Seq("a", "b"), Seq("b", ""), Seq(), Seq("") ) + val outputData = Seq( + Seq( ("a", (Some(1), Some("x"))), ("b", (Some(1), Some("x"))) ), + Seq( ("", (Some(1), Some("x"))), ("a", (Some(1), None)), ("b", (None, Some("x"))) ), + Seq( ("", (Some(1), None)) ), + Seq( ("", (None, Some("x"))) ) + ) + val operation = (s1: DStream[String], s2: DStream[String]) => { + s1.map(x => (x, 1)).fullOuterJoin(s2.map(x => (x, "x"))) + } + testOperation(inputData1, inputData2, operation, outputData, true) + } + test("updateStateByKey") { val inputData = Seq( From b8487713d3bf288a4f6fc149e6ee4cc8196d6e7d Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 24 Sep 2014 23:10:26 -0700 Subject: [PATCH 22/22] [SPARK-2778] [yarn] Add yarn integration tests. This patch adds a couple of, currently, very simple integration tests to make sure both client and cluster modes are working. The tests don't do much yet other than run a simple job, but the plan is to enhance them after we get the framework in. The cluster tests are noisy, so redirect all log output to a file like other tests do. Copying the conf around sucks but it's less work than messing with maven/sbt and having to clean up other projects. Note the test is only added for yarn-stable. The code compiles against yarn-alpha but there are two issues I ran into that I could not overcome: - an old netty dependency kept creeping into the classpath and causing akka to not work, when using sbt; the old netty was correctly suppressed under maven. - MiniYARNCluster kept failing to execute containers because it did not create the NM's local dir itself; this is apparently a known behavior, but I'm not sure how to work around it. None of those issues are present with the stable Yarn. Also, these tests are a little slow to run. Apparently Spark doesn't yet tag tests (so that these could be isolated in a "slow" batch), so this is something to keep in mind. Author: Marcelo Vanzin Closes #2257 from vanzin/yarn-tests and squashes the following commits: 6d5b84e [Marcelo Vanzin] Fix wrong system property being set. 8b0933d [Marcelo Vanzin] Merge branch 'master' into yarn-tests 5c2b56f [Marcelo Vanzin] Use custom log4j conf for Yarn containers. ec73f17 [Marcelo Vanzin] More review feedback. 67f5b02 [Marcelo Vanzin] Review feedback. f01517c [Marcelo Vanzin] Review feedback. 68fbbbf [Marcelo Vanzin] Use older constructor available in older Hadoop releases. d07ef9a [Marcelo Vanzin] Merge branch 'master' into yarn-tests add8416 [Marcelo Vanzin] [SPARK-2778] [yarn] Add yarn integration tests. 
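On the note above about test tagging: ScalaTest does support tags, so suites like this could eventually be isolated into a "slow" batch. A rough sketch of the mechanism (the tag name is made up, not an existing Spark convention):

```scala
import org.scalatest.{FunSuite, Tag}

// Hypothetical tag for slow integration tests; a runner can then include or
// exclude it (for example via ScalaTest's -n / -l options).
object SlowTest extends Tag("org.apache.spark.test.SlowTest")

class ExampleIntegrationSuite extends FunSuite {
  test("runs a job against a mini YARN cluster", SlowTest) {
    // test body omitted; only the tagging mechanism is shown here
  }
}
```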
--- pom.xml | 31 +++- .../spark/deploy/yarn/ApplicationMaster.scala | 10 +- .../apache/spark/deploy/yarn/ClientBase.scala | 2 +- .../deploy/yarn/ExecutorRunnableUtil.scala | 2 +- yarn/pom.xml | 3 +- yarn/stable/pom.xml | 9 + .../src/test/resources/log4j.properties | 28 ++++ .../spark/deploy/yarn/YarnClusterSuite.scala | 154 ++++++++++++++++++ 8 files changed, 229 insertions(+), 10 deletions(-) create mode 100644 yarn/stable/src/test/resources/log4j.properties create mode 100644 yarn/stable/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala diff --git a/pom.xml b/pom.xml index 520aed3806937..f3de097b9cb32 100644 --- a/pom.xml +++ b/pom.xml @@ -712,6 +712,35 @@ + + org.apache.hadoop + hadoop-yarn-server-tests + ${yarn.version} + tests + test + + + asm + asm + + + org.ow2.asm + asm + + + org.jboss.netty + netty + + + javax.servlet + servlet-api + + + commons-logging + commons-logging + + + org.apache.hadoop hadoop-yarn-server-web-proxy @@ -1187,7 +1216,7 @@ org.apache.zookeeper zookeeper - 3.4.5-mapr-1406 + 3.4.5-mapr-1406 diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 9050808157257..b51daeb437516 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -401,17 +401,17 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, // it has an uncaught exception thrown out. It needs a shutdown hook to set SUCCEEDED. status = FinalApplicationStatus.SUCCEEDED } catch { - case e: InvocationTargetException => { + case e: InvocationTargetException => e.getCause match { - case _: InterruptedException => { + case _: InterruptedException => // Reporter thread can interrupt to stop user class - } + + case e => throw e } - } } finally { logDebug("Finishing main") + finalStatus = status } - finalStatus = status } } userClassThread.setName("Driver") diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index 4870b0cb3ddaf..1cf19c198509c 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -348,7 +348,7 @@ private[spark] trait ClientBase extends Logging { } // For log4j configuration to reference - javaOpts += "-D=spark.yarn.app.container.log.dir=" + ApplicationConstants.LOG_DIR_EXPANSION_VAR + javaOpts += ("-Dspark.yarn.app.container.log.dir=" + ApplicationConstants.LOG_DIR_EXPANSION_VAR) val userClass = if (args.userClass != null) { diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala index bbbf615510762..d7a7175d5e578 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala @@ -98,7 +98,7 @@ trait ExecutorRunnableUtil extends Logging { */ // For log4j configuration to reference - javaOpts += "-D=spark.yarn.app.container.log.dir=" + ApplicationConstants.LOG_DIR_EXPANSION_VAR + javaOpts += ("-Dspark.yarn.app.container.log.dir=" + ApplicationConstants.LOG_DIR_EXPANSION_VAR) val commands = Seq(Environment.JAVA_HOME.$() + "/bin/java", "-server", diff --git 
a/yarn/pom.xml b/yarn/pom.xml index 815a736c2e8fd..8a7035c85e9f1 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -140,7 +140,6 @@ ${basedir}/../.. - ${spark.classpath} @@ -148,7 +147,7 @@ target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes - + ../common/src/main/resources diff --git a/yarn/stable/pom.xml b/yarn/stable/pom.xml index fd934b7726181..97eb0548e77c3 100644 --- a/yarn/stable/pom.xml +++ b/yarn/stable/pom.xml @@ -32,4 +32,13 @@ jar Spark Project YARN Stable API + + + org.apache.hadoop + hadoop-yarn-server-tests + tests + test + + + diff --git a/yarn/stable/src/test/resources/log4j.properties b/yarn/stable/src/test/resources/log4j.properties new file mode 100644 index 0000000000000..26b73a1b39744 --- /dev/null +++ b/yarn/stable/src/test/resources/log4j.properties @@ -0,0 +1,28 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file core/target/unit-tests.log +log4j.rootCategory=INFO, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=false +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n + +# Ignore messages below warning level from Jetty, because it's a bit verbose +log4j.logger.org.eclipse.jetty=WARN +org.eclipse.jetty.LEVEL=WARN diff --git a/yarn/stable/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/stable/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala new file mode 100644 index 0000000000000..857a4447dd738 --- /dev/null +++ b/yarn/stable/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.yarn + +import java.io.File + +import scala.collection.JavaConversions._ + +import com.google.common.base.Charsets +import com.google.common.io.Files +import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers} + +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.server.MiniYARNCluster + +import org.apache.spark.{Logging, SparkConf, SparkContext} +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.util.Utils + +class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers { + + // log4j configuration for the Yarn containers, so that their output is collected + // by Yarn instead of trying to overwrite unit-tests.log. + private val LOG4J_CONF = """ + |log4j.rootCategory=DEBUG, console + |log4j.appender.console=org.apache.log4j.ConsoleAppender + |log4j.appender.console.target=System.err + |log4j.appender.console.layout=org.apache.log4j.PatternLayout + |log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n + """.stripMargin + + private var yarnCluster: MiniYARNCluster = _ + private var tempDir: File = _ + private var fakeSparkJar: File = _ + private var oldConf: Map[String, String] = _ + + override def beforeAll() { + tempDir = Utils.createTempDir() + + val logConfDir = new File(tempDir, "log4j") + logConfDir.mkdir() + + val logConfFile = new File(logConfDir, "log4j.properties") + Files.write(LOG4J_CONF, logConfFile, Charsets.UTF_8) + + val childClasspath = logConfDir.getAbsolutePath() + File.pathSeparator + + sys.props("java.class.path") + + oldConf = sys.props.filter { case (k, v) => k.startsWith("spark.") }.toMap + + yarnCluster = new MiniYARNCluster(getClass().getName(), 1, 1, 1) + yarnCluster.init(new YarnConfiguration()) + yarnCluster.start() + yarnCluster.getConfig().foreach { e => + sys.props += ("spark.hadoop." + e.getKey() -> e.getValue()) + } + + fakeSparkJar = File.createTempFile("sparkJar", null, tempDir) + sys.props += ("spark.yarn.jar" -> ("local:" + fakeSparkJar.getAbsolutePath())) + sys.props += ("spark.executor.instances" -> "1") + sys.props += ("spark.driver.extraClassPath" -> childClasspath) + sys.props += ("spark.executor.extraClassPath" -> childClasspath) + + super.beforeAll() + } + + override def afterAll() { + yarnCluster.stop() + sys.props.retain { case (k, v) => !k.startsWith("spark.") } + sys.props ++= oldConf + super.afterAll() + } + + test("run Spark in yarn-client mode") { + var result = File.createTempFile("result", null, tempDir) + YarnClusterDriver.main(Array("yarn-client", result.getAbsolutePath())) + checkResult(result) + } + + test("run Spark in yarn-cluster mode") { + val main = YarnClusterDriver.getClass.getName().stripSuffix("$") + var result = File.createTempFile("result", null, tempDir) + + // The Client object will call System.exit() after the job is done, and we don't want + // that because it messes up the scalatest monitoring. So replicate some of what main() + // does here. 
+ val args = Array("--class", main, + "--jar", "file:" + fakeSparkJar.getAbsolutePath(), + "--arg", "yarn-cluster", + "--arg", result.getAbsolutePath(), + "--num-executors", "1") + val sparkConf = new SparkConf() + val yarnConf = SparkHadoopUtil.get.newConfiguration(sparkConf) + val clientArgs = new ClientArguments(args, sparkConf) + new Client(clientArgs, yarnConf, sparkConf).run() + checkResult(result) + } + + /** + * This is a workaround for an issue with yarn-cluster mode: the Client class will not provide + * any sort of error when the job process finishes successfully, but the job itself fails. So + * the tests enforce that something is written to a file after everything is ok to indicate + * that the job succeeded. + */ + private def checkResult(result: File) = { + var resultString = Files.toString(result, Charsets.UTF_8) + resultString should be ("success") + } + +} + +private object YarnClusterDriver extends Logging with Matchers { + + def main(args: Array[String]) = { + if (args.length != 2) { + System.err.println( + s""" + |Invalid command line: ${args.mkString(" ")} + | + |Usage: YarnClusterDriver [master] [result file] + """.stripMargin) + System.exit(1) + } + + val sc = new SparkContext(new SparkConf().setMaster(args(0)) + .setAppName("yarn \"test app\" 'with quotes' and \\back\\slashes and $dollarSigns")) + val status = new File(args(1)) + var result = "failure" + try { + val data = sc.parallelize(1 to 4, 4).collect().toSet + data should be (Set(1, 2, 3, 4)) + result = "success" + } finally { + sc.stop() + Files.write(result, status, Charsets.UTF_8) + } + } + +}