From c6ba7cca3338e3f4f719d86dbcff4406d949edc7 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Wed, 10 Jun 2015 09:45:45 -0700 Subject: [PATCH 001/151] [SPARK-8215] [SPARK-8212] [SQL] add leaf math expression for e and pi Author: Daoyuan Wang Closes #6716 from adrian-wang/epi and squashes the following commits: e2e8dbd [Daoyuan Wang] move tests 11b351c [Daoyuan Wang] add tests and remove pu db331c9 [Daoyuan Wang] py style 599ddd8 [Daoyuan Wang] add py e6783ef [Daoyuan Wang] register function 82d426e [Daoyuan Wang] add function entry dbf3ab5 [Daoyuan Wang] add PI and E --- .../catalyst/analysis/FunctionRegistry.scala | 2 ++ .../spark/sql/catalyst/expressions/math.scala | 35 +++++++++++++++++++ .../expressions/MathFunctionsSuite.scala | 22 ++++++++++++ .../org/apache/spark/sql/functions.scala | 18 ++++++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 19 ++++++++++ 5 files changed, 96 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 936ffc7d5ff55..ba89a5c8d1372 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -106,6 +106,7 @@ object FunctionRegistry { expression[Cbrt]("cbrt"), expression[Ceil]("ceil"), expression[Cos]("cos"), + expression[EulerNumber]("e"), expression[Exp]("exp"), expression[Expm1]("expm1"), expression[Floor]("floor"), @@ -113,6 +114,7 @@ object FunctionRegistry { expression[Log]("log"), expression[Log10]("log10"), expression[Log1p]("log1p"), + expression[Pi]("pi"), expression[Pow]("pow"), expression[Rint]("rint"), expression[Signum]("signum"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 7dacb6a9b47b6..e1d8c9a0cdb5a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -20,9 +20,34 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types.{DataType, DoubleType} +/** + * A leaf expression specifically for math constants. Math constants expect no input. + * @param c The math constant. + * @param name The short name of the function + */ +abstract class LeafMathExpression(c: Double, name: String) + extends LeafExpression with Serializable { + self: Product => + + override def dataType: DataType = DoubleType + override def foldable: Boolean = true + override def nullable: Boolean = false + override def toString: String = s"$name()" + + override def eval(input: Row): Any = c + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + s""" + boolean ${ev.isNull} = false; + ${ctx.javaType(dataType)} ${ev.primitive} = java.lang.Math.$name; + """ + } +} + /** * A unary expression specifically for math functions. Math Functions expect a specific type of * input format, therefore these functions extend `ExpectsInputTypes`. + * @param f The math function. 
* @param name The short name of the function */ abstract class UnaryMathExpression(f: Double => Double, name: String) @@ -98,6 +123,16 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) } } +//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Leaf math functions +//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +case class EulerNumber() extends LeafMathExpression(math.E, "E") + +case class Pi() extends LeafMathExpression(math.Pi, "PI") + //////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////// // Unary math functions diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 25ebc70d095d8..1fe69059d39da 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -22,6 +22,20 @@ import org.apache.spark.sql.types.DoubleType class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { + /** + * Used for testing leaf math expressions. + * + * @param e expression + * @param c The constants in scala.math + * @tparam T Generic type for primitives + */ + private def testLeaf[T]( + e: () => Expression, + c: T): Unit = { + checkEvaluation(e(), c, EmptyRow) + checkEvaluation(e(), c, create_row(null)) + } + /** * Used for testing unary math expressions. * @@ -74,6 +88,14 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(c(Literal(1.0), Literal.create(null, DoubleType)), null, create_row(null)) } + test("e") { + testLeaf(EulerNumber, math.E) + } + + test("pi") { + testLeaf(Pi, math.Pi) + } + test("sin") { testUnary(Sin, math.sin) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 454af47913bf1..b3fc1e6cd987e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -944,6 +944,15 @@ object functions { */ def cosh(columnName: String): Column = cosh(Column(columnName)) + /** + * Returns the double value that is closer than any other to e, the base of the natural + * logarithms. + * + * @group math_funcs + * @since 1.5.0 + */ + def e(): Column = EulerNumber() + /** * Computes the exponential of the given value. * @@ -1105,6 +1114,15 @@ object functions { */ def log1p(columnName: String): Column = log1p(Column(columnName)) + /** + * Returns the double value that is closer than any other to pi, the ratio of the circumference + * of a circle to its diameter. + * + * @group math_funcs + * @since 1.5.0 + */ + def pi(): Column = Pi() + /** * Returns the value of the first argument raised to the power of the second argument. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 53c2befb73702..b93ad39f5da45 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -85,6 +85,25 @@ class DataFrameFunctionsSuite extends QueryTest { } } + test("constant functions") { + checkAnswer( + testData2.select(e()).limit(1), + Row(scala.math.E) + ) + checkAnswer( + testData2.select(pi()).limit(1), + Row(scala.math.Pi) + ) + checkAnswer( + ctx.sql("SELECT E()"), + Row(scala.math.E) + ) + checkAnswer( + ctx.sql("SELECT PI()"), + Row(scala.math.Pi) + ) + } + test("bitwiseNOT") { checkAnswer( testData2.select(bitwiseNOT($"a")), From 2b550a521e45e1dbca2cca40ddd94e20c013831c Mon Sep 17 00:00:00 2001 From: Ilya Ganelin Date: Wed, 10 Jun 2015 11:21:12 -0700 Subject: [PATCH 002/151] [SPARK-7996] Deprecate the developer api SparkEnv.actorSystem Changed ```SparkEnv.actorSystem``` to be a function such that we can use the deprecated flag with it and added a deprecated message. Author: Ilya Ganelin Closes #6731 from ilganeli/SPARK-7996 and squashes the following commits: be43817 [Ilya Ganelin] Restored to val 9ed89e7 [Ilya Ganelin] Added a version info for deprecation 9610b08 [Ilya Ganelin] Converted actorSystem to function and added deprecated flag --- core/src/main/scala/org/apache/spark/SparkEnv.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index a185954089528..b0665570e2681 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -20,6 +20,8 @@ package org.apache.spark import java.io.File import java.net.Socket +import akka.actor.ActorSystem + import scala.collection.JavaConversions._ import scala.collection.mutable import scala.util.Properties @@ -75,7 +77,8 @@ class SparkEnv ( val conf: SparkConf) extends Logging { // TODO Remove actorSystem - val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem + @deprecated("Actor system is no longer supported as of 1.4") + val actorSystem: ActorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem private[spark] var isStopped = false private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() From 8f7308f9c49805b9486aaae5f60e4481e8ba24e8 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 10 Jun 2015 11:48:14 -0700 Subject: [PATCH 003/151] [SQL] [MINOR] Fixes a minor Java example error in SQL programming guide Author: Cheng Lian Closes #6749 from liancheng/java-sample-fix and squashes the following commits: 5b44585 [Cheng Lian] Fixes a minor Java example error in SQL programming guide --- docs/sql-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 40e33f757d693..c5ab074e4439f 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1479,7 +1479,7 @@ expressed in HiveQL. {% highlight java %} // sc is an existing JavaSparkContext. 
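// HiveContext takes a SparkContext rather than a JavaSparkContext, hence the use of the
// JavaSparkContext's underlying SparkContext in the corrected line below.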
-HiveContext sqlContext = new org.apache.spark.sql.hive.HiveContext(sc); +HiveContext sqlContext = new org.apache.spark.sql.hive.HiveContext(sc.sc); sqlContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)"); sqlContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src"); From 38112905bc3b33f2ae75274afba1c30e116f6e46 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 10 Jun 2015 13:17:29 -0700 Subject: [PATCH 004/151] [SPARK-5479] [YARN] Handle --py-files correctly in YARN. The bug description is a little misleading: the actual issue is that .py files are not handled correctly when distributed by YARN. They're added to "spark.submit.pyFiles", which, when processed by context.py, explicitly whitelists certain extensions (see PACKAGE_EXTENSIONS), and that does not include .py files. On top of that, archives were not handled at all! They made it to the driver's python path, but never made it to executors, since the mechanism used to propagate their location (spark.submit.pyFiles) only works on the driver side. So, instead, ignore "spark.submit.pyFiles" and just build PYTHONPATH correctly for both driver and executors. Individual .py files are placed in a subdirectory of the container's local dir in the cluster, which is then added to the python path. Archives are added directly. The change, as a side effect, ends up solving the symptom described in the bug. The issue was not that the files were not being distributed, but that they were never made visible to the python application running under Spark. Also included is a proper unit test for running python on YARN, which broke in several different ways with the previous code. A short walk around of the changes: - SparkSubmit does not try to be smart about how YARN handles python files anymore. It just passes down the configs to the YARN client code. - The YARN client distributes python files and archives differently, placing the files in a subdirectory. - The YARN client now sets PYTHONPATH for the processes it launches; to properly handle different locations, it uses YARN's support for embedding env variables, so to avoid YARN expanding those at the wrong time, SparkConf is now propagated to the AM using a conf file instead of command line options. - Because the Client initialization code is a maze of implicit dependencies, some code needed to be moved around to make sure all needed state was available when the code ran. - The pyspark tests in YarnClusterSuite now actually distribute and try to use both a python file and an archive containing a different python module. Also added a yarn-client tests for completeness. - I cleaned up some of the code around distributing files to YARN, to avoid adding more copied & pasted code to handle the new files being distributed. Author: Marcelo Vanzin Closes #6360 from vanzin/SPARK-5479 and squashes the following commits: bcaf7e6 [Marcelo Vanzin] Feedback. c47501f [Marcelo Vanzin] Fix yarn-client mode. 46b1d0c [Marcelo Vanzin] Merge branch 'master' into SPARK-5479 c743778 [Marcelo Vanzin] Only pyspark cares about python archives. c8e5a82 [Marcelo Vanzin] Actually run pyspark in client mode. 705571d [Marcelo Vanzin] Move some code to the YARN module. 1dd4d0c [Marcelo Vanzin] Review feedback. 71ee736 [Marcelo Vanzin] Merge branch 'master' into SPARK-5479 220358b [Marcelo Vanzin] Scalastyle. cdbb990 [Marcelo Vanzin] Merge branch 'master' into SPARK-5479 7fe3cd4 [Marcelo Vanzin] No need to distribute primary file to executors. 09045f1 [Marcelo Vanzin] Style. 
943cbf4 [Marcelo Vanzin] [SPARK-5479] [yarn] Handle --py-files correctly in YARN. --- .../org/apache/spark/deploy/SparkSubmit.scala | 77 +---- .../spark/deploy/yarn/ApplicationMaster.scala | 20 +- .../yarn/ApplicationMasterArguments.scala | 12 +- .../org/apache/spark/deploy/yarn/Client.scala | 295 +++++++++++------- .../spark/deploy/yarn/ClientArguments.scala | 4 +- .../cluster/YarnClientSchedulerBackend.scala | 5 +- .../spark/deploy/yarn/ClientSuite.scala | 4 +- .../spark/deploy/yarn/YarnClusterSuite.scala | 61 ++-- 8 files changed, 270 insertions(+), 208 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index a0eae774268ed..b8978e25a02d2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -324,55 +324,20 @@ object SparkSubmit { // Usage: PythonAppRunner
[app arguments] args.mainClass = "org.apache.spark.deploy.PythonRunner" args.childArgs = ArrayBuffer(args.primaryResource, args.pyFiles) ++ args.childArgs - args.files = mergeFileLists(args.files, args.primaryResource) + if (clusterManager != YARN) { + // The YARN backend distributes the primary file differently, so don't merge it. + args.files = mergeFileLists(args.files, args.primaryResource) + } + } + if (clusterManager != YARN) { + // The YARN backend handles python files differently, so don't merge the lists. + args.files = mergeFileLists(args.files, args.pyFiles) } - args.files = mergeFileLists(args.files, args.pyFiles) if (args.pyFiles != null) { sysProps("spark.submit.pyFiles") = args.pyFiles } } - // In yarn mode for a python app, add pyspark archives to files - // that can be distributed with the job - if (args.isPython && clusterManager == YARN) { - var pyArchives: String = null - val pyArchivesEnvOpt = sys.env.get("PYSPARK_ARCHIVES_PATH") - if (pyArchivesEnvOpt.isDefined) { - pyArchives = pyArchivesEnvOpt.get - } else { - if (!sys.env.contains("SPARK_HOME")) { - printErrorAndExit("SPARK_HOME does not exist for python application in yarn mode.") - } - val pythonPath = new ArrayBuffer[String] - for (sparkHome <- sys.env.get("SPARK_HOME")) { - val pyLibPath = Seq(sparkHome, "python", "lib").mkString(File.separator) - val pyArchivesFile = new File(pyLibPath, "pyspark.zip") - if (!pyArchivesFile.exists()) { - printErrorAndExit("pyspark.zip does not exist for python application in yarn mode.") - } - val py4jFile = new File(pyLibPath, "py4j-0.8.2.1-src.zip") - if (!py4jFile.exists()) { - printErrorAndExit("py4j-0.8.2.1-src.zip does not exist for python application " + - "in yarn mode.") - } - pythonPath += pyArchivesFile.getAbsolutePath() - pythonPath += py4jFile.getAbsolutePath() - } - pyArchives = pythonPath.mkString(",") - } - - pyArchives = pyArchives.split(",").map { localPath => - val localURI = Utils.resolveURI(localPath) - if (localURI.getScheme != "local") { - args.files = mergeFileLists(args.files, localURI.toString) - new Path(localPath).getName - } else { - localURI.getPath - } - }.mkString(File.pathSeparator) - sysProps("spark.submit.pyArchives") = pyArchives - } - // If we're running a R app, set the main class to our specific R runner if (args.isR && deployMode == CLIENT) { if (args.primaryResource == SPARKR_SHELL) { @@ -386,19 +351,10 @@ object SparkSubmit { } } - if (isYarnCluster) { - // In yarn-cluster mode for a python app, add primary resource and pyFiles to files - // that can be distributed with the job - if (args.isPython) { - args.files = mergeFileLists(args.files, args.primaryResource) - args.files = mergeFileLists(args.files, args.pyFiles) - } - + if (isYarnCluster && args.isR) { // In yarn-cluster mode for a R app, add primary resource to files // that can be distributed with the job - if (args.isR) { - args.files = mergeFileLists(args.files, args.primaryResource) - } + args.files = mergeFileLists(args.files, args.primaryResource) } // Special flag to avoid deprecation warnings at the client @@ -515,17 +471,18 @@ object SparkSubmit { } } + // Let YARN know it's a pyspark app, so it distributes needed libraries. 
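+    // (Client.scala, further down in this patch, checks this property and ships the pyspark.zip
+    // and py4j archives found by findPySparkArchives() along with the application.)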
+ if (clusterManager == YARN && args.isPython) { + sysProps.put("spark.yarn.isPython", "true") + } + // In yarn-cluster mode, use yarn.Client as a wrapper around the user class if (isYarnCluster) { childMainClass = "org.apache.spark.deploy.yarn.Client" if (args.isPython) { - val mainPyFile = new Path(args.primaryResource).getName - childArgs += ("--primary-py-file", mainPyFile) + childArgs += ("--primary-py-file", args.primaryResource) if (args.pyFiles != null) { - // These files will be distributed to each machine's working directory, so strip the - // path prefix - val pyFilesNames = args.pyFiles.split(",").map(p => (new Path(p)).getName).mkString(",") - childArgs += ("--py-files", pyFilesNames) + childArgs += ("--py-files", args.pyFiles) } childArgs += ("--class", "org.apache.spark.deploy.PythonRunner") } else if (args.isR) { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 002d7b6eaf498..83dafa4a125d2 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -32,7 +32,7 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.spark.rpc._ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkEnv} import org.apache.spark.SparkException -import org.apache.spark.deploy.{PythonRunner, SparkHadoopUtil} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.history.HistoryServer import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, YarnSchedulerBackend} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ @@ -46,6 +46,14 @@ private[spark] class ApplicationMaster( client: YarnRMClient) extends Logging { + // Load the properties file with the Spark configuration and set entries as system properties, + // so that user code run inside the AM also has access to them. + if (args.propertiesFile != null) { + Utils.getPropertiesFromFile(args.propertiesFile).foreach { case (k, v) => + sys.props(k) = v + } + } + // TODO: Currently, task to container is computed once (TaskSetManager) - which need not be // optimal as more containers are available. Might need to handle this better. @@ -490,9 +498,11 @@ private[spark] class ApplicationMaster( new MutableURLClassLoader(urls, Utils.getContextOrSparkClassLoader) } + var userArgs = args.userArgs if (args.primaryPyFile != null && args.primaryPyFile.endsWith(".py")) { - System.setProperty("spark.submit.pyFiles", - PythonRunner.formatPaths(args.pyFiles).mkString(",")) + // When running pyspark, the app is run using PythonRunner. The second argument is the list + // of files to add to PYTHONPATH, which Client.scala already handles, so it's empty. 
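+      // (PythonRunner's expected arguments are: <primary python file> <comma-separated py-files>
+      // [app arguments], matching the usage noted in SparkSubmit.)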
+ userArgs = Seq(args.primaryPyFile, "") ++ userArgs } if (args.primaryRFile != null && args.primaryRFile.endsWith(".R")) { // TODO(davies): add R dependencies here @@ -503,9 +513,7 @@ private[spark] class ApplicationMaster( val userThread = new Thread { override def run() { try { - val mainArgs = new Array[String](args.userArgs.size) - args.userArgs.copyToArray(mainArgs, 0, args.userArgs.size) - mainMethod.invoke(null, mainArgs) + mainMethod.invoke(null, userArgs.toArray) finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) logDebug("Done running users class") } catch { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala index ae6dc1094d724..68e9f6b4db7f4 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala @@ -26,11 +26,11 @@ class ApplicationMasterArguments(val args: Array[String]) { var userClass: String = null var primaryPyFile: String = null var primaryRFile: String = null - var pyFiles: String = null - var userArgs: Seq[String] = Seq[String]() + var userArgs: Seq[String] = Nil var executorMemory = 1024 var executorCores = 1 var numExecutors = DEFAULT_NUMBER_EXECUTORS + var propertiesFile: String = null parseArgs(args.toList) @@ -59,10 +59,6 @@ class ApplicationMasterArguments(val args: Array[String]) { primaryRFile = value args = tail - case ("--py-files") :: value :: tail => - pyFiles = value - args = tail - case ("--args" | "--arg") :: value :: tail => userArgsBuffer += value args = tail @@ -79,6 +75,10 @@ class ApplicationMasterArguments(val args: Array[String]) { executorCores = value args = tail + case ("--properties-file") :: value :: tail => + propertiesFile = value + args = tail + case _ => printUsageAndExit(1, args) } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index f4d43214b08ca..ec9402afff329 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -17,11 +17,12 @@ package org.apache.spark.deploy.yarn -import java.io.{ByteArrayInputStream, DataInputStream, File, FileOutputStream, IOException} +import java.io.{ByteArrayInputStream, DataInputStream, File, FileOutputStream, IOException, + OutputStreamWriter} import java.net.{InetAddress, UnknownHostException, URI, URISyntaxException} import java.nio.ByteBuffer import java.security.PrivilegedExceptionAction -import java.util.UUID +import java.util.{Properties, UUID} import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.JavaConversions._ @@ -29,6 +30,7 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, ListBuffer, Map} import scala.reflect.runtime.universe import scala.util.{Try, Success, Failure} +import com.google.common.base.Charsets.UTF_8 import com.google.common.base.Objects import com.google.common.io.Files @@ -247,7 +249,9 @@ private[spark] class Client( * This is used for setting up a container launch context for our ApplicationMaster. * Exposed for testing. 
*/ - def prepareLocalResources(appStagingDir: String): HashMap[String, LocalResource] = { + def prepareLocalResources( + appStagingDir: String, + pySparkArchives: Seq[String]): HashMap[String, LocalResource] = { logInfo("Preparing resources for our AM container") // Upload Spark and the application JAR to the remote file system if necessary, // and add them as local resources to the application master. @@ -277,20 +281,6 @@ private[spark] class Client( "for alternatives.") } - // If we passed in a keytab, make sure we copy the keytab to the staging directory on - // HDFS, and setup the relevant environment vars, so the AM can login again. - if (loginFromKeytab) { - logInfo("To enable the AM to login from keytab, credentials are being copied over to the AM" + - " via the YARN Secure Distributed Cache.") - val localUri = new URI(args.keytab) - val localPath = getQualifiedLocalPath(localUri, hadoopConf) - val destinationPath = copyFileToRemote(dst, localPath, replication) - val destFs = FileSystem.get(destinationPath.toUri(), hadoopConf) - distCacheMgr.addResource( - destFs, hadoopConf, destinationPath, localResources, LocalResourceType.FILE, - sparkConf.get("spark.yarn.keytab"), statCache, appMasterOnly = true) - } - def addDistributedUri(uri: URI): Boolean = { val uriStr = uri.toString() if (distributedUris.contains(uriStr)) { @@ -302,6 +292,57 @@ private[spark] class Client( } } + /** + * Distribute a file to the cluster. + * + * If the file's path is a "local:" URI, it's actually not distributed. Other files are copied + * to HDFS (if not already there) and added to the application's distributed cache. + * + * @param path URI of the file to distribute. + * @param resType Type of resource being distributed. + * @param destName Name of the file in the distributed cache. + * @param targetDir Subdirectory where to place the file. + * @param appMasterOnly Whether to distribute only to the AM. + * @return A 2-tuple. First item is whether the file is a "local:" URI. Second item is the + * localized path for non-local paths, or the input `path` for local paths. + * The localized path will be null if the URI has already been added to the cache. + */ + def distribute( + path: String, + resType: LocalResourceType = LocalResourceType.FILE, + destName: Option[String] = None, + targetDir: Option[String] = None, + appMasterOnly: Boolean = false): (Boolean, String) = { + val localURI = new URI(path.trim()) + if (localURI.getScheme != LOCAL_SCHEME) { + if (addDistributedUri(localURI)) { + val localPath = getQualifiedLocalPath(localURI, hadoopConf) + val linkname = targetDir.map(_ + "/").getOrElse("") + + destName.orElse(Option(localURI.getFragment())).getOrElse(localPath.getName()) + val destPath = copyFileToRemote(dst, localPath, replication) + distCacheMgr.addResource( + fs, hadoopConf, destPath, localResources, resType, linkname, statCache, + appMasterOnly = appMasterOnly) + (false, linkname) + } else { + (false, null) + } + } else { + (true, path.trim()) + } + } + + // If we passed in a keytab, make sure we copy the keytab to the staging directory on + // HDFS, and setup the relevant environment vars, so the AM can login again. 
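+    // (appMasterOnly = true below means the keytab is localized only into the AM's container,
+    // never into executor containers.)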
+ if (loginFromKeytab) { + logInfo("To enable the AM to login from keytab, credentials are being copied over to the AM" + + " via the YARN Secure Distributed Cache.") + val (_, localizedPath) = distribute(args.keytab, + destName = Some(sparkConf.get("spark.yarn.keytab")), + appMasterOnly = true) + require(localizedPath != null, "Keytab file already distributed.") + } + /** * Copy the given main resource to the distributed cache if the scheme is not "local". * Otherwise, set the corresponding key in our SparkConf to handle it downstream. @@ -314,33 +355,18 @@ private[spark] class Client( (SPARK_JAR, sparkJar(sparkConf), CONF_SPARK_JAR), (APP_JAR, args.userJar, CONF_SPARK_USER_JAR), ("log4j.properties", oldLog4jConf.orNull, null) - ).foreach { case (destName, _localPath, confKey) => - val localPath: String = if (_localPath != null) _localPath.trim() else "" - if (!localPath.isEmpty()) { - val localURI = new URI(localPath) - if (localURI.getScheme != LOCAL_SCHEME) { - if (addDistributedUri(localURI)) { - val src = getQualifiedLocalPath(localURI, hadoopConf) - val destPath = copyFileToRemote(dst, src, replication) - val destFs = FileSystem.get(destPath.toUri(), hadoopConf) - distCacheMgr.addResource(destFs, hadoopConf, destPath, - localResources, LocalResourceType.FILE, destName, statCache) - } - } else if (confKey != null) { + ).foreach { case (destName, path, confKey) => + if (path != null && !path.trim().isEmpty()) { + val (isLocal, localizedPath) = distribute(path, destName = Some(destName)) + if (isLocal && confKey != null) { + require(localizedPath != null, s"Path $path already distributed.") // If the resource is intended for local use only, handle this downstream // by setting the appropriate property - sparkConf.set(confKey, localPath) + sparkConf.set(confKey, localizedPath) } } } - createConfArchive().foreach { file => - require(addDistributedUri(file.toURI())) - val destPath = copyFileToRemote(dst, new Path(file.toURI()), replication) - distCacheMgr.addResource(fs, hadoopConf, destPath, localResources, LocalResourceType.ARCHIVE, - LOCALIZED_HADOOP_CONF_DIR, statCache, appMasterOnly = true) - } - /** * Do the same for any additional resources passed in through ClientArguments. 
* Each resource category is represented by a 3-tuple of: @@ -356,21 +382,10 @@ private[spark] class Client( ).foreach { case (flist, resType, addToClasspath) => if (flist != null && !flist.isEmpty()) { flist.split(',').foreach { file => - val localURI = new URI(file.trim()) - if (localURI.getScheme != LOCAL_SCHEME) { - if (addDistributedUri(localURI)) { - val localPath = new Path(localURI) - val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName()) - val destPath = copyFileToRemote(dst, localPath, replication) - distCacheMgr.addResource( - fs, hadoopConf, destPath, localResources, resType, linkname, statCache) - if (addToClasspath) { - cachedSecondaryJarLinks += linkname - } - } - } else if (addToClasspath) { - // Resource is intended for local use only and should be added to the class path - cachedSecondaryJarLinks += file.trim() + val (_, localizedPath) = distribute(file, resType = resType) + require(localizedPath != null) + if (addToClasspath) { + cachedSecondaryJarLinks += localizedPath } } } @@ -379,11 +394,31 @@ private[spark] class Client( sparkConf.set(CONF_SPARK_YARN_SECONDARY_JARS, cachedSecondaryJarLinks.mkString(",")) } + if (isClusterMode && args.primaryPyFile != null) { + distribute(args.primaryPyFile, appMasterOnly = true) + } + + pySparkArchives.foreach { f => distribute(f) } + + // The python files list needs to be treated especially. All files that are not an + // archive need to be placed in a subdirectory that will be added to PYTHONPATH. + args.pyFiles.foreach { f => + val targetDir = if (f.endsWith(".py")) Some(LOCALIZED_PYTHON_DIR) else None + distribute(f, targetDir = targetDir) + } + + // Distribute an archive with Hadoop and Spark configuration for the AM. + val (_, confLocalizedPath) = distribute(createConfArchive().getAbsolutePath(), + resType = LocalResourceType.ARCHIVE, + destName = Some(LOCALIZED_CONF_DIR), + appMasterOnly = true) + require(confLocalizedPath != null) + localResources } /** - * Create an archive with the Hadoop config files for distribution. + * Create an archive with the config files for distribution. * * These are only used by the AM, since executors will use the configuration object broadcast by * the driver. The files are zipped and added to the job as an archive, so that YARN will explode @@ -395,8 +430,11 @@ private[spark] class Client( * * Currently this makes a shallow copy of the conf directory. If there are cases where a * Hadoop config directory contains subdirectories, this code will have to be fixed. + * + * The archive also contains some Spark configuration. Namely, it saves the contents of + * SparkConf in a file to be loaded by the AM process. 
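+   * The AM loads this file through its --properties-file argument and sets its entries as
+   * system properties (see the ApplicationMaster changes above).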
*/ - private def createConfArchive(): Option[File] = { + private def createConfArchive(): File = { val hadoopConfFiles = new HashMap[String, File]() Seq("HADOOP_CONF_DIR", "YARN_CONF_DIR").foreach { envKey => sys.env.get(envKey).foreach { path => @@ -411,28 +449,32 @@ private[spark] class Client( } } - if (!hadoopConfFiles.isEmpty) { - val hadoopConfArchive = File.createTempFile(LOCALIZED_HADOOP_CONF_DIR, ".zip", - new File(Utils.getLocalDir(sparkConf))) + val confArchive = File.createTempFile(LOCALIZED_CONF_DIR, ".zip", + new File(Utils.getLocalDir(sparkConf))) + val confStream = new ZipOutputStream(new FileOutputStream(confArchive)) - val hadoopConfStream = new ZipOutputStream(new FileOutputStream(hadoopConfArchive)) - try { - hadoopConfStream.setLevel(0) - hadoopConfFiles.foreach { case (name, file) => - if (file.canRead()) { - hadoopConfStream.putNextEntry(new ZipEntry(name)) - Files.copy(file, hadoopConfStream) - hadoopConfStream.closeEntry() - } + try { + confStream.setLevel(0) + hadoopConfFiles.foreach { case (name, file) => + if (file.canRead()) { + confStream.putNextEntry(new ZipEntry(name)) + Files.copy(file, confStream) + confStream.closeEntry() } - } finally { - hadoopConfStream.close() } - Some(hadoopConfArchive) - } else { - None + // Save Spark configuration to a file in the archive. + val props = new Properties() + sparkConf.getAll.foreach { case (k, v) => props.setProperty(k, v) } + confStream.putNextEntry(new ZipEntry(SPARK_CONF_FILE)) + val writer = new OutputStreamWriter(confStream, UTF_8) + props.store(writer, "Spark configuration.") + writer.flush() + confStream.closeEntry() + } finally { + confStream.close() } + confArchive } /** @@ -460,7 +502,9 @@ private[spark] class Client( /** * Set up the environment for launching our ApplicationMaster container. */ - private def setupLaunchEnv(stagingDir: String): HashMap[String, String] = { + private def setupLaunchEnv( + stagingDir: String, + pySparkArchives: Seq[String]): HashMap[String, String] = { logInfo("Setting up the launch environment for our AM container") val env = new HashMap[String, String]() val extraCp = sparkConf.getOption("spark.driver.extraClassPath") @@ -478,9 +522,6 @@ private[spark] class Client( val renewalInterval = getTokenRenewalInterval(stagingDirPath) sparkConf.set("spark.yarn.token.renewal.interval", renewalInterval.toString) } - // Set the environment variables to be passed on to the executors. - distCacheMgr.setDistFilesEnv(env) - distCacheMgr.setDistArchivesEnv(env) // Pick up any environment variables for the AM provided through spark.yarn.appMasterEnv.* val amEnvPrefix = "spark.yarn.appMasterEnv." @@ -497,15 +538,32 @@ private[spark] class Client( env("SPARK_YARN_USER_ENV") = userEnvs } - // if spark.submit.pyArchives is in sparkConf, append pyArchives to PYTHONPATH - // that can be passed on to the ApplicationMaster and the executors. - if (sparkConf.contains("spark.submit.pyArchives")) { - var pythonPath = sparkConf.get("spark.submit.pyArchives") - if (env.contains("PYTHONPATH")) { - pythonPath = Seq(env.get("PYTHONPATH"), pythonPath).mkString(File.pathSeparator) + // If pyFiles contains any .py files, we need to add LOCALIZED_PYTHON_DIR to the PYTHONPATH + // of the container processes too. Add all non-.py files directly to PYTHONPATH. + // + // NOTE: the code currently does not handle .py files defined with a "local:" scheme. 
+ val pythonPath = new ListBuffer[String]() + val (pyFiles, pyArchives) = args.pyFiles.partition(_.endsWith(".py")) + if (pyFiles.nonEmpty) { + pythonPath += buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), + LOCALIZED_PYTHON_DIR) + } + (pySparkArchives ++ pyArchives).foreach { path => + val uri = new URI(path) + if (uri.getScheme != LOCAL_SCHEME) { + pythonPath += buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), + new Path(path).getName()) + } else { + pythonPath += uri.getPath() } - env("PYTHONPATH") = pythonPath - sparkConf.setExecutorEnv("PYTHONPATH", pythonPath) + } + + // Finally, update the Spark config to propagate PYTHONPATH to the AM and executors. + if (pythonPath.nonEmpty) { + val pythonPathStr = (sys.env.get("PYTHONPATH") ++ pythonPath) + .mkString(YarnSparkHadoopUtil.getClassPathSeparator) + env("PYTHONPATH") = pythonPathStr + sparkConf.setExecutorEnv("PYTHONPATH", pythonPathStr) } // In cluster mode, if the deprecated SPARK_JAVA_OPTS is set, we need to propagate it to @@ -555,8 +613,19 @@ private[spark] class Client( logInfo("Setting up container launch context for our AM") val appId = newAppResponse.getApplicationId val appStagingDir = getAppStagingDir(appId) - val localResources = prepareLocalResources(appStagingDir) - val launchEnv = setupLaunchEnv(appStagingDir) + val pySparkArchives = + if (sys.props.getOrElse("spark.yarn.isPython", "false").toBoolean) { + findPySparkArchives() + } else { + Nil + } + val launchEnv = setupLaunchEnv(appStagingDir, pySparkArchives) + val localResources = prepareLocalResources(appStagingDir, pySparkArchives) + + // Set the environment variables to be passed on to the executors. + distCacheMgr.setDistFilesEnv(launchEnv) + distCacheMgr.setDistArchivesEnv(launchEnv) + val amContainer = Records.newRecord(classOf[ContainerLaunchContext]) amContainer.setLocalResources(localResources) amContainer.setEnvironment(launchEnv) @@ -596,13 +665,6 @@ private[spark] class Client( javaOpts += "-XX:CMSIncrementalDutyCycle=10" } - // Forward the Spark configuration to the application master / executors. - // TODO: it might be nicer to pass these as an internal environment variable rather than - // as Java options, due to complications with string parsing of nested quotes. 
- for ((k, v) <- sparkConf.getAll) { - javaOpts += YarnSparkHadoopUtil.escapeForShell(s"-D$k=$v") - } - // Include driver-specific java options if we are launching a driver if (isClusterMode) { val driverOpts = sparkConf.getOption("spark.driver.extraJavaOptions") @@ -655,14 +717,8 @@ private[spark] class Client( Nil } val primaryPyFile = - if (args.primaryPyFile != null) { - Seq("--primary-py-file", args.primaryPyFile) - } else { - Nil - } - val pyFiles = - if (args.pyFiles != null) { - Seq("--py-files", args.pyFiles) + if (isClusterMode && args.primaryPyFile != null) { + Seq("--primary-py-file", new Path(args.primaryPyFile).getName()) } else { Nil } @@ -678,9 +734,6 @@ private[spark] class Client( } else { Class.forName("org.apache.spark.deploy.yarn.ExecutorLauncher").getName } - if (args.primaryPyFile != null && args.primaryPyFile.endsWith(".py")) { - args.userArgs = ArrayBuffer(args.primaryPyFile, args.pyFiles) ++ args.userArgs - } if (args.primaryRFile != null && args.primaryRFile.endsWith(".R")) { args.userArgs = ArrayBuffer(args.primaryRFile) ++ args.userArgs } @@ -688,11 +741,13 @@ private[spark] class Client( Seq("--arg", YarnSparkHadoopUtil.escapeForShell(arg)) } val amArgs = - Seq(amClass) ++ userClass ++ userJar ++ primaryPyFile ++ pyFiles ++ primaryRFile ++ + Seq(amClass) ++ userClass ++ userJar ++ primaryPyFile ++ primaryRFile ++ userArgs ++ Seq( "--executor-memory", args.executorMemory.toString + "m", "--executor-cores", args.executorCores.toString, - "--num-executors ", args.numExecutors.toString) + "--num-executors ", args.numExecutors.toString, + "--properties-file", buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), + LOCALIZED_CONF_DIR, SPARK_CONF_FILE)) // Command for the ApplicationMaster val commands = prefixEnv ++ Seq( @@ -857,6 +912,22 @@ private[spark] class Client( } } } + + private def findPySparkArchives(): Seq[String] = { + sys.env.get("PYSPARK_ARCHIVES_PATH") + .map(_.split(",").toSeq) + .getOrElse { + val pyLibPath = Seq(sys.env("SPARK_HOME"), "python", "lib").mkString(File.separator) + val pyArchivesFile = new File(pyLibPath, "pyspark.zip") + require(pyArchivesFile.exists(), + "pyspark.zip not found; cannot run pyspark application in YARN mode.") + val py4jFile = new File(pyLibPath, "py4j-0.8.2.1-src.zip") + require(py4jFile.exists(), + "py4j-0.8.2.1-src.zip not found; cannot run pyspark application in YARN mode.") + Seq(pyArchivesFile.getAbsolutePath(), py4jFile.getAbsolutePath()) + } + } + } object Client extends Logging { @@ -907,8 +978,14 @@ object Client extends Logging { // Distribution-defined classpath to add to processes val ENV_DIST_CLASSPATH = "SPARK_DIST_CLASSPATH" - // Subdirectory where the user's hadoop config files will be placed. - val LOCALIZED_HADOOP_CONF_DIR = "__hadoop_conf__" + // Subdirectory where the user's Spark and Hadoop config files will be placed. + val LOCALIZED_CONF_DIR = "__spark_conf__" + + // Name of the file in the conf archive containing Spark configuration. + val SPARK_CONF_FILE = "__spark_conf__.properties" + + // Subdirectory where the user's python files (not archives) will be placed. 
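+  // This directory is added to the PYTHONPATH of both the AM and executor processes
+  // (see setupLaunchEnv above).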
+ val LOCALIZED_PYTHON_DIR = "__pyfiles__" /** * Find the user-defined Spark jar if configured, or return the jar containing this @@ -1033,7 +1110,7 @@ object Client extends Logging { if (isAM) { addClasspathEntry( YarnSparkHadoopUtil.expandEnvironment(Environment.PWD) + Path.SEPARATOR + - LOCALIZED_HADOOP_CONF_DIR, env) + LOCALIZED_CONF_DIR, env) } if (sparkConf.getBoolean("spark.yarn.user.classpath.first", false)) { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 9c7b1b3988082..35e990602a6cf 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -30,7 +30,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) var archives: String = null var userJar: String = null var userClass: String = null - var pyFiles: String = null + var pyFiles: Seq[String] = Nil var primaryPyFile: String = null var primaryRFile: String = null var userArgs: ArrayBuffer[String] = new ArrayBuffer[String]() @@ -228,7 +228,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) args = tail case ("--py-files") :: value :: tail => - pyFiles = value + pyFiles = value.split(",") args = tail case ("--files") :: value :: tail => diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 99c05329b4d73..1c8d7ec57635f 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -76,7 +76,8 @@ private[spark] class YarnClientSchedulerBackend( ("--executor-memory", "SPARK_EXECUTOR_MEMORY", "spark.executor.memory"), ("--executor-cores", "SPARK_WORKER_CORES", "spark.executor.cores"), ("--executor-cores", "SPARK_EXECUTOR_CORES", "spark.executor.cores"), - ("--queue", "SPARK_YARN_QUEUE", "spark.yarn.queue") + ("--queue", "SPARK_YARN_QUEUE", "spark.yarn.queue"), + ("--py-files", null, "spark.submit.pyFiles") ) // Warn against the following deprecated environment variables: env var -> suggestion val deprecatedEnvVars = Map( @@ -86,7 +87,7 @@ private[spark] class YarnClientSchedulerBackend( optionTuples.foreach { case (optionName, envVar, sparkProp) => if (sc.getConf.contains(sparkProp)) { extraArgs += (optionName, sc.getConf.get(sparkProp)) - } else if (System.getenv(envVar) != null) { + } else if (envVar != null && System.getenv(envVar) != null) { extraArgs += (optionName, System.getenv(envVar)) if (deprecatedEnvVars.contains(envVar)) { logWarning(s"NOTE: $envVar is deprecated. 
Use ${deprecatedEnvVars(envVar)} instead.") diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index 01d33c9ce9297..4ec976aa31387 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -113,7 +113,7 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll { Environment.PWD.$() } cp should contain(pwdVar) - cp should contain (s"$pwdVar${Path.SEPARATOR}${Client.LOCALIZED_HADOOP_CONF_DIR}") + cp should contain (s"$pwdVar${Path.SEPARATOR}${Client.LOCALIZED_CONF_DIR}") cp should not contain (Client.SPARK_JAR) cp should not contain (Client.APP_JAR) } @@ -129,7 +129,7 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll { val tempDir = Utils.createTempDir() try { - client.prepareLocalResources(tempDir.getAbsolutePath()) + client.prepareLocalResources(tempDir.getAbsolutePath(), Nil) sparkConf.getOption(Client.CONF_SPARK_USER_JAR) should be (Some(USER)) // The non-local path should be propagated by name only, since it will end up in the app's diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 93d587d0cb36a..a0f25ba450068 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -56,6 +56,7 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher """.stripMargin private val TEST_PYFILE = """ + |import mod1, mod2 |import sys |from operator import add | @@ -67,7 +68,7 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher | sc = SparkContext(conf=SparkConf()) | status = open(sys.argv[1],'w') | result = "failure" - | rdd = sc.parallelize(range(10)) + | rdd = sc.parallelize(range(10)).map(lambda x: x * mod1.func() * mod2.func()) | cnt = rdd.count() | if cnt == 10: | result = "success" @@ -76,6 +77,11 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher | sc.stop() """.stripMargin + private val TEST_PYMODULE = """ + |def func(): + | return 42 + """.stripMargin + private var yarnCluster: MiniYARNCluster = _ private var tempDir: File = _ private var fakeSparkJar: File = _ @@ -124,7 +130,7 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher logInfo(s"RM address in configuration is ${config.get(YarnConfiguration.RM_ADDRESS)}") fakeSparkJar = File.createTempFile("sparkJar", null, tempDir) - hadoopConfDir = new File(tempDir, Client.LOCALIZED_HADOOP_CONF_DIR) + hadoopConfDir = new File(tempDir, Client.LOCALIZED_CONF_DIR) assert(hadoopConfDir.mkdir()) File.createTempFile("token", ".txt", hadoopConfDir) } @@ -151,26 +157,12 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher } } - // Enable this once fix SPARK-6700 - test("run Python application in yarn-cluster mode") { - val primaryPyFile = new File(tempDir, "test.py") - Files.write(TEST_PYFILE, primaryPyFile, UTF_8) - val pyFile = new File(tempDir, "test2.py") - Files.write(TEST_PYFILE, pyFile, UTF_8) - var result = File.createTempFile("result", null, tempDir) + test("run Python application in yarn-client mode") { + testPySpark(true) + } - // The sbt assembly does not include pyspark / py4j python dependencies, so we need to - // propagate SPARK_HOME 
so that those are added to PYTHONPATH. See PythonUtils.scala. - val sparkHome = sys.props("spark.test.home") - val extraConf = Map( - "spark.executorEnv.SPARK_HOME" -> sparkHome, - "spark.yarn.appMasterEnv.SPARK_HOME" -> sparkHome) - - runSpark(false, primaryPyFile.getAbsolutePath(), - sparkArgs = Seq("--py-files", pyFile.getAbsolutePath()), - appArgs = Seq(result.getAbsolutePath()), - extraConf = extraConf) - checkResult(result) + test("run Python application in yarn-cluster mode") { + testPySpark(false) } test("user class path first in client mode") { @@ -188,6 +180,33 @@ class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matcher checkResult(result) } + private def testPySpark(clientMode: Boolean): Unit = { + val primaryPyFile = new File(tempDir, "test.py") + Files.write(TEST_PYFILE, primaryPyFile, UTF_8) + + val moduleDir = + if (clientMode) { + // In client-mode, .py files added with --py-files are not visible in the driver. + // This is something that the launcher library would have to handle. + tempDir + } else { + val subdir = new File(tempDir, "pyModules") + subdir.mkdir() + subdir + } + val pyModule = new File(moduleDir, "mod1.py") + Files.write(TEST_PYMODULE, pyModule, UTF_8) + + val mod2Archive = TestUtils.createJarWithFiles(Map("mod2.py" -> TEST_PYMODULE), moduleDir) + val pyFiles = Seq(pyModule.getAbsolutePath(), mod2Archive.getPath()).mkString(",") + val result = File.createTempFile("result", null, tempDir) + + runSpark(clientMode, primaryPyFile.getAbsolutePath(), + sparkArgs = Seq("--py-files", pyFiles), + appArgs = Seq(result.getAbsolutePath())) + checkResult(result) + } + private def testUseClassPathFirst(clientMode: Boolean): Unit = { // Create a jar file that contains a different version of "test.resource". val originalJar = TestUtils.createJarWithFiles(Map("test.resource" -> "ORIGINAL"), tempDir) From 30ebf1a233295539c2455bd838bae7315711e1e2 Mon Sep 17 00:00:00 2001 From: Hossein Date: Wed, 10 Jun 2015 13:18:48 -0700 Subject: [PATCH 005/151] [SPARK-8282] [SPARKR] Make number of threads used in RBackend configurable Read number of threads for RBackend from configuration. [SPARK-8282] #comment Linking with JIRA Author: Hossein Closes #6730 from falaki/SPARK-8282 and squashes the following commits: 33b3d98 [Hossein] Documented new config parameter 70f2a9c [Hossein] Fixing import ec44225 [Hossein] Read number of threads for RBackend from configuration --- .../main/scala/org/apache/spark/api/r/RBackend.scala | 5 +++-- docs/configuration.md | 12 ++++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala index d24c650d37bb0..1a5f2bca26c2b 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -29,7 +29,7 @@ import io.netty.channel.socket.nio.NioServerSocketChannel import io.netty.handler.codec.LengthFieldBasedFrameDecoder import io.netty.handler.codec.bytes.{ByteArrayDecoder, ByteArrayEncoder} -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkConf} /** * Netty-based backend server that is used to communicate between R and Java. 
@@ -41,7 +41,8 @@ private[spark] class RBackend { private[this] var bossGroup: EventLoopGroup = null def init(): Int = { - bossGroup = new NioEventLoopGroup(2) + val conf = new SparkConf() + bossGroup = new NioEventLoopGroup(conf.getInt("spark.r.numRBackendThreads", 2)) val workerGroup = bossGroup val handler = new RBackendHandler(this) diff --git a/docs/configuration.md b/docs/configuration.md index 3960e7e78bde1..95a322f79b40b 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1495,6 +1495,18 @@ Apart from these, the following properties are also available, and may be useful +#### SparkR + + + + + + + +
+<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
+<tr>
+  <td><code>spark.r.numRBackendThreads</code></td>
+  <td>2</td>
+  <td>
+    Number of threads used by RBackend to handle RPC calls from SparkR package.
+  </td>
+</tr>
+ #### Cluster Managers Each cluster manager in Spark has additional configuration options. Configurations can be found on the pages for each mode: From 19e30b48f3c6d0b72871d3e15b9564c1b2822700 Mon Sep 17 00:00:00 2001 From: Adam Roberts Date: Wed, 10 Jun 2015 13:21:01 -0700 Subject: [PATCH 006/151] [SPARK-7756] CORE RDDOperationScope fix for IBM Java IBM Java has an extra method when we do getStackTrace(): this is "getStackTraceImpl", a native method. This causes two tests to fail within "DStreamScopeSuite" when running with IBM Java. Instead of "map" or "filter" being the method names found, "getStackTrace" is returned. This commit addresses such an issue by using dropWhile. Given that our current method is withScope, we look for the next method that isn't ours: we don't care about methods that come before us in the stack trace: e.g. getStackTrace (regardless of how many levels this might go). IBM: java.lang.Thread.getStackTraceImpl(Native Method) java.lang.Thread.getStackTrace(Thread.java:1117) org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:104) Oracle: PRINTING STACKTRACE!!! java.lang.Thread.getStackTrace(Thread.java:1552) org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:106) I've tested this with Oracle and IBM Java, no side effects for other tests introduced. Author: Adam Roberts Author: a-roberts Closes #6740 from a-roberts/RDDScopeStackCrawlFix and squashes the following commits: 13ce390 [Adam Roberts] Ensure consistency with String equality checking a4fc0e0 [a-roberts] Update RDDOperationScope.scala --- .../scala/org/apache/spark/rdd/RDDOperationScope.scala | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala b/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala index 6b09dfafc889c..44667281c1063 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala @@ -95,10 +95,9 @@ private[spark] object RDDOperationScope extends Logging { private[spark] def withScope[T]( sc: SparkContext, allowNesting: Boolean = false)(body: => T): T = { - val stackTrace = Thread.currentThread.getStackTrace().tail // ignore "Thread#getStackTrace" - val ourMethodName = stackTrace(1).getMethodName // i.e. withScope - // Climb upwards to find the first method that's called something different - val callerMethodName = stackTrace + val ourMethodName = "withScope" + val callerMethodName = Thread.currentThread.getStackTrace() + .dropWhile(_.getMethodName != ourMethodName) .find(_.getMethodName != ourMethodName) .map(_.getMethodName) .getOrElse { From e90c9d92d9a86e9960c10a5c043f3c02f6c636f9 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Wed, 10 Jun 2015 13:22:52 -0700 Subject: [PATCH 007/151] [SPARK-7527] [CORE] Fix createNullValue to return the correct null values and REPL mode detection The root cause of SPARK-7527 is `createNullValue` returns an incompatible value `Byte(0)` for `char` and `boolean`. This PR fixes it and corrects the class name of the main class, and also adds an unit test to demonstrate it. 
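A minimal sketch of the incompatibility (the `Closure` class here is hypothetical, not part of the patch): when a closure constructor takes a `boolean` or `char` parameter, reflective instantiation rejects a boxed `Byte(0)` but accepts a type-correct value, which is what the fixed `createNullValue` now returns.
```
class Closure(b: Boolean, c: Char)

val ctor = classOf[Closure].getConstructors()(0)
// Throws IllegalArgumentException ("argument type mismatch"):
// ctor.newInstance(new java.lang.Byte(0: Byte), new java.lang.Byte(0: Byte))
// Succeeds with type-correct null values:
ctor.newInstance(java.lang.Boolean.FALSE, new java.lang.Character('\0'))
```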
Author: zsxwing Closes #6735 from zsxwing/SPARK-7527 and squashes the following commits: bbdb271 [zsxwing] Use pattern match in createNullValue b0a0e7e [zsxwing] Remove the noisy in the test output 903e269 [zsxwing] Remove the code for Utils.isInInterpreter == false 5f92dc1 [zsxwing] Fix createNullValue to return the correct null values and REPL mode detection --- .../apache/spark/util/ClosureCleaner.scala | 40 ++++++++--------- .../scala/org/apache/spark/util/Utils.scala | 9 +--- .../spark/util/ClosureCleanerSuite.scala | 44 +++++++++++++++++++ 3 files changed, 64 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index 6f2966bd4fd31..305de4c75539d 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -109,7 +109,14 @@ private[spark] object ClosureCleaner extends Logging { private def createNullValue(cls: Class[_]): AnyRef = { if (cls.isPrimitive) { - new java.lang.Byte(0: Byte) // Should be convertible to any primitive type + cls match { + case java.lang.Boolean.TYPE => new java.lang.Boolean(false) + case java.lang.Character.TYPE => new java.lang.Character('\0') + case java.lang.Void.TYPE => + // This should not happen because `Foo(void x) {}` does not compile. + throw new IllegalStateException("Unexpected void parameter in constructor") + case _ => new java.lang.Byte(0: Byte) + } } else { null } @@ -319,28 +326,17 @@ private[spark] object ClosureCleaner extends Logging { private def instantiateClass( cls: Class[_], enclosingObject: AnyRef): AnyRef = { - if (!Utils.isInInterpreter) { - // This is a bona fide closure class, whose constructor has no effects - // other than to set its fields, so use its constructor - val cons = cls.getConstructors()(0) - val params = cons.getParameterTypes.map(createNullValue).toArray - if (enclosingObject != null) { - params(0) = enclosingObject // First param is always enclosing object - } - return cons.newInstance(params: _*).asInstanceOf[AnyRef] - } else { - // Use reflection to instantiate object without calling constructor - val rf = sun.reflect.ReflectionFactory.getReflectionFactory() - val parentCtor = classOf[java.lang.Object].getDeclaredConstructor() - val newCtor = rf.newConstructorForSerialization(cls, parentCtor) - val obj = newCtor.newInstance().asInstanceOf[AnyRef] - if (enclosingObject != null) { - val field = cls.getDeclaredField("$outer") - field.setAccessible(true) - field.set(obj, enclosingObject) - } - obj + // Use reflection to instantiate object without calling constructor + val rf = sun.reflect.ReflectionFactory.getReflectionFactory() + val parentCtor = classOf[java.lang.Object].getDeclaredConstructor() + val newCtor = rf.newConstructorForSerialization(cls, parentCtor) + val obj = newCtor.newInstance().asInstanceOf[AnyRef] + if (enclosingObject != null) { + val field = cls.getDeclaredField("$outer") + field.setAccessible(true) + field.set(obj, enclosingObject) } + obj } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 153ece6224a6d..19157af5b6f4d 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1804,15 +1804,10 @@ private[spark] object Utils extends Logging { lazy val isInInterpreter: Boolean = { try { - val interpClass = classForName("spark.repl.Main") + val 
interpClass = classForName("org.apache.spark.repl.Main") interpClass.getMethod("interp").invoke(null) != null } catch { - // Returning true seems to be a mistake. - // Currently changing it to false causes tests failures in Streaming. - // For a more detailed discussion, please, refer to - // https://github.com/apache/spark/pull/5835#issuecomment-101042271 and subsequent comments. - // Addressing this changed is tracked as https://issues.apache.org/jira/browse/SPARK-7527 - case _: ClassNotFoundException => true + case _: ClassNotFoundException => false } } diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala index 70cd27b04347d..1053c6caf7718 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala @@ -121,6 +121,10 @@ class ClosureCleanerSuite extends SparkFunSuite { expectCorrectException { TestUserClosuresActuallyCleaned.testSubmitJob(sc) } } } + + test("createNullValue") { + new TestCreateNullValue().run() + } } // A non-serializable class we create in closures to make sure that we aren't @@ -350,3 +354,43 @@ private object TestUserClosuresActuallyCleaned { ) } } + +class TestCreateNullValue { + + var x = 5 + + def getX: Int = x + + def run(): Unit = { + val bo: Boolean = true + val c: Char = '1' + val b: Byte = 1 + val s: Short = 1 + val i: Int = 1 + val l: Long = 1 + val f: Float = 1 + val d: Double = 1 + + // Bring in all primitive types into the closure such that they become + // parameters of the closure constructor. This allows us to test whether + // null values are created correctly for each type. + val nestedClosure = () => { + if (s.toString == "123") { // Don't really output them to avoid noisy + println(bo) + println(c) + println(b) + println(s) + println(i) + println(l) + println(f) + println(d) + } + + val closure = () => { + println(getX) + } + ClosureCleaner.clean(closure) + } + nestedClosure() + } +} From 80043e9e761c44ce2c3a432dcd1989be573f8bb4 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Wed, 10 Jun 2015 13:25:59 -0700 Subject: [PATCH 008/151] [SPARK-7261] [CORE] Change default log level to WARN in the REPL 1. Add `log4j-defaults-repl.properties` that has log level WARN. 2. When logging is initialized, check whether inside the REPL. If so, use `log4j-defaults-repl.properties`. 3. 
Print the following information if using `log4j-defaults-repl.properties`: ``` Using Spark's repl log4j profile: org/apache/spark/log4j-defaults-repl.properties To adjust logging level use sc.setLogLevel("INFO") ``` Author: zsxwing Closes #6734 from zsxwing/log4j-repl and squashes the following commits: 3835eff [zsxwing] Change default log level to WARN in the REPL --- .rat-excludes | 1 + .../spark/log4j-defaults-repl.properties | 12 +++++++++ .../main/scala/org/apache/spark/Logging.scala | 26 ++++++++++++++----- 3 files changed, 32 insertions(+), 7 deletions(-) create mode 100644 core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties diff --git a/.rat-excludes b/.rat-excludes index 994c7e86f8a91..aa008e6e920f5 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -28,6 +28,7 @@ spark-env.sh spark-env.cmd spark-env.sh.template log4j-defaults.properties +log4j-defaults-repl.properties bootstrap-tooltip.js jquery-1.11.1.min.js d3.min.js diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties b/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties new file mode 100644 index 0000000000000..b146f8a784127 --- /dev/null +++ b/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties @@ -0,0 +1,12 @@ +# Set everything to be logged to the console +log4j.rootCategory=WARN, console +log4j.appender.console=org.apache.log4j.ConsoleAppender +log4j.appender.console.target=System.err +log4j.appender.console.layout=org.apache.log4j.PatternLayout +log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n + +# Settings to quiet third party logs that are too verbose +log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR +log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO +log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala index 419d093d55643..7fcb7830e7b0b 100644 --- a/core/src/main/scala/org/apache/spark/Logging.scala +++ b/core/src/main/scala/org/apache/spark/Logging.scala @@ -121,13 +121,25 @@ trait Logging { if (usingLog4j12) { val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements if (!log4j12Initialized) { - val defaultLogProps = "org/apache/spark/log4j-defaults.properties" - Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match { - case Some(url) => - PropertyConfigurator.configure(url) - System.err.println(s"Using Spark's default log4j profile: $defaultLogProps") - case None => - System.err.println(s"Spark was unable to load $defaultLogProps") + if (Utils.isInInterpreter) { + val replDefaultLogProps = "org/apache/spark/log4j-defaults-repl.properties" + Option(Utils.getSparkClassLoader.getResource(replDefaultLogProps)) match { + case Some(url) => + PropertyConfigurator.configure(url) + System.err.println(s"Using Spark's repl log4j profile: $replDefaultLogProps") + System.err.println("To adjust logging level use sc.setLogLevel(\"INFO\")") + case None => + System.err.println(s"Spark was unable to load $replDefaultLogProps") + } + } else { + val defaultLogProps = "org/apache/spark/log4j-defaults.properties" + Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match { + case Some(url) => + PropertyConfigurator.configure(url) + System.err.println(s"Using Spark's default log4j profile: $defaultLogProps") + case None => + 
System.err.println(s"Spark was unable to load $defaultLogProps") + } } } } From cb871c44c38a4c1575ed076389f14641afafad7d Mon Sep 17 00:00:00 2001 From: WangTaoTheTonic Date: Wed, 10 Jun 2015 13:30:16 -0700 Subject: [PATCH 009/151] [SPARK-8290] spark class command builder need read SPARK_JAVA_OPTS and SPARK_DRIVER_MEMORY properly SPARK_JAVA_OPTS was missed in reconstructing the launcher part, we should add it back so process launched by spark-class could read it properly. And so does `SPARK_DRIVER_MEMORY`. The missing part is [here](https://github.com/apache/spark/blob/1c30afdf94b27e1ad65df0735575306e65d148a1/bin/spark-class#L97). Author: WangTaoTheTonic Author: Tao Wang Closes #6741 from WangTaoTheTonic/SPARK-8290 and squashes the following commits: bd89f0f [Tao Wang] make sure the memory setting is right too e313520 [WangTaoTheTonic] spark class command builder need read SPARK_JAVA_OPTS --- .../org/apache/spark/launcher/SparkClassCommandBuilder.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java index d80abf2a8676e..de85720febf23 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java @@ -93,6 +93,9 @@ public List buildCommand(Map env) throws IOException { toolsDir.getAbsolutePath(), className); javaOptsKeys.add("SPARK_JAVA_OPTS"); + } else { + javaOptsKeys.add("SPARK_JAVA_OPTS"); + memKey = "SPARK_DRIVER_MEMORY"; } List cmd = buildJavaCommand(extraClassPath); From 5014d0ed7e2f69810654003f8dd38078b945cf05 Mon Sep 17 00:00:00 2001 From: WangTaoTheTonic Date: Wed, 10 Jun 2015 13:34:19 -0700 Subject: [PATCH 010/151] [SPARK-8273] Driver hangs up when yarn shutdown in client mode In client mode, if yarn was shut down with spark application running, the application will hang up after several retries(default: 30) because the exception throwed by YarnClientImpl could not be caught by upper level, we should exit in case that user can not be aware that. The exception we wanna catch is [here](https://github.com/apache/hadoop/blob/branch-2.7.0/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/io/retry/RetryInvocationHandler.java#L122), and I try to fix it refer to [MR](https://github.com/apache/hadoop/blob/branch-2.7.0/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-jobclient/src/main/java/org/apache/hadoop/mapred/ClientServiceDelegate.java#L320). 
Author: WangTaoTheTonic Closes #6717 from WangTaoTheTonic/SPARK-8273 and squashes the following commits: 28752d6 [WangTaoTheTonic] catch the throwed exception --- yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index ec9402afff329..da1ec2a0fe2e9 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -29,6 +29,7 @@ import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, ListBuffer, Map} import scala.reflect.runtime.universe import scala.util.{Try, Success, Failure} +import scala.util.control.NonFatal import com.google.common.base.Charsets.UTF_8 import com.google.common.base.Objects @@ -826,6 +827,9 @@ private[spark] class Client( case e: ApplicationNotFoundException => logError(s"Application $appId not found.") return (YarnApplicationState.KILLED, FinalApplicationStatus.KILLED) + case NonFatal(e) => + logError(s"Failed to contact YARN for application $appId.", e) + return (YarnApplicationState.FAILED, FinalApplicationStatus.FAILED) } val state = report.getYarnApplicationState From 96a7c888d806adfdb2c722025a1079ed7eaa2052 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Wed, 10 Jun 2015 15:03:40 -0700 Subject: [PATCH 011/151] [SPARK-2774] Set preferred locations for reduce tasks Set preferred locations for reduce tasks. The basic design is that we maintain a map from reducerId to a list of (sizes, locations) for each shuffle. We then set the preferred locations to be any machines that have 20% of more of the output that needs to be read by the reduce task. This will result in at most 5 preferred locations for each reduce task. Selecting the preferred locations involves O(# map tasks * # reduce tasks) computation, so we restrict this feature to cases where we have fewer than 1000 map tasks and 1000 reduce tasks. 
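The selection rule described above can be sketched independently of Spark (the helper below is hypothetical and leaves out MapStatus, block manager IDs and the map/reduce-count thresholds; it only shows the fraction test): sum the block sizes per host for one reducer and keep the hosts that hold at least the given fraction of the total output.

```scala
// Toy illustration of the 20%-of-output rule; not MapOutputTracker code.
object PreferredLocationsSketch {
  def locationsWithLargestOutputs(
      blockSizesByHost: Seq[(String, Long)],
      fractionThreshold: Double): Seq[String] = {
    val sizePerHost: Map[String, Long] = blockSizesByHost
      .filter(_._2 > 0)                                   // skip empty blocks
      .groupBy(_._1)
      .map { case (host, blocks) => host -> blocks.map(_._2).sum }
    val total = sizePerHost.values.sum.toDouble
    if (total == 0) Seq.empty
    else sizePerHost.toSeq.collect {
      case (host, size) if size / total >= fractionThreshold => host
    }
  }

  def main(args: Array[String]): Unit = {
    // Map outputs read by a single reducer: hostA holds 4 of 7 bytes, hostB holds 3.
    val sizes = Seq("hostA" -> 2L, "hostA" -> 2L, "hostB" -> 3L)
    println(locationsWithLargestOutputs(sizes, 0.5))  // only hostA
    println(locationsWithLargestOutputs(sizes, 0.2))  // hostA and hostB
  }
}
```

With a 0.2 threshold no reducer can end up with more than five such hosts, since five locations at 20% each already account for all of the output.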
Author: Shivaram Venkataraman Closes #6652 from shivaram/reduce-locations and squashes the following commits: 492e25e [Shivaram Venkataraman] Remove unused import 2ef2d39 [Shivaram Venkataraman] Address code review comments 897a914 [Shivaram Venkataraman] Remove unused hash map f5be578 [Shivaram Venkataraman] Use fraction of map outputs to determine locations Also removes caching of preferred locations to make the API cleaner 68bc29e [Shivaram Venkataraman] Fix line length 1090b58 [Shivaram Venkataraman] Change flag name 77ce7d8 [Shivaram Venkataraman] Merge branch 'master' of https://github.com/apache/spark into reduce-locations e5d56bd [Shivaram Venkataraman] Add flag to turn off locality for shuffle deps 6cfae98 [Shivaram Venkataraman] Filter out zero blocks, rename variables 9d5831a [Shivaram Venkataraman] Address some more comments 8e31266 [Shivaram Venkataraman] Fix style 0df3180 [Shivaram Venkataraman] Address code review comments e7d5449 [Shivaram Venkataraman] Fix merge issues ad7cb53 [Shivaram Venkataraman] Merge branch 'master' of https://github.com/apache/spark into reduce-locations df14cee [Shivaram Venkataraman] Merge branch 'master' of https://github.com/apache/spark into reduce-locations 5093aea [Shivaram Venkataraman] Merge branch 'master' of https://github.com/apache/spark into reduce-locations 0171d3c [Shivaram Venkataraman] Merge branch 'master' of https://github.com/apache/spark into reduce-locations bc4dfd6 [Shivaram Venkataraman] Merge branch 'master' of https://github.com/apache/spark into reduce-locations 774751b [Shivaram Venkataraman] Fix bug introduced by line length adjustment 34d0283 [Shivaram Venkataraman] Fix style issues 3b464b7 [Shivaram Venkataraman] Set preferred locations for reduce tasks This is another attempt at #1697 addressing some of the earlier concerns. This adds a couple of thresholds based on number map and reduce tasks beyond which we don't use preferred locations for reduce tasks. --- .../org/apache/spark/MapOutputTracker.scala | 49 +++++++++++- .../apache/spark/scheduler/DAGScheduler.scala | 37 ++++++++- .../apache/spark/MapOutputTrackerSuite.scala | 35 +++++++++ .../spark/scheduler/DAGSchedulerSuite.scala | 76 +++++++++++++++---- 4 files changed, 177 insertions(+), 20 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 018422827e1c8..862ffe868f58f 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -21,7 +21,7 @@ import java.io._ import java.util.concurrent.ConcurrentHashMap import java.util.zip.{GZIPInputStream, GZIPOutputStream} -import scala.collection.mutable.{HashSet, Map} +import scala.collection.mutable.{HashMap, HashSet, Map} import scala.collection.JavaConversions._ import scala.reflect.ClassTag @@ -284,6 +284,53 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) cachedSerializedStatuses.contains(shuffleId) || mapStatuses.contains(shuffleId) } + /** + * Return a list of locations that each have fraction of map output greater than the specified + * threshold. + * + * @param shuffleId id of the shuffle + * @param reducerId id of the reduce task + * @param numReducers total number of reducers in the shuffle + * @param fractionThreshold fraction of total map output size that a location must have + * for it to be considered large. + * + * This method is not thread-safe. 
+ */ + def getLocationsWithLargestOutputs( + shuffleId: Int, + reducerId: Int, + numReducers: Int, + fractionThreshold: Double) + : Option[Array[BlockManagerId]] = { + + if (mapStatuses.contains(shuffleId)) { + val statuses = mapStatuses(shuffleId) + if (statuses.nonEmpty) { + // HashMap to add up sizes of all blocks at the same location + val locs = new HashMap[BlockManagerId, Long] + var totalOutputSize = 0L + var mapIdx = 0 + while (mapIdx < statuses.length) { + val status = statuses(mapIdx) + val blockSize = status.getSizeForBlock(reducerId) + if (blockSize > 0) { + locs(status.location) = locs.getOrElse(status.location, 0L) + blockSize + totalOutputSize += blockSize + } + mapIdx = mapIdx + 1 + } + val topLocs = locs.filter { case (loc, size) => + size.toDouble / totalOutputSize >= fractionThreshold + } + // Return if we have any locations which satisfy the required threshold + if (topLocs.nonEmpty) { + return Some(topLocs.map(_._1).toArray) + } + } + } + None + } + def incrementEpoch() { epochLock.synchronized { epoch += 1 diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 75a567fb31520..aea6674ed20be 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -137,6 +137,22 @@ class DAGScheduler( private[scheduler] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this) taskScheduler.setDAGScheduler(this) + // Flag to control if reduce tasks are assigned preferred locations + private val shuffleLocalityEnabled = + sc.getConf.getBoolean("spark.shuffle.reduceLocality.enabled", true) + // Number of map, reduce tasks above which we do not assign preferred locations + // based on map output sizes. We limit the size of jobs for which assign preferred locations + // as computing the top locations by size becomes expensive. + private[this] val SHUFFLE_PREF_MAP_THRESHOLD = 1000 + // NOTE: This should be less than 2000 as we use HighlyCompressedMapStatus beyond that + private[this] val SHUFFLE_PREF_REDUCE_THRESHOLD = 1000 + + // Fraction of total map output that must be at a location for it to considered as a preferred + // location for a reduce task. + // Making this larger will focus on fewer locations where most data can be read locally, but + // may lead to more delay in scheduling if those locations are busy. + private[scheduler] val REDUCER_PREF_LOCS_FRACTION = 0.2 + // Called by TaskScheduler to report task's starting. def taskStarted(task: Task[_], taskInfo: TaskInfo) { eventProcessLoop.post(BeginEvent(task, taskInfo)) @@ -1384,17 +1400,32 @@ class DAGScheduler( if (rddPrefs.nonEmpty) { return rddPrefs.map(TaskLocation(_)) } - // If the RDD has narrow dependencies, pick the first partition of the first narrow dep - // that has any placement preferences. Ideally we would choose based on transfer sizes, - // but this will do for now. + rdd.dependencies.foreach { case n: NarrowDependency[_] => + // If the RDD has narrow dependencies, pick the first partition of the first narrow dep + // that has any placement preferences. Ideally we would choose based on transfer sizes, + // but this will do for now. 
for (inPart <- n.getParents(partition)) { val locs = getPreferredLocsInternal(n.rdd, inPart, visited) if (locs != Nil) { return locs } } + case s: ShuffleDependency[_, _, _] => + // For shuffle dependencies, pick locations which have at least REDUCER_PREF_LOCS_FRACTION + // of data as preferred locations + if (shuffleLocalityEnabled && + rdd.partitions.size < SHUFFLE_PREF_REDUCE_THRESHOLD && + s.rdd.partitions.size < SHUFFLE_PREF_MAP_THRESHOLD) { + // Get the preferred map output locations for this reducer + val topLocsForReducer = mapOutputTracker.getLocationsWithLargestOutputs(s.shuffleId, + partition, rdd.partitions.size, REDUCER_PREF_LOCS_FRACTION) + if (topLocsForReducer.nonEmpty) { + return topLocsForReducer.get.map(loc => TaskLocation(loc.host, loc.executorId)) + } + } + case _ => } Nil diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 1fab69678d040..7a1961137cce5 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -205,4 +205,39 @@ class MapOutputTrackerSuite extends SparkFunSuite { // masterTracker.stop() // this throws an exception rpcEnv.shutdown() } + + test("getLocationsWithLargestOutputs with multiple outputs in same machine") { + val rpcEnv = createRpcEnv("test") + val tracker = new MapOutputTrackerMaster(conf) + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) + // Setup 3 map tasks + // on hostA with output size 2 + // on hostA with output size 2 + // on hostB with output size 3 + tracker.registerShuffle(10, 3) + tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), + Array(2L))) + tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("a", "hostA", 1000), + Array(2L))) + tracker.registerMapOutput(10, 2, MapStatus(BlockManagerId("b", "hostB", 1000), + Array(3L))) + + // When the threshold is 50%, only host A should be returned as a preferred location + // as it has 4 out of 7 bytes of output. + val topLocs50 = tracker.getLocationsWithLargestOutputs(10, 0, 1, 0.5) + assert(topLocs50.nonEmpty) + assert(topLocs50.get.size === 1) + assert(topLocs50.get.head === BlockManagerId("a", "hostA", 1000)) + + // When the threshold is 20%, both hosts should be returned as preferred locations. 
+ val topLocs20 = tracker.getLocationsWithLargestOutputs(10, 0, 1, 0.2) + assert(topLocs20.nonEmpty) + assert(topLocs20.get.size === 2) + assert(topLocs20.get.toSet === + Seq(BlockManagerId("a", "hostA", 1000), BlockManagerId("b", "hostB", 1000)).toSet) + + tracker.stop() + rpcEnv.shutdown() + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 47b2868753c0e..833b600746e90 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -490,8 +490,8 @@ class DAGSchedulerSuite val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) submit(reduceRdd, Array(0, 1)) complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", 1)), - (Success, makeMapStatus("hostB", 1)))) + (Success, makeMapStatus("hostA", reduceRdd.partitions.size)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.size)))) // the 2nd ResultTask failed complete(taskSets(1), Seq( (Success, 42), @@ -501,7 +501,7 @@ class DAGSchedulerSuite // ask the scheduler to try it again scheduler.resubmitFailedStages() // have the 2nd attempt pass - complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1)))) + complete(taskSets(2), Seq((Success, makeMapStatus("hostA", reduceRdd.partitions.size)))) // we can see both result blocks now assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === Array("hostA", "hostB")) @@ -517,8 +517,8 @@ class DAGSchedulerSuite val reduceRdd = new MyRDD(sc, 2, List(shuffleDep)) submit(reduceRdd, Array(0, 1)) complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", 1)), - (Success, makeMapStatus("hostB", 1)))) + (Success, makeMapStatus("hostA", reduceRdd.partitions.size)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.size)))) // The MapOutputTracker should know about both map output locations. 
assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === Array("hostA", "hostB")) @@ -560,18 +560,18 @@ class DAGSchedulerSuite assert(newEpoch > oldEpoch) val taskSet = taskSets(0) // should be ignored for being too old - runEvent(CompletionEvent( - taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", + reduceRdd.partitions.size), null, createFakeTaskInfo(), null)) // should work because it's a non-failed host - runEvent(CompletionEvent( - taskSet.tasks(0), Success, makeMapStatus("hostB", 1), null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", + reduceRdd.partitions.size), null, createFakeTaskInfo(), null)) // should be ignored for being too old - runEvent(CompletionEvent( - taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", + reduceRdd.partitions.size), null, createFakeTaskInfo(), null)) // should work because it's a new epoch taskSet.tasks(1).epoch = newEpoch - runEvent(CompletionEvent( - taskSet.tasks(1), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null)) + runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", + reduceRdd.partitions.size), null, createFakeTaskInfo(), null)) assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) complete(taskSets(1), Seq((Success, 42), (Success, 43))) @@ -800,6 +800,50 @@ class DAGSchedulerSuite assertDataStructuresEmpty() } + test("reduce tasks should be placed locally with map output") { + // Create an shuffleMapRdd with 1 partition + val shuffleMapRdd = new MyRDD(sc, 1, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, 1, List(shuffleDep)) + submit(reduceRdd, Array(0)) + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", 1)))) + assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === + Array(makeBlockManagerId("hostA"))) + + // Reducer should run on the same host that map task ran + val reduceTaskSet = taskSets(1) + assertLocations(reduceTaskSet, Seq(Seq("hostA"))) + complete(reduceTaskSet, Seq((Success, 42))) + assert(results === Map(0 -> 42)) + assertDataStructuresEmpty + } + + test("reduce task locality preferences should only include machines with largest map outputs") { + val numMapTasks = 4 + // Create an shuffleMapRdd with more partitions + val shuffleMapRdd = new MyRDD(sc, numMapTasks, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, 1, List(shuffleDep)) + submit(reduceRdd, Array(0)) + + val statuses = (1 to numMapTasks).map { i => + (Success, makeMapStatus("host" + i, 1, (10*i).toByte)) + } + complete(taskSets(0), statuses) + + // Reducer should prefer the last 3 hosts as they have 20%, 30% and 40% of data + val hosts = (1 to numMapTasks).map(i => "host" + i).reverse.take(numMapTasks - 1) + + val reduceTaskSet = taskSets(1) + assertLocations(reduceTaskSet, Seq(hosts)) + complete(reduceTaskSet, Seq((Success, 42))) + assert(results === Map(0 -> 42)) + assertDataStructuresEmpty + } + /** * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations. 
* Note that this checks only the host and not the executor ID. @@ -807,12 +851,12 @@ class DAGSchedulerSuite private def assertLocations(taskSet: TaskSet, hosts: Seq[Seq[String]]) { assert(hosts.size === taskSet.tasks.size) for ((taskLocs, expectedLocs) <- taskSet.tasks.map(_.preferredLocations).zip(hosts)) { - assert(taskLocs.map(_.host) === expectedLocs) + assert(taskLocs.map(_.host).toSet === expectedLocs.toSet) } } - private def makeMapStatus(host: String, reduces: Int): MapStatus = - MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(2)) + private def makeMapStatus(host: String, reduces: Int, sizes: Byte = 2): MapStatus = + MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(sizes)) private def makeBlockManagerId(host: String): BlockManagerId = BlockManagerId("exec-" + host, host, 12345) From b928f543845ddd39e914a0e8f0b0205fd86100c5 Mon Sep 17 00:00:00 2001 From: Paavo Date: Wed, 10 Jun 2015 23:17:42 +0100 Subject: [PATCH 012/151] [SPARK-8200] [MLLIB] Check for empty RDDs in StreamingLinearAlgorithm Test cases for both StreamingLinearRegression and StreamingLogisticRegression, and code fix. Edit: This contribution is my original work and I license the work to the project under the project's open source license. Author: Paavo Closes #6713 from pparkkin/streamingmodel-empty-rdd and squashes the following commits: ff5cd78 [Paavo] Update strings to use interpolation. db234cf [Paavo] Use !rdd.isEmpty. 54ad89e [Paavo] Test case for empty stream. 393e36f [Paavo] Ignore empty RDDs. 0bfc365 [Paavo] Test case for empty stream. --- .../regression/StreamingLinearAlgorithm.scala | 14 ++++++++------ .../StreamingLogisticRegressionSuite.scala | 17 +++++++++++++++++ .../StreamingLinearRegressionSuite.scala | 18 ++++++++++++++++++ 3 files changed, 43 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala index aee51bf22d8d0..141052ba813ee 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala @@ -83,13 +83,15 @@ abstract class StreamingLinearAlgorithm[ throw new IllegalArgumentException("Model must be initialized before starting training.") } data.foreachRDD { (rdd, time) => - model = Some(algorithm.run(rdd, model.get.weights)) - logInfo("Model updated at time %s".format(time.toString)) - val display = model.get.weights.size match { - case x if x > 100 => model.get.weights.toArray.take(100).mkString("[", ",", "...") - case _ => model.get.weights.toArray.mkString("[", ",", "]") + if (!rdd.isEmpty) { + model = Some(algorithm.run(rdd, model.get.weights)) + logInfo(s"Model updated at time ${time.toString}") + val display = model.get.weights.size match { + case x if x > 100 => model.get.weights.toArray.take(100).mkString("[", ",", "...") + case _ => model.get.weights.toArray.mkString("[", ",", "]") + } + logInfo(s"Current model: weights, ${display}") } - logInfo("Current model: weights, %s".format (display)) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala index e98b61e13e21f..fd653296c9d97 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala +++ 
b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala @@ -158,4 +158,21 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase val error = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints).toList assert(error.head > 0.8 & error.last < 0.2) } + + // Test empty RDDs in a stream + test("handling empty RDDs in a stream") { + val model = new StreamingLogisticRegressionWithSGD() + .setInitialWeights(Vectors.dense(-0.1)) + .setStepSize(0.01) + .setNumIterations(10) + val numBatches = 10 + val emptyInput = Seq.empty[Seq[LabeledPoint]] + val ssc = setupStreams(emptyInput, + (inputDStream: DStream[LabeledPoint]) => { + model.trainOn(inputDStream) + model.predictOnValues(inputDStream.map(x => (x.label, x.features))) + } + ) + val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala index 9a379406d5061..f5e2d31056cbd 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala @@ -166,4 +166,22 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase { val error = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints).toList assert((error.head - error.last) > 2) } + + // Test empty RDDs in a stream + test("handling empty RDDs in a stream") { + val model = new StreamingLinearRegressionWithSGD() + .setInitialWeights(Vectors.dense(0.0, 0.0)) + .setStepSize(0.2) + .setNumIterations(25) + val numBatches = 10 + val nPoints = 100 + val emptyInput = Seq.empty[Seq[LabeledPoint]] + val ssc = setupStreams(emptyInput, + (inputDStream: DStream[LabeledPoint]) => { + model.trainOn(inputDStream) + model.predictOnValues(inputDStream.map(x => (x.label, x.features))) + } + ) + val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches) + } } From 37719e0cd0b00cc5ffee0ebe1652d465a574db0f Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 10 Jun 2015 16:55:39 -0700 Subject: [PATCH 013/151] [SPARK-8189] [SQL] use Long for TimestampType in SQL This PR change to use Long as internal type for TimestampType for efficiency, which means it will the precision below 100ns. 
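To make the new representation concrete, here is a small standalone sketch mirroring the DateUtils.fromJavaTimestamp / toJavaTimestamp conversions added further down (the surrounding object name is made up): a timestamp is stored as a Long counting 100ns intervals since the epoch, so anything finer than 100ns is dropped, and java.sql.Timestamp values are converted on the way in and out.

```scala
import java.sql.Timestamp

object HundredNanosSketch {
  private val HundredNanosPerSecond = 10000000L

  def fromJavaTimestamp(t: Timestamp): Long =
    // getTime is in milliseconds and already contains the whole milliseconds of
    // the nanos field, so only the sub-millisecond remainder is added back.
    t.getTime * 10000L + (t.getNanos.toLong / 100) % 10000L

  def toJavaTimestamp(num100ns: Long): Timestamp = {
    var seconds = num100ns / HundredNanosPerSecond
    var nanos = num100ns % HundredNanosPerSecond
    if (nanos < 0) {          // setNanos rejects negative values
      nanos += HundredNanosPerSecond
      seconds -= 1
    }
    val t = new Timestamp(seconds * 1000)
    t.setNanos(nanos.toInt * 100)
    t
  }

  def main(args: Array[String]): Unit = {
    val now = new Timestamp(System.currentTimeMillis())
    now.setNanos(now.getNanos + 100)        // 100ns below millisecond precision
    val encoded = fromJavaTimestamp(now)
    assert(toJavaTimestamp(encoded) == now) // round-trips at 100ns granularity
    println(s"$now -> $encoded hundred-nanosecond units since the epoch")
  }
}
```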
Author: Davies Liu Closes #6733 from davies/timestamp and squashes the following commits: d9565fa [Davies Liu] remove print 65cf2f1 [Davies Liu] fix Timestamp in SparkR 86fecfb [Davies Liu] disable two timestamp tests 8f77ee0 [Davies Liu] fix scala style 246ee74 [Davies Liu] address comments 309d2e1 [Davies Liu] use Long for TimestampType in SQL --- .../scala/org/apache/spark/api/r/SerDe.scala | 17 +++-- python/pyspark/sql/types.py | 11 ++++ .../scala/org/apache/spark/sql/BaseRow.java | 6 ++ .../main/scala/org/apache/spark/sql/Row.scala | 8 ++- .../sql/catalyst/CatalystTypeConverters.scala | 13 +++- .../spark/sql/catalyst/expressions/Cast.scala | 62 +++++++++---------- .../expressions/SpecificMutableRow.scala | 1 + .../expressions/codegen/CodeGenerator.scala | 4 +- .../codegen/GenerateProjection.scala | 10 ++- .../sql/catalyst/expressions/literals.scala | 15 +++-- .../sql/catalyst/expressions/predicates.scala | 6 +- .../spark/sql/catalyst/util/DateUtils.scala | 44 ++++++++++--- .../spark/sql/types/TimestampType.scala | 10 +-- .../sql/catalyst/expressions/CastSuite.scala | 11 ++-- .../sql/catalyst/util/DateUtilsSuite.scala | 40 ++++++++++++ .../spark/sql/types/DataTypeSuite.scala | 2 +- .../spark/sql/columnar/ColumnStats.scala | 21 +------ .../spark/sql/columnar/ColumnType.scala | 19 +++--- .../sql/execution/SparkSqlSerializer2.scala | 17 ++--- .../spark/sql/execution/debug/package.scala | 2 + .../spark/sql/execution/pythonUdfs.scala | 7 ++- .../org/apache/spark/sql/jdbc/JDBCRDD.scala | 10 ++- .../apache/spark/sql/json/JacksonParser.scala | 5 +- .../org/apache/spark/sql/json/JsonRDD.scala | 10 ++- .../spark/sql/parquet/ParquetConverter.scala | 9 +-- .../sql/parquet/ParquetTableSupport.scala | 10 +-- .../apache/spark/sql/CachedTableSuite.scala | 2 +- .../spark/sql/columnar/ColumnStatsSuite.scala | 2 +- .../spark/sql/columnar/ColumnTypeSuite.scala | 11 ++-- .../sql/columnar/ColumnarTestUtils.scala | 9 +-- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 2 +- .../org/apache/spark/sql/json/JsonSuite.scala | 14 +++-- .../execution/HiveCompatibilitySuite.scala | 8 ++- .../spark/sql/hive/HiveInspectors.scala | 20 +++--- .../apache/spark/sql/hive/TableReader.scala | 4 +- ...cast #5-0-dbd7bcd167d322d6617b884c02c7f247 | 2 +- 36 files changed, 272 insertions(+), 172 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateUtilsSuite.scala diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index f8e3f1a79082e..56adc857d4ce0 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -18,7 +18,7 @@ package org.apache.spark.api.r import java.io.{DataInputStream, DataOutputStream} -import java.sql.{Date, Time} +import java.sql.{Timestamp, Date, Time} import scala.collection.JavaConversions._ @@ -107,9 +107,12 @@ private[spark] object SerDe { Date.valueOf(readString(in)) } - def readTime(in: DataInputStream): Time = { - val t = in.readDouble() - new Time((t * 1000L).toLong) + def readTime(in: DataInputStream): Timestamp = { + val seconds = in.readDouble() + val sec = Math.floor(seconds).toLong + val t = new Timestamp(sec * 1000L) + t.setNanos(((seconds - sec) * 1e9).toInt) + t } def readBytesArr(in: DataInputStream): Array[Array[Byte]] = { @@ -227,6 +230,9 @@ private[spark] object SerDe { case "java.sql.Time" => writeType(dos, "time") writeTime(dos, value.asInstanceOf[Time]) + case "java.sql.Timestamp" => + writeType(dos, 
"time") + writeTime(dos, value.asInstanceOf[Timestamp]) case "[B" => writeType(dos, "raw") writeBytes(dos, value.asInstanceOf[Array[Byte]]) @@ -289,6 +295,9 @@ private[spark] object SerDe { out.writeDouble(value.getTime.toDouble / 1000.0) } + def writeTime(out: DataOutputStream, value: Timestamp): Unit = { + out.writeDouble((value.getTime / 1000).toDouble + value.getNanos.toDouble / 1e9) + } // NOTE: Only works for ASCII right now def writeString(out: DataOutputStream, value: String): Unit = { diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index b6ec6137c9180..8f286b631f4f0 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -19,6 +19,7 @@ import decimal import time import datetime +import calendar import keyword import warnings import json @@ -654,6 +655,8 @@ def _need_python_to_sql_conversion(dataType): _need_python_to_sql_conversion(dataType.valueType) elif isinstance(dataType, UserDefinedType): return True + elif isinstance(dataType, TimestampType): + return True else: return False @@ -707,6 +710,14 @@ def converter(obj): return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()]) elif isinstance(dataType, UserDefinedType): return lambda obj: dataType.serialize(obj) + elif isinstance(dataType, TimestampType): + + def to_posix_timstamp(dt): + if dt.tzinfo is None: + return int(time.mktime(dt.timetuple()) * 1e7 + dt.microsecond * 10) + else: + return int(calendar.timegm(dt.utctimetuple()) * 1e7 + dt.microsecond * 10) + return to_posix_timstamp else: raise ValueError("Unexpected type %r" % dataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java index d138b43a3482b..6584882a62fd1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/BaseRow.java @@ -19,6 +19,7 @@ import java.math.BigDecimal; import java.sql.Date; +import java.sql.Timestamp; import java.util.List; import scala.collection.Seq; @@ -103,6 +104,11 @@ public Date getDate(int i) { throw new UnsupportedOperationException(); } + @Override + public Timestamp getTimestamp(int i) { + throw new UnsupportedOperationException(); + } + @Override public Seq getSeq(int i) { throw new UnsupportedOperationException(); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index 0d460b634d9b0..8aaf5d7d89154 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -260,9 +260,15 @@ trait Row extends Serializable { * * @throws ClassCastException when data type does not match. */ - // TODO(davies): This is not the right default implementation, we use Int as Date internally def getDate(i: Int): java.sql.Date = apply(i).asInstanceOf[java.sql.Date] + /** + * Returns the value at position i of date type as java.sql.Timestamp. + * + * @throws ClassCastException when data type does not match. + */ + def getTimestamp(i: Int): java.sql.Timestamp = apply(i).asInstanceOf[java.sql.Timestamp] + /** * Returns the value at position i of array type as a Scala Seq. 
* diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 2e7b4c236d8f8..beb82dbc08642 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst import java.lang.{Iterable => JavaIterable} import java.math.{BigDecimal => JavaBigDecimal} -import java.sql.Date +import java.sql.{Timestamp, Date} import java.util.{Map => JavaMap} import javax.annotation.Nullable @@ -58,6 +58,7 @@ object CatalystTypeConverters { case structType: StructType => StructConverter(structType) case StringType => StringConverter case DateType => DateConverter + case TimestampType => TimestampConverter case dt: DecimalType => BigDecimalConverter case BooleanType => BooleanConverter case ByteType => ByteConverter @@ -274,6 +275,15 @@ object CatalystTypeConverters { override def toScalaImpl(row: Row, column: Int): Date = toScala(row.getInt(column)) } + private object TimestampConverter extends CatalystTypeConverter[Timestamp, Timestamp, Any] { + override def toCatalystImpl(scalaValue: Timestamp): Long = + DateUtils.fromJavaTimestamp(scalaValue) + override def toScala(catalystValue: Any): Timestamp = + if (catalystValue == null) null + else DateUtils.toJavaTimestamp(catalystValue.asInstanceOf[Long]) + override def toScalaImpl(row: Row, column: Int): Timestamp = toScala(row.getLong(column)) + } + private object BigDecimalConverter extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] { override def toCatalystImpl(scalaValue: Any): Decimal = scalaValue match { case d: BigDecimal => Decimal(d) @@ -367,6 +377,7 @@ object CatalystTypeConverters { def convertToCatalyst(a: Any): Any = a match { case s: String => StringConverter.toCatalyst(s) case d: Date => DateConverter.toCatalyst(d) + case t: Timestamp => TimestampConverter.toCatalyst(t) case d: BigDecimal => BigDecimalConverter.toCatalyst(d) case d: JavaBigDecimal => BigDecimalConverter.toCatalyst(d) case seq: Seq[Any] => seq.map(convertToCatalyst) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 18102d1acb5b3..8d93957fea2fc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -113,7 +113,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w private[this] def castToString(from: DataType): Any => Any = from match { case BinaryType => buildCast[Array[Byte]](_, UTF8String(_)) case DateType => buildCast[Int](_, d => UTF8String(DateUtils.toString(d))) - case TimestampType => buildCast[Timestamp](_, t => UTF8String(timestampToString(t))) + case TimestampType => buildCast[Long](_, + t => UTF8String(timestampToString(DateUtils.toJavaTimestamp(t)))) case _ => buildCast[Any](_, o => UTF8String(o.toString)) } @@ -127,7 +128,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case StringType => buildCast[UTF8String](_, _.length() != 0) case TimestampType => - buildCast[Timestamp](_, t => t.getTime() != 0 || t.getNanos() != 0) + buildCast[Long](_, t => t != 0) case DateType => // Hive would return null when cast from date to boolean 
buildCast[Int](_, d => null) @@ -158,20 +159,21 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w if (periodIdx != -1 && n.length() - periodIdx > 9) { n = n.substring(0, periodIdx + 10) } - try Timestamp.valueOf(n) catch { case _: java.lang.IllegalArgumentException => null } + try DateUtils.fromJavaTimestamp(Timestamp.valueOf(n)) + catch { case _: java.lang.IllegalArgumentException => null } }) case BooleanType => - buildCast[Boolean](_, b => new Timestamp(if (b) 1 else 0)) + buildCast[Boolean](_, b => (if (b) 1L else 0)) case LongType => - buildCast[Long](_, l => new Timestamp(l)) + buildCast[Long](_, l => longToTimestamp(l)) case IntegerType => - buildCast[Int](_, i => new Timestamp(i)) + buildCast[Int](_, i => longToTimestamp(i.toLong)) case ShortType => - buildCast[Short](_, s => new Timestamp(s)) + buildCast[Short](_, s => longToTimestamp(s.toLong)) case ByteType => - buildCast[Byte](_, b => new Timestamp(b)) + buildCast[Byte](_, b => longToTimestamp(b.toLong)) case DateType => - buildCast[Int](_, d => new Timestamp(DateUtils.toJavaDate(d).getTime)) + buildCast[Int](_, d => DateUtils.toMillisSinceEpoch(d) * 10000) // TimestampWritable.decimalToTimestamp case DecimalType() => buildCast[Decimal](_, d => decimalToTimestamp(d)) @@ -191,25 +193,17 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w }) } - private[this] def decimalToTimestamp(d: Decimal) = { - val seconds = Math.floor(d.toDouble).toLong - val bd = (d.toBigDecimal - seconds) * 1000000000 - val nanos = bd.intValue() - - val millis = seconds * 1000 - val t = new Timestamp(millis) - - // remaining fractional portion as nanos - t.setNanos(nanos) - t + private[this] def decimalToTimestamp(d: Decimal): Long = { + (d.toBigDecimal * 10000000L).longValue() } - // Timestamp to long, converting milliseconds to seconds - private[this] def timestampToLong(ts: Timestamp) = Math.floor(ts.getTime / 1000.0).toLong - - private[this] def timestampToDouble(ts: Timestamp) = { - // First part is the seconds since the beginning of time, followed by nanosecs. - Math.floor(ts.getTime / 1000.0).toLong + ts.getNanos.toDouble / 1000000000 + // converting milliseconds to 100ns + private[this] def longToTimestamp(t: Long): Long = t * 10000L + // converting 100ns to seconds + private[this] def timestampToLong(ts: Long): Long = math.floor(ts.toDouble / 10000000L).toLong + // converting 100ns to seconds in double + private[this] def timestampToDouble(ts: Long): Double = { + ts / 10000000.0 } // Converts Timestamp to string according to Hive TimestampWritable convention @@ -234,7 +228,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case TimestampType => // throw valid precision more than seconds, according to Hive. // Timestamp.nanos is in 0 to 999,999,999, no more than a second. 
- buildCast[Timestamp](_, t => DateUtils.millisToDays(t.getTime)) + buildCast[Long](_, t => DateUtils.millisToDays(t / 10000L)) // Hive throws this exception as a Semantic Exception // It is never possible to compare result when hive return with exception, // so we can return null @@ -253,7 +247,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case DateType => buildCast[Int](_, d => null) case TimestampType => - buildCast[Timestamp](_, t => timestampToLong(t)) + buildCast[Long](_, t => timestampToLong(t)) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b) } @@ -269,7 +263,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case DateType => buildCast[Int](_, d => null) case TimestampType => - buildCast[Timestamp](_, t => timestampToLong(t).toInt) + buildCast[Long](_, t => timestampToLong(t).toInt) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b) } @@ -285,7 +279,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case DateType => buildCast[Int](_, d => null) case TimestampType => - buildCast[Timestamp](_, t => timestampToLong(t).toShort) + buildCast[Long](_, t => timestampToLong(t).toShort) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort } @@ -301,7 +295,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case DateType => buildCast[Int](_, d => null) case TimestampType => - buildCast[Timestamp](_, t => timestampToLong(t).toByte) + buildCast[Long](_, t => timestampToLong(t).toByte) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte } @@ -334,7 +328,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w buildCast[Int](_, d => null) // date can't cast to decimal in Hive case TimestampType => // Note that we lose precision here. 
- buildCast[Timestamp](_, t => changePrecision(Decimal(timestampToDouble(t)), target)) + buildCast[Long](_, t => changePrecision(Decimal(timestampToDouble(t)), target)) case DecimalType() => b => changePrecision(b.asInstanceOf[Decimal].clone(), target) case LongType => @@ -358,7 +352,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case DateType => buildCast[Int](_, d => null) case TimestampType => - buildCast[Timestamp](_, t => timestampToDouble(t)) + buildCast[Long](_, t => timestampToDouble(t)) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b) } @@ -374,7 +368,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case DateType => buildCast[Int](_, d => null) case TimestampType => - buildCast[Timestamp](_, t => timestampToDouble(t).toFloat) + buildCast[Long](_, t => timestampToDouble(t).toFloat) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index aa4099e4d7bf9..2c884517d62a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -203,6 +203,7 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR case BooleanType => new MutableBoolean case LongType => new MutableLong case DateType => new MutableInt // We use INT for DATE internally + case TimestampType => new MutableLong // We use Long for Timestamp internally case _ => new MutableAny }.toArray) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index e95682f952a7b..80aa8fa056146 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -122,7 +122,7 @@ class CodeGenContext { case BinaryType => "byte[]" case StringType => stringType case DateType => "int" - case TimestampType => "java.sql.Timestamp" + case TimestampType => "long" case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName case _ => "Object" @@ -140,6 +140,7 @@ class CodeGenContext { case FloatType => "Float" case BooleanType => "Boolean" case DateType => "Integer" + case TimestampType => "Long" case _ => javaType(dt) } @@ -155,6 +156,7 @@ class CodeGenContext { case DoubleType => "-1.0" case IntegerType => "-1" case DateType => "-1" + case TimestampType => "-1L" case _ => "null" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 7caf4aaab88bb..274429cd1c55f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -73,7 +73,9 @@ object GenerateProjection extends 
CodeGenerator[Seq[Expression], Projection] { val specificAccessorFunctions = ctx.nativeTypes.map { dataType => val cases = expressions.zipWithIndex.map { - case (e, i) if e.dataType == dataType => + case (e, i) if e.dataType == dataType + || dataType == IntegerType && e.dataType == DateType + || dataType == LongType && e.dataType == TimestampType => s"case $i: return c$i;" case _ => "" }.mkString("\n ") @@ -96,7 +98,9 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { val specificMutatorFunctions = ctx.nativeTypes.map { dataType => val cases = expressions.zipWithIndex.map { - case (e, i) if e.dataType == dataType => + case (e, i) if e.dataType == dataType + || dataType == IntegerType && e.dataType == DateType + || dataType == LongType && e.dataType == TimestampType => s"case $i: { c$i = value; return; }" case _ => "" }.mkString("\n") @@ -119,7 +123,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { val nonNull = e.dataType match { case BooleanType => s"$col ? 0 : 1" case ByteType | ShortType | IntegerType | DateType => s"$col" - case LongType => s"$col ^ ($col >>> 32)" + case LongType | TimestampType => s"$col ^ ($col >>> 32)" case FloatType => s"Float.floatToIntBits($col)" case DoubleType => s"(int)(Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32))" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 297b35b4da94c..833c08a293dcb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -37,7 +37,7 @@ object Literal { case d: BigDecimal => Literal(Decimal(d), DecimalType.Unlimited) case d: java.math.BigDecimal => Literal(Decimal(d), DecimalType.Unlimited) case d: Decimal => Literal(d, DecimalType.Unlimited) - case t: Timestamp => Literal(t, TimestampType) + case t: Timestamp => Literal(DateUtils.fromJavaTimestamp(t), TimestampType) case d: Date => Literal(DateUtils.fromJavaDate(d), DateType) case a: Array[Byte] => Literal(a, BinaryType) case null => Literal(null, NullType) @@ -100,7 +100,7 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres ev.isNull = "false" ev.primitive = value.toString "" - case FloatType => // This must go before NumericType + case FloatType => val v = value.asInstanceOf[Float] if (v.isNaN || v.isInfinite) { super.genCode(ctx, ev) @@ -109,7 +109,7 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres ev.primitive = s"${value}f" "" } - case DoubleType => // This must go before NumericType + case DoubleType => val v = value.asInstanceOf[Double] if (v.isNaN || v.isInfinite) { super.genCode(ctx, ev) @@ -118,15 +118,18 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres ev.primitive = s"${value}" "" } - - case ByteType | ShortType => // This must go before NumericType + case ByteType | ShortType => ev.isNull = "false" ev.primitive = s"(${ctx.javaType(dataType)})$value" "" - case dt: NumericType if !dt.isInstanceOf[DecimalType] => + case IntegerType | DateType => ev.isNull = "false" ev.primitive = value.toString "" + case TimestampType | LongType => + ev.isNull = "false" + ev.primitive = s"${value}L" + "" // eval() version may be faster for non-primitive types case other => super.genCode(ctx, ev) diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 3cbdfdfb13847..2c49352874fc3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -254,9 +254,9 @@ abstract class BinaryComparison extends BinaryExpression with Predicate { case dt: NumericType if ctx.isNativeType(dt) => defineCodeGen (ctx, ev, { (c1, c3) => s"$c1 $symbol $c3" }) - case TimestampType => - // java.sql.Timestamp does not have compare() - super.genCode(ctx, ev) + case DateType | TimestampType => defineCodeGen (ctx, ev, { + (c1, c3) => s"$c1 $symbol $c3" + }) case other => defineCodeGen (ctx, ev, { (c1, c2) => s"$c1.compare($c2) $symbol 0" }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateUtils.scala index ad649acf536f9..5cadc141af1df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.util -import java.sql.Date +import java.sql.{Timestamp, Date} import java.text.SimpleDateFormat import java.util.{Calendar, TimeZone} @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.Cast */ object DateUtils { private val MILLIS_PER_DAY = 86400000 + private val HUNDRED_NANOS_PER_SECOND = 10000000L // Java TimeZone has no mention of thread safety. Use thread local instance to be safe. private val LOCAL_TIMEZONE = new ThreadLocal[TimeZone] { @@ -45,17 +46,17 @@ object DateUtils { ((millisLocal + LOCAL_TIMEZONE.get().getOffset(millisLocal)) / MILLIS_PER_DAY).toInt } - private def toMillisSinceEpoch(days: Int): Long = { + def toMillisSinceEpoch(days: Int): Long = { val millisUtc = days.toLong * MILLIS_PER_DAY millisUtc - LOCAL_TIMEZONE.get().getOffset(millisUtc) } - def fromJavaDate(date: java.sql.Date): Int = { + def fromJavaDate(date: Date): Int = { javaDateToDays(date) } - def toJavaDate(daysSinceEpoch: Int): java.sql.Date = { - new java.sql.Date(toMillisSinceEpoch(daysSinceEpoch)) + def toJavaDate(daysSinceEpoch: Int): Date = { + new Date(toMillisSinceEpoch(daysSinceEpoch)) } def toString(days: Int): String = Cast.threadLocalDateFormat.get.format(toJavaDate(days)) @@ -64,9 +65,9 @@ object DateUtils { if (!s.contains('T')) { // JDBC escape string if (s.contains(' ')) { - java.sql.Timestamp.valueOf(s) + Timestamp.valueOf(s) } else { - java.sql.Date.valueOf(s) + Date.valueOf(s) } } else if (s.endsWith("Z")) { // this is zero timezone of ISO8601 @@ -87,4 +88,33 @@ object DateUtils { ISO8601GMT.parse(s) } } + + /** + * Return a java.sql.Timestamp from number of 100ns since epoch + */ + def toJavaTimestamp(num100ns: Long): Timestamp = { + // setNanos() will overwrite the millisecond part, so the milliseconds should be + // cut off at seconds + var seconds = num100ns / HUNDRED_NANOS_PER_SECOND + var nanos = num100ns % HUNDRED_NANOS_PER_SECOND + // setNanos() can not accept negative value + if (nanos < 0) { + nanos += HUNDRED_NANOS_PER_SECOND + seconds -= 1 + } + val t = new Timestamp(seconds * 1000) + t.setNanos(nanos.toInt * 100) + t + } + + /** + * Return the number of 100ns since epoch from java.sql.Timestamp. 
+ */ + def fromJavaTimestamp(t: Timestamp): Long = { + if (t != null) { + t.getTime() * 10000L + (t.getNanos().toLong / 100) % 10000L + } else { + 0L + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala index aebabfc475925..a558641fcfed7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.types -import java.sql.Timestamp - import scala.math.Ordering import scala.reflect.runtime.universe.typeTag @@ -38,18 +36,16 @@ class TimestampType private() extends AtomicType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "TimestampType$" in byte code. // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type InternalType = Timestamp + private[sql] type InternalType = Long @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } - private[sql] val ordering = new Ordering[InternalType] { - def compare(x: Timestamp, y: Timestamp): Int = x.compareTo(y) - } + private[sql] val ordering = implicitly[Ordering[InternalType]] /** * The default size of a value of the TimestampType is 12 bytes. */ - override def defaultSize: Int = 12 + override def defaultSize: Int = 8 private[spark] override def asNullable: TimestampType = this } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 5bc7c30eee1b6..3aca94db3bd8f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Timestamp, Date} import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types._ /** @@ -137,7 +138,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(cast(sd, DateType), StringType), sd) checkEvaluation(cast(cast(d, StringType), DateType), 0) checkEvaluation(cast(cast(nts, TimestampType), StringType), nts) - checkEvaluation(cast(cast(ts, StringType), TimestampType), ts) + checkEvaluation(cast(cast(ts, StringType), TimestampType), DateUtils.fromJavaTimestamp(ts)) // all convert to string type to check checkEvaluation(cast(cast(cast(nts, TimestampType), DateType), StringType), sd) @@ -269,9 +270,9 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(ts, LongType), 15.toLong) checkEvaluation(cast(ts, FloatType), 15.002f) checkEvaluation(cast(ts, DoubleType), 15.002) - checkEvaluation(cast(cast(tss, ShortType), TimestampType), ts) - checkEvaluation(cast(cast(tss, IntegerType), TimestampType), ts) - checkEvaluation(cast(cast(tss, LongType), TimestampType), ts) + checkEvaluation(cast(cast(tss, ShortType), TimestampType), DateUtils.fromJavaTimestamp(ts)) + checkEvaluation(cast(cast(tss, IntegerType), TimestampType), DateUtils.fromJavaTimestamp(ts)) + checkEvaluation(cast(cast(tss, LongType), TimestampType), DateUtils.fromJavaTimestamp(ts)) checkEvaluation( 
cast(cast(millis.toFloat / 1000, TimestampType), FloatType), millis.toFloat / 1000) @@ -283,7 +284,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { Decimal(1)) // A test for higher precision than millis - checkEvaluation(cast(cast(0.00000001, TimestampType), DoubleType), 0.00000001) + checkEvaluation(cast(cast(0.0000001, TimestampType), DoubleType), 0.0000001) checkEvaluation(cast(Double.NaN, TimestampType), null) checkEvaluation(cast(1.0 / 0.0, TimestampType), null) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateUtilsSuite.scala new file mode 100644 index 0000000000000..a4245545ffc1d --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateUtilsSuite.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import java.sql.Timestamp + +import org.apache.spark.SparkFunSuite + + +class DateUtilsSuite extends SparkFunSuite { + + test("timestamp") { + val now = new Timestamp(System.currentTimeMillis()) + now.setNanos(100) + val ns = DateUtils.fromJavaTimestamp(now) + assert(ns % 10000000L == 1) + assert(DateUtils.toJavaTimestamp(ns) == now) + + List(-111111111111L, -1L, 0, 1L, 111111111111L).foreach { t => + val ts = DateUtils.toJavaTimestamp(t) + assert(DateUtils.fromJavaTimestamp(ts) == t) + assert(DateUtils.toJavaTimestamp(DateUtils.fromJavaTimestamp(ts)) == ts) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 261c4fcad24aa..077c0ad70ac4f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -190,7 +190,7 @@ class DataTypeSuite extends SparkFunSuite { checkDefaultSize(DecimalType(10, 5), 4096) checkDefaultSize(DecimalType.Unlimited, 4096) checkDefaultSize(DateType, 4) - checkDefaultSize(TimestampType, 12) + checkDefaultSize(TimestampType, 8) checkDefaultSize(StringType, 4096) checkDefaultSize(BinaryType, 4096) checkDefaultSize(ArrayType(DoubleType, true), 800) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index b0f983c180673..83881a3687090 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -17,10 +17,8 @@ package org.apache.spark.sql.columnar -import java.sql.Timestamp - import org.apache.spark.sql.Row -import 
org.apache.spark.sql.catalyst.expressions.{AttributeMap, Attribute, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference} import org.apache.spark.sql.types._ private[sql] class ColumnStatisticsSchema(a: Attribute) extends Serializable { @@ -234,22 +232,7 @@ private[sql] class StringColumnStats extends ColumnStats { private[sql] class DateColumnStats extends IntColumnStats -private[sql] class TimestampColumnStats extends ColumnStats { - protected var upper: Timestamp = null - protected var lower: Timestamp = null - - override def gatherStats(row: Row, ordinal: Int): Unit = { - super.gatherStats(row, ordinal) - if (!row.isNullAt(ordinal)) { - val value = row(ordinal).asInstanceOf[Timestamp] - if (upper == null || value.compareTo(upper) > 0) upper = value - if (lower == null || value.compareTo(lower) < 0) lower = value - sizeInBytes += TIMESTAMP.defaultSize - } - } - - override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes) -} +private[sql] class TimestampColumnStats extends LongColumnStats private[sql] class BinaryColumnStats extends ColumnStats { override def gatherStats(row: Row, ordinal: Int): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index 20be5ca9d0046..c9c4d630fb5f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer -import java.sql.Timestamp import scala.reflect.runtime.universe.TypeTag @@ -355,22 +354,20 @@ private[sql] object DATE extends NativeColumnType(DateType, 8, 4) { } } -private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 9, 12) { - override def extract(buffer: ByteBuffer): Timestamp = { - val timestamp = new Timestamp(buffer.getLong()) - timestamp.setNanos(buffer.getInt()) - timestamp +private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 9, 8) { + override def extract(buffer: ByteBuffer): Long = { + buffer.getLong } - override def append(v: Timestamp, buffer: ByteBuffer): Unit = { - buffer.putLong(v.getTime).putInt(v.getNanos) + override def append(v: Long, buffer: ByteBuffer): Unit = { + buffer.putLong(v) } - override def getField(row: Row, ordinal: Int): Timestamp = { - row(ordinal).asInstanceOf[Timestamp] + override def getField(row: Row, ordinal: Int): Long = { + row(ordinal).asInstanceOf[Long] } - override def setField(row: MutableRow, ordinal: Int, value: Timestamp): Unit = { + override def setField(row: MutableRow, ordinal: Int, value: Long): Unit = { row(ordinal) = value } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala index 256d527d7b636..60f3b2d539ffe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala @@ -20,14 +20,13 @@ package org.apache.spark.sql.execution import java.io._ import java.math.{BigDecimal, BigInteger} import java.nio.ByteBuffer -import java.sql.Timestamp import scala.reflect.ClassTag -import org.apache.spark.serializer._ import org.apache.spark.Logging +import org.apache.spark.serializer._ import org.apache.spark.sql.Row -import 
org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, MutableRow, GenericMutableRow} +import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, MutableRow, SpecificMutableRow} import org.apache.spark.sql.types._ /** @@ -304,11 +303,7 @@ private[sql] object SparkSqlSerializer2 { out.writeByte(NULL) } else { out.writeByte(NOT_NULL) - val timestamp = row.getAs[java.sql.Timestamp](i) - val time = timestamp.getTime - val nanos = timestamp.getNanos - out.writeLong(time - (nanos / 1000000)) // Write the milliseconds value. - out.writeInt(nanos) // Write the nanoseconds part. + out.writeLong(row.getAs[Long](i)) } case StringType => @@ -429,11 +424,7 @@ private[sql] object SparkSqlSerializer2 { if (in.readByte() == NULL) { mutableRow.setNullAt(i) } else { - val time = in.readLong() // Read the milliseconds value. - val nanos = in.readInt() // Read the nanoseconds part. - val timestamp = new Timestamp(time) - timestamp.setNanos(nanos) - mutableRow.update(i, timestamp) + mutableRow.update(i, in.readLong()) } case StringType => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index dffb265601bdb..720b529d5946f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -170,6 +170,8 @@ package object debug { case (_: Short, ShortType) => case (_: Boolean, BooleanType) => case (_: Double, DoubleType) => + case (_: Int, DateType) => + case (_: Long, TimestampType) => case (v, udt: UserDefinedType[_]) => typeCheck(v, udt.sqlType) case (d, t) => sys.error(s"Invalid data found: got $d (${d.getClass}) expected $t") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index 342587904789a..955b478a4882f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -148,6 +148,7 @@ object EvaluatePython { case (ud, udt: UserDefinedType[_]) => toJava(udt.serialize(ud), udt.sqlType) case (date: Int, DateType) => DateUtils.toJavaDate(date) + case (t: Long, TimestampType) => DateUtils.toJavaTimestamp(t) case (s: UTF8String, StringType) => s.toString // Pyrolite can handle Timestamp and Decimal @@ -186,10 +187,12 @@ object EvaluatePython { }): Row case (c: java.util.Calendar, DateType) => - DateUtils.fromJavaDate(new java.sql.Date(c.getTime().getTime())) + DateUtils.fromJavaDate(new java.sql.Date(c.getTimeInMillis)) case (c: java.util.Calendar, TimestampType) => - new java.sql.Timestamp(c.getTime().getTime()) + c.getTimeInMillis * 10000L + case (t: java.sql.Timestamp, TimestampType) => + DateUtils.fromJavaTimestamp(t) case (_, udt: UserDefinedType[_]) => fromJava(obj, udt.sqlType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index db68b9c86db1b..9028d5ed72c92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -385,7 +385,7 @@ private[sql] class JDBCRDD( // DateUtils.fromJavaDate does not handle null value, so we need to check it. 
val dateVal = rs.getDate(pos) if (dateVal != null) { - mutableRow.update(i, DateUtils.fromJavaDate(dateVal)) + mutableRow.setInt(i, DateUtils.fromJavaDate(dateVal)) } else { mutableRow.update(i, null) } @@ -417,7 +417,13 @@ private[sql] class JDBCRDD( case LongConversion => mutableRow.setLong(i, rs.getLong(pos)) // TODO(davies): use getBytes for better performance, if the encoding is UTF-8 case StringConversion => mutableRow.setString(i, rs.getString(pos)) - case TimestampConversion => mutableRow.update(i, rs.getTimestamp(pos)) + case TimestampConversion => + val t = rs.getTimestamp(pos) + if (t != null) { + mutableRow.setLong(i, DateUtils.fromJavaTimestamp(t)) + } else { + mutableRow.update(i, null) + } case BinaryConversion => mutableRow.update(i, rs.getBytes(pos)) case BinaryLongConversion => { val bytes = rs.getBytes(pos) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala index 0e223758051a6..4e07cf36ae434 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.json import java.io.ByteArrayOutputStream -import java.sql.Timestamp import scala.collection.Map @@ -65,10 +64,10 @@ private[sql] object JacksonParser { DateUtils.millisToDays(DateUtils.stringToTime(parser.getText).getTime) case (VALUE_STRING, TimestampType) => - new Timestamp(DateUtils.stringToTime(parser.getText).getTime) + DateUtils.stringToTime(parser.getText).getTime * 10000L case (VALUE_NUMBER_INT, TimestampType) => - new Timestamp(parser.getLongValue) + parser.getLongValue * 10000L case (_, StringType) => val writer = new ByteArrayOutputStream() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 7e1e21f5fbb99..fb0d137bdbfdb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.json -import java.sql.Timestamp - import scala.collection.Map import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper} @@ -398,11 +396,11 @@ private[sql] object JsonRDD extends Logging { } } - private def toTimestamp(value: Any): Timestamp = { + private def toTimestamp(value: Any): Long = { value match { - case value: java.lang.Integer => new Timestamp(value.asInstanceOf[Int].toLong) - case value: java.lang.Long => new Timestamp(value) - case value: java.lang.String => toTimestamp(DateUtils.stringToTime(value).getTime) + case value: java.lang.Integer => value.asInstanceOf[Int].toLong * 10000L + case value: java.lang.Long => value * 10000L + case value: java.lang.String => DateUtils.stringToTime(value).getTime * 10000L } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index 85c2ce740fe52..ddc5097f88fb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -28,6 +28,7 @@ import org.apache.parquet.io.api.{PrimitiveConverter, GroupConverter, Binary, Co import org.apache.parquet.schema.MessageType import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.DateUtils import 
org.apache.spark.sql.parquet.CatalystConverter.FieldType import org.apache.spark.sql.types._ import org.apache.spark.sql.parquet.timestamp.NanoTime @@ -266,8 +267,8 @@ private[parquet] abstract class CatalystConverter extends GroupConverter { /** * Read a Timestamp value from a Parquet Int96Value */ - protected[parquet] def readTimestamp(value: Binary): Timestamp = { - CatalystTimestampConverter.convertToTimestamp(value) + protected[parquet] def readTimestamp(value: Binary): Long = { + DateUtils.fromJavaTimestamp(CatalystTimestampConverter.convertToTimestamp(value)) } } @@ -401,7 +402,7 @@ private[parquet] class CatalystPrimitiveRowConverter( current.setInt(fieldIndex, value) override protected[parquet] def updateDate(fieldIndex: Int, value: Int): Unit = - current.update(fieldIndex, value) + current.setInt(fieldIndex, value) override protected[parquet] def updateLong(fieldIndex: Int, value: Long): Unit = current.setLong(fieldIndex, value) @@ -425,7 +426,7 @@ private[parquet] class CatalystPrimitiveRowConverter( current.update(fieldIndex, UTF8String(value)) override protected[parquet] def updateTimestamp(fieldIndex: Int, value: Binary): Unit = - current.update(fieldIndex, readTimestamp(value)) + current.setLong(fieldIndex, readTimestamp(value)) override protected[parquet] def updateDecimal( fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 89db408b1c382..e03dbdec0491d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -29,6 +29,7 @@ import org.apache.parquet.schema.MessageType import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions.{Attribute, Row} +import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types._ /** @@ -204,7 +205,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { case IntegerType => writer.addInteger(value.asInstanceOf[Int]) case ShortType => writer.addInteger(value.asInstanceOf[Short]) case LongType => writer.addLong(value.asInstanceOf[Long]) - case TimestampType => writeTimestamp(value.asInstanceOf[java.sql.Timestamp]) + case TimestampType => writeTimestamp(value.asInstanceOf[Long]) case ByteType => writer.addInteger(value.asInstanceOf[Byte]) case DoubleType => writer.addDouble(value.asInstanceOf[Double]) case FloatType => writer.addFloat(value.asInstanceOf[Float]) @@ -311,8 +312,9 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { writer.addBinary(Binary.fromByteArray(scratchBytes, 0, numBytes)) } - private[parquet] def writeTimestamp(ts: java.sql.Timestamp): Unit = { - val binaryNanoTime = CatalystTimestampConverter.convertFromTimestamp(ts) + private[parquet] def writeTimestamp(ts: Long): Unit = { + val binaryNanoTime = CatalystTimestampConverter.convertFromTimestamp( + DateUtils.toJavaTimestamp(ts)) writer.addBinary(binaryNanoTime) } } @@ -357,7 +359,7 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport { case FloatType => writer.addFloat(record.getFloat(index)) case BooleanType => writer.addBoolean(record.getBoolean(index)) case DateType => writer.addInteger(record.getInt(index)) - case TimestampType => writeTimestamp(record(index).asInstanceOf[java.sql.Timestamp]) + case TimestampType => 
writeTimestamp(record(index).asInstanceOf[Long]) case d: DecimalType => if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) { sys.error(s"Unsupported datatype $d, cannot write to consumer") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 72e60d9aa75cb..17a3cec48b856 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -25,7 +25,7 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.Accumulators import org.apache.spark.sql.TestData._ import org.apache.spark.sql.columnar._ -import org.apache.spark.storage.{RDDBlockId, StorageLevel} +import org.apache.spark.storage.{StorageLevel, RDDBlockId} case class BigData(s: String) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index 339e719f39f16..16836628cb73a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -31,7 +31,7 @@ class ColumnStatsSuite extends SparkFunSuite { testColumnStats(classOf[FixedDecimalColumnStats], FIXED_DECIMAL(15, 10), Row(null, null, 0)) testColumnStats(classOf[StringColumnStats], STRING, Row(null, null, 0)) testColumnStats(classOf[DateColumnStats], DATE, Row(Int.MaxValue, Int.MinValue, 0)) - testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, Row(null, null, 0)) + testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, Row(Long.MaxValue, Long.MinValue, 0)) def testColumnStats[T <: AtomicType, U <: ColumnStats]( columnStatsClass: Class[U], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index a1e76eaa982cc..8421e670ff05d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -18,17 +18,16 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer -import java.sql.Timestamp -import com.esotericsoftware.kryo.{Serializer, Kryo} import com.esotericsoftware.kryo.io.{Input, Output} -import org.apache.spark.serializer.KryoRegistrator +import com.esotericsoftware.kryo.{Kryo, Serializer} -import org.apache.spark.{Logging, SparkConf, SparkFunSuite} +import org.apache.spark.serializer.KryoRegistrator import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.types._ +import org.apache.spark.{Logging, SparkConf, SparkFunSuite} class ColumnTypeSuite extends SparkFunSuite with Logging { val DEFAULT_BUFFER_SIZE = 512 @@ -36,7 +35,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { test("defaultSize") { val checks = Map( INT -> 4, SHORT -> 2, LONG -> 8, BYTE -> 1, DOUBLE -> 8, FLOAT -> 4, - FIXED_DECIMAL(15, 10) -> 8, BOOLEAN -> 1, STRING -> 8, DATE -> 4, TIMESTAMP -> 12, + FIXED_DECIMAL(15, 10) -> 8, BOOLEAN -> 1, STRING -> 8, DATE -> 4, TIMESTAMP -> 8, BINARY -> 16, GENERIC -> 16) checks.foreach { case (columnType, expectedSize) => @@ -69,7 +68,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { checkActualSize(BOOLEAN, true, 1) 
checkActualSize(STRING, UTF8String("hello"), 4 + "hello".getBytes("utf-8").length) checkActualSize(DATE, 0, 4) - checkActualSize(TIMESTAMP, new Timestamp(0L), 12) + checkActualSize(TIMESTAMP, 0L, 8) val binary = Array.fill[Byte](4)(0: Byte) checkActualSize(BINARY, binary, 4 + 4) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala index 75d993e563e06..c5d38595c0bec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala @@ -17,14 +17,12 @@ package org.apache.spark.sql.columnar -import java.sql.Timestamp - import scala.collection.immutable.HashSet import scala.util.Random import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.types.{UTF8String, DataType, Decimal, AtomicType} +import org.apache.spark.sql.types.{AtomicType, DataType, Decimal, UTF8String} object ColumnarTestUtils { def makeNullRow(length: Int): GenericMutableRow = { @@ -52,10 +50,7 @@ object ColumnarTestUtils { case BOOLEAN => Random.nextBoolean() case BINARY => randomBytes(Random.nextInt(32)) case DATE => Random.nextInt() - case TIMESTAMP => - val timestamp = new Timestamp(Random.nextLong()) - timestamp.setNanos(Random.nextInt(999999999)) - timestamp + case TIMESTAMP => Random.nextLong() case _ => // Using a random one-element map instead of an arbitrary object Map(Random.nextInt() -> Random.nextString(Random.nextInt(32))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 49d348c3ed21b..69ab1c292d221 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -326,7 +326,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { assert(cal.get(Calendar.HOUR) === 11) assert(cal.get(Calendar.MINUTE) === 22) assert(cal.get(Calendar.SECOND) === 33) - assert(rows(0).getAs[java.sql.Timestamp](2).getNanos === 543543543) + assert(rows(0).getAs[java.sql.Timestamp](2).getNanos === 543543500) } test("test DATE types") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index d889c7be17ce7..fca24364fe6ec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -76,21 +76,25 @@ class JsonSuite extends QueryTest with TestJsonData { checkTypePromotion( Decimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType.Unlimited)) - checkTypePromotion(new Timestamp(intNumber), enforceCorrectType(intNumber, TimestampType)) - checkTypePromotion(new Timestamp(intNumber.toLong), + checkTypePromotion(DateUtils.fromJavaTimestamp(new Timestamp(intNumber)), + enforceCorrectType(intNumber, TimestampType)) + checkTypePromotion(DateUtils.fromJavaTimestamp(new Timestamp(intNumber.toLong)), enforceCorrectType(intNumber.toLong, TimestampType)) val strTime = "2014-09-30 12:34:56" - checkTypePromotion(Timestamp.valueOf(strTime), enforceCorrectType(strTime, TimestampType)) + checkTypePromotion(DateUtils.fromJavaTimestamp(Timestamp.valueOf(strTime)), + enforceCorrectType(strTime, TimestampType)) val strDate = "2014-10-15" checkTypePromotion( 
DateUtils.fromJavaDate(Date.valueOf(strDate)), enforceCorrectType(strDate, DateType)) val ISO8601Time1 = "1970-01-01T01:00:01.0Z" - checkTypePromotion(new Timestamp(3601000), enforceCorrectType(ISO8601Time1, TimestampType)) + checkTypePromotion(DateUtils.fromJavaTimestamp(new Timestamp(3601000)), + enforceCorrectType(ISO8601Time1, TimestampType)) checkTypePromotion(DateUtils.millisToDays(3601000), enforceCorrectType(ISO8601Time1, DateType)) val ISO8601Time2 = "1970-01-01T02:00:01-01:00" - checkTypePromotion(new Timestamp(10801000), enforceCorrectType(ISO8601Time2, TimestampType)) + checkTypePromotion(DateUtils.fromJavaTimestamp(new Timestamp(10801000)), + enforceCorrectType(ISO8601Time2, TimestampType)) checkTypePromotion(DateUtils.millisToDays(10801000), enforceCorrectType(ISO8601Time2, DateType)) } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 0693c7ea5b332..82c0b494598a8 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -252,7 +252,11 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "load_dyn_part14.*", // These work alone but fail when run with other tests... // the answer is sensitive for jdk version - "udf_java_method" + "udf_java_method", + + // Spark SQL use Long for TimestampType, lose the precision under 100ns + "timestamp_1", + "timestamp_2" ) /** @@ -795,8 +799,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "stats_publisher_error_1", "subq2", "tablename_with_select", - "timestamp_1", - "timestamp_2", "timestamp_3", "timestamp_comparison", "timestamp_lazy", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index c466203cd0220..1f14cba78f479 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -250,7 +250,8 @@ private[hive] trait HiveInspectors { PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector, poi.getWritableConstantValue.getHiveDecimal) case poi: WritableConstantTimestampObjectInspector => - poi.getWritableConstantValue.getTimestamp.clone() + val t = poi.getWritableConstantValue + t.getSeconds * 10000000L + t.getNanos / 100L case poi: WritableConstantIntObjectInspector => poi.getWritableConstantValue.get() case poi: WritableConstantDoubleObjectInspector => @@ -313,11 +314,11 @@ private[hive] trait HiveInspectors { case x: DateObjectInspector if x.preferWritable() => DateUtils.fromJavaDate(x.getPrimitiveWritableObject(data).get()) case x: DateObjectInspector => DateUtils.fromJavaDate(x.getPrimitiveJavaObject(data)) - // org.apache.hadoop.hive.serde2.io.TimestampWritable.set will reset current time object - // if next timestamp is null, so Timestamp object is cloned case x: TimestampObjectInspector if x.preferWritable() => - x.getPrimitiveWritableObject(data).getTimestamp.clone() - case ti: TimestampObjectInspector => ti.getPrimitiveJavaObject(data).clone() + val t = x.getPrimitiveWritableObject(data) + t.getSeconds * 10000000L + t.getNanos / 100 + case ti: TimestampObjectInspector => + 
DateUtils.fromJavaTimestamp(ti.getPrimitiveJavaObject(data)) case _ => pi.getPrimitiveJavaObject(data) } case li: ListObjectInspector => @@ -356,6 +357,9 @@ private[hive] trait HiveInspectors { case _: JavaDateObjectInspector => (o: Any) => DateUtils.toJavaDate(o.asInstanceOf[Int]) + case _: JavaTimestampObjectInspector => + (o: Any) => DateUtils.toJavaTimestamp(o.asInstanceOf[Long]) + case soi: StandardStructObjectInspector => val wrappers = soi.getAllStructFieldRefs.map(ref => wrapperFor(ref.getFieldObjectInspector)) (o: Any) => { @@ -465,7 +469,7 @@ private[hive] trait HiveInspectors { case _: DateObjectInspector if x.preferWritable() => getDateWritable(a) case _: DateObjectInspector => DateUtils.toJavaDate(a.asInstanceOf[Int]) case _: TimestampObjectInspector if x.preferWritable() => getTimestampWritable(a) - case _: TimestampObjectInspector => a.asInstanceOf[java.sql.Timestamp] + case _: TimestampObjectInspector => DateUtils.toJavaTimestamp(a.asInstanceOf[Long]) } case x: SettableStructObjectInspector => val fieldRefs = x.getAllStructFieldRefs @@ -727,7 +731,7 @@ private[hive] trait HiveInspectors { TypeInfoFactory.voidTypeInfo, null) private def getStringWritable(value: Any): hadoopIo.Text = - if (value == null) null else new hadoopIo.Text(value.asInstanceOf[UTF8String].toString) + if (value == null) null else new hadoopIo.Text(value.asInstanceOf[UTF8String].getBytes) private def getIntWritable(value: Any): hadoopIo.IntWritable = if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int]) @@ -776,7 +780,7 @@ private[hive] trait HiveInspectors { if (value == null) { null } else { - new hiveIo.TimestampWritable(value.asInstanceOf[java.sql.Timestamp]) + new hiveIo.TimestampWritable(DateUtils.toJavaTimestamp(value.asInstanceOf[Long])) } private def getDecimalWritable(value: Any): hiveIo.HiveDecimalWritable = diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 334bfccc9d200..d3c82d8c2e326 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -363,10 +363,10 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging { row.update(ordinal, HiveShim.toCatalystDecimal(oi, value)) case oi: TimestampObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => - row.update(ordinal, oi.getPrimitiveJavaObject(value).clone()) + row.setLong(ordinal, DateUtils.fromJavaTimestamp(oi.getPrimitiveJavaObject(value))) case oi: DateObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => - row.update(ordinal, DateUtils.fromJavaDate(oi.getPrimitiveJavaObject(value))) + row.setInt(ordinal, DateUtils.fromJavaDate(oi.getPrimitiveJavaObject(value))) case oi: BinaryObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => row.update(ordinal, oi.getPrimitiveJavaObject(value)) diff --git a/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 b/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 index 27de46fdf22ac..84a31a5a6970b 100644 --- a/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 +++ b/sql/hive/src/test/resources/golden/timestamp cast #5-0-dbd7bcd167d322d6617b884c02c7f247 @@ -1 +1 @@ --0.0010000000000000009 +-0.001 From 6a47114bc297f0bce874e425feb1c24a5c26cef0 Mon Sep 17 00:00:00 2001 From: "navis.ryu" Date: Wed, 10 Jun 2015 18:19:12 
-0700 Subject: [PATCH 014/151] [SPARK-8285] [SQL] CombineSum should be calculated as unlimited decimal first case cs CombineSum(expr) => val calcType = expr.dataType expr.dataType match { case DecimalType.Fixed(_, _) => DecimalType.Unlimited case _ => expr.dataType } calcType is always expr.dataType. credits all belong to IntelliJ Author: navis.ryu Closes #6736 from navis/SPARK-8285 and squashes the following commits: 20382c1 [navis.ryu] [SPARK-8285] [SQL] CombineSum should be calculated as unlimited decimal first --- .../org/apache/spark/sql/execution/GeneratedAggregate.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index 3e27c1bde2dfd..af3791734d0c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -118,7 +118,7 @@ case class GeneratedAggregate( AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result) case cs @ CombineSum(expr) => - val calcType = expr.dataType + val calcType = expr.dataType match { case DecimalType.Fixed(_, _) => DecimalType.Unlimited @@ -129,7 +129,7 @@ case class GeneratedAggregate( val currentSum = AttributeReference("currentSum", calcType, nullable = true)() val initialValue = Literal.create(null, calcType) - // Coalasce avoids double calculation... + // Coalesce avoids double calculation... // but really, common sub expression elimination would be better.... val zero = Cast(Literal(0), calcType) // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its From 4e42842e82e058d54329bd66185d8a7e77ab335a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 10 Jun 2015 18:22:47 -0700 Subject: [PATCH 015/151] [SPARK-8164] transformExpressions should support nested expression sequence Currently we only support `Seq[Expression]`; we should handle cases like `Seq[Seq[Expression]]` so that we can remove the unnecessary `GroupExpression`. Author: Wenchen Fan Closes #6706 from cloud-fan/clean and squashes the following commits: 60a1193 [Wenchen Fan] support nested expression sequence and remove GroupExpression --- .../sql/catalyst/analysis/Analyzer.scala | 6 ++--- .../sql/catalyst/expressions/Expression.scala | 12 ---------- .../spark/sql/catalyst/plans/QueryPlan.scala | 22 +++++++++---------- .../plans/logical/basicOperators.scala | 2 +- .../sql/catalyst/trees/TreeNodeSuite.scala | 14 ++++++++++++ .../apache/spark/sql/execution/Expand.scala | 4 ++-- 6 files changed, 30 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c4f12cfe87993..cbd8def4f1d3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -172,8 +172,8 @@ class Analyzer( * expressions which equal GroupBy expressions with Literal(null), if those expressions * are not set for this grouping set (according to the bit mask). 
*/ - private[this] def expand(g: GroupingSets): Seq[GroupExpression] = { - val result = new scala.collection.mutable.ArrayBuffer[GroupExpression] + private[this] def expand(g: GroupingSets): Seq[Seq[Expression]] = { + val result = new scala.collection.mutable.ArrayBuffer[Seq[Expression]] g.bitmasks.foreach { bitmask => // get the non selected grouping attributes according to the bit mask @@ -194,7 +194,7 @@ class Analyzer( Literal.create(bitmask, IntegerType) }) - result += GroupExpression(substitution) + result += substitution } result.toSeq diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index a05794f1dbd86..63dd5f9854aed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -239,18 +239,6 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio } } -// TODO Semantically we probably not need GroupExpression -// All we need is holding the Seq[Expression], and ONLY used in doing the -// expressions transformation correctly. Probably will be removed since it's -// not like a real expressions. -case class GroupExpression(children: Seq[Expression]) extends Expression { - self: Product => - override def eval(input: Row): Any = throw new UnsupportedOperationException - override def nullable: Boolean = false - override def foldable: Boolean = false - override def dataType: DataType = throw new UnsupportedOperationException -} - /** * Expressions that require a specific `DataType` as input should implement this trait * so that the proper type conversions can be performed in the analyzer. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index eff5c61644944..2f545bb432165 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -81,17 +81,16 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy } } - val newArgs = productIterator.map { + def recursiveTransform(arg: Any): AnyRef = arg match { case e: Expression => transformExpressionDown(e) case Some(e: Expression) => Some(transformExpressionDown(e)) case m: Map[_, _] => m case d: DataType => d // Avoid unpacking Structs - case seq: Traversable[_] => seq.map { - case e: Expression => transformExpressionDown(e) - case other => other - } + case seq: Traversable[_] => seq.map(recursiveTransform) case other: AnyRef => other - }.toArray + } + + val newArgs = productIterator.map(recursiveTransform).toArray if (changed) makeCopy(newArgs) else this } @@ -114,17 +113,16 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy } } - val newArgs = productIterator.map { + def recursiveTransform(arg: Any): AnyRef = arg match { case e: Expression => transformExpressionUp(e) case Some(e: Expression) => Some(transformExpressionUp(e)) case m: Map[_, _] => m case d: DataType => d // Avoid unpacking Structs - case seq: Traversable[_] => seq.map { - case e: Expression => transformExpressionUp(e) - case other => other - } + case seq: Traversable[_] => seq.map(recursiveTransform) case other: AnyRef => other - }.toArray + } + + val newArgs = productIterator.map(recursiveTransform).toArray if (changed) makeCopy(newArgs) else this } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index e77e5c27b687a..963c7820914f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -226,7 +226,7 @@ case class Window( * @param child Child operator */ case class Expand( - projections: Seq[GroupExpression], + projections: Seq[Seq[Expression]], output: Seq[Attribute], child: LogicalPlan) extends UnaryNode { override def statistics: Statistics = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 67db3d5e6d751..8ec79c3d4d28d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -31,6 +31,11 @@ case class Dummy(optKey: Option[Expression]) extends Expression { override def eval(input: Row): Any = null.asInstanceOf[Any] } +case class ComplexPlan(exprs: Seq[Seq[Expression]]) + extends org.apache.spark.sql.catalyst.plans.logical.LeafNode { + override def output: Seq[Attribute] = Nil +} + class TreeNodeSuite extends SparkFunSuite { test("top node changed") { val after = Literal(1) transform { case Literal(1, _) => Literal(2) } @@ -220,4 +225,13 @@ class TreeNodeSuite extends SparkFunSuite { assert(expected === actual) } } + + test("transformExpressions on nested expression sequence") { + val plan = 
ComplexPlan(Seq(Seq(Literal(1)), Seq(Literal(2)))) + val actual = plan.transformExpressions { + case Literal(value, _) => Literal(value.toString) + } + val expected = ComplexPlan(Seq(Seq(Literal("1")), Seq(Literal("2")))) + assert(expected === actual) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala index f16ca36909fab..4b601c11924b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{UnknownPartitioning, Partit */ @DeveloperApi case class Expand( - projections: Seq[GroupExpression], + projections: Seq[Seq[Expression]], output: Seq[Attribute], child: SparkPlan) extends UnaryNode { @@ -49,7 +49,7 @@ case class Expand( // workers via closure. However we can't assume the Projection // is serializable because of the code gen, so we have to // create the projections within each of the partition processing. - val groups = projections.map(ee => newProjection(ee.children, child.output)).toArray + val groups = projections.map(ee => newProjection(ee, child.output)).toArray new Iterator[Row] { private[this] var result: Row = _ From 9fe3adccef687c92ff1ac17d946af089c8e28d66 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Wed, 10 Jun 2015 19:55:10 -0700 Subject: [PATCH 016/151] [SPARK-8248][SQL] string function: length Author: Cheng Hao Closes #6724 from chenghao-intel/length and squashes the following commits: aaa3c31 [Cheng Hao] revert the additional change 97148a9 [Cheng Hao] remove the codegen testing temporally ae08003 [Cheng Hao] update the comments 1eb1fd1 [Cheng Hao] simplify the code as commented 3e92d32 [Cheng Hao] use the selectExpr in unit test intead of SQLQuery 3c729aa [Cheng Hao] fix bug for constant null value in codegen 3641f06 [Cheng Hao] keep the length() method for registered function 8e30171 [Cheng Hao] update the code as comment db604ae [Cheng Hao] Add code gen support 548d2ef [Cheng Hao] register the length() 09a0738 [Cheng Hao] add length support --- .../catalyst/analysis/FunctionRegistry.scala | 13 +++++++----- .../sql/catalyst/expressions/Expression.scala | 3 +++ .../expressions/stringOperations.scala | 21 +++++++++++++++++++ .../expressions/StringFunctionsSuite.scala | 12 +++++++++++ .../org/apache/spark/sql/functions.scala | 18 ++++++++++++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 20 ++++++++++++++++++ 6 files changed, 82 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index ba89a5c8d1372..39875d7f216b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -89,14 +89,10 @@ object FunctionRegistry { expression[CreateArray]("array"), expression[Coalesce]("coalesce"), expression[Explode]("explode"), - expression[Lower]("lower"), - expression[Substring]("substr"), - expression[Substring]("substring"), expression[Rand]("rand"), expression[Randn]("randn"), expression[CreateStruct]("struct"), expression[Sqrt]("sqrt"), - expression[Upper]("upper"), // Math functions expression[Acos]("acos"), @@ -132,7 +128,14 @@ object FunctionRegistry { expression[Last]("last"), 
expression[Max]("max"), expression[Min]("min"), - expression[Sum]("sum") + expression[Sum]("sum"), + + // string functions + expression[Lower]("lower"), + expression[StringLength]("length"), + expression[Substring]("substr"), + expression[Substring]("substring"), + expression[Upper]("upper") ) val builtin: FunctionRegistry = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 63dd5f9854aed..8c1e4d74f9df1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -212,6 +212,9 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression] abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] { self: Product => + override def foldable: Boolean = child.foldable + override def nullable: Boolean = child.nullable + /** * Called by unary expressions to generate a code block that returns null if its parent returns * null, and if not not null, use `f` to generate the expression. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 856f56488c7a5..345038323ddc5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -294,3 +294,24 @@ object Substring { apply(str, pos, Literal(Integer.MAX_VALUE)) } } + +/** + * A function that returns the length of the given string expression. 
+ */ +case class StringLength(child: Expression) extends UnaryExpression with ExpectsInputTypes { + override def dataType: DataType = IntegerType + override def expectedChildTypes: Seq[DataType] = Seq(StringType) + + override def eval(input: Row): Any = { + val string = child.eval(input) + if (string == null) null else string.asInstanceOf[UTF8String].length + } + + override def toString: String = s"length($child)" + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => s"($c).length()") + } +} + + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala index 2e81296c4e623..d363e631540d8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala @@ -215,4 +215,16 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { evaluate("abbbbc" rlike regEx, create_row("**")) } } + + test("length for string") { + val regEx = 'a.string.at(0) + checkEvaluation(StringLength(Literal("abc")), 3, create_row("abdef")) + checkEvaluation(StringLength(regEx), 5, create_row("abdef")) + checkEvaluation(StringLength(regEx), 0, create_row("")) + checkEvaluation(StringLength(regEx), null, create_row(null)) + // TODO currently bug in codegen, let's temporally disable this + // checkEvaluation(StringLength(Literal.create(null, StringType)), null, create_row("abdef")) + } + + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index b3fc1e6cd987e..083f6b6bceee8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -37,6 +37,7 @@ import org.apache.spark.util.Utils * @groupname normal_funcs Non-aggregate functions * @groupname math_funcs Math functions * @groupname window_funcs Window functions + * @groupname string_funcs String functions * @groupname Ungrouped Support functions for DataFrames. 
* @since 1.3.0 */ @@ -1317,6 +1318,23 @@ object functions { */ def toRadians(columnName: String): Column = toRadians(Column(columnName)) + ////////////////////////////////////////////////////////////////////////////////////////////// + // String functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Computes the length of a given string value + * @group string_funcs + * @since 1.5.0 + */ + def strlen(e: Column): Column = StringLength(e.expr) + + /** + * Computes the length of a given string column + * @group string_funcs + * @since 1.5.0 + */ + def strlen(columnName: String): Column = strlen(Column(columnName)) ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index b93ad39f5da45..171a2151e67ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -109,4 +109,24 @@ class DataFrameFunctionsSuite extends QueryTest { testData2.select(bitwiseNOT($"a")), testData2.collect().toSeq.map(r => Row(~r.getInt(0)))) } + + test("length") { + checkAnswer( + nullStrings.select(strlen($"s"), strlen("s")), + nullStrings.collect().toSeq.map { r => + val v = r.getString(1) + val l = if (v == null) null else v.length + Row(l, l) + }) + + checkAnswer( + nullStrings.selectExpr("length(s)"), + nullStrings.collect().toSeq.map { r => + val v = r.getString(1) + val l = if (v == null) null else v.length + Row(l) + }) + } + + } From 2758ff0a96f03a61e10999b2462acf7a13236b7c Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Wed, 10 Jun 2015 20:22:32 -0700 Subject: [PATCH 017/151] [SPARK-8217] [SQL] math function log2 Author: Daoyuan Wang This patch had conflicts when merged, resolved by Committer: Reynold Xin Closes #6718 from adrian-wang/udflog2 and squashes the following commits: 3909f48 [Daoyuan Wang] math function: log2 --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../spark/sql/catalyst/expressions/math.scala | 17 ++++++++++++++++ .../expressions/MathFunctionsSuite.scala | 6 ++++++ .../org/apache/spark/sql/functions.scala | 20 +++++++++++++++++-- .../spark/sql/DataFrameFunctionsSuite.scala | 12 +++++++++++ 5 files changed, 54 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 39875d7f216b2..a7816e327526f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -111,6 +111,7 @@ object FunctionRegistry { expression[Log10]("log10"), expression[Log1p]("log1p"), expression[Pi]("pi"), + expression[Log2]("log2"), expression[Pow]("pow"), expression[Rint]("rint"), expression[Signum]("signum"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index e1d8c9a0cdb5a..97e960b8d6422 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -161,6 +161,23 @@ case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLO case class Log(child: Expression) extends UnaryMathExpression(math.log, "LOG") +case class Log2(child: Expression) + extends UnaryMathExpression((x: Double) => math.log(x) / math.log(2), "LOG2") { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val eval = child.gen(ctx) + eval.code + s""" + boolean ${ev.isNull} = ${eval.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = java.lang.Math.log(${eval.primitive}) / java.lang.Math.log(2); + if (Double.valueOf(${ev.primitive}).isNaN()) { + ${ev.isNull} = true; + } + } + """ + } +} + case class Log10(child: Expression) extends UnaryMathExpression(math.log10, "LOG10") case class Log1p(child: Expression) extends UnaryMathExpression(math.log1p, "LOG1P") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 1fe69059d39da..864c954ee82cb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -185,6 +185,12 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { testUnary(Log1p, math.log1p, (-10 to -2).map(_ * 1.0), expectNull = true) } + test("log2") { + def f: (Double) => Double = (x: Double) => math.log(x) / math.log(2) + testUnary(Log2, f, (0 to 20).map(_ * 0.1)) + testUnary(Log2, f, (-5 to -1).map(_ * 1.0), expectNull = true) + } + test("pow") { testBinary(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0))) testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNull = true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 083f6b6bceee8..c5b77724aae17 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1084,7 +1084,7 @@ object functions { def log(columnName: String): Column = log(Column(columnName)) /** - * Computes the logarithm of the given value in Base 10. + * Computes the logarithm of the given value in base 10. * * @group math_funcs * @since 1.4.0 @@ -1092,7 +1092,7 @@ object functions { def log10(e: Column): Column = Log10(e.expr) /** - * Computes the logarithm of the given value in Base 10. + * Computes the logarithm of the given value in base 10. * * @group math_funcs * @since 1.4.0 @@ -1124,6 +1124,22 @@ object functions { */ def pi(): Column = Pi() + /** + * Computes the logarithm of the given column in base 2. + * + * @group math_funcs + * @since 1.5.0 + */ + def log2(expr: Column): Column = Log2(expr.expr) + + /** + * Computes the logarithm of the given value in base 2. + * + * @group math_funcs + * @since 1.5.0 + */ + def log2(columnName: String): Column = log2(Column(columnName)) + /** * Returns the value of the first argument raised to the power of the second argument. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 171a2151e67ae..659b64c185f43 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -128,5 +128,17 @@ class DataFrameFunctionsSuite extends QueryTest { }) } + test("log2 functions test") { + val df = Seq((1, 2)).toDF("a", "b") + checkAnswer( + df.select(log2("b") + log2("a")), + Row(1)) + checkAnswer( + ctx.sql("SELECT LOG2(8)"), + Row(3)) + checkAnswer( + ctx.sql("SELECT LOG2(null)"), + Row(null)) + } } From a777eb04bf981312b640326607158f78dd4163cd Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 10 Jun 2015 21:13:47 -0700 Subject: [PATCH 018/151] [HOTFIX] Adding more contributor name bindings --- dev/create-release/known_translations | 42 +++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/dev/create-release/known_translations b/dev/create-release/known_translations index 0a599b5a65549..bbd4330e1c2e5 100644 --- a/dev/create-release/known_translations +++ b/dev/create-release/known_translations @@ -91,3 +91,45 @@ zapletal-martin - Martin Zapletal zuxqoj - Shekhar Bansal mingyukim - Mingyu Kim sigmoidanalytics - Mayur Rustagi +AiHe - Ai He +BenFradet - Ben Fradet +FavioVazquez - Favio Vazquez +JaysonSunshine - Jayson Sunshine +Liuchang0812 - Liu Chang +Sephiroth-Lin - Sephiroth Lin +baishuo - Cheng Lian +daisukebe - Shixiong Zhu +dobashim - Masaru Dobashi +ehnalis - Zoltan Zvara +emres - Emre Sevinc +gchen - Guancheng Chen +haiyangsea - Haiyang Sea +hlin09 - Hao Lin +hqzizania - Qian Huang +jeanlyn - Jean Lyn +jerluc - Jeremy A. Lucas +jrabary - Jaonary Rabarisoa +judynash - Judy Nash +kaka1992 - Chen Song +ksonj - Kalle Jepsen +kuromatsu-nobuyuki - Nobuyuki Kuromatsu +lazyman500 - Dong Xu +leahmcguire - Leah McGuire +mbittmann - Mark Bittmann +mbonaci - Marko Bonaci +meawoppl - Matthew Goodman +nyaapa - Arsenii Krasikov +phatak-dev - Madhukara Phatak +prabeesh - Prabeesh K +rakeshchalasani - Rakesh Chalasani +raschild - Marcelo Vanzin +rekhajoshm - Rekha Joshi +sisihj - June He +szheng79 - Shuai Zheng +ted-yu - Andrew Or +texasmichelle - Michelle Casbon +vinodkc - Vinod KC +yongtang - Yong Tang +ypcat - Pei-Lun Lee +zhichao-li - Zhichao Li +zzcclp - Zhichao Zhang From e84545fa771dde90de5675a9c551fe287af6f7fb Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 10 Jun 2015 22:56:36 -0700 Subject: [PATCH 019/151] [HOTFIX] Fixing errors in name mappings --- dev/create-release/known_translations | 4 ---- 1 file changed, 4 deletions(-) diff --git a/dev/create-release/known_translations b/dev/create-release/known_translations index bbd4330e1c2e5..5f2671a6e5053 100644 --- a/dev/create-release/known_translations +++ b/dev/create-release/known_translations @@ -97,8 +97,6 @@ FavioVazquez - Favio Vazquez JaysonSunshine - Jayson Sunshine Liuchang0812 - Liu Chang Sephiroth-Lin - Sephiroth Lin -baishuo - Cheng Lian -daisukebe - Shixiong Zhu dobashim - Masaru Dobashi ehnalis - Zoltan Zvara emres - Emre Sevinc @@ -122,11 +120,9 @@ nyaapa - Arsenii Krasikov phatak-dev - Madhukara Phatak prabeesh - Prabeesh K rakeshchalasani - Rakesh Chalasani -raschild - Marcelo Vanzin rekhajoshm - Rekha Joshi sisihj - June He szheng79 - Shuai Zheng -ted-yu - Andrew Or texasmichelle - Michelle Casbon vinodkc - Vinod KC yongtang - Yong Tang From 6b68366df345d4572cf138f9efe17e23d0d1971e Mon Sep 17 
00:00:00 2001 From: Adam Roberts Date: Thu, 11 Jun 2015 08:40:46 +0100 Subject: [PATCH 020/151] [SPARK-8289] Specify stack size for consistency with Java tests - resolves test failures This change is a simple one and specifies a stack size of 4096k instead of the vendor default for Java tests (the defaults vary between Java vendors). This remedies test failures observed with JavaALSSuite with IBM and Oracle Java owing to a lower default size in comparison to the size with OpenJDK. 4096k is a suitable default where the tests pass with each Java vendor tested. The alternative is to reduce the number of iterations in the test (no observed failures with 5 iterations instead of 15). -Xss works with Oracle's HotSpot VM, IBM's J9 VM and OpenJDK (IcedTea). I have ensured this does not have any negative implications for other tests. Author: Adam Roberts Author: a-roberts Closes #6727 from a-roberts/IncJavaStackSize and squashes the following commits: ab40aea [Adam Roberts] Specify stack size for SBT builds 5032d8d [a-roberts] Update pom.xml --- pom.xml | 2 +- project/SparkBuild.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pom.xml b/pom.xml index e9700a5d7b149..6d4f717d4931b 100644 --- a/pom.xml +++ b/pom.xml @@ -1244,7 +1244,7 @@ **/*Suite.java ${project.build.directory}/surefire-reports - -Xmx3g -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m + -Xmx3g -Xss4096k -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m diff --git a/docs/css/main.css b/docs/css/main.css index f6fe7d5f07da1..89305a7d3a358 100755 --- a/docs/css/main.css +++ b/docs/css/main.css @@ -146,3 +146,8 @@ ul.nav li.dropdown ul.dropdown-menu li.dropdown-submenu ul.dropdown-menu { .MathJax .mi { color: inherit } .MathJax .mf { color: inherit } .MathJax .mh { color: inherit } + +/** + * AnchorJS (anchor links when hovering over headers) + */ +a.anchorjs-link:hover { text-decoration: none; } diff --git a/docs/js/main.js b/docs/js/main.js index f1a90e47e89a7..f5d66b16f7b21 100755 --- a/docs/js/main.js +++ b/docs/js/main.js @@ -68,38 +68,11 @@ function codeTabs() { }); } -function makeCollapsable(elt, accordionClass, accordionBodyId, title) { - $(elt).addClass("accordion-inner"); - $(elt).wrap('
<div class="accordion ' + accordionClass + '"></div>') - $(elt).wrap('<div class="accordion-group"></div>') - $(elt).wrap('<div id="' + accordionBodyId + '" class="accordion-body collapse"></div>') - $(elt).parent().before( - '<div class="accordion-heading">' + - '<a class="accordion-toggle" data-toggle="collapse" href="#' + accordionBodyId + '">' + - title + - '</a>' + - '</div>
' - ); -} - -// Enable "view solution" sections (for exercises) -function viewSolution() { - var counter = 0 - $("div.solution").each(function() { - var id = "solution_" + counter - makeCollapsable(this, "", id, - '' + - '' + "View Solution"); - counter++; - }); -} // A script to fix internal hash links because we have an overlapping top bar. // Based on https://github.com/twitter/bootstrap/issues/193#issuecomment-2281510 function maybeScrollToHash() { - console.log("HERE"); if (window.location.hash && $(window.location.hash).length) { - console.log("HERE2", $(window.location.hash), $(window.location.hash).offset().top); var newTop = $(window.location.hash).offset().top - 57; $(window).scrollTop(newTop); } @@ -107,7 +80,12 @@ function maybeScrollToHash() { $(function() { codeTabs(); - viewSolution(); + // Display anchor links when hovering over headers. For documentation of the + // configuration options, see the AnchorJS documentation. + anchors.options = { + placement: 'left' + }; + anchors.add(); $(window).bind('hashchange', function() { maybeScrollToHash(); diff --git a/docs/js/vendor/anchor.min.js b/docs/js/vendor/anchor.min.js new file mode 100755 index 0000000000000..68c3cb7073b6d --- /dev/null +++ b/docs/js/vendor/anchor.min.js @@ -0,0 +1,6 @@ +/*! + * AnchorJS - v1.1.1 - 2015-05-23 + * https://github.com/bryanbraun/anchorjs + * Copyright (c) 2015 Bryan Braun; Licensed MIT + */ +function AnchorJS(A){"use strict";this.options=A||{},this._applyRemainingDefaultOptions=function(A){this.options.icon=this.options.hasOwnProperty("icon")?A.icon:"",this.options.visible=this.options.hasOwnProperty("visible")?A.visible:"hover",this.options.placement=this.options.hasOwnProperty("placement")?A.placement:"right",this.options.class=this.options.hasOwnProperty("class")?A.class:""},this._applyRemainingDefaultOptions(A),this.add=function(A){var e,t,o,n,i,s,a,l,c,r,h,g,B,Q;if(this._applyRemainingDefaultOptions(this.options),A){if("string"!=typeof A)throw new Error("The selector provided to AnchorJS was invalid.")}else A="h1, h2, h3, h4, h5, h6";if(e=document.querySelectorAll(A),0===e.length)return!1;for(this._addBaselineStyles(),t=document.querySelectorAll("[id]"),o=[].map.call(t,function(A){return A.id}),i=0;i',B=document.createElement("div"),B.innerHTML=g,Q=B.childNodes,"always"===this.options.visible&&(Q[0].style.opacity="1"),""===this.options.icon&&(Q[0].style.fontFamily="anchorjs-icons",Q[0].style.fontStyle="normal",Q[0].style.fontVariant="normal",Q[0].style.fontWeight="normal"),"left"===this.options.placement?(Q[0].style.position="absolute",Q[0].style.marginLeft="-1em",Q[0].style.paddingRight="0.5em",e[i].insertBefore(Q[0],e[i].firstChild)):(Q[0].style.paddingLeft="0.375em",e[i].appendChild(Q[0]))}return this},this.remove=function(A){for(var e,t=document.querySelectorAll(A),o=0;o Date: Thu, 18 Jun 2015 16:00:27 -0700 Subject: [PATCH 099/151] [SPARK-8376] [DOCS] Add common lang3 to the Spark Flume Sink doc Commons Lang 3 has been added as one of the dependencies of Spark Flume Sink since #5703. This PR updates the doc for it. 
Author: zsxwing Closes #6829 from zsxwing/flume-sink-dep and squashes the following commits: f8617f0 [zsxwing] Add common lang3 to the Spark Flume Sink doc --- docs/streaming-flume-integration.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/streaming-flume-integration.md b/docs/streaming-flume-integration.md index c8ab146bcae0a..8d6e74370918f 100644 --- a/docs/streaming-flume-integration.md +++ b/docs/streaming-flume-integration.md @@ -99,6 +99,12 @@ Configuring Flume on the chosen machine requires the following two steps. artifactId = scala-library version = {{site.SCALA_VERSION}} + (iii) *Commons Lang 3 JAR*: Download the Commons Lang 3 JAR. It can be found with the following artifact detail (or, [direct link](http://search.maven.org/remotecontent?filepath=org/apache/commons/commons-lang3/3.3.2/commons-lang3-3.3.2.jar)). + + groupId = org.apache.commons + artifactId = commons-lang3 + version = 3.3.2 + 2. **Configuration file**: On that machine, configure Flume agent to send data to an Avro sink by having the following in the configuration file. agent.sinks = spark From 207a98ca59757d9cdd033d0f72863ad9ffb4e4b9 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 18 Jun 2015 16:45:14 -0700 Subject: [PATCH 100/151] [SPARK-8446] [SQL] Add helper functions for testing SparkPlan physical operators This patch introduces `SparkPlanTest`, a base class for unit tests of SparkPlan physical operators. This is analogous to Spark SQL's existing `QueryTest`, which does something similar for end-to-end tests with actual queries. These helper methods provide nicer error output when tests fail and help developers to avoid writing lots of boilerplate in order to execute manually constructed physical plans. Author: Josh Rosen Author: Josh Rosen Author: Michael Armbrust Closes #6885 from JoshRosen/spark-plan-test and squashes the following commits: f8ce275 [Josh Rosen] Fix some IntelliJ inspections and delete some dead code 84214be [Josh Rosen] Add an extra column which isn't part of the sort ae1896b [Josh Rosen] Provide implicits automatically a80f9b0 [Josh Rosen] Merge pull request #4 from marmbrus/pr/6885 d9ab1e4 [Michael Armbrust] Add simple resolver c60a44d [Josh Rosen] Manually bind references 996332a [Josh Rosen] Add types so that tests compile a46144a [Josh Rosen] WIP --- .../spark/sql/execution/SortSuite.scala | 44 +++++ .../spark/sql/execution/SparkPlanTest.scala | 167 ++++++++++++++++++ 2 files changed, 211 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala new file mode 100644 index 0000000000000..a1e3ca11b1ad9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.dsl.expressions._ + +class SortSuite extends SparkPlanTest { + + // This test was originally added as an example of how to use [[SparkPlanTest]]; + // it's not designed to be a comprehensive test of ExternalSort. + test("basic sorting using ExternalSort") { + + val input = Seq( + ("Hello", 4, 2.0), + ("Hello", 1, 1.0), + ("World", 8, 3.0) + ) + + checkAnswer( + input.toDF("a", "b", "c"), + ExternalSort('a.asc :: 'b.asc :: Nil, global = false, _: SparkPlan), + input.sorted) + + checkAnswer( + input.toDF("a", "b", "c"), + ExternalSort('b.asc :: 'a.asc :: Nil, global = false, _: SparkPlan), + input.sortBy(t => (t._2, t._1))) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala new file mode 100644 index 0000000000000..13f3be8ca28d6 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import scala.language.implicitConversions +import scala.reflect.runtime.universe.TypeTag +import scala.util.control.NonFatal + +import org.apache.spark.SparkFunSuite + +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.BoundReference +import org.apache.spark.sql.catalyst.util._ + +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.{DataFrameHolder, Row, DataFrame} + +/** + * Base class for writing tests for individual physical operators. For an example of how this + * class's test helper methods can be used, see [[SortSuite]]. + */ +class SparkPlanTest extends SparkFunSuite { + + /** + * Creates a DataFrame from a local Seq of Product. + */ + implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = { + TestSQLContext.implicits.localSeqToDataFrameHolder(data) + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * @param input the input data to be used. + * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate + * the physical operator that's being tested. 
+ * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + */ + protected def checkAnswer( + input: DataFrame, + planFunction: SparkPlan => SparkPlan, + expectedAnswer: Seq[Row]): Unit = { + SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer) match { + case Some(errorMessage) => fail(errorMessage) + case None => + } + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * @param input the input data to be used. + * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate + * the physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s. + */ + protected def checkAnswer[A <: Product : TypeTag]( + input: DataFrame, + planFunction: SparkPlan => SparkPlan, + expectedAnswer: Seq[A]): Unit = { + val expectedRows = expectedAnswer.map(Row.fromTuple) + SparkPlanTest.checkAnswer(input, planFunction, expectedRows) match { + case Some(errorMessage) => fail(errorMessage) + case None => + } + } +} + +/** + * Helper methods for writing tests of individual physical operators. + */ +object SparkPlanTest { + + /** + * Runs the plan and makes sure the answer matches the expected result. + * @param input the input data to be used. + * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate + * the physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + */ + def checkAnswer( + input: DataFrame, + planFunction: SparkPlan => SparkPlan, + expectedAnswer: Seq[Row]): Option[String] = { + + val outputPlan = planFunction(input.queryExecution.sparkPlan) + + // A very simple resolver to make writing tests easier. In contrast to the real resolver + // this is always case sensitive and does not try to handle scoping or complex type resolution. + val resolvedPlan = outputPlan transform { + case plan: SparkPlan => + val inputMap = plan.children.flatMap(_.output).zipWithIndex.map { + case (a, i) => + (a.name, BoundReference(i, a.dataType, a.nullable)) + }.toMap + + plan.transformExpressions { + case UnresolvedAttribute(Seq(u)) => + inputMap.getOrElse(u, + sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) + } + } + + def prepareAnswer(answer: Seq[Row]): Seq[Row] = { + // Converts data to types that we can do equality comparison using Scala collections. + // For BigDecimal type, the Scala type has a better definition of equality test (similar to + // Java's java.math.BigDecimal.compareTo). + // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for + // equality test. 
+ // This function is copied from Catalyst's QueryTest + val converted: Seq[Row] = answer.map { s => + Row.fromSeq(s.toSeq.map { + case d: java.math.BigDecimal => BigDecimal(d) + case b: Array[Byte] => b.toSeq + case o => o + }) + } + converted.sortBy(_.toString()) + } + + val sparkAnswer: Seq[Row] = try { + resolvedPlan.executeCollect().toSeq + } catch { + case NonFatal(e) => + val errorMessage = + s""" + | Exception thrown while executing Spark plan: + | $outputPlan + | == Exception == + | $e + | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} + """.stripMargin + return Some(errorMessage) + } + + if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { + val errorMessage = + s""" + | Results do not match for Spark plan: + | $outputPlan + | == Results == + | ${sideBySide( + s"== Correct Answer - ${expectedAnswer.size} ==" +: + prepareAnswer(expectedAnswer).map(_.toString()), + s"== Spark Answer - ${sparkAnswer.size} ==" +: + prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")} + """.stripMargin + return Some(errorMessage) + } + + None + } +} + From dc413138995b45a7a957acae007dc11622110310 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 18 Jun 2015 18:41:15 -0700 Subject: [PATCH 101/151] [SPARK-8218][SQL] Binary log math function update. Some minor updates based on after merging #6725. Author: Reynold Xin Closes #6871 from rxin/log and squashes the following commits: ab51542 [Reynold Xin] Use JVM log 76fc8de [Reynold Xin] Fixed arg. a7c1522 [Reynold Xin] [SPARK-8218][SQL] Binary log math function update. --- python/pyspark/sql/functions.py | 13 +++++++++---- .../spark/sql/catalyst/expressions/math.scala | 4 ++++ 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 177fc196e0834..acdb01d3d3f5f 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -404,18 +404,23 @@ def when(condition, value): return Column(jc) -@since(1.4) -def log(col, base=math.e): +@since(1.5) +def log(arg1, arg2=None): """Returns the first argument-based logarithm of the second argument. - >>> df.select(log(df.age, 10.0).alias('ten')).map(lambda l: str(l.ten)[:7]).collect() + If there is only one argument, then this takes the natural logarithm of the argument. + + >>> df.select(log(10.0, df.age).alias('ten')).map(lambda l: str(l.ten)[:7]).collect() ['0.30102', '0.69897'] >>> df.select(log(df.age).alias('e')).map(lambda l: str(l.e)[:7]).collect() ['0.69314', '1.60943'] """ sc = SparkContext._active_spark_context - jc = sc._jvm.functions.log(base, _to_java_column(col)) + if arg2 is None: + jc = sc._jvm.functions.log(_to_java_column(arg1)) + else: + jc = sc._jvm.functions.log(arg1, _to_java_column(arg2)) return Column(jc) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 3b83c6da0e60c..f79bf4aee00d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -260,6 +260,10 @@ case class Pow(left: Expression, right: Expression) case class Logarithm(left: Expression, right: Expression) extends BinaryMathExpression((c1, c2) => math.log(c2) / math.log(c1), "LOG") { + + /** + * Natural log, i.e. using e as the base. 
+ */ def this(child: Expression) = { this(EulerNumber(), child) } From 43f50decdd20fafc55913c56ffa30f56040090e4 Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Thu, 18 Jun 2015 19:36:05 -0700 Subject: [PATCH 102/151] [SPARK-8135] Don't load defaults when reconstituting Hadoop Configurations Author: Sandy Ryza Closes #6679 from sryza/sandy-spark-8135 and squashes the following commits: c5554ff [Sandy Ryza] SPARK-8135. In SerializableWritable, don't load defaults when instantiating Configuration --- .../apache/spark/SerializableWritable.scala | 2 +- .../scala/org/apache/spark/SparkContext.scala | 2 +- .../org/apache/spark/SparkHadoopWriter.scala | 3 +- .../spark/api/python/PythonHadoopUtil.scala | 6 +-- .../apache/spark/api/python/PythonRDD.scala | 12 +++--- .../org/apache/spark/rdd/CheckpointRDD.scala | 11 +++--- .../org/apache/spark/rdd/HadoopRDD.scala | 8 ++-- .../org/apache/spark/rdd/NewHadoopRDD.scala | 4 +- .../apache/spark/rdd/PairRDDFunctions.scala | 6 +-- .../apache/spark/rdd/RDDCheckpointData.scala | 3 +- .../util/SerializableConfiguration.scala | 36 ++++++++++++++++++ .../spark/util/SerializableJobConf.scala | 37 +++++++++++++++++++ .../sql/parquet/ParquetTableOperations.scala | 5 ++- .../apache/spark/sql/parquet/newParquet.scala | 7 ++-- .../sql/sources/DataSourceStrategy.scala | 8 ++-- .../spark/sql/sources/SqlNewHadoopRDD.scala | 4 +- .../apache/spark/sql/sources/commands.scala | 3 +- .../apache/spark/sql/sources/interfaces.scala | 6 +-- .../apache/spark/sql/hive/TableReader.scala | 9 ++--- .../hive/execution/InsertIntoHiveTable.scala | 7 ++-- .../spark/sql/hive/hiveWriterContainers.scala | 3 +- .../spark/sql/hive/orc/OrcRelation.scala | 5 ++- .../streaming/dstream/FileInputDStream.scala | 5 +-- .../dstream/PairDStreamFunctions.scala | 7 ++-- .../rdd/WriteAheadLogBackedBlockRDD.scala | 5 +-- .../streaming/scheduler/ReceiverTracker.scala | 9 +++-- 26 files changed, 146 insertions(+), 67 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/util/SerializableConfiguration.scala create mode 100644 core/src/main/scala/org/apache/spark/util/SerializableJobConf.scala diff --git a/core/src/main/scala/org/apache/spark/SerializableWritable.scala b/core/src/main/scala/org/apache/spark/SerializableWritable.scala index cb2cae185256a..beb2e27254725 100644 --- a/core/src/main/scala/org/apache/spark/SerializableWritable.scala +++ b/core/src/main/scala/org/apache/spark/SerializableWritable.scala @@ -41,7 +41,7 @@ class SerializableWritable[T <: Writable](@transient var t: T) extends Serializa private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { in.defaultReadObject() val ow = new ObjectWritable() - ow.setConf(new Configuration()) + ow.setConf(new Configuration(false)) ow.readFields(in) t = ow.get().asInstanceOf[T] } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index a453c9bf4864a..141276ac901fb 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -974,7 +974,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli minPartitions: Int = defaultMinPartitions): RDD[(K, V)] = withScope { assertNotStopped() // A Hadoop configuration can be about 10 KB, which is pretty big, so broadcast it. 
- val confBroadcast = broadcast(new SerializableWritable(hadoopConfiguration)) + val confBroadcast = broadcast(new SerializableConfiguration(hadoopConfiguration)) val setInputPathsFunc = (jobConf: JobConf) => FileInputFormat.setInputPaths(jobConf, path) new HadoopRDD( this, diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index 59ac82ccec53b..f5dd36cbcfe6d 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -28,6 +28,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.HadoopRDD +import org.apache.spark.util.SerializableJobConf /** * Internal helper class that saves an RDD using a Hadoop OutputFormat. @@ -42,7 +43,7 @@ class SparkHadoopWriter(@transient jobConf: JobConf) with Serializable { private val now = new Date() - private val conf = new SerializableWritable(jobConf) + private val conf = new SerializableJobConf(jobConf) private var jobID = 0 private var splitID = 0 diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala index c9181a29d4756..b959b683d1674 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala @@ -19,8 +19,8 @@ package org.apache.spark.api.python import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD -import org.apache.spark.util.Utils -import org.apache.spark.{Logging, SerializableWritable, SparkException} +import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.{Logging, SparkException} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io._ import scala.util.{Failure, Success, Try} @@ -61,7 +61,7 @@ private[python] object Converter extends Logging { * Other objects are passed through without conversion. 
*/ private[python] class WritableToJavaConverter( - conf: Broadcast[SerializableWritable[Configuration]]) extends Converter[Any, Any] { + conf: Broadcast[SerializableConfiguration]) extends Converter[Any, Any] { /** * Converts a [[org.apache.hadoop.io.Writable]] to the underlying primitive, String or diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 55a37f8c944b2..dc9f62f39e6d5 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -36,7 +36,7 @@ import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} import org.apache.spark.broadcast.Broadcast import org.apache.spark.input.PortableDataStream import org.apache.spark.rdd.RDD -import org.apache.spark.util.Utils +import org.apache.spark.util.{SerializableConfiguration, Utils} import scala.util.control.NonFatal @@ -445,7 +445,7 @@ private[spark] object PythonRDD extends Logging { val kc = Utils.classForName(keyClass).asInstanceOf[Class[K]] val vc = Utils.classForName(valueClass).asInstanceOf[Class[V]] val rdd = sc.sc.sequenceFile[K, V](path, kc, vc, minSplits) - val confBroadcasted = sc.sc.broadcast(new SerializableWritable(sc.hadoopConfiguration())) + val confBroadcasted = sc.sc.broadcast(new SerializableConfiguration(sc.hadoopConfiguration())) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) @@ -471,7 +471,7 @@ private[spark] object PythonRDD extends Logging { val rdd = newAPIHadoopRDDFromClassNames[K, V, F](sc, Some(path), inputFormatClass, keyClass, valueClass, mergedConf) - val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf)) + val confBroadcasted = sc.sc.broadcast(new SerializableConfiguration(mergedConf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) @@ -497,7 +497,7 @@ private[spark] object PythonRDD extends Logging { val rdd = newAPIHadoopRDDFromClassNames[K, V, F](sc, None, inputFormatClass, keyClass, valueClass, conf) - val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf)) + val confBroadcasted = sc.sc.broadcast(new SerializableConfiguration(conf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) @@ -540,7 +540,7 @@ private[spark] object PythonRDD extends Logging { val rdd = hadoopRDDFromClassNames[K, V, F](sc, Some(path), inputFormatClass, keyClass, valueClass, mergedConf) - val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf)) + val confBroadcasted = sc.sc.broadcast(new SerializableConfiguration(mergedConf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) @@ -566,7 +566,7 @@ private[spark] object PythonRDD extends Logging { val rdd = hadoopRDDFromClassNames[K, V, F](sc, None, inputFormatClass, keyClass, valueClass, conf) - val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf)) + val confBroadcasted = sc.sc.broadcast(new SerializableConfiguration(conf)) val converted = convertRDD(rdd, keyConverterClass, 
valueConverterClass, new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index a4715e3437d94..33e6998b2cb10 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -21,13 +21,12 @@ import java.io.IOException import scala.reflect.ClassTag -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.util.Utils +import org.apache.spark.util.{SerializableConfiguration, Utils} private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {} @@ -38,7 +37,7 @@ private[spark] class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String) extends RDD[T](sc, Nil) { - val broadcastedConf = sc.broadcast(new SerializableWritable(sc.hadoopConfiguration)) + val broadcastedConf = sc.broadcast(new SerializableConfiguration(sc.hadoopConfiguration)) @transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration) @@ -87,7 +86,7 @@ private[spark] object CheckpointRDD extends Logging { def writeToFile[T: ClassTag]( path: String, - broadcastedConf: Broadcast[SerializableWritable[Configuration]], + broadcastedConf: Broadcast[SerializableConfiguration], blockSize: Int = -1 )(ctx: TaskContext, iterator: Iterator[T]) { val env = SparkEnv.get @@ -135,7 +134,7 @@ private[spark] object CheckpointRDD extends Logging { def readFromFile[T]( path: Path, - broadcastedConf: Broadcast[SerializableWritable[Configuration]], + broadcastedConf: Broadcast[SerializableConfiguration], context: TaskContext ): Iterator[T] = { val env = SparkEnv.get @@ -164,7 +163,7 @@ private[spark] object CheckpointRDD extends Logging { val path = new Path(hdfsPath, "temp") val conf = SparkHadoopUtil.get.newConfiguration(new SparkConf()) val fs = path.getFileSystem(conf) - val broadcastedConf = sc.broadcast(new SerializableWritable(conf)) + val broadcastedConf = sc.broadcast(new SerializableConfiguration(conf)) sc.runJob(rdd, CheckpointRDD.writeToFile[Int](path.toString, broadcastedConf, 1024) _) val cpRDD = new CheckpointRDD[Int](sc, path.toString) assert(cpRDD.partitions.length == rdd.partitions.length, "Number of partitions is not the same") diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 2cefe63d44b20..bee59a437f120 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -44,7 +44,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD -import org.apache.spark.util.{NextIterator, Utils} +import org.apache.spark.util.{SerializableConfiguration, NextIterator, Utils} import org.apache.spark.scheduler.{HostTaskLocation, HDFSCacheTaskLocation} import org.apache.spark.storage.StorageLevel @@ -100,7 +100,7 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSp @DeveloperApi class HadoopRDD[K, V]( @transient sc: SparkContext, - broadcastedConf: Broadcast[SerializableWritable[Configuration]], + broadcastedConf: 
Broadcast[SerializableConfiguration], initLocalJobConfFuncOpt: Option[JobConf => Unit], inputFormatClass: Class[_ <: InputFormat[K, V]], keyClass: Class[K], @@ -121,8 +121,8 @@ class HadoopRDD[K, V]( minPartitions: Int) = { this( sc, - sc.broadcast(new SerializableWritable(conf)) - .asInstanceOf[Broadcast[SerializableWritable[Configuration]]], + sc.broadcast(new SerializableConfiguration(conf)) + .asInstanceOf[Broadcast[SerializableConfiguration]], None /* initLocalJobConfFuncOpt */, inputFormatClass, keyClass, diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 84456d6d868dc..f827270ee6a44 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -33,7 +33,7 @@ import org.apache.spark._ import org.apache.spark.executor.DataReadMethod import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD -import org.apache.spark.util.Utils +import org.apache.spark.util.{SerializableConfiguration, Utils} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.storage.StorageLevel @@ -74,7 +74,7 @@ class NewHadoopRDD[K, V]( with Logging { // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it - private val confBroadcast = sc.broadcast(new SerializableWritable(conf)) + private val confBroadcast = sc.broadcast(new SerializableConfiguration(conf)) // private val serializableConf = new SerializableWritable(conf) private val jobTrackerId: String = { diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index cfd3e26faf2b9..91a6a2d039852 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -44,7 +44,7 @@ import org.apache.spark.executor.{DataWriteMethod, OutputMetrics} import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.serializer.Serializer -import org.apache.spark.util.Utils +import org.apache.spark.util.{SerializableConfiguration, Utils} import org.apache.spark.util.collection.CompactBuffer import org.apache.spark.util.random.StratifiedSamplingUtils @@ -1002,7 +1002,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val formatter = new SimpleDateFormat("yyyyMMddHHmm") val jobtrackerID = formatter.format(new Date()) val stageId = self.id - val wrappedConf = new SerializableWritable(job.getConfiguration) + val wrappedConf = new SerializableConfiguration(job.getConfiguration) val outfmt = job.getOutputFormatClass val jobFormat = outfmt.newInstance @@ -1065,7 +1065,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) def saveAsHadoopDataset(conf: JobConf): Unit = self.withScope { // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). 
val hadoopConf = conf - val wrappedConf = new SerializableWritable(hadoopConf) + val wrappedConf = new SerializableConfiguration(hadoopConf) val outputFormatInstance = hadoopConf.getOutputFormat val keyClass = hadoopConf.getOutputKeyClass val valueClass = hadoopConf.getOutputValueClass diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala index 1722c27e55003..acbd31aacdf59 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -23,6 +23,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark._ import org.apache.spark.scheduler.{ResultTask, ShuffleMapTask} +import org.apache.spark.util.SerializableConfiguration /** * Enumeration to manage state transitions of an RDD through checkpointing @@ -91,7 +92,7 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) // Save to file, and reload it as an RDD val broadcastedConf = rdd.context.broadcast( - new SerializableWritable(rdd.context.hadoopConfiguration)) + new SerializableConfiguration(rdd.context.hadoopConfiguration)) val newRDD = new CheckpointRDD[T](rdd.context, path.toString) if (rdd.conf.getBoolean("spark.cleaner.referenceTracking.cleanCheckpoints", false)) { rdd.context.cleaner.foreach { cleaner => diff --git a/core/src/main/scala/org/apache/spark/util/SerializableConfiguration.scala b/core/src/main/scala/org/apache/spark/util/SerializableConfiguration.scala new file mode 100644 index 0000000000000..30bcf1d2f24d5 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/SerializableConfiguration.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.util + +import java.io.{ObjectInputStream, ObjectOutputStream} + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.util.Utils + +private[spark] +class SerializableConfiguration(@transient var value: Configuration) extends Serializable { + private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { + out.defaultWriteObject() + value.write(out) + } + + private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { + value = new Configuration(false) + value.readFields(in) + } +} diff --git a/core/src/main/scala/org/apache/spark/util/SerializableJobConf.scala b/core/src/main/scala/org/apache/spark/util/SerializableJobConf.scala new file mode 100644 index 0000000000000..afbcc6efc850c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/SerializableJobConf.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.io.{ObjectInputStream, ObjectOutputStream} + +import org.apache.hadoop.mapred.JobConf + +import org.apache.spark.util.Utils + +private[spark] +class SerializableJobConf(@transient var value: JobConf) extends Serializable { + private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { + out.defaultWriteObject() + value.write(out) + } + + private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { + value = new JobConf(false) + value.readFields(in) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 65ecad9878f8e..b30fc171c0af1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -49,7 +49,8 @@ import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, InternalRow, _} import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode} import org.apache.spark.sql.types.StructType -import org.apache.spark.{Logging, SerializableWritable, TaskContext} +import org.apache.spark.{Logging, TaskContext} +import org.apache.spark.util.SerializableConfiguration /** * :: DeveloperApi :: @@ -329,7 +330,7 @@ private[sql] case class InsertIntoParquetTable( job.setOutputKeyClass(keyType) job.setOutputValueClass(classOf[InternalRow]) NewFileOutputFormat.setOutputPath(job, new Path(path)) - val wrappedConf = new SerializableWritable(job.getConfiguration) + val wrappedConf = new SerializableConfiguration(job.getConfiguration) val formatter = new SimpleDateFormat("yyyyMMddHHmm") val jobtrackerID = formatter.format(new Date()) val stageId = sqlContext.sparkContext.newRddId() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 4c702c3b0d43f..c9de45e0ddfbb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -24,7 +24,6 @@ import scala.collection.JavaConversions._ import scala.util.Try import com.google.common.base.Objects -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ @@ -42,8 +41,8 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{DataType, StructType} -import org.apache.spark.util.Utils -import org.apache.spark.{Logging, SerializableWritable, SparkException, Partition => 
SparkPartition} +import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.{Logging, SparkException, Partition => SparkPartition} private[sql] class DefaultSource extends HadoopFsRelationProvider { override def createRelation( @@ -258,7 +257,7 @@ private[sql] class ParquetRelation2( requiredColumns: Array[String], filters: Array[Filter], inputFiles: Array[FileStatus], - broadcastedConf: Broadcast[SerializableWritable[Configuration]]): RDD[Row] = { + broadcastedConf: Broadcast[SerializableConfiguration]): RDD[Row] = { val useMetadataCache = sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA) val parquetFilterPushDown = sqlContext.conf.parquetFilterPushDown // Create the function to set variable Parquet confs at both driver and executor side. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala index 4cf67439b9b8d..a8f56f4767407 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.sources +import org.apache.spark.{Logging, TaskContext} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{MapPartitionsRDD, RDD, UnionRDD} import org.apache.spark.sql._ @@ -27,9 +28,8 @@ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.sql.{SaveMode, Strategy, execution, sources} -import org.apache.spark.util.Utils +import org.apache.spark.util.{SerializableConfiguration, Utils} import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.{Logging, SerializableWritable, TaskContext} /** * A Strategy for planning scans over data sources defined using the sources API. @@ -91,7 +91,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { // broadcast HadoopConf. val sharedHadoopConf = SparkHadoopUtil.get.conf val confBroadcast = - t.sqlContext.sparkContext.broadcast(new SerializableWritable(sharedHadoopConf)) + t.sqlContext.sparkContext.broadcast(new SerializableConfiguration(sharedHadoopConf)) pruneFilterProject( l, projects, @@ -126,7 +126,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { // Otherwise, the cost of broadcasting HadoopConf in every RDD will be high. val sharedHadoopConf = SparkHadoopUtil.get.conf val confBroadcast = - relation.sqlContext.sparkContext.broadcast(new SerializableWritable(sharedHadoopConf)) + relation.sqlContext.sparkContext.broadcast(new SerializableConfiguration(sharedHadoopConf)) // Builds RDD[Row]s for each selected partition. 
val perPartitionRows = partitions.map { case Partition(partitionValues, dir) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala index ebad0c1564ec0..2bdc341021256 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala @@ -34,7 +34,7 @@ import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.{RDD, HadoopRDD} import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.Utils +import org.apache.spark.util.{SerializableConfiguration, Utils} import scala.reflect.ClassTag @@ -65,7 +65,7 @@ private[spark] class SqlNewHadoopPartition( */ private[sql] class SqlNewHadoopRDD[K, V]( @transient sc : SparkContext, - broadcastedConf: Broadcast[SerializableWritable[Configuration]], + broadcastedConf: Broadcast[SerializableConfiguration], @transient initDriverSideJobFuncOpt: Option[Job => Unit], initLocalJobFuncOpt: Option[Job => Unit], inputFormatClass: Class[_ <: InputFormat[K, V]], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index d39a20b388375..c16bd9ae52c81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{DataFrame, Row, SQLConf, SQLContext, SaveMode} +import org.apache.spark.util.SerializableConfiguration private[sql] case class InsertIntoDataSource( logicalRelation: LogicalRelation, @@ -260,7 +261,7 @@ private[sql] abstract class BaseWriterContainer( with Logging with Serializable { - protected val serializableConf = new SerializableWritable(ContextUtil.getConfiguration(job)) + protected val serializableConf = new SerializableConfiguration(ContextUtil.getConfiguration(job)) // This is only used on driver side. 
@transient private val jobContext: JobContext = job diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 43d3507d7d2ba..7005c7079af91 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -27,12 +27,12 @@ import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD -import org.apache.spark.SerializableWritable import org.apache.spark.sql.execution.RDDConversions import org.apache.spark.sql.{DataFrame, Row, SaveMode, SQLContext} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.types.StructType +import org.apache.spark.util.SerializableConfiguration /** * ::DeveloperApi:: @@ -518,7 +518,7 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio requiredColumns: Array[String], filters: Array[Filter], inputPaths: Array[String], - broadcastedConf: Broadcast[SerializableWritable[Configuration]]): RDD[Row] = { + broadcastedConf: Broadcast[SerializableConfiguration]): RDD[Row] = { val inputStatuses = inputPaths.flatMap { input => val path = new Path(input) @@ -648,7 +648,7 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio requiredColumns: Array[String], filters: Array[Filter], inputFiles: Array[FileStatus], - broadcastedConf: Broadcast[SerializableWritable[Configuration]]): RDD[Row] = { + broadcastedConf: Broadcast[SerializableConfiguration]): RDD[Row] = { buildScan(requiredColumns, filters, inputFiles) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 485810320f3c1..439f39bafc926 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.hive -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{Path, PathFilter} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.metastore.api.hive_metastoreConstants._ @@ -30,12 +29,12 @@ import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf} -import org.apache.spark.{Logging, SerializableWritable} +import org.apache.spark.{Logging} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.DateUtils -import org.apache.spark.util.Utils +import org.apache.spark.util.{SerializableConfiguration, Utils} /** * A trait for subclasses that handle table scans. @@ -72,7 +71,7 @@ class HadoopTableReader( // TODO: set aws s3 credentials. 
private val _broadcastedHiveConf = - sc.sparkContext.broadcast(new SerializableWritable(hiveExtraConf)) + sc.sparkContext.broadcast(new SerializableConfiguration(hiveExtraConf)) override def makeRDDForTable(hiveTable: HiveTable): RDD[InternalRow] = makeRDDForTable( @@ -276,7 +275,7 @@ class HadoopTableReader( val rdd = new HadoopRDD( sc.sparkContext, - _broadcastedHiveConf.asInstanceOf[Broadcast[SerializableWritable[Configuration]]], + _broadcastedHiveConf.asInstanceOf[Broadcast[SerializableConfiguration]], Some(initializeJobConfFunc), inputFormatClass, classOf[Writable], diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 1d306c5d10af8..404bb937aaf87 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -35,9 +35,10 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, InternalRow} import org.apache.spark.sql.execution.{UnaryNode, SparkPlan} import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.hive._ -import org.apache.spark.{SerializableWritable, SparkException, TaskContext} +import org.apache.spark.{SparkException, TaskContext} import scala.collection.JavaConversions._ +import org.apache.spark.util.SerializableJobConf private[hive] case class InsertIntoHiveTable( @@ -64,7 +65,7 @@ case class InsertIntoHiveTable( rdd: RDD[InternalRow], valueClass: Class[_], fileSinkConf: FileSinkDesc, - conf: SerializableWritable[JobConf], + conf: SerializableJobConf, writerContainer: SparkHiveWriterContainer): Unit = { assert(valueClass != null, "Output value class not set") conf.value.setOutputValueClass(valueClass) @@ -172,7 +173,7 @@ case class InsertIntoHiveTable( } val jobConf = new JobConf(sc.hiveconf) - val jobConfSer = new SerializableWritable(jobConf) + val jobConfSer = new SerializableJobConf(jobConf) val writerContainer = if (numDynamicPartitions > 0) { val dynamicPartColNames = partitionColumnNames.takeRight(numDynamicPartitions) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index ee440e304ec19..0bc69c00c241c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -37,6 +37,7 @@ import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter} import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.types._ +import org.apache.spark.util.SerializableJobConf /** * Internal helper class that saves an RDD using a Hive OutputFormat. 
@@ -57,7 +58,7 @@ private[hive] class SparkHiveWriterContainer( PlanUtils.configureOutputJobPropertiesForStorageHandler(tableDesc) Utilities.copyTableJobPropertiesToConf(tableDesc, jobConf) } - protected val conf = new SerializableWritable(jobConf) + protected val conf = new SerializableJobConf(jobConf) private var jobID = 0 private var splitID = 0 diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index f03c4cd54e7e6..77f1ca9ae0875 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -39,7 +39,8 @@ import org.apache.spark.sql.hive.{HiveContext, HiveInspectors, HiveMetastoreType import org.apache.spark.sql.sources.{Filter, _} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.{Logging, SerializableWritable} +import org.apache.spark.{Logging} +import org.apache.spark.util.SerializableConfiguration /* Implicit conversions */ import scala.collection.JavaConversions._ @@ -283,7 +284,7 @@ private[orc] case class OrcTableScan( classOf[Writable] ).asInstanceOf[HadoopRDD[NullWritable, Writable]] - val wrappedConf = new SerializableWritable(conf) + val wrappedConf = new SerializableConfiguration(conf) rdd.mapPartitionsWithInputSplit { case (split: OrcSplit, iterator) => val mutableRow = new SpecificMutableRow(attributes.map(_.dataType)) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index 6c1fab56740ee..86a8e2beff57c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -26,10 +26,9 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path, PathFilter} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} -import org.apache.spark.{SparkConf, SerializableWritable} import org.apache.spark.rdd.{RDD, UnionRDD} import org.apache.spark.streaming._ -import org.apache.spark.util.{TimeStampedHashMap, Utils} +import org.apache.spark.util.{SerializableConfiguration, TimeStampedHashMap, Utils} /** * This class represents an input stream that monitors a Hadoop-compatible filesystem for new @@ -78,7 +77,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]) extends InputDStream[(K, V)](ssc_) { - private val serializableConfOpt = conf.map(new SerializableWritable(_)) + private val serializableConfOpt = conf.map(new SerializableConfiguration(_)) /** * Minimum duration of remembering the information of selected files. Defaults to 60 seconds. 
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala index 358e4c66df7ba..71bec96d46c8d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala @@ -24,10 +24,11 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapred.{JobConf, OutputFormat} import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} -import org.apache.spark.{HashPartitioner, Partitioner, SerializableWritable} +import org.apache.spark.{HashPartitioner, Partitioner} import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Duration, Time} import org.apache.spark.streaming.StreamingContext.rddToFileName +import org.apache.spark.util.{SerializableConfiguration, SerializableJobConf} /** * Extra functions available on DStream of (key, value) pairs through an implicit conversion. @@ -688,7 +689,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) conf: JobConf = new JobConf(ssc.sparkContext.hadoopConfiguration) ): Unit = ssc.withScope { // Wrap conf in SerializableWritable so that ForeachDStream can be serialized for checkpoints - val serializableConf = new SerializableWritable(conf) + val serializableConf = new SerializableJobConf(conf) val saveFunc = (rdd: RDD[(K, V)], time: Time) => { val file = rddToFileName(prefix, suffix, time) rdd.saveAsHadoopFile(file, keyClass, valueClass, outputFormatClass, serializableConf.value) @@ -721,7 +722,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) conf: Configuration = ssc.sparkContext.hadoopConfiguration ): Unit = ssc.withScope { // Wrap conf in SerializableWritable so that ForeachDStream can be serialized for checkpoints - val serializableConf = new SerializableWritable(conf) + val serializableConf = new SerializableConfiguration(conf) val saveFunc = (rdd: RDD[(K, V)], time: Time) => { val file = rddToFileName(prefix, suffix, time) rdd.saveAsNewAPIHadoopFile( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index ffce6a4c3c74c..31ce8e1ec14d7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -23,12 +23,11 @@ import java.util.UUID import scala.reflect.ClassTag import scala.util.control.NonFatal -import org.apache.commons.io.FileUtils - import org.apache.spark._ import org.apache.spark.rdd.BlockRDD import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.streaming.util._ +import org.apache.spark.util.SerializableConfiguration /** * Partition class for [[org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD]]. @@ -94,7 +93,7 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( // Hadoop configuration is not serializable, so broadcast it as a serializable. 
@transient private val hadoopConfig = sc.hadoopConfiguration - private val broadcastedHadoopConf = new SerializableWritable(hadoopConfig) + private val broadcastedHadoopConf = new SerializableConfiguration(hadoopConfig) override def isValid(): Boolean = true diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index f1504b09c9873..e6cdbec11e94c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -21,10 +21,12 @@ import scala.collection.mutable.{HashMap, SynchronizedMap} import scala.language.existentials import org.apache.spark.streaming.util.WriteAheadLogUtils -import org.apache.spark.{Logging, SerializableWritable, SparkEnv, SparkException} +import org.apache.spark.{Logging, SparkEnv, SparkException} import org.apache.spark.rpc._ import org.apache.spark.streaming.{StreamingContext, Time} -import org.apache.spark.streaming.receiver.{CleanupOldBlocks, Receiver, ReceiverSupervisorImpl, StopReceiver} +import org.apache.spark.streaming.receiver.{CleanupOldBlocks, Receiver, ReceiverSupervisorImpl, + StopReceiver} +import org.apache.spark.util.SerializableConfiguration /** * Messages used by the NetworkReceiver and the ReceiverTracker to communicate @@ -294,7 +296,8 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false } val checkpointDirOption = Option(ssc.checkpointDir) - val serializableHadoopConf = new SerializableWritable(ssc.sparkContext.hadoopConfiguration) + val serializableHadoopConf = + new SerializableConfiguration(ssc.sparkContext.hadoopConfiguration) // Function to start the receiver on the worker node val startReceiver = (iterator: Iterator[Receiver[_]]) => { From 4ce3bab89f6bdf6208fdad2fbfaba0b53d1954e3 Mon Sep 17 00:00:00 2001 From: Lars Francke Date: Thu, 18 Jun 2015 19:40:32 -0700 Subject: [PATCH 103/151] [SPARK-8462] [DOCS] Documentation fixes for Spark SQL This fixes various minor documentation issues on the Spark SQL page Author: Lars Francke Closes #6890 from lfrancke/SPARK-8462 and squashes the following commits: dd7e302 [Lars Francke] Merge branch 'master' into SPARK-8462 34eff2c [Lars Francke] Minor documentation fixes --- docs/sql-programming-guide.md | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index c6e6ec88a205f..9b5ea394a6efb 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -819,8 +819,8 @@ saveDF(select(df, "name", "age"), "namesAndAges.parquet") You can also manually specify the data source that will be used along with any extra options that you would like to pass to the data source. Data sources are specified by their fully qualified -name (i.e., `org.apache.spark.sql.parquet`), but for built-in sources you can also use the shorted -name (`json`, `parquet`, `jdbc`). DataFrames of any type can be converted into other types +name (i.e., `org.apache.spark.sql.parquet`), but for built-in sources you can also use their short +names (`json`, `parquet`, `jdbc`). DataFrames of any type can be converted into other types using this syntax.
@@ -828,7 +828,7 @@ using this syntax. {% highlight scala %} val df = sqlContext.read.format("json").load("examples/src/main/resources/people.json") -df.select("name", "age").write.format("json").save("namesAndAges.parquet") +df.select("name", "age").write.format("json").save("namesAndAges.json") {% endhighlight %}
@@ -975,7 +975,7 @@ schemaPeople.write().parquet("people.parquet"); // The result of loading a parquet file is also a DataFrame. DataFrame parquetFile = sqlContext.read().parquet("people.parquet"); -//Parquet files can also be registered as tables and then used in SQL statements. +// Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile"); DataFrame teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); List teenagerNames = teenagers.javaRDD().map(new Function() { @@ -1059,7 +1059,7 @@ SELECT * FROM parquetTable Table partitioning is a common optimization approach used in systems like Hive. In a partitioned table, data are usually stored in different directories, with partitioning column values encoded in the path of each partition directory. The Parquet data source is now able to discover and infer -partitioning information automatically. For exmaple, we can store all our previously used +partitioning information automatically. For example, we can store all our previously used population data into a partitioned table using the following directory structure, with two extra columns, `gender` and `country` as partitioning columns: @@ -1125,12 +1125,12 @@ source is now able to automatically detect this case and merge schemas of all th import sqlContext.implicits._ // Create a simple DataFrame, stored into a partition directory -val df1 = sparkContext.makeRDD(1 to 5).map(i => (i, i * 2)).toDF("single", "double") +val df1 = sc.makeRDD(1 to 5).map(i => (i, i * 2)).toDF("single", "double") df1.write.parquet("data/test_table/key=1") // Create another DataFrame in a new partition directory, // adding a new column and dropping an existing column -val df2 = sparkContext.makeRDD(6 to 10).map(i => (i, i * 3)).toDF("single", "triple") +val df2 = sc.makeRDD(6 to 10).map(i => (i, i * 3)).toDF("single", "triple") df2.write.parquet("data/test_table/key=2") // Read the partitioned table @@ -1138,7 +1138,7 @@ val df3 = sqlContext.read.parquet("data/test_table") df3.printSchema() // The final schema consists of all 3 columns in the Parquet files together -// with the partiioning column appeared in the partition directory paths. +// with the partitioning column appeared in the partition directory paths. // root // |-- single: int (nullable = true) // |-- double: int (nullable = true) @@ -1169,7 +1169,7 @@ df3 = sqlContext.load("data/test_table", "parquet") df3.printSchema() # The final schema consists of all 3 columns in the Parquet files together -# with the partiioning column appeared in the partition directory paths. +# with the partitioning column appeared in the partition directory paths. # root # |-- single: int (nullable = true) # |-- double: int (nullable = true) @@ -1196,7 +1196,7 @@ df3 <- loadDF(sqlContext, "data/test_table", "parquet") printSchema(df3) # The final schema consists of all 3 columns in the Parquet files together -# with the partiioning column appeared in the partition directory paths. +# with the partitioning column appeared in the partition directory paths. # root # |-- single: int (nullable = true) # |-- double: int (nullable = true) @@ -1253,7 +1253,7 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext` false Turn on Parquet filter pushdown optimization. This feature is turned off by default because of a known - bug in Paruet 1.6.0rc3 (PARQUET-136). + bug in Parquet 1.6.0rc3 (PARQUET-136). 
However, if your table doesn't contain any nullable string or binary columns, it's still safe to turn this feature on. @@ -1402,7 +1402,7 @@ sqlContext <- sparkRSQL.init(sc) # The path can be either a single text file or a directory storing text files. path <- "examples/src/main/resources/people.json" # Create a DataFrame from the file(s) pointed to by path -people <- jsonFile(sqlContex,t path) +people <- jsonFile(sqlContext, path) # The inferred schema can be visualized using the printSchema() method. printSchema(people) @@ -1474,7 +1474,7 @@ sqlContext.sql("FROM src SELECT key, value").collect().foreach(println) When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and adds support for finding tables in the MetaStore and writing queries using HiveQL. In addition to -the `sql` method a `HiveContext` also provides an `hql` methods, which allows queries to be +the `sql` method a `HiveContext` also provides an `hql` method, which allows queries to be expressed in HiveQL. {% highlight java %} @@ -2770,7 +2770,7 @@ from pyspark.sql.types import * MapType - enviroment + environment list(type="map", keyType=keyType, valueType=valueType, valueContainsNull=[valueContainsNull])
Note: The default value of valueContainsNull is True. From 3eaed8769c16e887edb9d54f5816b4ee6da23de5 Mon Sep 17 00:00:00 2001 From: Dibyendu Bhattacharya Date: Thu, 18 Jun 2015 19:58:47 -0700 Subject: [PATCH 104/151] [SPARK-8080] [STREAMING] Receiver.store with Iterator does not give correct count at Spark UI tdas zsxwing this is the new PR for Spark-8080 I have merged https://github.com/apache/spark/pull/6659 Also to mention , for MEMORY_ONLY settings , when Block is not able to unrollSafely to memory if enough space is not there, BlockManager won't try to put the block and ReceivedBlockHandler will throw SparkException as it could not find the block id in PutResult. Thus number of records in block won't be counted if Block failed to unroll in memory. Which is fine. For MEMORY_DISK settings , if BlockManager not able to unroll block to memory, block will still get deseralized to Disk. Same for WAL based store. So for those cases ( storage level = memory + disk ) number of records will be counted even though the block not able to unroll to memory. thus I added the isFullyConsumed in the CountingIterator but have not used it as such case will never happen that block not fully consumed and ReceivedBlockHandler still get the block ID. I have added few test cases to cover those block unrolling scenarios also. Author: Dibyendu Bhattacharya Author: U-PEROOT\UBHATD1 Closes #6707 from dibbhatt/master and squashes the following commits: f6cb6b5 [Dibyendu Bhattacharya] [SPARK-8080][STREAMING] Receiver.store with Iterator does not give correct count at Spark UI f37cfd8 [Dibyendu Bhattacharya] [SPARK-8080][STREAMING] Receiver.store with Iterator does not give correct count at Spark UI 5a8344a [Dibyendu Bhattacharya] [SPARK-8080][STREAMING] Receiver.store with Iterator does not give correct count at Spark UI Count ByteBufferBlock as 1 count fceac72 [Dibyendu Bhattacharya] [SPARK-8080][STREAMING] Receiver.store with Iterator does not give correct count at Spark UI 0153e7e [Dibyendu Bhattacharya] [SPARK-8080][STREAMING] Receiver.store with Iterator does not give correct count at Spark UI Fixed comments given by @zsxwing 4c5931d [Dibyendu Bhattacharya] [SPARK-8080][STREAMING] Receiver.store with Iterator does not give correct count at Spark UI 01e6dc8 [U-PEROOT\UBHATD1] A --- .../receiver/ReceivedBlockHandler.scala | 53 +++++- .../receiver/ReceiverSupervisorImpl.scala | 7 +- .../streaming/ReceivedBlockHandlerSuite.scala | 154 +++++++++++++++++- .../streaming/ReceivedBlockTrackerSuite.scala | 2 +- 4 files changed, 194 insertions(+), 22 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala index 207d64d9414ee..c8dd6e06812dc 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -32,7 +32,10 @@ import org.apache.spark.{Logging, SparkConf, SparkException} /** Trait that represents the metadata related to storage of blocks */ private[streaming] trait ReceivedBlockStoreResult { - def blockId: StreamBlockId // Any implementation of this trait will store a block id + // Any implementation of this trait will store a block id + def blockId: StreamBlockId + // Any implementation of this trait will have to return the number of records + def numRecords: Option[Long] } /** Trait that represents a class that handles the storage of 
blocks received by receiver */ @@ -51,7 +54,8 @@ private[streaming] trait ReceivedBlockHandler { * that stores the metadata related to storage of blocks using * [[org.apache.spark.streaming.receiver.BlockManagerBasedBlockHandler]] */ -private[streaming] case class BlockManagerBasedStoreResult(blockId: StreamBlockId) +private[streaming] case class BlockManagerBasedStoreResult( + blockId: StreamBlockId, numRecords: Option[Long]) extends ReceivedBlockStoreResult @@ -64,11 +68,20 @@ private[streaming] class BlockManagerBasedBlockHandler( extends ReceivedBlockHandler with Logging { def storeBlock(blockId: StreamBlockId, block: ReceivedBlock): ReceivedBlockStoreResult = { + + var numRecords = None: Option[Long] + val putResult: Seq[(BlockId, BlockStatus)] = block match { case ArrayBufferBlock(arrayBuffer) => - blockManager.putIterator(blockId, arrayBuffer.iterator, storageLevel, tellMaster = true) + numRecords = Some(arrayBuffer.size.toLong) + blockManager.putIterator(blockId, arrayBuffer.iterator, storageLevel, + tellMaster = true) case IteratorBlock(iterator) => - blockManager.putIterator(blockId, iterator, storageLevel, tellMaster = true) + val countIterator = new CountingIterator(iterator) + val putResult = blockManager.putIterator(blockId, countIterator, storageLevel, + tellMaster = true) + numRecords = countIterator.count + putResult case ByteBufferBlock(byteBuffer) => blockManager.putBytes(blockId, byteBuffer, storageLevel, tellMaster = true) case o => @@ -79,7 +92,7 @@ private[streaming] class BlockManagerBasedBlockHandler( throw new SparkException( s"Could not store $blockId to block manager with storage level $storageLevel") } - BlockManagerBasedStoreResult(blockId) + BlockManagerBasedStoreResult(blockId, numRecords) } def cleanupOldBlocks(threshTime: Long) { @@ -96,6 +109,7 @@ private[streaming] class BlockManagerBasedBlockHandler( */ private[streaming] case class WriteAheadLogBasedStoreResult( blockId: StreamBlockId, + numRecords: Option[Long], walRecordHandle: WriteAheadLogRecordHandle ) extends ReceivedBlockStoreResult @@ -151,12 +165,17 @@ private[streaming] class WriteAheadLogBasedBlockHandler( */ def storeBlock(blockId: StreamBlockId, block: ReceivedBlock): ReceivedBlockStoreResult = { + var numRecords = None: Option[Long] // Serialize the block so that it can be inserted into both val serializedBlock = block match { case ArrayBufferBlock(arrayBuffer) => + numRecords = Some(arrayBuffer.size.toLong) blockManager.dataSerialize(blockId, arrayBuffer.iterator) case IteratorBlock(iterator) => - blockManager.dataSerialize(blockId, iterator) + val countIterator = new CountingIterator(iterator) + val serializedBlock = blockManager.dataSerialize(blockId, countIterator) + numRecords = countIterator.count + serializedBlock case ByteBufferBlock(byteBuffer) => byteBuffer case _ => @@ -181,7 +200,7 @@ private[streaming] class WriteAheadLogBasedBlockHandler( // Combine the futures, wait for both to complete, and return the write ahead log record handle val combinedFuture = storeInBlockManagerFuture.zip(storeInWriteAheadLogFuture).map(_._2) val walRecordHandle = Await.result(combinedFuture, blockStoreTimeout) - WriteAheadLogBasedStoreResult(blockId, walRecordHandle) + WriteAheadLogBasedStoreResult(blockId, numRecords, walRecordHandle) } def cleanupOldBlocks(threshTime: Long) { @@ -199,3 +218,23 @@ private[streaming] object WriteAheadLogBasedBlockHandler { new Path(checkpointDir, new Path("receivedData", streamId.toString)).toString } } + +/** + * A utility that will wrap the Iterator to get the 
count + */ +private class CountingIterator[T](iterator: Iterator[T]) extends Iterator[T] { + private var _count = 0 + + private def isFullyConsumed: Boolean = !iterator.hasNext + + def hasNext(): Boolean = iterator.hasNext + + def count(): Option[Long] = { + if (isFullyConsumed) Some(_count) else None + } + + def next(): T = { + _count += 1 + iterator.next() + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 8be732b64e3a3..6078cdf8f8790 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -137,15 +137,10 @@ private[streaming] class ReceiverSupervisorImpl( blockIdOption: Option[StreamBlockId] ) { val blockId = blockIdOption.getOrElse(nextBlockId) - val numRecords = receivedBlock match { - case ArrayBufferBlock(arrayBuffer) => Some(arrayBuffer.size.toLong) - case _ => None - } - val time = System.currentTimeMillis val blockStoreResult = receivedBlockHandler.storeBlock(blockId, receivedBlock) logDebug(s"Pushed block $blockId in ${(System.currentTimeMillis - time)} ms") - + val numRecords = blockStoreResult.numRecords val blockInfo = ReceivedBlockInfo(streamId, numRecords, metadataOption, blockStoreResult) trackerEndpoint.askWithRetry[Boolean](AddBlock(blockInfo)) logDebug(s"Reported block $blockId") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index cca8cedb1d080..6c0c926755c20 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -49,7 +49,6 @@ class ReceivedBlockHandlerSuite val conf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.rollingIntervalSecs", "1") val hadoopConf = new Configuration() - val storageLevel = StorageLevel.MEMORY_ONLY_SER val streamId = 1 val securityMgr = new SecurityManager(conf) val mapOutputTracker = new MapOutputTrackerMaster(conf) @@ -57,10 +56,12 @@ class ReceivedBlockHandlerSuite val serializer = new KryoSerializer(conf) val manualClock = new ManualClock val blockManagerSize = 10000000 + val blockManagerBuffer = new ArrayBuffer[BlockManager]() var rpcEnv: RpcEnv = null var blockManagerMaster: BlockManagerMaster = null var blockManager: BlockManager = null + var storageLevel: StorageLevel = null var tempDirectory: File = null before { @@ -70,20 +71,21 @@ class ReceivedBlockHandlerSuite blockManagerMaster = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus)), conf, true) - blockManager = new BlockManager("bm", rpcEnv, blockManagerMaster, serializer, - blockManagerSize, conf, mapOutputTracker, shuffleManager, - new NioBlockTransferService(conf, securityMgr), securityMgr, 0) - blockManager.initialize("app-id") + storageLevel = StorageLevel.MEMORY_ONLY_SER + blockManager = createBlockManager(blockManagerSize, conf) tempDirectory = Utils.createTempDir() manualClock.setTime(0) } after { - if (blockManager != null) { - blockManager.stop() - blockManager = null + for ( blockManager <- blockManagerBuffer ) { + if (blockManager != null) { + blockManager.stop() + } } + blockManager = null + 
blockManagerBuffer.clear() if (blockManagerMaster != null) { blockManagerMaster.stop() blockManagerMaster = null @@ -174,6 +176,130 @@ class ReceivedBlockHandlerSuite } } + test("Test Block - count messages") { + // Test count with BlockManagedBasedBlockHandler + testCountWithBlockManagerBasedBlockHandler(true) + // Test count with WriteAheadLogBasedBlockHandler + testCountWithBlockManagerBasedBlockHandler(false) + } + + test("Test Block - isFullyConsumed") { + val sparkConf = new SparkConf() + sparkConf.set("spark.storage.unrollMemoryThreshold", "512") + // spark.storage.unrollFraction set to 0.4 for BlockManager + sparkConf.set("spark.storage.unrollFraction", "0.4") + // Block Manager with 12000 * 0.4 = 4800 bytes of free space for unroll + blockManager = createBlockManager(12000, sparkConf) + + // there is not enough space to store this block in MEMORY, + // But BlockManager will be able to sereliaze this block to WAL + // and hence count returns correct value. + testRecordcount(false, StorageLevel.MEMORY_ONLY, + IteratorBlock((List.fill(70)(new Array[Byte](100))).iterator), blockManager, Some(70)) + + // there is not enough space to store this block in MEMORY, + // But BlockManager will be able to sereliaze this block to DISK + // and hence count returns correct value. + testRecordcount(true, StorageLevel.MEMORY_AND_DISK, + IteratorBlock((List.fill(70)(new Array[Byte](100))).iterator), blockManager, Some(70)) + + // there is not enough space to store this block With MEMORY_ONLY StorageLevel. + // BlockManager will not be able to unroll this block + // and hence it will not tryToPut this block, resulting the SparkException + storageLevel = StorageLevel.MEMORY_ONLY + withBlockManagerBasedBlockHandler { handler => + val thrown = intercept[SparkException] { + storeSingleBlock(handler, IteratorBlock((List.fill(70)(new Array[Byte](100))).iterator)) + } + } + } + + private def testCountWithBlockManagerBasedBlockHandler(isBlockManagerBasedBlockHandler: Boolean) { + // ByteBufferBlock-MEMORY_ONLY + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY, + ByteBufferBlock(ByteBuffer.wrap(Array.tabulate(100)(i => i.toByte))), blockManager, None) + // ByteBufferBlock-MEMORY_ONLY_SER + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY_SER, + ByteBufferBlock(ByteBuffer.wrap(Array.tabulate(100)(i => i.toByte))), blockManager, None) + // ArrayBufferBlock-MEMORY_ONLY + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY, + ArrayBufferBlock(ArrayBuffer.fill(25)(0)), blockManager, Some(25)) + // ArrayBufferBlock-MEMORY_ONLY_SER + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY_SER, + ArrayBufferBlock(ArrayBuffer.fill(25)(0)), blockManager, Some(25)) + // ArrayBufferBlock-DISK_ONLY + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.DISK_ONLY, + ArrayBufferBlock(ArrayBuffer.fill(50)(0)), blockManager, Some(50)) + // ArrayBufferBlock-MEMORY_AND_DISK + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_AND_DISK, + ArrayBufferBlock(ArrayBuffer.fill(75)(0)), blockManager, Some(75)) + // IteratorBlock-MEMORY_ONLY + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY, + IteratorBlock((ArrayBuffer.fill(100)(0)).iterator), blockManager, Some(100)) + // IteratorBlock-MEMORY_ONLY_SER + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY_SER, + IteratorBlock((ArrayBuffer.fill(100)(0)).iterator), blockManager, Some(100)) + // IteratorBlock-DISK_ONLY + 
testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.DISK_ONLY, + IteratorBlock((ArrayBuffer.fill(125)(0)).iterator), blockManager, Some(125)) + // IteratorBlock-MEMORY_AND_DISK + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_AND_DISK, + IteratorBlock((ArrayBuffer.fill(150)(0)).iterator), blockManager, Some(150)) + } + + private def createBlockManager( + maxMem: Long, + conf: SparkConf, + name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { + val transfer = new NioBlockTransferService(conf, securityMgr) + val manager = new BlockManager(name, rpcEnv, blockManagerMaster, serializer, maxMem, conf, + mapOutputTracker, shuffleManager, transfer, securityMgr, 0) + manager.initialize("app-id") + blockManagerBuffer += manager + manager + } + + /** + * Test storing of data using different types of Handler, StorageLevle and ReceivedBlocks + * and verify the correct record count + */ + private def testRecordcount(isBlockManagedBasedBlockHandler: Boolean, + sLevel: StorageLevel, + receivedBlock: ReceivedBlock, + bManager: BlockManager, + expectedNumRecords: Option[Long] + ) { + blockManager = bManager + storageLevel = sLevel + var bId: StreamBlockId = null + try { + if (isBlockManagedBasedBlockHandler) { + // test received block with BlockManager based handler + withBlockManagerBasedBlockHandler { handler => + val (blockId, blockStoreResult) = storeSingleBlock(handler, receivedBlock) + bId = blockId + assert(blockStoreResult.numRecords === expectedNumRecords, + "Message count not matches for a " + + receivedBlock.getClass.getName + + " being inserted using BlockManagerBasedBlockHandler with " + sLevel) + } + } else { + // test received block with WAL based handler + withWriteAheadLogBasedBlockHandler { handler => + val (blockId, blockStoreResult) = storeSingleBlock(handler, receivedBlock) + bId = blockId + assert(blockStoreResult.numRecords === expectedNumRecords, + "Message count not matches for a " + + receivedBlock.getClass.getName + + " being inserted using WriteAheadLogBasedBlockHandler with " + sLevel) + } + } + } finally { + // Removing the Block Id to use same blockManager for next test + blockManager.removeBlock(bId, true) + } + } + /** * Test storing of data using different forms of ReceivedBlocks and verify that they succeeded * using the given verification function @@ -251,9 +377,21 @@ class ReceivedBlockHandlerSuite (blockIds, storeResults) } + /** Store single block using a handler */ + private def storeSingleBlock( + handler: ReceivedBlockHandler, + block: ReceivedBlock + ): (StreamBlockId, ReceivedBlockStoreResult) = { + val blockId = generateBlockId + val blockStoreResult = handler.storeBlock(blockId, block) + logDebug("Done inserting") + (blockId, blockStoreResult) + } + private def getWriteAheadLogFiles(): Seq[String] = { getLogFilesInDirectory(checkpointDirToLogDir(tempDirectory.toString, streamId)) } private def generateBlockId(): StreamBlockId = StreamBlockId(streamId, scala.util.Random.nextLong) } + diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index be305b5e0dfea..f793a12843b2f 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -225,7 +225,7 @@ class ReceivedBlockTrackerSuite /** Generate blocks infos using random ids */ def generateBlockInfos(): 
Seq[ReceivedBlockInfo] = { List.fill(5)(ReceivedBlockInfo(streamId, Some(0L), None, - BlockManagerBasedStoreResult(StreamBlockId(streamId, math.abs(Random.nextInt))))) + BlockManagerBasedStoreResult(StreamBlockId(streamId, math.abs(Random.nextInt)), Some(0L)))) } /** Get all the data written in the given write ahead log file. */ From a71cbbdea581573192a59bf8472861c463c40fcb Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 18 Jun 2015 22:01:52 -0700 Subject: [PATCH 105/151] [SPARK-8458] [SQL] Don't strip scheme part of output path when writing ORC files `Path.toUri.getPath` strips scheme part of output path (from `file:///foo` to `/foo`), which causes ORC data source only writes to the file system configured in Hadoop configuration. Should use `Path.toString` instead. Author: Cheng Lian Closes #6892 from liancheng/spark-8458 and squashes the following commits: 87f8199 [Cheng Lian] Don't strip scheme of output path when writing ORC files --- .../main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 77f1ca9ae0875..dbce39f21d271 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -111,7 +111,7 @@ private[orc] class OrcOutputWriter( new OrcOutputFormat().getRecordWriter( new Path(path, filename).getFileSystem(conf), conf.asInstanceOf[JobConf], - new Path(path, filename).toUri.getPath, + new Path(path, filename).toString, Reporter.NULL ).asInstanceOf[RecordWriter[NullWritable, Writable]] } From 754929b153aba3a8f8fbafa1581957da4ccc18be Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 18 Jun 2015 23:13:05 -0700 Subject: [PATCH 106/151] [SPARK-8348][SQL] Add in operator to DataFrame Column I have added it for only Scala. TODO: we should also support `in` operator in Python. Author: Yu ISHIKAWA Closes #6824 from yu-iskw/SPARK-8348 and squashes the following commits: e76d02f [Yu ISHIKAWA] Not use infix notation 6f744ac [Yu ISHIKAWA] Fit the test cases because these used the old test data set. 00077d3 [Yu ISHIKAWA] [SPARK-8348][SQL] Add in operator to DataFrame Column --- .../main/scala/org/apache/spark/sql/Column.scala | 2 +- .../apache/spark/sql/ColumnExpressionSuite.scala | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index d3efa83380d04..b4e008a6e8480 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -621,7 +621,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @since 1.3.0 */ @scala.annotation.varargs - def in(list: Column*): Column = In(expr, list.map(_.expr)) + def in(list: Any*): Column = In(expr, list.map(lit(_).expr)) /** * SQL like expression. 
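Before this change, `Column.in` only accepted other `Column`s, so callers had to wrap every value in `lit(...)`. With the relaxed `Any*` signature the common case of filtering against a handful of literals reads naturally, because each argument is lifted to a literal column internally. A short usage sketch, assuming a `SQLContext` named `sqlContext` with its implicits imported:

import sqlContext.implicits._

val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")

// Keep rows whose column matches any of the listed literal values.
df.filter($"a".in(1, 3)).show()      // rows where a is 1 or 3
df.filter($"b".in("x", "z")).show()  // rows where b is "x" or "z"

Mixing `Column` arguments and plain literals should also keep working, since `lit` passes an existing `Column` through unchanged.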
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 5a08578e7ba4b..88bb743ab0bc9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -296,6 +296,22 @@ class ColumnExpressionSuite extends QueryTest { checkAnswer(testData.filter($"a".between($"b", $"c")), expectAnswer) } + test("in") { + val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") + checkAnswer(df.filter($"a".in(1, 2)), + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".in(3, 2)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".in(3, 1)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + checkAnswer(df.filter($"b".in("y", "x")), + df.collect().toSeq.filter(r => r.getString(1) == "y" || r.getString(1) == "x")) + checkAnswer(df.filter($"b".in("z", "x")), + df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "x")) + checkAnswer(df.filter($"b".in("z", "y")), + df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y")) + } + val booleanData = ctx.createDataFrame(ctx.sparkContext.parallelize( Row(false, false) :: Row(false, true) :: From a2016b4bc4ef13339f168c3f4e135fa422046137 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Fri, 19 Jun 2015 00:07:53 -0700 Subject: [PATCH 107/151] [SPARK-8444] [STREAMING] Adding Python streaming example for queueStream A Python example similar to the existing one for Scala. Author: Bryan Cutler Closes #6884 from BryanCutler/streaming-queueStream-example-8444 and squashes the following commits: 435ba7e [Bryan Cutler] [SPARK-8444] Fixed style checks, increased sleep time to show empty queue 257abb0 [Bryan Cutler] [SPARK-8444] Stop context gracefully, Removed unused import, Added description comment 376ef6e [Bryan Cutler] [SPARK-8444] Fixed bug causing DStream.pprint to append empty parenthesis to output instead of blank line 1ff5f8b [Bryan Cutler] [SPARK-8444] Adding Python streaming example for queue_stream --- .../src/main/python/streaming/queue_stream.py | 50 +++++++++++++++++++ python/pyspark/streaming/dstream.py | 2 +- 2 files changed, 51 insertions(+), 1 deletion(-) create mode 100644 examples/src/main/python/streaming/queue_stream.py diff --git a/examples/src/main/python/streaming/queue_stream.py b/examples/src/main/python/streaming/queue_stream.py new file mode 100644 index 0000000000000..dcd6a0fc6ff91 --- /dev/null +++ b/examples/src/main/python/streaming/queue_stream.py @@ -0,0 +1,50 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +""" + Create a queue of RDDs that will be mapped/reduced one at a time in + 1 second intervals. + + To run this example use + `$ bin/spark-submit examples/src/main/python/streaming/queue_stream.py +""" +import sys +import time + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext + +if __name__ == "__main__": + + sc = SparkContext(appName="PythonStreamingQueueStream") + ssc = StreamingContext(sc, 1) + + # Create the queue through which RDDs can be pushed to + # a QueueInputDStream + rddQueue = [] + for i in xrange(5): + rddQueue += [ssc.sparkContext.parallelize([j for j in xrange(1, 1001)], 10)] + + # Create the QueueInputDStream and use it do some processing + inputStream = ssc.queueStream(rddQueue) + mappedStream = inputStream.map(lambda x: (x % 10, 1)) + reducedStream = mappedStream.reduceByKey(lambda a, b: a + b) + reducedStream.pprint() + + ssc.start() + time.sleep(6) + ssc.stop(stopSparkContext=True, stopGraceFully=True) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index ff097985fae3e..8dcb9645cdc6b 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -176,7 +176,7 @@ def takeAndPrint(time, rdd): print(record) if len(taken) > num: print("...") - print() + print("") self.foreachRDD(takeAndPrint) From fdf63f12490c674cc1877ddf7b70343c4fd6f4f1 Mon Sep 17 00:00:00 2001 From: Kevin Conor Date: Fri, 19 Jun 2015 00:12:20 -0700 Subject: [PATCH 108/151] [SPARK-8339] [PYSPARK] integer division for python 3 Itertools islice requires an integer for the stop argument. Switching to integer division here prevents a ValueError when vs is evaluated above. davies This is my original work, and I license it to the project. Author: Kevin Conor Closes #6794 from kconor/kconor-patch-1 and squashes the following commits: da5e700 [Kevin Conor] Integer division for batch size --- python/pyspark/serializers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index d8cdcda3a3783..7f9d0a338d31e 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -272,7 +272,7 @@ def dump_stream(self, iterator, stream): if size < best: batch *= 2 elif size > best * 10 and batch > 1: - batch /= 2 + batch //= 2 def __repr__(self): return "AutoBatchedSerializer(%s)" % self.serializer From 54557f353e588f5ff622ab8e67068bab408bce92 Mon Sep 17 00:00:00 2001 From: Carson Wang Date: Fri, 19 Jun 2015 09:57:12 +0200 Subject: [PATCH 109/151] [SPARK-8387] [FOLLOWUP ] [WEBUI] Update driver log URL to show only 4096 bytes This is to follow up #6834 , update the driver log URL as well for consistency. 
Author: Carson Wang Closes #6878 from carsonwang/logUrl and squashes the following commits: 13be948 [Carson Wang] update log URL in YarnClusterSuite a0004f4 [Carson Wang] Update driver log URL to show only 4096 bytes --- .../scheduler/cluster/YarnClusterSchedulerBackend.scala | 5 +++-- .../org/apache/spark/deploy/yarn/YarnClusterSuite.scala | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala index 1ace1a97d5156..33f580aaebdc0 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala @@ -115,8 +115,9 @@ private[spark] class YarnClusterSchedulerBackend( val httpScheme = if (yarnHttpPolicy == "HTTPS_ONLY") "https://" else "http://" val baseUrl = s"$httpScheme$httpAddress/node/containerlogs/$containerId/$user" logDebug(s"Base URL for logs: $baseUrl") - driverLogs = Some( - Map("stderr" -> s"$baseUrl/stderr?start=0", "stdout" -> s"$baseUrl/stdout?start=0")) + driverLogs = Some(Map( + "stderr" -> s"$baseUrl/stderr?start=-4096", + "stdout" -> s"$baseUrl/stdout?start=-4096")) } } } catch { diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index a0f25ba450068..335e966519c7c 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -376,7 +376,7 @@ private object YarnClusterDriver extends Logging with Matchers { new URL(urlStr) val containerId = YarnSparkHadoopUtil.get.getContainerId val user = Utils.getCurrentUserName() - assert(urlStr.endsWith(s"/node/containerlogs/$containerId/$user/stderr?start=0")) + assert(urlStr.endsWith(s"/node/containerlogs/$containerId/$user/stderr?start=-4096")) } } From 93360dc3cd6186e9d33c762d153a829a5882b72b Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 19 Jun 2015 11:58:07 +0200 Subject: [PATCH 110/151] [SPARK-7913] [CORE] Make AppendOnlyMap use the same growth strategy of OpenHashSet and consistent exception message This is a follow up PR for #6456 to make AppendOnlyMap consistent with OpenHashSet. /cc srowen andrewor14 Author: zsxwing Closes #6879 from zsxwing/append-only-map and squashes the following commits: 912c0ad [zsxwing] Fix the doc dd4385b [zsxwing] Make AppendOnlyMap use the same growth strategy of OpenHashSet and consistent exception message --- .../apache/spark/util/collection/AppendOnlyMap.scala | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala index d215ee43cb539..4c1e16155462e 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala @@ -32,7 +32,7 @@ import org.apache.spark.annotation.DeveloperApi * size, which is guaranteed to explore all spaces for each key (see * http://en.wikipedia.org/wiki/Quadratic_probing). * - * The map can support up to `536870912 (2 ^ 29)` elements. + * The map can support up to `375809638 (0.7 * 2 ^ 29)` elements. * * TODO: Cache the hash values of each key? java.util.HashMap does that. 
*/ @@ -199,11 +199,8 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) /** Increase table size by 1, rehashing if necessary */ private def incrementSize() { - if (curSize == MAXIMUM_CAPACITY) { - throw new IllegalStateException(s"Can't put more that ${MAXIMUM_CAPACITY} elements") - } curSize += 1 - if (curSize > growThreshold && capacity < MAXIMUM_CAPACITY) { + if (curSize > growThreshold) { growTable() } } @@ -216,7 +213,8 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) /** Double the table's size and re-hash everything */ protected def growTable() { // capacity < MAXIMUM_CAPACITY (2 ^ 29) so capacity * 2 won't overflow - val newCapacity = (capacity * 2).min(MAXIMUM_CAPACITY) + val newCapacity = capacity * 2 + require(newCapacity <= MAXIMUM_CAPACITY, s"Can't contain more than ${growThreshold} elements") val newData = new Array[AnyRef](2 * newCapacity) val newMask = newCapacity - 1 // Insert all our old values into the new array. Note that because our old keys are From ebd363aecde977511469d47fb1ea7cb5df3c3541 Mon Sep 17 00:00:00 2001 From: Jihong MA Date: Fri, 19 Jun 2015 14:05:11 +0200 Subject: [PATCH 111/151] [SPARK-7265] Improving documentation for Spark SQL Hive support Please review this pull request. Author: Jihong MA Closes #5933 from JihongMA/SPARK-7265 and squashes the following commits: dfaa971 [Jihong MA] SPARK-7265 minor fix of the content ace454d [Jihong MA] SPARK-7265 take out PySpark on YARN limitation 9ea0832 [Jihong MA] Merge remote-tracking branch 'upstream/master' d5bf3f5 [Jihong MA] Merge remote-tracking branch 'upstream/master' 7b842e6 [Jihong MA] Merge remote-tracking branch 'upstream/master' 9c84695 [Jihong MA] SPARK-7265 address review comment a399aa6 [Jihong MA] SPARK-7265 Improving documentation for Spark SQL Hive support --- docs/sql-programming-guide.md | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 9b5ea394a6efb..26c036f6648da 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1445,7 +1445,12 @@ This command builds a new assembly jar that includes Hive. Note that this Hive a on all of the worker nodes, as they will need access to the Hive serialization and deserialization libraries (SerDes) in order to access data stored in Hive. -Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. +Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. Please note when running +the query on a YARN cluster (`yarn-cluster` mode), the `datanucleus` jars under the `lib_managed/jars` directory +and `hive-site.xml` under `conf/` directory need to be available on the driver and all executors launched by the +YARN cluster. The convenient way to do this is adding them through the `--jars` option and `--file` option of the +`spark-submit` command. +
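To make the configuration requirement above concrete: on a YARN cluster the extra files can be shipped at submit time, for example along the lines of `spark-submit --master yarn-cluster --jars <comma-separated datanucleus jars from lib_managed/jars> --files conf/hive-site.xml ...` (the exact jar names depend on the build). Once the metastore configuration is visible to the driver and executors, the driver-side code is the usual `HiveContext` pattern already shown in the guide; a minimal sketch, where the application name and the `src` table are just the guide's running example:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.hive.HiveContext

// The hive-site.xml visible to this process determines which metastore is used.
val sc = new SparkContext(new SparkConf().setAppName("HiveOnYarnExample"))
val hiveContext = new HiveContext(sc)

hiveContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
hiveContext.sql("SELECT key, value FROM src LIMIT 10").collect().foreach(println)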
From 47af7c1ebfdbd7637f626ab07bf2bda6534f37ea Mon Sep 17 00:00:00 2001 From: cody koeninger Date: Fri, 19 Jun 2015 14:51:19 +0200 Subject: [PATCH 112/151] =?UTF-8?q?[SPARK-8389]=20[STREAMING]=20[KAFKA]=20?= =?UTF-8?q?Example=20of=20getting=20offset=20ranges=20out=20o=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …f the existing java direct stream api Author: cody koeninger Closes #6846 from koeninger/SPARK-8389 and squashes the following commits: 3f3c57a [cody koeninger] [Streaming][Kafka][SPARK-8389] Example of getting offset ranges out of the existing java direct stream api --- .../kafka/JavaDirectKafkaStreamSuite.java | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java index c0669fb336657..3913b711ba28b 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java @@ -32,6 +32,7 @@ import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.function.Function; import org.apache.spark.streaming.Durations; import org.apache.spark.streaming.api.java.JavaDStream; @@ -65,8 +66,8 @@ public void tearDown() { @Test public void testKafkaStream() throws InterruptedException { - String topic1 = "topic1"; - String topic2 = "topic2"; + final String topic1 = "topic1"; + final String topic2 = "topic2"; String[] topic1data = createTopicAndSendData(topic1); String[] topic2data = createTopicAndSendData(topic2); @@ -87,6 +88,16 @@ public void testKafkaStream() throws InterruptedException { StringDecoder.class, kafkaParams, topicToSet(topic1) + ).transformToPair( + // Make sure you can get offset ranges from the rdd + new Function, JavaPairRDD>() { + @Override + public JavaPairRDD call(JavaPairRDD rdd) throws Exception { + OffsetRange[] offsets = ((HasOffsetRanges)rdd.rdd()).offsetRanges(); + Assert.assertEquals(offsets[0].topic(), topic1); + return rdd; + } + } ).map( new Function, String>() { @Override From 43c7ec6384e51105dedf3a53354b6a3732cc27b2 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Fri, 19 Jun 2015 09:46:51 -0700 Subject: [PATCH 113/151] [SPARK-8151] [MLLIB] pipeline components should correctly implement copy Otherwise, extra params get ignored in `PipelineModel.transform`. 
jkbradley Author: Xiangrui Meng Closes #6622 from mengxr/SPARK-8087 and squashes the following commits: 0e4c8c4 [Xiangrui Meng] fix merge issues 26fc1f0 [Xiangrui Meng] address comments e607a04 [Xiangrui Meng] merge master b85b57e [Xiangrui Meng] fix examples/compile d6f7891 [Xiangrui Meng] rename defaultCopyWithParams to defaultCopy 84ec278 [Xiangrui Meng] remove setter checks due to generics 2cf2ed0 [Xiangrui Meng] snapshot 291814f [Xiangrui Meng] OneVsRest.copy 1dfe3bd [Xiangrui Meng] PipelineModel.copy should copy stages --- .../examples/ml/JavaDeveloperApiExample.java | 5 ++++ .../examples/ml/DeveloperApiExample.scala | 2 ++ .../scala/org/apache/spark/ml/Estimator.scala | 4 +-- .../scala/org/apache/spark/ml/Model.scala | 5 +--- .../scala/org/apache/spark/ml/Pipeline.scala | 6 ++-- .../scala/org/apache/spark/ml/Predictor.scala | 4 +-- .../org/apache/spark/ml/Transformer.scala | 6 ++-- .../spark/ml/classification/Classifier.scala | 1 + .../DecisionTreeClassifier.scala | 2 ++ .../ml/classification/GBTClassifier.scala | 2 ++ .../classification/LogisticRegression.scala | 2 ++ .../spark/ml/classification/OneVsRest.scala | 16 +++++++++- .../RandomForestClassifier.scala | 2 ++ .../BinaryClassificationEvaluator.scala | 2 ++ .../spark/ml/evaluation/Evaluator.scala | 4 +-- .../ml/evaluation/RegressionEvaluator.scala | 4 ++- .../apache/spark/ml/feature/Binarizer.scala | 2 ++ .../apache/spark/ml/feature/Bucketizer.scala | 2 ++ .../spark/ml/feature/ElementwiseProduct.scala | 2 +- .../apache/spark/ml/feature/HashingTF.scala | 4 ++- .../org/apache/spark/ml/feature/IDF.scala | 13 ++++++-- .../spark/ml/feature/OneHotEncoder.scala | 2 ++ .../ml/feature/PolynomialExpansion.scala | 4 ++- .../spark/ml/feature/StandardScaler.scala | 7 +++++ .../spark/ml/feature/StringIndexer.scala | 7 +++++ .../apache/spark/ml/feature/Tokenizer.scala | 4 +++ .../spark/ml/feature/VectorAssembler.scala | 3 ++ .../spark/ml/feature/VectorIndexer.scala | 9 +++++- .../apache/spark/ml/feature/Word2Vec.scala | 7 +++++ .../org/apache/spark/ml/param/params.scala | 15 +++++++--- .../apache/spark/ml/recommendation/ALS.scala | 7 +++++ .../ml/regression/DecisionTreeRegressor.scala | 2 ++ .../spark/ml/regression/GBTRegressor.scala | 2 ++ .../ml/regression/LinearRegression.scala | 2 ++ .../ml/regression/RandomForestRegressor.scala | 2 ++ .../spark/ml/tuning/CrossValidator.scala | 11 +++++++ .../org/apache/spark/mllib/feature/IDF.scala | 2 +- .../apache/spark/mllib/feature/Word2Vec.scala | 2 +- .../apache/spark/ml/param/JavaTestParams.java | 5 ++++ .../org/apache/spark/ml/PipelineSuite.scala | 10 +++++++ .../DecisionTreeClassifierSuite.scala | 12 ++++++-- .../classification/GBTClassifierSuite.scala | 11 +++++++ .../LogisticRegressionSuite.scala | 9 +++++- .../ml/classification/OneVsRestSuite.scala | 30 +++++++++++++++++++ .../RandomForestClassifierSuite.scala | 10 ++++++- .../BinaryClassificationEvaluatorSuite.scala | 28 +++++++++++++++++ .../evaluation/RegressionEvaluatorSuite.scala | 5 ++++ .../spark/ml/feature/BinarizerSuite.scala | 5 ++++ .../spark/ml/feature/BucketizerSuite.scala | 5 ++++ .../spark/ml/feature/HashingTFSuite.scala | 3 +- .../apache/spark/ml/feature/IDFSuite.scala | 8 +++++ .../spark/ml/feature/OneHotEncoderSuite.scala | 5 ++++ .../ml/feature/PolynomialExpansionSuite.scala | 5 ++++ .../spark/ml/feature/StringIndexerSuite.scala | 7 +++++ .../spark/ml/feature/TokenizerSuite.scala | 12 ++++++++ .../ml/feature/VectorAssemblerSuite.scala | 5 ++++ .../spark/ml/feature/VectorIndexerSuite.scala | 7 +++++ 
.../spark/ml/feature/Word2VecSuite.scala | 8 +++++ .../apache/spark/ml/param/ParamsSuite.scala | 22 +++++++++----- .../apache/spark/ml/param/TestParams.scala | 4 +-- .../ml/param/shared/SharedParamsSuite.scala | 6 ++-- .../spark/ml/tuning/CrossValidatorSuite.scala | 5 +++- 62 files changed, 350 insertions(+), 55 deletions(-) create mode 100644 mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java index ec533d174ebdc..9df26ffca5775 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java @@ -156,6 +156,11 @@ public MyJavaLogisticRegressionModel train(DataFrame dataset) { // Create a model, and return it. return new MyJavaLogisticRegressionModel(uid(), weights).setParent(this); } + + @Override + public MyJavaLogisticRegression copy(ParamMap extra) { + return defaultCopy(extra); + } } /** diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index 3ee456edbe01e..7b8cc21ed8982 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -130,6 +130,8 @@ private class MyLogisticRegression(override val uid: String) // Create a model, and return it. new MyLogisticRegressionModel(uid, weights).setParent(this) } + + override def copy(extra: ParamMap): MyLogisticRegression = defaultCopy(extra) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala index e9a5d7c0e7988..57e416591de69 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala @@ -78,7 +78,5 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage { paramMaps.map(fit(dataset, _)) } - override def copy(extra: ParamMap): Estimator[M] = { - super.copy(extra).asInstanceOf[Estimator[M]] - } + override def copy(extra: ParamMap): Estimator[M] } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Model.scala b/mllib/src/main/scala/org/apache/spark/ml/Model.scala index 186bf7ae7a2f6..252acc156583f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Model.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Model.scala @@ -45,8 +45,5 @@ abstract class Model[M <: Model[M]] extends Transformer { /** Indicates whether this [[Model]] has a corresponding parent. */ def hasParent: Boolean = parent != null - override def copy(extra: ParamMap): M = { - // The default implementation of Params.copy doesn't work for models. 
- throw new NotImplementedError(s"${this.getClass} doesn't implement copy(extra: ParamMap)") - } + override def copy(extra: ParamMap): M } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index a9bd28df71ee1..a1f3851d804ff 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -66,9 +66,7 @@ abstract class PipelineStage extends Params with Logging { outputSchema } - override def copy(extra: ParamMap): PipelineStage = { - super.copy(extra).asInstanceOf[PipelineStage] - } + override def copy(extra: ParamMap): PipelineStage } /** @@ -198,6 +196,6 @@ class PipelineModel private[ml] ( } override def copy(extra: ParamMap): PipelineModel = { - new PipelineModel(uid, stages) + new PipelineModel(uid, stages.map(_.copy(extra))) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index e752b81a14282..edaa2afb790e6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -90,9 +90,7 @@ abstract class Predictor[ copyValues(train(dataset).setParent(this)) } - override def copy(extra: ParamMap): Learner = { - super.copy(extra).asInstanceOf[Learner] - } + override def copy(extra: ParamMap): Learner /** * Train a model using the given dataset and parameters. diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index f07f733a5ddb5..3c7bcf7590e6d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -67,9 +67,7 @@ abstract class Transformer extends PipelineStage { */ def transform(dataset: DataFrame): DataFrame - override def copy(extra: ParamMap): Transformer = { - super.copy(extra).asInstanceOf[Transformer] - } + override def copy(extra: ParamMap): Transformer } /** @@ -120,4 +118,6 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]] dataset.withColumn($(outputCol), callUDF(this.createTransformFunc, outputDataType, dataset($(inputCol)))) } + + override def copy(extra: ParamMap): T = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index 263d580fe2dd3..14c285dbfc54a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor} import org.apache.spark.ml.param.shared.HasRawPredictionCol import org.apache.spark.ml.util.SchemaUtils diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 8030e0728a56c..2dc1824964a42 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -86,6 +86,8 @@ final class DecisionTreeClassifier(override val uid: String) super.getOldStrategy(categoricalFeatures, 
numClasses, OldAlgo.Classification, getOldImpurity, subsamplingRate = 1.0) } + + override def copy(extra: ParamMap): DecisionTreeClassifier = defaultCopy(extra) } @Experimental diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 62f4b51f770e9..554e3b8e052b2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -141,6 +141,8 @@ final class GBTClassifier(override val uid: String) val oldModel = oldGBT.run(oldDataset) GBTClassificationModel.fromOld(oldModel, this, categoricalFeatures) } + + override def copy(extra: ParamMap): GBTClassifier = defaultCopy(extra) } @Experimental diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index f136bcee9cf2b..2e6eedd45ab07 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -220,6 +220,8 @@ class LogisticRegression(override val uid: String) new LogisticRegressionModel(uid, weights.compressed, intercept) } + + override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 825f9ed1b54b2..b657882f8ad3f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -24,7 +24,7 @@ import scala.language.existentials import org.apache.spark.annotation.Experimental import org.apache.spark.ml._ import org.apache.spark.ml.attribute._ -import org.apache.spark.ml.param.Param +import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.sql.{DataFrame, Row} @@ -133,6 +133,12 @@ final class OneVsRestModel private[ml] ( aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata)) .drop(accColName) } + + override def copy(extra: ParamMap): OneVsRestModel = { + val copied = new OneVsRestModel( + uid, labelMetadata, models.map(_.copy(extra).asInstanceOf[ClassificationModel[_, _]])) + copyValues(copied, extra) + } } /** @@ -209,4 +215,12 @@ final class OneVsRest(override val uid: String) val model = new OneVsRestModel(uid, labelAttribute.toMetadata(), models).setParent(this) copyValues(model) } + + override def copy(extra: ParamMap): OneVsRest = { + val copied = defaultCopy(extra).asInstanceOf[OneVsRest] + if (isDefined(classifier)) { + copied.setClassifier($(classifier).copy(extra)) + } + copied + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 852a67e066322..d3c67494a31e4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -97,6 +97,8 @@ final class RandomForestClassifier(override val uid: String) oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, 
getSeed.toInt) RandomForestClassificationModel.fromOld(oldModel, this, categoricalFeatures) } + + override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra) } @Experimental diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index f695ddaeefc72..4a82b77f0edcb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -79,4 +79,6 @@ class BinaryClassificationEvaluator(override val uid: String) metrics.unpersist() metric } + + override def copy(extra: ParamMap): BinaryClassificationEvaluator = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala index 61e937e693699..e56c946a063e8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala @@ -46,7 +46,5 @@ abstract class Evaluator extends Params { */ def evaluate(dataset: DataFrame): Double - override def copy(extra: ParamMap): Evaluator = { - super.copy(extra).asInstanceOf[Evaluator] - } + override def copy(extra: ParamMap): Evaluator } diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala index abb1b35bedea5..8670e9679d055 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.param.{Param, ParamValidators} +import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.RegressionMetrics @@ -80,4 +80,6 @@ final class RegressionEvaluator(override val uid: String) } metric } + + override def copy(extra: ParamMap): RegressionEvaluator = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index b06122d733853..46314854d5e3a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -83,4 +83,6 @@ final class Binarizer(override val uid: String) val outputFields = inputFields :+ attr.toStructField() StructType(outputFields) } + + override def copy(extra: ParamMap): Binarizer = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index a3d1f6f65ccaf..67e4785bc3553 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -89,6 +89,8 @@ final class Bucketizer(override val uid: String) SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType) SchemaUtils.appendColumn(schema, prepOutputField(schema)) } + + override def copy(extra: ParamMap): Bucketizer = defaultCopy(extra) } private[feature] 
object Bucketizer { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala index 1e758cb775de7..a359cb8f37ec3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.Experimental import org.apache.spark.ml.UnaryTransformer -import org.apache.spark.ml.param.Param +import org.apache.spark.ml.param.{ParamMap, Param} import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index f936aef80f8af..319d23e46cef4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.Experimental import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.AttributeGroup -import org.apache.spark.ml.param.{IntParam, ParamValidators} +import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.mllib.feature @@ -74,4 +74,6 @@ class HashingTF(override val uid: String) extends Transformer with HasInputCol w val attrGroup = new AttributeGroup($(outputCol), $(numFeatures)) SchemaUtils.appendColumn(schema, attrGroup.toStructField()) } + + override def copy(extra: ParamMap): HashingTF = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index 376b84530cd57..ecde80810580c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -45,9 +45,6 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol /** @group getParam */ def getMinDocFreq: Int = $(minDocFreq) - /** @group setParam */ - def setMinDocFreq(value: Int): this.type = set(minDocFreq, value) - /** * Validate and transform the input schema. 
*/ @@ -72,6 +69,9 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) + /** @group setParam */ + def setMinDocFreq(value: Int): this.type = set(minDocFreq, value) + override def fit(dataset: DataFrame): IDFModel = { transformSchema(dataset.schema, logging = true) val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v } @@ -82,6 +82,8 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } + + override def copy(extra: ParamMap): IDF = defaultCopy(extra) } /** @@ -109,4 +111,9 @@ class IDFModel private[ml] ( override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } + + override def copy(extra: ParamMap): IDFModel = { + val copied = new IDFModel(uid, idfModel) + copyValues(copied, extra) + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 8f34878c8d329..3825942795645 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -165,4 +165,6 @@ class OneHotEncoder(override val uid: String) extends Transformer dataset.select(col("*"), encode(col(inputColName).cast(DoubleType)).as(outputColName, metadata)) } + + override def copy(extra: ParamMap): OneHotEncoder = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala index 442e95820217a..d85e468562d4a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala @@ -21,7 +21,7 @@ import scala.collection.mutable import org.apache.spark.annotation.Experimental import org.apache.spark.ml.UnaryTransformer -import org.apache.spark.ml.param.{IntParam, ParamValidators} +import org.apache.spark.ml.param.{ParamMap, IntParam, ParamValidators} import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.linalg._ import org.apache.spark.sql.types.DataType @@ -61,6 +61,8 @@ class PolynomialExpansion(override val uid: String) } override protected def outputDataType: DataType = new VectorUDT() + + override def copy(extra: ParamMap): PolynomialExpansion = defaultCopy(extra) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index b0fd06d84fdb3..ca3c1cfb56b7f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -92,6 +92,8 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) StructType(outputFields) } + + override def copy(extra: ParamMap): StandardScaler = defaultCopy(extra) } /** @@ -125,4 +127,9 @@ class StandardScalerModel private[ml] ( val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) StructType(outputFields) } + + override def copy(extra: ParamMap): StandardScalerModel = { + val copied = new StandardScalerModel(uid, scaler) + 
copyValues(copied, extra) + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index f4e250757560a..bf7be363b8224 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -83,6 +83,8 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } + + override def copy(extra: ParamMap): StringIndexer = defaultCopy(extra) } /** @@ -144,4 +146,9 @@ class StringIndexerModel private[ml] ( schema } } + + override def copy(extra: ParamMap): StringIndexerModel = { + val copied = new StringIndexerModel(uid, labels) + copyValues(copied, extra) + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 21c15b6c33f6c..5f9f57a2ebcfa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -43,6 +43,8 @@ class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[S } override protected def outputDataType: DataType = new ArrayType(StringType, false) + + override def copy(extra: ParamMap): Tokenizer = defaultCopy(extra) } /** @@ -112,4 +114,6 @@ class RegexTokenizer(override val uid: String) } override protected def outputDataType: DataType = new ArrayType(StringType, false) + + override def copy(extra: ParamMap): RegexTokenizer = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 229ee27ec5942..9f83c2ee16178 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -23,6 +23,7 @@ import org.apache.spark.SparkException import org.apache.spark.annotation.Experimental import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute} +import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} @@ -117,6 +118,8 @@ class VectorAssembler(override val uid: String) } StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, false)) } + + override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra) } private object VectorAssembler { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index 1d0f23b4fb3db..f4854a5e4b7b7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -25,7 +25,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.attribute._ -import org.apache.spark.ml.param.{IntParam, ParamValidators, Params} +import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators, Params} import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.{Identifiable, SchemaUtils} 
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, VectorUDT} @@ -131,6 +131,8 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerMod SchemaUtils.checkColumnType(schema, $(inputCol), dataType) SchemaUtils.appendColumn(schema, $(outputCol), dataType) } + + override def copy(extra: ParamMap): VectorIndexer = defaultCopy(extra) } private object VectorIndexer { @@ -399,4 +401,9 @@ class VectorIndexerModel private[ml] ( val newAttributeGroup = new AttributeGroup($(outputCol), featureAttributes) newAttributeGroup.toStructField() } + + override def copy(extra: ParamMap): VectorIndexerModel = { + val copied = new VectorIndexerModel(uid, numFeatures, categoryMaps) + copyValues(copied, extra) + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 36f19509f0cfb..6ea6590956300 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -132,6 +132,8 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } + + override def copy(extra: ParamMap): Word2Vec = defaultCopy(extra) } /** @@ -180,4 +182,9 @@ class Word2VecModel private[ml] ( override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } + + override def copy(extra: ParamMap): Word2VecModel = { + val copied = new Word2VecModel(uid, wordVectors) + copyValues(copied, extra) + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index ba94d6a3a80a9..15ebad8838a2a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -492,13 +492,20 @@ trait Params extends Identifiable with Serializable { /** * Creates a copy of this instance with the same UID and some extra params. - * The default implementation tries to create a new instance with the same UID. + * Subclasses should implement this method and set the return type properly. + * + * @see [[defaultCopy()]] + */ + def copy(extra: ParamMap): Params + + /** + * Default implementation of copy with extra params. + * It tries to create a new instance with the same UID. * Then it copies the embedded and extra parameters over and returns the new instance. - * Subclasses should override this method if the default approach is not sufficient. 
*/ - def copy(extra: ParamMap): Params = { + protected final def defaultCopy[T <: Params](extra: ParamMap): T = { val that = this.getClass.getConstructor(classOf[String]).newInstance(uid) - copyValues(that, extra) + copyValues(that, extra).asInstanceOf[T] } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index df009d855ecbb..2e44cd4cc6a22 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -216,6 +216,11 @@ class ALSModel private[ml] ( SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType) SchemaUtils.appendColumn(schema, $(predictionCol), FloatType) } + + override def copy(extra: ParamMap): ALSModel = { + val copied = new ALSModel(uid, rank, userFactors, itemFactors) + copyValues(copied, extra) + } } @@ -330,6 +335,8 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } + + override def copy(extra: ParamMap): ALS = defaultCopy(extra) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 43b68e7bb20fa..be1f8063d41d8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -76,6 +76,8 @@ final class DecisionTreeRegressor(override val uid: String) super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity, subsamplingRate = 1.0) } + + override def copy(extra: ParamMap): DecisionTreeRegressor = defaultCopy(extra) } @Experimental diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index b7e374bb6cb49..036e3acb07412 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -131,6 +131,8 @@ final class GBTRegressor(override val uid: String) val oldModel = oldGBT.run(oldDataset) GBTRegressionModel.fromOld(oldModel, this, categoricalFeatures) } + + override def copy(extra: ParamMap): GBTRegressor = defaultCopy(extra) } @Experimental diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 70cd8e9e87fae..01306545fc7cd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -186,6 +186,8 @@ class LinearRegression(override val uid: String) // TODO: Converts to sparse format based on the storage, but may base on the scoring speed. 
copyValues(new LinearRegressionModel(uid, weights.compressed, intercept)) } + + override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 49a1f7ce8c995..21c59061a02fa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -86,6 +86,8 @@ final class RandomForestRegressor(override val uid: String) oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt) RandomForestRegressionModel.fromOld(oldModel, this, categoricalFeatures) } + + override def copy(extra: ParamMap): RandomForestRegressor = defaultCopy(extra) } @Experimental diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index cb29392e8bc63..e2444ab65b43b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -149,6 +149,17 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM est.copy(paramMap).validateParams() } } + + override def copy(extra: ParamMap): CrossValidator = { + val copied = defaultCopy(extra).asInstanceOf[CrossValidator] + if (copied.isDefined(estimator)) { + copied.setEstimator(copied.getEstimator.copy(extra)) + } + if (copied.isDefined(evaluator)) { + copied.setEvaluator(copied.getEvaluator.copy(extra)) + } + copied + } } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala index efbfeb4059f5a..3fab7ea79befc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala @@ -159,7 +159,7 @@ private object IDF { * Represents an IDF model that can transform term frequency vectors. */ @Experimental -class IDFModel private[mllib] (val idf: Vector) extends Serializable { +class IDFModel private[spark] (val idf: Vector) extends Serializable { /** * Transforms term frequency (TF) vectors to TF-IDF vectors. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 51546d41c36a6..f087d06d2a46a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -431,7 +431,7 @@ class Word2Vec extends Serializable with Logging { * Word2Vec model */ @Experimental -class Word2VecModel private[mllib] ( +class Word2VecModel private[spark] ( model: Map[String, Array[Float]]) extends Serializable with Saveable { // wordList: Ordered list of words obtained from model. 
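Note on the recurring ML change above: `Params.copy` becomes abstract, classes with a public single-`String` constructor delegate to the new reflection-based `defaultCopy`, and models that carry extra constructor arguments (e.g. `IDFModel`, `Word2VecModel`, `ALSModel`) rebuild themselves and then call `copyValues`. The following is a minimal, self-contained sketch of that pattern; the `Params`/`ParamMap` stand-ins here are simplified assumptions for illustration only, not the actual Spark classes.

// Simplified stand-ins for Spark's Params/ParamMap; illustration only.
case class ParamMap(values: Map[String, Any] = Map.empty)

trait Params {
  def uid: String

  // Now abstract: every concrete class declares copy with its own precise return type.
  def copy(extra: ParamMap): Params

  // Reflection-based default: build a new instance with the same UID via the
  // single-String constructor, then copy embedded and extra params over.
  protected final def defaultCopy[T <: Params](extra: ParamMap): T = {
    val that = this.getClass.getConstructor(classOf[String]).newInstance(uid)
    copyValues(that, extra).asInstanceOf[T]
  }

  // Param copying is elided in this sketch.
  protected def copyValues[T <: Params](to: T, extra: ParamMap): T = to
}

// A transformer/estimator with a single-String constructor just delegates to defaultCopy.
class MyTransformer(val uid: String) extends Params {
  override def copy(extra: ParamMap): MyTransformer = defaultCopy(extra)
}

// A model with extra constructor arguments rebuilds itself and copies params explicitly.
class MyModel(val uid: String, val weights: Array[Double]) extends Params {
  override def copy(extra: ParamMap): MyModel = copyValues(new MyModel(uid, weights), extra)
}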
diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java index ff5929235ac2c..3ae09d39ef500 100644 --- a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java @@ -102,4 +102,9 @@ private void init() { setDefault(myDoubleArrayParam(), new double[] {1.0, 2.0}); setDefault(myDoubleArrayParam().w(new double[] {1.0, 2.0})); } + + @Override + public JavaTestParams copy(ParamMap extra) { + return defaultCopy(extra); + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 29394fefcbc43..63d2fa31c7499 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -24,6 +24,7 @@ import org.mockito.Mockito.when import org.scalatest.mock.MockitoSugar.mock import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.feature.HashingTF import org.apache.spark.ml.param.ParamMap import org.apache.spark.sql.DataFrame @@ -84,6 +85,15 @@ class PipelineSuite extends SparkFunSuite { } } + test("PipelineModel.copy") { + val hashingTF = new HashingTF() + .setNumFeatures(100) + val model = new PipelineModel("pipeline", Array[Transformer](hashingTF)) + val copied = model.copy(ParamMap(hashingTF.numFeatures -> 10)) + require(copied.stages(0).asInstanceOf[HashingTF].getNumFeatures === 10, + "copy should handle extra stage params") + } + test("pipeline model constructors") { val transform0 = mock[Transformer] val model1 = mock[MyModel] diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index ae40b0b8ff854..73b4805c4c597 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -19,15 +19,15 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.tree.LeafNode import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, - DecisionTreeSuite => OldDecisionTreeSuite} +import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame - class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { import DecisionTreeClassifierSuite.compareAPIs @@ -55,6 +55,12 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte OldDecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()) } + test("params") { + ParamsSuite.checkParams(new DecisionTreeClassifier) + val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0)) + ParamsSuite.checkParams(model) + } + ///////////////////////////////////////////////////////////////////////////// // Tests calling train() ///////////////////////////////////////////////////////////////////////////// diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 1302da3c373ff..82c345491bb3c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -19,6 +19,9 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.regression.DecisionTreeRegressionModel +import org.apache.spark.ml.tree.LeafNode import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} @@ -51,6 +54,14 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80), 2) } + test("params") { + ParamsSuite.checkParams(new GBTClassifier) + val model = new GBTClassificationModel("gbtc", + Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0))), + Array(1.0)) + ParamsSuite.checkParams(model) + } + test("Binary classification with continuous features: Log Loss") { val categoricalFeatures = Map.empty[Int, Int] testCombinations.foreach { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index a755cac3ea76e..5a6265ea992c6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -18,8 +18,9 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.classification.LogisticRegressionSuite._ -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.linalg.{Vectors, Vector} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} @@ -62,6 +63,12 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("params") { + ParamsSuite.checkParams(new LogisticRegression) + val model = new LogisticRegressionModel("logReg", Vectors.dense(0.0), 0.0) + ParamsSuite.checkParams(model) + } + test("logistic regression: default params") { val lr = new LogisticRegression assert(lr.getLabelCol === "label") diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 1d04ccb509057..75cf5bd4ead4f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -19,15 +19,18 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.NominalAttribute +import org.apache.spark.ml.param.{ParamMap, ParamsSuite} import org.apache.spark.ml.util.MetadataUtils import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import 
org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.types.Metadata class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -52,6 +55,13 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext { dataset = sqlContext.createDataFrame(rdd) } + test("params") { + ParamsSuite.checkParams(new OneVsRest) + val lrModel = new LogisticRegressionModel("lr", Vectors.dense(0.0), 0.0) + val model = new OneVsRestModel("ovr", Metadata.empty, Array(lrModel)) + ParamsSuite.checkParams(model) + } + test("one-vs-rest: default params") { val numClasses = 3 val ova = new OneVsRest() @@ -102,6 +112,26 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext { val output = ovr.fit(dataset).transform(dataset) assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction")) } + + test("OneVsRest.copy and OneVsRestModel.copy") { + val lr = new LogisticRegression() + .setMaxIter(1) + + val ovr = new OneVsRest() + withClue("copy with classifier unset should work") { + ovr.copy(ParamMap(lr.maxIter -> 10)) + } + ovr.setClassifier(lr) + val ovr1 = ovr.copy(ParamMap(lr.maxIter -> 10)) + require(ovr.getClassifier.getOrDefault(lr.maxIter) === 1, "copy should have no side-effects") + require(ovr1.getClassifier.getOrDefault(lr.maxIter) === 10, + "copy should handle extra classifier params") + + val ovrModel = ovr1.fit(dataset).copy(ParamMap(lr.threshold -> 0.1)) + ovrModel.models.foreach { case m: LogisticRegressionModel => + require(m.getThreshold === 0.1, "copy should handle extra model params") + } + } } private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index eee9355a67be3..1b6b69c7dc71e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.tree.LeafNode import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} @@ -27,7 +29,6 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame - /** * Test suite for [[RandomForestClassifier]]. */ @@ -62,6 +63,13 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte compareAPIs(orderedLabeledPoints50_1000, newRF, categoricalFeatures, numClasses) } + test("params") { + ParamsSuite.checkParams(new RandomForestClassifier) + val model = new RandomForestClassificationModel("rfc", + Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0)))) + ParamsSuite.checkParams(model) + } + test("Binary classification with continuous features:" + " comparing DecisionTree vs. 
RandomForest(numTrees = 1)") { val rf = new RandomForestClassifier() diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala new file mode 100644 index 0000000000000..def869fe66777 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.evaluation + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite + +class BinaryClassificationEvaluatorSuite extends SparkFunSuite { + + test("params") { + ParamsSuite.checkParams(new BinaryClassificationEvaluator) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala index 36a1ac6b7996d..aa722da323935 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala @@ -18,12 +18,17 @@ package org.apache.spark.ml.evaluation import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.LinearRegression import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext { + test("params") { + ParamsSuite.checkParams(new RegressionEvaluator) + } + test("Regression Evaluator: default params") { /** * Here is the instruction describing how to export the test data into CSV format diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala index 7953bd0417191..2086043983661 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} @@ -30,6 +31,10 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext { data = Array(0.1, -0.5, 0.2, -0.3, 0.8, 0.7, -0.1, -0.4) } + test("params") { + ParamsSuite.checkParams(new Binarizer) + } + test("Binarize continuous features with default parameter") { val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0) val dataFrame: DataFrame = 
sqlContext.createDataFrame( diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index 507a8a7db24c7..ec85e0d151e07 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.feature import scala.util.Random import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ @@ -27,6 +28,10 @@ import org.apache.spark.sql.{DataFrame, Row} class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext { + test("params") { + ParamsSuite.checkParams(new Bucketizer) + } + test("Bucket continuous features, without -inf,inf") { // Check a set of valid feature values. val splits = Array(-0.5, 0.0, 0.5) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala index 7b2d70e644005..4157b84b29d01 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala @@ -28,8 +28,7 @@ import org.apache.spark.util.Utils class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext { test("params") { - val hashingTF = new HashingTF - ParamsSuite.checkParams(hashingTF, 3) + ParamsSuite.checkParams(new HashingTF) } test("hashingTF") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala index d83772e8be755..08f80af03429b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala @@ -18,6 +18,8 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.feature.{IDFModel => OldIDFModel} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ @@ -38,6 +40,12 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("params") { + ParamsSuite.checkParams(new IDF) + val model = new IDFModel("idf", new OldIDFModel(Vectors.dense(1.0))) + ParamsSuite.checkParams(model) + } + test("compute IDF with default parameter") { val numOfFeatures = 4 val data = Array( diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala index 2e5036a844562..65846a846b7b4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute} +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.DataFrame @@ -36,6 +37,10 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext { 
indexer.transform(df) } + test("params") { + ParamsSuite.checkParams(new OneHotEncoder) + } + test("OneHotEncoder dropLast = false") { val transformed = stringIndexed() val encoder = new OneHotEncoder() diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala index feca866cd711d..29eebd8960ebc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.feature +import org.apache.spark.ml.param.ParamsSuite import org.scalatest.exceptions.TestFailedException import org.apache.spark.SparkFunSuite @@ -27,6 +28,10 @@ import org.apache.spark.sql.Row class PolynomialExpansionSuite extends SparkFunSuite with MLlibTestSparkContext { + test("params") { + ParamsSuite.checkParams(new PolynomialExpansion) + } + test("Polynomial expansion with default parameter") { val data = Array( Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 5f557e16e5150..99f82bea42688 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -19,10 +19,17 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.util.MLlibTestSparkContext class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { + test("params") { + ParamsSuite.checkParams(new StringIndexer) + val model = new StringIndexerModel("indexer", Array("a", "b")) + ParamsSuite.checkParams(model) + } + test("StringIndexer") { val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) val df = sqlContext.createDataFrame(data).toDF("id", "label") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala index ac279cb3215c2..e5fd21c3f6fca 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala @@ -20,15 +20,27 @@ package org.apache.spark.ml.feature import scala.beans.BeanInfo import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} @BeanInfo case class TokenizerTestData(rawText: String, wantedTokens: Array[String]) +class TokenizerSuite extends SparkFunSuite { + + test("params") { + ParamsSuite.checkParams(new Tokenizer) + } +} + class RegexTokenizerSuite extends SparkFunSuite with MLlibTestSparkContext { import org.apache.spark.ml.feature.RegexTokenizerSuite._ + test("params") { + ParamsSuite.checkParams(new RegexTokenizer) + } + test("RegexTokenizer") { val tokenizer0 = new RegexTokenizer() .setGaps(false) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index 489abb5af7130..bb4d5b983e0d4 100644 --- 
a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.feature import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row @@ -26,6 +27,10 @@ import org.apache.spark.sql.functions.col class VectorAssemblerSuite extends SparkFunSuite with MLlibTestSparkContext { + test("params") { + ParamsSuite.checkParams(new VectorAssembler) + } + test("assemble") { import org.apache.spark.ml.feature.VectorAssembler.assemble assert(assemble(0.0) === Vectors.sparse(1, Array.empty, Array.empty)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index 06affc7305cf5..8c85c96d5c6d8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -21,6 +21,7 @@ import scala.beans.{BeanInfo, BeanProperty} import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute._ +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD @@ -91,6 +92,12 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { private def getIndexer: VectorIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexed") + test("params") { + ParamsSuite.checkParams(new VectorIndexer) + val model = new VectorIndexerModel("indexer", 1, Map.empty) + ParamsSuite.checkParams(model) + } + test("Cannot fit an empty DataFrame") { val rdd = sqlContext.createDataFrame(sc.parallelize(Array.empty[Vector], 2).map(FeatureData)) val vectorIndexer = getIndexer diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index 94ebc3aebfa37..aa6ce533fd885 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -18,13 +18,21 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.mllib.feature.{Word2VecModel => OldWord2VecModel} class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { + test("params") { + ParamsSuite.checkParams(new Word2Vec) + val model = new Word2VecModel("w2v", new OldWord2VecModel(Map("a" -> Array(0.0f)))) + ParamsSuite.checkParams(model) + } + test("Word2Vec") { val sqlContext = new SQLContext(sc) import sqlContext.implicits._ diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index 96094d7a099aa..050d4170ea017 100644 --- 
a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -205,19 +205,27 @@ class ParamsSuite extends SparkFunSuite { object ParamsSuite extends SparkFunSuite { /** - * Checks common requirements for [[Params.params]]: 1) number of params; 2) params are ordered - * by names; 3) param parent has the same UID as the object's UID; 4) param name is the same as - * the param method name. + * Checks common requirements for [[Params.params]]: + * - params are ordered by names + * - param parent has the same UID as the object's UID + * - param name is the same as the param method name + * - obj.copy should return the same type as the obj */ - def checkParams(obj: Params, expectedNumParams: Int): Unit = { + def checkParams(obj: Params): Unit = { + val clazz = obj.getClass + val params = obj.params - require(params.length === expectedNumParams, - s"Expect $expectedNumParams params but got ${params.length}: ${params.map(_.name).toSeq}.") val paramNames = params.map(_.name) - require(paramNames === paramNames.sorted) + require(paramNames === paramNames.sorted, "params must be ordered by names") params.foreach { p => assert(p.parent === obj.uid) assert(obj.getParam(p.name) === p) + // TODO: Check that setters return self, which needs special handling for generic types. } + + val copyMethod = clazz.getMethod("copy", classOf[ParamMap]) + val copyReturnType = copyMethod.getReturnType + require(copyReturnType === obj.getClass, + s"${clazz.getName}.copy should return ${clazz.getName} instead of ${copyReturnType.getName}.") } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala index a9e78366ad98f..2759248344531 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala @@ -38,7 +38,5 @@ class TestParams(override val uid: String) extends Params with HasMaxIter with H require(isDefined(inputCol)) } - override def copy(extra: ParamMap): TestParams = { - super.copy(extra).asInstanceOf[TestParams] - } + override def copy(extra: ParamMap): TestParams = defaultCopy(extra) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala index eb5408d3fee7c..b3af81a3c60b6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala @@ -18,13 +18,15 @@ package org.apache.spark.ml.param.shared import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.param.Params +import org.apache.spark.ml.param.{ParamMap, Params} class SharedParamsSuite extends SparkFunSuite { test("outputCol") { - class Obj(override val uid: String) extends Params with HasOutputCol + class Obj(override val uid: String) extends Params with HasOutputCol { + override def copy(extra: ParamMap): Obj = defaultCopy(extra) + } val obj = new Obj("obj") diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 9b3619f0046ea..36af4b34a9e40 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.ml.tuning import 
org.apache.spark.SparkFunSuite - import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator} @@ -98,6 +97,8 @@ object CrossValidatorSuite { override def transformSchema(schema: StructType): StructType = { throw new UnsupportedOperationException } + + override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra) } class MyEvaluator extends Evaluator { @@ -107,5 +108,7 @@ object CrossValidatorSuite { } override val uid: String = "eval" + + override def copy(extra: ParamMap): MyEvaluator = defaultCopy(extra) } } From 2c59d5c12a0a02702839bfaf631505b8a311c5a9 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 19 Jun 2015 10:09:31 -0700 Subject: [PATCH 114/151] [SPARK-8207] [SQL] Add math function bin JIRA: https://issues.apache.org/jira/browse/SPARK-8207 Author: Liang-Chi Hsieh Closes #6721 from viirya/expr_bin and squashes the following commits: 07e1c8f [Liang-Chi Hsieh] Remove AbstractUnaryMathExpression and let BIN inherit UnaryExpression. 0677f1a [Liang-Chi Hsieh] For comments. cf62b95 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into expr_bin 0cf20f2 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into expr_bin dea9c12 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into expr_bin d4f4774 [Liang-Chi Hsieh] Add @ignore_unicode_prefix. 7a0196f [Liang-Chi Hsieh] Fix python style. ac2bacd [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into expr_bin a0a2d0f [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into expr_bin 4cb764d [Liang-Chi Hsieh] For comments. 0f78682 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into expr_bin c0c3197 [Liang-Chi Hsieh] Add bin to FunctionRegistry. 824f761 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into expr_bin 50e0c3b [Liang-Chi Hsieh] Add math function bin(a: long): string. --- python/pyspark/sql/functions.py | 14 ++++++++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../spark/sql/catalyst/expressions/math.scala | 33 +++++++++++++++++- .../expressions/MathFunctionsSuite.scala | 34 +++++++++++++++---- .../org/apache/spark/sql/functions.scala | 18 ++++++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 10 ++++++ 6 files changed, 102 insertions(+), 8 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index acdb01d3d3f5f..cfa87aeea193a 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -35,6 +35,7 @@ __all__ = [ 'array', 'approxCountDistinct', + 'bin', 'coalesce', 'countDistinct', 'explode', @@ -231,6 +232,19 @@ def approxCountDistinct(col, rsd=None): return Column(jc) +@ignore_unicode_prefix +@since(1.5) +def bin(col): + """Returns the string representation of the binary value of the given column. + + >>> df.select(bin(df.age).alias('c')).collect() + [Row(c=u'10'), Row(c=u'101')] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.bin(_to_java_column(col)) + return Column(jc) + + @since(1.4) def coalesce(*cols): """Returns the first column that is not null. 
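Note on the `bin` function registered in the Scala diffs below: the Catalyst `Bin` expression reduces to `java.lang.Long.toBinaryString` applied to a nullable long input. A standalone sketch in plain Scala, without the Catalyst `InternalRow`/`UTF8String` plumbing and with `Option` as a stand-in for SQL null handling:

def bin(input: Option[Long]): Option[String] =
  input.map(v => java.lang.Long.toBinaryString(v))

// Matches the expectations exercised by the patch's tests:
// bin(Some(12L))   == Some("1100")
// bin(Some(-123L)) == Some(java.lang.Long.toBinaryString(-123L))  // 64-bit two's complement form
// bin(None)        == None                                        // null in, null out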
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 13b2bb05f5280..79273a78408a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -103,6 +103,7 @@ object FunctionRegistry { expression[Asin]("asin"), expression[Atan]("atan"), expression[Atan2]("atan2"), + expression[Bin]("bin"), expression[Cbrt]("cbrt"), expression[Ceil]("ceil"), expression[Ceil]("ceiling"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index f79bf4aee00d5..250564dc4b818 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -17,9 +17,12 @@ package org.apache.spark.sql.catalyst.expressions +import java.lang.{Long => JLong} + import org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.types.{DataType, DoubleType} +import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StringType} +import org.apache.spark.unsafe.types.UTF8String /** * A leaf expression specifically for math constants. Math constants expect no input. @@ -207,6 +210,34 @@ case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadia override def funcName: String = "toRadians" } +case class Bin(child: Expression) + extends UnaryExpression with Serializable with ExpectsInputTypes { + + val name: String = "BIN" + + override def foldable: Boolean = child.foldable + override def nullable: Boolean = true + override def toString: String = s"$name($child)" + + override def expectedChildTypes: Seq[DataType] = Seq(LongType) + override def dataType: DataType = StringType + + def funcName: String = name.toLowerCase + + override def eval(input: catalyst.InternalRow): Any = { + val evalE = child.eval(input) + if (evalE == null) { + null + } else { + UTF8String.fromString(JLong.toBinaryString(evalE.asInstanceOf[Long])) + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, (c) => + s"${ctx.stringType}.fromString(java.lang.Long.toBinaryString($c))") + } +} //////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 21e9b92b7214e..0d1d5ebdff2d5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.types.DoubleType +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.types.{DataType, DoubleType, LongType} class MathFunctionsSuite extends SparkFunSuite 
with ExpressionEvalHelper { @@ -41,16 +42,18 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { * Used for testing unary math expressions. * * @param c expression - * @param f The functions in scala.math + * @param f The functions in scala.math or elsewhere used to generate expected results * @param domain The set of values to run the function with * @param expectNull Whether the given values should return null or not * @tparam T Generic type for primitives + * @tparam U Generic type for the output of the given function `f` */ - private def testUnary[T]( + private def testUnary[T, U]( c: Expression => Expression, - f: T => T, + f: T => U, domain: Iterable[T] = (-20 to 20).map(_ * 0.1), - expectNull: Boolean = false): Unit = { + expectNull: Boolean = false, + evalType: DataType = DoubleType): Unit = { if (expectNull) { domain.foreach { value => checkEvaluation(c(Literal(value)), null, EmptyRow) @@ -60,7 +63,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(c(Literal(value)), f(value), EmptyRow) } } - checkEvaluation(c(Literal.create(null, DoubleType)), null, create_row(null)) + checkEvaluation(c(Literal.create(null, evalType)), null, create_row(null)) } /** @@ -168,7 +171,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("signum") { - testUnary[Double](Signum, math.signum) + testUnary[Double, Double](Signum, math.signum) } test("log") { @@ -186,6 +189,23 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { testUnary(Log1p, math.log1p, (-10 to -2).map(_ * 1.0), expectNull = true) } + test("bin") { + testUnary(Bin, java.lang.Long.toBinaryString, (-20 to 20).map(_.toLong), evalType = LongType) + + val row = create_row(null, 12L, 123L, 1234L, -123L) + val l1 = 'a.long.at(0) + val l2 = 'a.long.at(1) + val l3 = 'a.long.at(2) + val l4 = 'a.long.at(3) + val l5 = 'a.long.at(4) + + checkEvaluation(Bin(l1), null, row) + checkEvaluation(Bin(l2), java.lang.Long.toBinaryString(12), row) + checkEvaluation(Bin(l3), java.lang.Long.toBinaryString(123), row) + checkEvaluation(Bin(l4), java.lang.Long.toBinaryString(1234), row) + checkEvaluation(Bin(l5), java.lang.Long.toBinaryString(-123), row) + } + test("log2") { def f: (Double) => Double = (x: Double) => math.log(x) / math.log(2) testUnary(Log2, f, (0 to 20).map(_ * 0.1)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index d8a91bead7c33..40ae9f5df8e9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -889,6 +889,24 @@ object functions { */ def atan2(l: Double, rightName: String): Column = atan2(l, Column(rightName)) + /** + * An expression that returns the string representation of the binary value of the given long + * column. For example, bin("12") returns "1100". + * + * @group math_funcs + * @since 1.5.0 + */ + def bin(e: Column): Column = Bin(e.expr) + + /** + * An expression that returns the string representation of the binary value of the given long + * column. For example, bin("12") returns "1100". + * + * @group math_funcs + * @since 1.5.0 + */ + def bin(columnName: String): Column = bin(Column(columnName)) + /** * Computes the cube-root of the given value. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index cfd23867a9bba..70819fe287060 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -110,6 +110,16 @@ class DataFrameFunctionsSuite extends QueryTest { testData2.collect().toSeq.map(r => Row(~r.getInt(0)))) } + test("bin") { + val df = Seq[(Integer, Integer)]((12, null)).toDF("a", "b") + checkAnswer( + df.select(bin("a"), bin("b")), + Row("1100", null)) + checkAnswer( + df.selectExpr("bin(a)", "bin(b)"), + Row("1100", null)) + } + test("if function") { val df = Seq((1, 2)).toDF("a", "b") checkAnswer( From 9baf093014a48c5ec49f747773f4500dafdfa4ec Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Fri, 19 Jun 2015 10:47:07 -0700 Subject: [PATCH 115/151] [SPARK-8430] ExternalShuffleBlockResolver of shuffle service should support UnsafeShuffleManager andrewor14 can you take a look?thanks Author: Lianhui Wang Closes #6873 from lianhuiwang/SPARK-8430 and squashes the following commits: 51c47ca [Lianhui Wang] update andrewor's comments 2b27b19 [Lianhui Wang] support UnsafeShuffleManager --- .../spark/network/shuffle/ExternalShuffleBlockResolver.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index dd08e24cade23..022ed88a16480 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -108,7 +108,8 @@ public ManagedBuffer getBlockData(String appId, String execId, String blockId) { if ("org.apache.spark.shuffle.hash.HashShuffleManager".equals(executor.shuffleManager)) { return getHashBasedShuffleBlockData(executor, blockId); - } else if ("org.apache.spark.shuffle.sort.SortShuffleManager".equals(executor.shuffleManager)) { + } else if ("org.apache.spark.shuffle.sort.SortShuffleManager".equals(executor.shuffleManager) + || "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager".equals(executor.shuffleManager)) { return getSortBasedShuffleBlockData(executor, shuffleId, mapId, reduceId); } else { throw new UnsupportedOperationException( From fe08561e2ee13fc8f641db8b6e6c1499bdfd4d29 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 19 Jun 2015 10:48:16 -0700 Subject: [PATCH 116/151] [SPARK-8476] [CORE] Setters inc/decDiskBytesSpilled in TaskMetrics should also be private. This is a follow-up of [SPARK-3288](https://issues.apache.org/jira/browse/SPARK-3288). Author: Takuya UESHIN Closes #6896 from ueshin/issues/SPARK-8476 and squashes the following commits: 89251d8 [Takuya UESHIN] Make inc/decDiskBytesSpilled in TaskMetrics private[spark]. 
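As an aside on the change summarized above (not part of the patch): the sketch below, using a hypothetical ExampleMetrics class, illustrates the Scala visibility pattern being applied here, where the value stays publicly readable but only code under the org.apache.spark package may mutate it.

package org.apache.spark.example

// Minimal sketch, assuming a metrics-style holder: public getter,
// setters restricted to the org.apache.spark package.
class ExampleMetrics extends Serializable {
  private var _bytesSpilled: Long = 0L
  def bytesSpilled: Long = _bytesSpilled
  private[spark] def incBytesSpilled(value: Long): Unit = _bytesSpilled += value
  private[spark] def decBytesSpilled(value: Long): Unit = _bytesSpilled -= value
}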
--- .../main/scala/org/apache/spark/executor/TaskMetrics.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 38b61d7242fce..a3b4561b07e7f 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -94,8 +94,8 @@ class TaskMetrics extends Serializable { */ private var _diskBytesSpilled: Long = _ def diskBytesSpilled: Long = _diskBytesSpilled - def incDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled += value - def decDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled -= value + private[spark] def incDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled += value + private[spark] def decDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled -= value /** * If this task reads from a HadoopRDD or from persisted data, metrics on how much data was read From 0c32fc125c45e59f06cb55f3ba7da612d840ca86 Mon Sep 17 00:00:00 2001 From: Shilei Date: Fri, 19 Jun 2015 10:49:27 -0700 Subject: [PATCH 117/151] [SPARK-8234][SQL] misc function: md5 Author: Shilei Closes #6779 from qiansl127/MD5 and squashes the following commits: 11fcdb2 [Shilei] Fix the indent 04bd27b [Shilei] Add codegen da60eb3 [Shilei] Remove checkInputDataTypes function 9509ad0 [Shilei] Format code 12c61f4 [Shilei] Accept only BinaryType for Md5 1df0b5b [Shilei] format to scala type 60ccde1 [Shilei] Add more test case b8c73b4 [Shilei] Rewrite the type check for Md5 c166167 [Shilei] Add md5 function --- .../catalyst/analysis/FunctionRegistry.scala | 3 ++ .../spark/sql/catalyst/expressions/misc.scala | 50 +++++++++++++++++++ .../expressions/MiscFunctionsSuite.scala | 32 ++++++++++++ .../org/apache/spark/sql/functions.scala | 21 ++++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 11 ++++ 5 files changed, 117 insertions(+) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 79273a78408a9..5fb3369f85d12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -133,6 +133,9 @@ object FunctionRegistry { expression[ToDegrees]("degrees"), expression[ToRadians]("radians"), + // misc functions + expression[Md5]("md5"), + // aggregate functions expression[Average]("avg"), expression[Count]("count"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala new file mode 100644 index 0000000000000..4bee8cb728e5c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.commons.codec.digest.DigestUtils +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.types.{BinaryType, StringType, DataType} +import org.apache.spark.unsafe.types.UTF8String + +/** + * A function that calculates an MD5 128-bit checksum and returns it as a hex string + * For input of type [[BinaryType]] + */ +case class Md5(child: Expression) + extends UnaryExpression with ExpectsInputTypes { + + override def dataType: DataType = StringType + + override def expectedChildTypes: Seq[DataType] = Seq(BinaryType) + + override def eval(input: InternalRow): Any = { + val value = child.eval(input) + if (value == null) { + null + } else { + UTF8String.fromString(DigestUtils.md5Hex(value.asInstanceOf[Array[Byte]])) + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => + "org.apache.spark.unsafe.types.UTF8String.fromString" + + s"(org.apache.commons.codec.digest.DigestUtils.md5Hex($c))") + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala new file mode 100644 index 0000000000000..48b84130b4556 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types.{StringType, BinaryType} + +class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("md5") { + checkEvaluation(Md5(Literal("ABC".getBytes)), "902fbdd2b1df0c4f70b4a5d23525e932") + checkEvaluation(Md5(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)), + "6ac1e56bc78f031059be7be854522c4c") + checkEvaluation(Md5(Literal.create(null, BinaryType)), null) + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 40ae9f5df8e9a..7e7a099a8318b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -36,6 +36,7 @@ import org.apache.spark.util.Utils * @groupname sort_funcs Sorting functions * @groupname normal_funcs Non-aggregate functions * @groupname math_funcs Math functions + * @groupname misc_funcs Misc functions * @groupname window_funcs Window functions * @groupname string_funcs String functions * @groupname Ungrouped Support functions for DataFrames. @@ -1376,6 +1377,26 @@ object functions { */ def toRadians(columnName: String): Column = toRadians(Column(columnName)) + ////////////////////////////////////////////////////////////////////////////////////////////// + // Misc functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Calculates the MD5 digest and returns the value as a 32 character hex string. + * + * @group misc_funcs + * @since 1.5.0 + */ + def md5(e: Column): Column = Md5(e.expr) + + /** + * Calculates the MD5 digest and returns the value as a 32 character hex string. + * + * @group misc_funcs + * @since 1.5.0 + */ + def md5(columnName: String): Column = md5(Column(columnName)) + ////////////////////////////////////////////////////////////////////////////////////////////// // String functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 70819fe287060..8b53b384a22fd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -133,6 +133,17 @@ class DataFrameFunctionsSuite extends QueryTest { Row("x", "y", null)) } + test("misc md5 function") { + val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b") + checkAnswer( + df.select(md5($"a"), md5("b")), + Row("902fbdd2b1df0c4f70b4a5d23525e932", "6ac1e56bc78f031059be7be854522c4c")) + + checkAnswer( + df.selectExpr("md5(a)", "md5(b)"), + Row("902fbdd2b1df0c4f70b4a5d23525e932", "6ac1e56bc78f031059be7be854522c4c")) + } + test("string length function") { checkAnswer( nullStrings.select(strlen($"s"), strlen("s")), From a9858036bfd339b47dd6d2ed69ccbb61269c225e Mon Sep 17 00:00:00 2001 From: RJ Nowling Date: Fri, 19 Jun 2015 10:50:44 -0700 Subject: [PATCH 118/151] Add example that reads a local file, writes to a DFS path provided by th... ...e user, reads the file back from the DFS, and compares word counts on the local and DFS versions. Useful for verifying DFS correctness. 
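For context, running this example end to end would look roughly like the command below; the master URL, assembly jar path, and the two arguments (local file, DFS directory) are placeholders that depend on your build and cluster, so treat it as an illustration rather than documented usage.

./bin/spark-submit \
  --class org.apache.spark.examples.DFSReadWriteTest \
  --master yarn-client \
  examples/target/scala-2.10/spark-examples-assembly.jar \
  /path/to/local/words.txt hdfs:///tmp/dfs_read_write_test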
Author: RJ Nowling Closes #3347 from rnowling/dfs_read_write_test and squashes the following commits: af8ccb7 [RJ Nowling] Don't use java.io.File since DFS may not be POSIX-compatible b0ef9ea [RJ Nowling] Fix string style 07c6132 [RJ Nowling] Fix string style 7d9a8df [RJ Nowling] Fix string style f74c160 [RJ Nowling] Fix else statement style b9edf12 [RJ Nowling] Fix spark wc style 44415b9 [RJ Nowling] Fix local wc style 94a4691 [RJ Nowling] Fix space df59b65 [RJ Nowling] Fix if statements 1b314f0 [RJ Nowling] Add scaladoc a931d70 [RJ Nowling] Fix import order 0c89558 [RJ Nowling] Add example that reads a local file, writes to a DFS path provided by the user, reads the file back from the DFS, and compares word counts on the local and DFS versions. Useful for verifying DFS correctness. --- .../spark/examples/DFSReadWriteTest.scala | 138 ++++++++++++++++++ 1 file changed, 138 insertions(+) create mode 100644 examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala new file mode 100644 index 0000000000000..c05890dfbfde1 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples + +import java.io.File + +import scala.io.Source._ + +import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.SparkContext._ + +/** + * Simple test for reading and writing to a distributed + * file system. This example does the following: + * + * 1. Reads local file + * 2. Computes word count on local file + * 3. Writes local file to a DFS + * 4. Reads the file back from the DFS + * 5. Computes word count on the file using Spark + * 6. 
Compares the word count results + */ +object DFSReadWriteTest { + + private var localFilePath: File = new File(".") + private var dfsDirPath: String = "" + + private val NPARAMS = 2 + + private def readFile(filename: String): List[String] = { + val lineIter: Iterator[String] = fromFile(filename).getLines() + val lineList: List[String] = lineIter.toList + lineList + } + + private def printUsage(): Unit = { + val usage: String = "DFS Read-Write Test\n" + + "\n" + + "Usage: localFile dfsDir\n" + + "\n" + + "localFile - (string) local file to use in test\n" + + "dfsDir - (string) DFS directory for read/write tests\n" + + println(usage) + } + + private def parseArgs(args: Array[String]): Unit = { + if (args.length != NPARAMS) { + printUsage() + System.exit(1) + } + + var i = 0 + + localFilePath = new File(args(i)) + if (!localFilePath.exists) { + System.err.println("Given path (" + args(i) + ") does not exist.\n") + printUsage() + System.exit(1) + } + + if (!localFilePath.isFile) { + System.err.println("Given path (" + args(i) + ") is not a file.\n") + printUsage() + System.exit(1) + } + + i += 1 + dfsDirPath = args(i) + } + + def runLocalWordCount(fileContents: List[String]): Int = { + fileContents.flatMap(_.split(" ")) + .flatMap(_.split("\t")) + .filter(_.size > 0) + .groupBy(w => w) + .mapValues(_.size) + .values + .sum + } + + def main(args: Array[String]): Unit = { + parseArgs(args) + + println("Performing local word count") + val fileContents = readFile(localFilePath.toString()) + val localWordCount = runLocalWordCount(fileContents) + + println("Creating SparkConf") + val conf = new SparkConf().setAppName("DFS Read Write Test") + + println("Creating SparkContext") + val sc = new SparkContext(conf) + + println("Writing local file to DFS") + val dfsFilename = dfsDirPath + "/dfs_read_write_test" + val fileRDD = sc.parallelize(fileContents) + fileRDD.saveAsTextFile(dfsFilename) + + println("Reading file from DFS and running Word Count") + val readFileRDD = sc.textFile(dfsFilename) + + val dfsWordCount = readFileRDD + .flatMap(_.split(" ")) + .flatMap(_.split("\t")) + .filter(_.size > 0) + .map(w => (w, 1)) + .countByKey() + .values + .sum + + sc.stop() + + if (localWordCount == dfsWordCount) { + println(s"Success! Local Word Count ($localWordCount) " + + s"and DFS Word Count ($dfsWordCount) agree.") + } else { + println(s"Failure! Local Word Count ($localWordCount) " + + s"and DFS Word Count ($dfsWordCount) disagree.") + } + + } +} From 866816eb97002863ec205d854e1397982aecbc5e Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 19 Jun 2015 10:52:30 -0700 Subject: [PATCH 119/151] [SPARK-7180] [SPARK-8090] [SPARK-8091] Fix a number of SerializationDebugger bugs and limitations This PR solves three SerializationDebugger issues. * SPARK-7180 - SerializationDebugger fails with ArrayOutOfBoundsException * SPARK-8090 - SerializationDebugger does not handle classes with writeReplace correctly * SPARK-8091 - SerializationDebugger does not handle classes with writeObject method The solutions for each are explained as follows * SPARK-7180 - The wrong slot desc was used for getting the value of the fields in the object being tested. * SPARK-8090 - Test the type of the replaced object. * SPARK-8091 - Use a dummy ObjectOutputStream to collect all the objects written by the writeObject() method, and then test those objects as usual. I also added more tests in the testsuite to increase code coverage. For example, added tests for cases where there are not serializability issues. 
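To make the two serialization hooks concrete, the sketch below (hypothetical classes, not taken from the patch) shows the shapes the debugger now has to traverse: writeReplace() substitutes a different object at serialization time, and writeObject() may write arbitrary extra objects to the stream.

import java.io.{IOException, ObjectOutputStream}

// Hypothetical class using writeReplace: the object actually serialized is the
// payload, so the debugger must re-inspect the replacement's type (SPARK-8090).
class ReplacedWrapper(@transient val payload: AnyRef) extends Serializable {
  private def writeReplace(): Object = payload
}

// Hypothetical class using writeObject: whatever it writes here must be captured
// (e.g. via a dummy ObjectOutputStream) and visited as well (SPARK-8091).
class CustomWritten(val field: AnyRef) extends Serializable {
  @throws(classOf[IOException])
  private def writeObject(out: ObjectOutputStream): Unit = {
    out.defaultWriteObject()
    out.writeObject("extra state written by hand")
  }
}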
Author: Tathagata Das Closes #6625 from tdas/SPARK-7180 and squashes the following commits: c7cb046 [Tathagata Das] Addressed comments on docs ae212c8 [Tathagata Das] Improved docs 304c97b [Tathagata Das] Fixed build error 26b5179 [Tathagata Das] more tests.....92% line coverage 7e2fdcf [Tathagata Das] Added more tests d1967fb [Tathagata Das] Added comments. da75d34 [Tathagata Das] Removed unnecessary lines. 50a608d [Tathagata Das] Fixed bugs and added support for writeObject --- .../serializer/SerializationDebugger.scala | 112 ++++++++++++++++- .../SerializationDebuggerSuite.scala | 119 +++++++++++++++++- .../spark/streaming/StreamingContext.scala | 4 +- 3 files changed, 223 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala index bb5db545531d2..cc2f0506817d3 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala @@ -17,7 +17,7 @@ package org.apache.spark.serializer -import java.io.{NotSerializableException, ObjectOutput, ObjectStreamClass, ObjectStreamField} +import java.io._ import java.lang.reflect.{Field, Method} import java.security.AccessController @@ -62,7 +62,7 @@ private[spark] object SerializationDebugger extends Logging { * * It does not yet handle writeObject override, but that shouldn't be too hard to do either. */ - def find(obj: Any): List[String] = { + private[serializer] def find(obj: Any): List[String] = { new SerializationDebugger().visit(obj, List.empty) } @@ -125,6 +125,12 @@ private[spark] object SerializationDebugger extends Logging { return List.empty } + /** + * Visit an externalizable object. + * Since writeExternal() can choose to add arbitrary objects at the time of serialization, + * the only way to capture all the objects it will serialize is by using a + * dummy ObjectOutput that collects all the relevant objects for further testing. + */ private def visitExternalizable(o: java.io.Externalizable, stack: List[String]): List[String] = { val fieldList = new ListObjectOutput @@ -145,17 +151,50 @@ private[spark] object SerializationDebugger extends Logging { // An object contains multiple slots in serialization. // Get the slots and visit fields in all of them. val (finalObj, desc) = findObjectAndDescriptor(o) + + // If the object has been replaced using writeReplace(), + // then call visit() on it again to test its type again. + if (!finalObj.eq(o)) { + return visit(finalObj, s"writeReplace data (class: ${finalObj.getClass.getName})" :: stack) + } + + // Every class is associated with one or more "slots", each slot refers to the parent + // classes of this class. These slots are used by the ObjectOutputStream + // serialization code to recursively serialize the fields of an object and + // its parent classes. For example, if there are the following classes. + // + // class ParentClass(parentField: Int) + // class ChildClass(childField: Int) extends ParentClass(1) + // + // Then serializing the an object Obj of type ChildClass requires first serializing the fields + // of ParentClass (that is, parentField), and then serializing the fields of ChildClass + // (that is, childField). Correspondingly, there will be two slots related to this object: + // + // 1. ParentClass slot, which will be used to serialize parentField of Obj + // 2. 
ChildClass slot, which will be used to serialize childField fields of Obj + // + // The following code uses the description of each slot to find the fields in the + // corresponding object to visit. + // val slotDescs = desc.getSlotDescs var i = 0 while (i < slotDescs.length) { val slotDesc = slotDescs(i) if (slotDesc.hasWriteObjectMethod) { - // TODO: Handle classes that specify writeObject method. + // If the class type corresponding to current slot has writeObject() defined, + // then its not obvious which fields of the class will be serialized as the writeObject() + // can choose arbitrary fields for serialization. This case is handled separately. + val elem = s"writeObject data (class: ${slotDesc.getName})" + val childStack = visitSerializableWithWriteObjectMethod(finalObj, elem :: stack) + if (childStack.nonEmpty) { + return childStack + } } else { + // Visit all the fields objects of the class corresponding to the current slot. val fields: Array[ObjectStreamField] = slotDesc.getFields val objFieldValues: Array[Object] = new Array[Object](slotDesc.getNumObjFields) val numPrims = fields.length - objFieldValues.length - desc.getObjFieldValues(finalObj, objFieldValues) + slotDesc.getObjFieldValues(finalObj, objFieldValues) var j = 0 while (j < objFieldValues.length) { @@ -169,18 +208,54 @@ private[spark] object SerializationDebugger extends Logging { } j += 1 } - } i += 1 } return List.empty } + + /** + * Visit a serializable object which has the writeObject() defined. + * Since writeObject() can choose to add arbitrary objects at the time of serialization, + * the only way to capture all the objects it will serialize is by using a + * dummy ObjectOutputStream that collects all the relevant fields for further testing. + * This is similar to how externalizable objects are visited. + */ + private def visitSerializableWithWriteObjectMethod( + o: Object, stack: List[String]): List[String] = { + val innerObjectsCatcher = new ListObjectOutputStream + var notSerializableFound = false + try { + innerObjectsCatcher.writeObject(o) + } catch { + case io: IOException => + notSerializableFound = true + } + + // If something was not serializable, then visit the captured objects. + // Otherwise, all the captured objects are safely serializable, so no need to visit them. + // As an optimization, just added them to the visited list. + if (notSerializableFound) { + val innerObjects = innerObjectsCatcher.outputArray + var k = 0 + while (k < innerObjects.length) { + val childStack = visit(innerObjects(k), stack) + if (childStack.nonEmpty) { + return childStack + } + k += 1 + } + } else { + visited ++= innerObjectsCatcher.outputArray + } + return List.empty + } } /** * Find the object to serialize and the associated [[ObjectStreamClass]]. This method handles * writeReplace in Serializable. It starts with the object itself, and keeps calling the - * writeReplace method until there is no more + * writeReplace method until there is no more. */ @tailrec private def findObjectAndDescriptor(o: Object): (Object, ObjectStreamClass) = { @@ -220,6 +295,31 @@ private[spark] object SerializationDebugger extends Logging { override def writeByte(i: Int): Unit = {} } + /** An output stream that emulates /dev/null */ + private class NullOutputStream extends OutputStream { + override def write(b: Int) { } + } + + /** + * A dummy [[ObjectOutputStream]] that saves the list of objects written to it and returns + * them through `outputArray`. 
This works by using the [[ObjectOutputStream]]'s `replaceObject()` + * method which gets called on every object, only if replacing is enabled. So this subclass + * of [[ObjectOutputStream]] enabled replacing, and uses replaceObject to get the objects that + * are being serializabled. The serialized bytes are ignored by sending them to a + * [[NullOutputStream]], which acts like a /dev/null. + */ + private class ListObjectOutputStream extends ObjectOutputStream(new NullOutputStream) { + private val output = new mutable.ArrayBuffer[Any] + this.enableReplaceObject(true) + + def outputArray: Array[Any] = output.toArray + + override def replaceObject(obj: Object): Object = { + output += obj + obj + } + } + /** An implicit class that allows us to call private methods of ObjectStreamClass. */ implicit class ObjectStreamClassMethods(val desc: ObjectStreamClass) extends AnyVal { def getSlotDescs: Array[ObjectStreamClass] = { diff --git a/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala index 2707bb53bc383..2d5e9d66b2e15 100644 --- a/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.serializer -import java.io.{ObjectOutput, ObjectInput} +import java.io._ import org.scalatest.BeforeAndAfterEach @@ -98,7 +98,7 @@ class SerializationDebuggerSuite extends SparkFunSuite with BeforeAndAfterEach { } test("externalizable class writing out not serializable object") { - val s = find(new ExternalizableClass) + val s = find(new ExternalizableClass(new SerializableClass2(new NotSerializable))) assert(s.size === 5) assert(s(0).contains("NotSerializable")) assert(s(1).contains("objectField")) @@ -106,6 +106,93 @@ class SerializationDebuggerSuite extends SparkFunSuite with BeforeAndAfterEach { assert(s(3).contains("writeExternal")) assert(s(4).contains("ExternalizableClass")) } + + test("externalizable class writing out serializable objects") { + assert(find(new ExternalizableClass(new SerializableClass1)).isEmpty) + } + + test("object containing writeReplace() which returns not serializable object") { + val s = find(new SerializableClassWithWriteReplace(new NotSerializable)) + assert(s.size === 3) + assert(s(0).contains("NotSerializable")) + assert(s(1).contains("writeReplace")) + assert(s(2).contains("SerializableClassWithWriteReplace")) + } + + test("object containing writeReplace() which returns serializable object") { + assert(find(new SerializableClassWithWriteReplace(new SerializableClass1)).isEmpty) + } + + test("object containing writeObject() and not serializable field") { + val s = find(new SerializableClassWithWriteObject(new NotSerializable)) + assert(s.size === 3) + assert(s(0).contains("NotSerializable")) + assert(s(1).contains("writeObject data")) + assert(s(2).contains("SerializableClassWithWriteObject")) + } + + test("object containing writeObject() and serializable field") { + assert(find(new SerializableClassWithWriteObject(new SerializableClass1)).isEmpty) + } + + test("object of serializable subclass with more fields than superclass (SPARK-7180)") { + // This should not throw ArrayOutOfBoundsException + find(new SerializableSubclass(new SerializableClass1)) + } + + test("crazy nested objects") { + + def findAndAssert(shouldSerialize: Boolean, obj: Any): Unit = { + val s = find(obj) + if (shouldSerialize) { + 
assert(s.isEmpty) + } else { + assert(s.nonEmpty) + assert(s.head.contains("NotSerializable")) + } + } + + findAndAssert(false, + new SerializableClassWithWriteReplace(new ExternalizableClass(new SerializableSubclass( + new SerializableArray( + Array(new SerializableClass1, new SerializableClass2(new NotSerializable)) + ) + ))) + ) + + findAndAssert(true, + new SerializableClassWithWriteReplace(new ExternalizableClass(new SerializableSubclass( + new SerializableArray( + Array(new SerializableClass1, new SerializableClass2(new SerializableClass1)) + ) + ))) + ) + } + + test("improveException") { + val e = SerializationDebugger.improveException( + new SerializableClass2(new NotSerializable), new NotSerializableException("someClass")) + assert(e.getMessage.contains("someClass")) // original exception message should be present + assert(e.getMessage.contains("SerializableClass2")) // found debug trace should be present + } + + test("improveException with error in debugger") { + // Object that throws exception in the SerializationDebugger + val o = new SerializableClass1 { + private def writeReplace(): Object = { + throw new Exception() + } + } + withClue("requirement: SerializationDebugger should fail trying debug this object") { + intercept[Exception] { + SerializationDebugger.find(o) + } + } + + val originalException = new NotSerializableException("someClass") + // verify thaht original exception is returned on failure + assert(SerializationDebugger.improveException(o, originalException).eq(originalException)) + } } @@ -118,10 +205,34 @@ class SerializableClass2(val objectField: Object) extends Serializable class SerializableArray(val arrayField: Array[Object]) extends Serializable -class ExternalizableClass extends java.io.Externalizable { +class SerializableSubclass(val objectField: Object) extends SerializableClass1 + + +class SerializableClassWithWriteObject(val objectField: Object) extends Serializable { + val serializableObjectField = new SerializableClass1 + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream): Unit = { + oos.defaultWriteObject() + } +} + + +class SerializableClassWithWriteReplace(@transient replacementFieldObject: Object) + extends Serializable { + private def writeReplace(): Object = { + replacementFieldObject + } +} + + +class ExternalizableClass(objectField: Object) extends java.io.Externalizable { + val serializableObjectField = new SerializableClass1 + override def writeExternal(out: ObjectOutput): Unit = { out.writeInt(1) - out.writeObject(new SerializableClass2(new NotSerializable)) + out.writeObject(serializableObjectField) + out.writeObject(objectField) } override def readExternal(in: ObjectInput): Unit = {} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 9cd9684d36404..1708f309fc002 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -549,8 +549,8 @@ class StreamingContext private[streaming] ( case e: NotSerializableException => throw new NotSerializableException( "DStream checkpointing has been enabled but the DStreams with their functions " + - "are not serializable\nSerialization stack:\n" + - SerializationDebugger.find(checkpoint).map("\t- " + _).mkString("\n") + "are not serializable\n" + + SerializationDebugger.improveException(checkpoint, e).getMessage() ) } } From 
68a2dca292776d4a3f988353ba55adc73a7c1aa2 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 19 Jun 2015 10:56:19 -0700 Subject: [PATCH 120/151] [SPARK-8451] [SPARK-7287] SparkSubmitSuite should check exit code This patch also reenables the tests. Now that we have access to the log4j logs it should be easier to debug the flakiness. yhuai brkyvz Author: Andrew Or Closes #6886 from andrewor14/spark-submit-suite-fix and squashes the following commits: 3f99ff1 [Andrew Or] Move destroy to finally block 9a62188 [Andrew Or] Re-enable ignored tests 2382672 [Andrew Or] Check for exit code --- .../apache/spark/deploy/SparkSubmitSuite.scala | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 46ea28d0f18f6..357ed90be3f5c 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -325,7 +325,7 @@ class SparkSubmitSuite runSparkSubmit(args) } - ignore("includes jars passed in through --jars") { + test("includes jars passed in through --jars") { val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA")) val jar2 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassB")) @@ -340,7 +340,7 @@ class SparkSubmitSuite } // SPARK-7287 - ignore("includes jars passed in through --packages") { + test("includes jars passed in through --packages") { val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val main = MavenCoordinate("my.great.lib", "mylib", "0.1") val dep = MavenCoordinate("my.great.dep", "mylib", "0.1") @@ -499,9 +499,16 @@ class SparkSubmitSuite Seq("./bin/spark-submit") ++ args, new File(sparkHome), Map("SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome)) - failAfter(60 seconds) { process.waitFor() } - // Ensure we still kill the process in case it timed out - process.destroy() + + try { + val exitCode = failAfter(60 seconds) { process.waitFor() } + if (exitCode != 0) { + fail(s"Process returned with exit code $exitCode. See the log4j logs for more detail.") + } + } finally { + // Ensure we still kill the process in case it timed out + process.destroy() + } } private def forConfDir(defaults: Map[String, String]) (f: String => Unit) = { From 4be53d0395d3c7f61eef6b7d72db078e2e1199a7 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 19 Jun 2015 11:03:04 -0700 Subject: [PATCH 121/151] [SPARK-5836] [DOCS] [STREAMING] Clarify what may cause long-running Spark apps to preserve shuffle files Clarify what may cause long-running Spark apps to preserve shuffle files Author: Sean Owen Closes #6901 from srowen/SPARK-5836 and squashes the following commits: a9faef0 [Sean Owen] Clarify what may cause long-running Spark apps to preserve shuffle files --- docs/programming-guide.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/programming-guide.md b/docs/programming-guide.md index d5ff416fe89a4..ae712d62746f6 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -1144,9 +1144,11 @@ generate these on the reduce side. When data does not fit in memory Spark will s to disk, incurring the additional overhead of disk I/O and increased garbage collection. Shuffle also generates a large number of intermediate files on disk. 
As of Spark 1.3, these files -are not cleaned up from Spark's temporary storage until Spark is stopped, which means that -long-running Spark jobs may consume available disk space. This is done so the shuffle doesn't need -to be re-computed if the lineage is re-computed. The temporary storage directory is specified by the +are preserved until the corresponding RDDs are no longer used and are garbage collected. +This is done so the shuffle files don't need to be re-created if the lineage is re-computed. +Garbage collection may happen only after a long period time, if the application retains references +to these RDDs or if GC does not kick in frequently. This means that long-running Spark jobs may +consume a large amount of disk space. The temporary storage directory is specified by the `spark.local.dir` configuration parameter when configuring the Spark context. Shuffle behavior can be tuned by adjusting a variety of configuration parameters. See the From c5876e529b8e29b25ca03c3a768c0e4709c9a535 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 19 Jun 2015 11:11:58 -0700 Subject: [PATCH 122/151] [SPARK-8368] [SPARK-8058] [SQL] HiveContext may override the context class loader of the current thread https://issues.apache.org/jira/browse/SPARK-8368 Also, I add tests according https://issues.apache.org/jira/browse/SPARK-8058. Author: Yin Huai Closes #6891 from yhuai/SPARK-8368 and squashes the following commits: 37bb3db [Yin Huai] Update test timeout and comment. 8762eec [Yin Huai] Style. 695cd2d [Yin Huai] Correctly set the class loader in the conf of the state in client wrapper. b3378fe [Yin Huai] Failed tests. --- .../apache/spark/sql/hive/HiveContext.scala | 3 +- .../spark/sql/hive/client/ClientWrapper.scala | 21 +- .../spark/sql/hive/client/HiveShim.scala | 15 +- .../hive/client/IsolatedClientLoader.scala | 13 +- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 182 ++++++++++++++++++ 5 files changed, 219 insertions(+), 15 deletions(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 4a66d6508ae0a..cf05c6c989655 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -158,7 +158,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { logInfo(s"Initializing execution hive, version $hiveExecutionVersion") new ClientWrapper( version = IsolatedClientLoader.hiveVersion(hiveExecutionVersion), - config = newTemporaryConfiguration()) + config = newTemporaryConfiguration(), + initClassLoader = Utils.getContextOrSparkClassLoader) } SessionState.setCurrentSessionState(executionHive.state) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index 982ed63874a5f..42c2d4c98ffb2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -54,10 +54,13 @@ import org.apache.spark.sql.execution.QueryExecutionException * @param version the version of hive used when pick function calls that are not compatible. * @param config a collection of configuration options that will be added to the hive conf before * opening the hive client. 
+ * @param initClassLoader the classloader used when creating the `state` field of + * this ClientWrapper. */ private[hive] class ClientWrapper( version: HiveVersion, - config: Map[String, String]) + config: Map[String, String], + initClassLoader: ClassLoader) extends ClientInterface with Logging { @@ -98,11 +101,18 @@ private[hive] class ClientWrapper( // Create an internal session state for this ClientWrapper. val state = { val original = Thread.currentThread().getContextClassLoader - Thread.currentThread().setContextClassLoader(getClass.getClassLoader) + // Switch to the initClassLoader. + Thread.currentThread().setContextClassLoader(initClassLoader) val ret = try { val oldState = SessionState.get() if (oldState == null) { val initialConf = new HiveConf(classOf[SessionState]) + // HiveConf is a Hadoop Configuration, which has a field of classLoader and + // the initial value will be the current thread's context class loader + // (i.e. initClassLoader at here). + // We call initialConf.setClassLoader(initClassLoader) at here to make + // this action explicit. + initialConf.setClassLoader(initClassLoader) config.foreach { case (k, v) => logDebug(s"Hive Config: $k=$v") initialConf.set(k, v) @@ -125,6 +135,7 @@ private[hive] class ClientWrapper( def conf: HiveConf = SessionState.get().getConf // TODO: should be a def?s + // When we create this val client, the HiveConf of it (conf) is the one associated with state. private val client = Hive.get(conf) /** @@ -132,13 +143,9 @@ private[hive] class ClientWrapper( */ private def withHiveState[A](f: => A): A = synchronized { val original = Thread.currentThread().getContextClassLoader - // This setContextClassLoader is used for Hive 0.12's metastore since Hive 0.12 will not - // internally override the context class loader of the current thread with the class loader - // associated with the HiveConf in `state`. - Thread.currentThread().setContextClassLoader(getClass.getClassLoader) // Set the thread local metastore client to the client associated with this ClientWrapper. Hive.set(client) - // Starting from Hive 0.13.0, setCurrentSessionState will use the classLoader associated + // setCurrentSessionState will use the classLoader associated // with the HiveConf in `state` to override the context class loader of the current // thread. shim.setCurrentSessionState(state) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 40c167926c8d6..5ae2dbb50d86b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -43,6 +43,11 @@ import org.apache.hadoop.hive.ql.session.SessionState */ private[client] sealed abstract class Shim { + /** + * Set the current SessionState to the given SessionState. Also, set the context classloader of + * the current thread to the one set in the HiveConf of this given `state`. + * @param state + */ def setCurrentSessionState(state: SessionState): Unit /** @@ -159,7 +164,15 @@ private[client] class Shim_v0_12 extends Shim { JBoolean.TYPE, JBoolean.TYPE) - override def setCurrentSessionState(state: SessionState): Unit = startMethod.invoke(null, state) + override def setCurrentSessionState(state: SessionState): Unit = { + // Starting from Hive 0.13, setCurrentSessionState will internally override + // the context class loader of the current thread by the class loader set in + // the conf of the SessionState. 
So, for this Hive 0.12 shim, we add the same + // behavior and make shim.setCurrentSessionState of all Hive versions have the + // consistent behavior. + Thread.currentThread().setContextClassLoader(state.getConf.getClassLoader) + startMethod.invoke(null, state) + } override def getDataLocation(table: Table): Option[String] = Option(getDataLocationMethod.invoke(table)).map(_.toString()) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 69cfc5c3c3216..0934ad5034671 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -95,9 +95,8 @@ private[hive] object IsolatedClientLoader { * @param config A set of options that will be added to the HiveConf of the constructed client. * @param isolationOn When true, custom versions of barrier classes will be constructed. Must be * true unless loading the version of hive that is on Sparks classloader. - * @param rootClassLoader The system root classloader. - * @param baseClassLoader The spark classloader that is used to load shared classes. Must not know - * about Hive classes. + * @param rootClassLoader The system root classloader. Must not know about Hive classes. + * @param baseClassLoader The spark classloader that is used to load shared classes. */ private[hive] class IsolatedClientLoader( val version: HiveVersion, @@ -110,8 +109,8 @@ private[hive] class IsolatedClientLoader( val barrierPrefixes: Seq[String] = Seq.empty) extends Logging { - // Check to make sure that the base classloader does not know about Hive. - assert(Try(baseClassLoader.loadClass("org.apache.hive.HiveConf")).isFailure) + // Check to make sure that the root classloader does not know about Hive. + assert(Try(rootClassLoader.loadClass("org.apache.hadoop.hive.conf.HiveConf")).isFailure) /** All jars used by the hive specific classloader. */ protected def allJars = execJars.toArray @@ -145,6 +144,7 @@ private[hive] class IsolatedClientLoader( def doLoadClass(name: String, resolve: Boolean): Class[_] = { val classFileName = name.replaceAll("\\.", "/") + ".class" if (isBarrierClass(name) && isolationOn) { + // For barrier classes, we construct a new copy of the class. val bytes = IOUtils.toByteArray(baseClassLoader.getResourceAsStream(classFileName)) logDebug(s"custom defining: $name - ${util.Arrays.hashCode(bytes)}") defineClass(name, bytes, 0, bytes.length) @@ -152,6 +152,7 @@ private[hive] class IsolatedClientLoader( logDebug(s"hive class: $name - ${getResource(classToPath(name))}") super.loadClass(name, resolve) } else { + // For shared classes, we delegate to baseClassLoader. 
logDebug(s"shared class: $name") baseClassLoader.loadClass(name) } @@ -167,7 +168,7 @@ private[hive] class IsolatedClientLoader( classLoader .loadClass(classOf[ClientWrapper].getName) .getConstructors.head - .newInstance(version, config) + .newInstance(version, config, classLoader) .asInstanceOf[ClientInterface] } catch { case e: InvocationTargetException => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala new file mode 100644 index 0000000000000..7963abf3b9c92 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.io.File + +import org.apache.spark._ +import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} +import org.apache.spark.util.{ResetSystemProperties, Utils} +import org.scalatest.Matchers +import org.scalatest.concurrent.Timeouts +import org.scalatest.time.SpanSugar._ + +/** + * This suite tests spark-submit with applications using HiveContext. + */ +class HiveSparkSubmitSuite + extends SparkFunSuite + with Matchers + with ResetSystemProperties + with Timeouts { + + def beforeAll() { + System.setProperty("spark.testing", "true") + } + + test("SPARK-8368: includes jars passed in through --jars") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA")) + val jar2 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassB")) + val jar3 = TestHive.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath() + val jar4 = TestHive.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath() + val jarsString = Seq(jar1, jar2, jar3, jar4).map(j => j.toString).mkString(",") + val args = Seq( + "--class", SparkSubmitClassLoaderTest.getClass.getName.stripSuffix("$"), + "--name", "SparkSubmitClassLoaderTest", + "--master", "local-cluster[2,1,512]", + "--jars", jarsString, + unusedJar.toString, "SparkSubmitClassA", "SparkSubmitClassB") + runSparkSubmit(args) + } + + test("SPARK-8020: set sql conf in spark conf") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val args = Seq( + "--class", SparkSQLConfTest.getClass.getName.stripSuffix("$"), + "--name", "SparkSQLConfTest", + "--master", "local-cluster[2,1,512]", + unusedJar.toString) + runSparkSubmit(args) + } + + // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. 
+ // This is copied from org.apache.spark.deploy.SparkSubmitSuite + private def runSparkSubmit(args: Seq[String]): Unit = { + val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) + val process = Utils.executeCommand( + Seq("./bin/spark-submit") ++ args, + new File(sparkHome), + Map("SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome)) + try { + val exitCode = failAfter(120 seconds) { process.waitFor() } + if (exitCode != 0) { + fail(s"Process returned with exit code $exitCode. See the log4j logs for more detail.") + } + } finally { + // Ensure we still kill the process in case it timed out + process.destroy() + } + } +} + +// This object is used for testing SPARK-8368: https://issues.apache.org/jira/browse/SPARK-8368. +// We test if we can load user jars in both driver and executors when HiveContext is used. +object SparkSubmitClassLoaderTest extends Logging { + def main(args: Array[String]) { + Utils.configTestLog4j("INFO") + val conf = new SparkConf() + val sc = new SparkContext(conf) + val hiveContext = new TestHiveContext(sc) + val df = hiveContext.createDataFrame((1 to 100).map(i => (i, i))).toDF("i", "j") + // First, we load classes at driver side. + try { + Class.forName(args(0), true, Thread.currentThread().getContextClassLoader) + Class.forName(args(1), true, Thread.currentThread().getContextClassLoader) + } catch { + case t: Throwable => + throw new Exception("Could not load user class from jar:\n", t) + } + // Second, we load classes at the executor side. + val result = df.mapPartitions { x => + var exception: String = null + try { + Class.forName(args(0), true, Thread.currentThread().getContextClassLoader) + Class.forName(args(1), true, Thread.currentThread().getContextClassLoader) + } catch { + case t: Throwable => + exception = t + "\n" + t.getStackTraceString + exception = exception.replaceAll("\n", "\n\t") + } + Option(exception).toSeq.iterator + }.collect() + if (result.nonEmpty) { + throw new Exception("Could not load user class from jar:\n" + result(0)) + } + + // Load a Hive UDF from the jar. + hiveContext.sql( + """ + |CREATE TEMPORARY FUNCTION example_max + |AS 'org.apache.hadoop.hive.contrib.udaf.example.UDAFExampleMax' + """.stripMargin) + val source = + hiveContext.createDataFrame((1 to 10).map(i => (i, s"str$i"))).toDF("key", "val") + source.registerTempTable("sourceTable") + // Load a Hive SerDe from the jar. + hiveContext.sql( + """ + |CREATE TABLE t1(key int, val string) + |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe' + """.stripMargin) + // Actually use the loaded UDF and SerDe. + hiveContext.sql( + "INSERT INTO TABLE t1 SELECT example_max(key) as key, val FROM sourceTable GROUP BY val") + val count = hiveContext.table("t1").orderBy("key", "val").count() + if (count != 10) { + throw new Exception(s"table t1 should have 10 rows instead of $count rows") + } + } +} + +// This object is used for testing SPARK-8020: https://issues.apache.org/jira/browse/SPARK-8020. +// We test if we can correctly set spark sql configurations when HiveContext is used. +object SparkSQLConfTest extends Logging { + def main(args: Array[String]) { + Utils.configTestLog4j("INFO") + // We override the SparkConf to add spark.sql.hive.metastore.version and + // spark.sql.hive.metastore.jars to the beginning of the conf entry array. 
+ // So, if metadataHive get initialized after we set spark.sql.hive.metastore.version but + // before spark.sql.hive.metastore.jars get set, we will see the following exception: + // Exception in thread "main" java.lang.IllegalArgumentException: Builtin jars can only + // be used when hive execution version == hive metastore version. + // Execution: 0.13.1 != Metastore: 0.12. Specify a vaild path to the correct hive jars + // using $HIVE_METASTORE_JARS or change spark.sql.hive.metastore.version to 0.13.1. + val conf = new SparkConf() { + override def getAll: Array[(String, String)] = { + def isMetastoreSetting(conf: String): Boolean = { + conf == "spark.sql.hive.metastore.version" || conf == "spark.sql.hive.metastore.jars" + } + // If there is any metastore settings, remove them. + val filteredSettings = super.getAll.filterNot(e => isMetastoreSetting(e._1)) + + // Always add these two metastore settings at the beginning. + ("spark.sql.hive.metastore.version" -> "0.12") +: + ("spark.sql.hive.metastore.jars" -> "maven") +: + filteredSettings + } + + // For this simple test, we do not really clone this object. + override def clone: SparkConf = this + } + val sc = new SparkContext(conf) + val hiveContext = new TestHiveContext(sc) + // Run a simple command to make sure all lazy vals in hiveContext get instantiated. + hiveContext.tables().collect() + } +} From 4a462c282c72c47eeecf35b4ab227c1bc71908e5 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 19 Jun 2015 11:36:59 -0700 Subject: [PATCH 123/151] [HOTFIX] Fix scala style in DFSReadWriteTest that causes tests failed This scala style problem causes tested failed. Author: Liang-Chi Hsieh Closes #6907 from viirya/hotfix_style and squashes the following commits: c53f188 [Liang-Chi Hsieh] Fix scala style. --- .../scala/org/apache/spark/examples/DFSReadWriteTest.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala index c05890dfbfde1..1f12034ce0f57 100644 --- a/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala @@ -27,7 +27,7 @@ import org.apache.spark.SparkContext._ /** * Simple test for reading and writing to a distributed * file system. This example does the following: - * + * * 1. Reads local file * 2. Computes word count on local file * 3. Writes local file to a DFS @@ -36,7 +36,7 @@ import org.apache.spark.SparkContext._ * 6. Compares the word count results */ object DFSReadWriteTest { - + private var localFilePath: File = new File(".") private var dfsDirPath: String = "" From e41e2fd6c61076f870de03b85c5da6c12b8da038 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 19 Jun 2015 11:40:04 -0700 Subject: [PATCH 124/151] [SPARK-8461] [SQL] fix codegen with REPL class loader The ExecutorClassLoader for REPL will cause Janino failed to find class for those in java.lang, so switch to use default class loader for Janino, which will also help performance. 
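The essence of the fix, sketched below with a trivial class body (an illustration under stated assumptions, not the patch's code): hand Janino an explicit parent class loader instead of letting it fall back to the thread context loader, which under the REPL is the ExecutorClassLoader.

import org.codehaus.janino.ClassBodyEvaluator

object JaninoLoaderSketch {
  def compileWithExplicitLoader(code: String): Class[_] = {
    val evaluator = new ClassBodyEvaluator()
    // Resolve referenced classes (java.lang, Spark classes) through a known-good
    // loader rather than the current thread's context class loader.
    evaluator.setParentClassLoader(getClass.getClassLoader)
    evaluator.cook(code)
    evaluator.getClazz()
  }
}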
cc liancheng yhuai Author: Davies Liu Closes #6898 from davies/fix_class_loader and squashes the following commits: 24276d4 [Davies Liu] add regression test 4ff0457 [Davies Liu] address comment, refactor 7f5ffbe [Davies Liu] fix REPL class loader with codegen --- .../org/apache/spark/repl/ReplSuite.scala | 11 ++++++++++ .../expressions/codegen/CodeGenerator.scala | 22 +++++++++++-------- .../codegen/GenerateMutableProjection.scala | 8 ++----- .../codegen/GenerateOrdering.scala | 7 +----- .../codegen/GeneratePredicate.scala | 8 +------ .../codegen/GenerateProjection.scala | 7 +----- 6 files changed, 29 insertions(+), 34 deletions(-) diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 50fd43a418bca..f150fec7db945 100644 --- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -267,6 +267,17 @@ class ReplSuite extends SparkFunSuite { assertDoesNotContain("Exception", output) } + test("SPARK-8461 SQL with codegen") { + val output = runInterpreter("local", + """ + |val sqlContext = new org.apache.spark.sql.SQLContext(sc) + |sqlContext.setConf("spark.sql.codegen", "true") + |sqlContext.range(0, 100).filter('id > 50).count() + """.stripMargin) + assertContains("Long = 49", output) + assertDoesNotContain("java.lang.ClassNotFoundException", output) + } + test("SPARK-2632 importing a method from non serializable class and not using it.") { val output = runInterpreter("local", """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index ab850d17a6dd3..bd5475d2066fc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -203,6 +203,11 @@ class CodeGenContext { def isPrimitiveType(dt: DataType): Boolean = primitiveTypes.contains(dt) } + +abstract class GeneratedClass { + def generate(expressions: Array[Expression]): Any +} + /** * A base class for generators of byte code to perform expression evaluation. Includes a set of * helpers for referring to Catalyst types and building trees that perform evaluation of individual @@ -214,11 +219,6 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin protected val mutableRowType: String = classOf[MutableRow].getName protected val genericMutableRowType: String = classOf[GenericMutableRow].getName - /** - * Can be flipped on manually in the console to add (expensive) expression evaluation trace code. - */ - var debugLogging = false - /** * Generates a class for a given input expression. Called when there is not cached code * already available. 
@@ -239,10 +239,14 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin * * It will track the time used to compile */ - protected def compile(code: String): Class[_] = { + protected def compile(code: String): GeneratedClass = { val startTime = System.nanoTime() - val clazz = try { - new ClassBodyEvaluator(code).getClazz() + val evaluator = new ClassBodyEvaluator() + evaluator.setParentClassLoader(getClass.getClassLoader) + evaluator.setDefaultImports(Array("org.apache.spark.sql.catalyst.InternalRow")) + evaluator.setExtendedClass(classOf[GeneratedClass]) + try { + evaluator.cook(code) } catch { case e: Exception => logError(s"failed to compile:\n $code", e) @@ -251,7 +255,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin val endTime = System.nanoTime() def timeMs: Double = (endTime - startTime).toDouble / 1000000 logDebug(s"Code (${code.size} bytes) compiled in $timeMs ms") - clazz + evaluator.getClazz().newInstance().asInstanceOf[GeneratedClass] } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 573a9ea0a5471..e75e82d380541 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -47,9 +47,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu """ }.mkString("\n") val code = s""" - import org.apache.spark.sql.catalyst.InternalRow; - - public SpecificProjection generate($exprType[] expr) { + public Object generate($exprType[] expr) { return new SpecificProjection(expr); } @@ -85,10 +83,8 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu logDebug(s"code for ${expressions.mkString(",")}:\n$code") val c = compile(code) - // fetch the only one method `generate(Expression[])` - val m = c.getDeclaredMethods()(0) () => { - m.invoke(c.newInstance(), ctx.references.toArray).asInstanceOf[BaseMutableProjection] + c.generate(ctx.references.toArray).asInstanceOf[MutableProjection] } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 3e9ee60f33037..7ed2c5addec9b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -76,8 +76,6 @@ object GenerateOrdering }.mkString("\n") val code = s""" - import org.apache.spark.sql.catalyst.InternalRow; - public SpecificOrdering generate($exprType[] expr) { return new SpecificOrdering(expr); } @@ -100,9 +98,6 @@ object GenerateOrdering logDebug(s"Generated Ordering: $code") - val c = compile(code) - // fetch the only one method `generate(Expression[])` - val m = c.getDeclaredMethods()(0) - m.invoke(c.newInstance(), ctx.references.toArray).asInstanceOf[BaseOrdering] + compile(code).generate(ctx.references.toArray).asInstanceOf[BaseOrdering] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index dad4364bdd94a..3ebc2c147579b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.expressions._ /** @@ -41,8 +40,6 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool val ctx = newCodeGenContext() val eval = predicate.gen(ctx) val code = s""" - import org.apache.spark.sql.catalyst.InternalRow; - public SpecificPredicate generate($exprType[] expr) { return new SpecificPredicate(expr); } @@ -62,10 +59,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool logDebug(s"Generated predicate '$predicate':\n$code") - val c = compile(code) - // fetch the only one method `generate(Expression[])` - val m = c.getDeclaredMethods()(0) - val p = m.invoke(c.newInstance(), ctx.references.toArray).asInstanceOf[Predicate] + val p = compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate] (r: InternalRow) => p.eval(r) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 8b5dc194be31f..2e20eda1a3002 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -147,8 +147,6 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { }.mkString("\n") val code = s""" - import org.apache.spark.sql.catalyst.InternalRow; - public SpecificProjection generate($exprType[] expr) { return new SpecificProjection(expr); } @@ -220,9 +218,6 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { logDebug(s"MutableRow, initExprs: ${expressions.mkString(",")} code:\n${code}") - val c = compile(code) - // fetch the only one method `generate(Expression[])` - val m = c.getDeclaredMethods()(0) - m.invoke(c.newInstance(), ctx.references.toArray).asInstanceOf[Projection] + compile(code).generate(ctx.references.toArray).asInstanceOf[Projection] } } From 54976e55e36465108b71b40b8a431be9d6d703ce Mon Sep 17 00:00:00 2001 From: MechCoder Date: Fri, 19 Jun 2015 12:23:15 -0700 Subject: [PATCH 125/151] [SPARK-4118] [MLLIB] [PYSPARK] Python bindings for StreamingKMeans Python bindings for StreamingKMeans Will change status to MRG once docs, tests and examples are updated. 
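For context, a minimal sketch of the existing Scala `StreamingKMeans` API that these Python bindings mirror; the application name, batch interval, and directory paths are placeholders:

```scala
import org.apache.spark.SparkConf
import org.apache.spark.mllib.clustering.StreamingKMeans
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.streaming.{Seconds, StreamingContext}

// Sketch only: paths and batch interval are placeholder values.
val ssc = new StreamingContext(new SparkConf().setAppName("streaming-kmeans-sketch"), Seconds(1))

val trainingData = ssc.textFileStream("/training/data/dir").map(Vectors.parse)
val testData = ssc.textFileStream("/testing/data/dir").map(LabeledPoint.parse)

val model = new StreamingKMeans()
  .setK(2)
  .setDecayFactor(1.0)
  .setRandomCenters(3, 0.0, 42L)  // 3-dimensional random initial centers, weight 0.0, fixed seed

model.trainOn(trainingData)
model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print()

ssc.start()
ssc.awaitTermination()
```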
Author: MechCoder Closes #6499 from MechCoder/spark-4118 and squashes the following commits: 7722d16 [MechCoder] minor style fixes 51052d3 [MechCoder] Doc fixes 2061a76 [MechCoder] Add tests for simultaneous training and prediction Minor style fixes 81482fd [MechCoder] minor 5d9fe61 [MechCoder] predictOn should take into account the latest model 8ab9e89 [MechCoder] Fix Python3 error a9817df [MechCoder] Better tests and minor fixes c80e451 [MechCoder] Add ignore_unicode_prefix ee8ce16 [MechCoder] Update tests, doc and examples 4b1481f [MechCoder] Some changes and tests d8b066a [MechCoder] [SPARK-4118] [MLlib] [PySpark] Python bindings for StreamingKMeans --- docs/mllib-clustering.md | 48 +++- .../mllib/api/python/PythonMLLibAPI.scala | 15 ++ python/pyspark/mllib/clustering.py | 207 +++++++++++++++++- python/pyspark/mllib/tests.py | 150 ++++++++++++- 4 files changed, 411 insertions(+), 9 deletions(-) diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index 1b088969ddc25..dcaa3784be874 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -592,15 +592,55 @@ model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print() ssc.start() ssc.awaitTermination() +{% endhighlight %} +
+
+<div data-lang="python" markdown="1">
+First we import the neccessary classes. + +{% highlight python %} +from pyspark.mllib.linalg import Vectors +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.clustering import StreamingKMeans {% endhighlight %} +Then we make an input stream of vectors for training, as well as a stream of labeled data +points for testing. We assume a StreamingContext `ssc` has been created, see +[Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) for more info. + +{% highlight python %} +def parse(lp): + label = float(lp[lp.find('(') + 1: lp.find(',')]) + vec = Vectors.dense(lp[lp.find('[') + 1: lp.find(']')].split(',')) + return LabeledPoint(label, vec) + +trainingData = ssc.textFileStream("/training/data/dir").map(Vectors.parse) +testData = ssc.textFileStream("/testing/data/dir").map(parse) +{% endhighlight %} + +We create a model with random clusters and specify the number of clusters to find + +{% highlight python %} +model = StreamingKMeans(k=2, decayFactor=1.0).setRandomCenters(3, 1.0, 0) +{% endhighlight %} + +Now register the streams for training and testing and start the job, printing +the predicted cluster assignments on new data points as they arrive. + +{% highlight python %} +model.trainOn(trainingData) +print(model.predictOnValues(testData.map(lambda lp: (lp.label, lp.features)))) + +ssc.start() +ssc.awaitTermination() +{% endhighlight %} +
+ + + As you add new text files with data the cluster centers will update. Each training point should be formatted as `[x1, x2, x3]`, and each test data point should be formatted as `(y, [x1, x2, x3])`, where `y` is some useful label or identifier (e.g. a true category assignment). Anytime a text file is placed in `/training/data/dir` the model will update. Anytime a text file is placed in `/testing/data/dir` you will see predictions. With new data, the cluster centers will change! - - - - diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 1812b3ac7cc0e..2897865af6912 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -964,6 +964,21 @@ private[python] class PythonMLLibAPI extends Serializable { points.asScala.toArray) } + /** + * Java stub for the update method of StreamingKMeansModel. + */ + def updateStreamingKMeansModel( + clusterCenters: JList[Vector], + clusterWeights: JList[Double], + data: JavaRDD[Vector], + decayFactor: Double, + timeUnit: String): JList[Object] = { + val model = new StreamingKMeansModel( + clusterCenters.asScala.toArray, clusterWeights.asScala.toArray) + .update(data, decayFactor, timeUnit) + List[AnyRef](model.clusterCenters, Vectors.dense(model.clusterWeights)).asJava + } + } /** diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index b55583f82223f..c38229864d3b4 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -21,16 +21,20 @@ if sys.version > '3': xrange = range -from numpy import array +from math import exp, log + +from numpy import array, random, tile -from pyspark import RDD from pyspark import SparkContext +from pyspark.rdd import RDD, ignore_unicode_prefix from pyspark.mllib.common import callMLlibFunc, callJavaFunc, _py2java, _java2py -from pyspark.mllib.linalg import SparseVector, _convert_to_vector +from pyspark.mllib.linalg import SparseVector, _convert_to_vector, DenseVector from pyspark.mllib.stat.distribution import MultivariateGaussian from pyspark.mllib.util import Saveable, Loader, inherit_doc +from pyspark.streaming import DStream -__all__ = ['KMeansModel', 'KMeans', 'GaussianMixtureModel', 'GaussianMixture'] +__all__ = ['KMeansModel', 'KMeans', 'GaussianMixtureModel', 'GaussianMixture', + 'StreamingKMeans', 'StreamingKMeansModel'] @inherit_doc @@ -98,6 +102,9 @@ def predict(self, x): """Find the cluster to which x belongs in this model.""" best = 0 best_distance = float("inf") + if isinstance(x, RDD): + return x.map(self.predict) + x = _convert_to_vector(x) for i in xrange(len(self.centers)): distance = x.squared_distance(self.centers[i]) @@ -264,6 +271,198 @@ def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initia return GaussianMixtureModel(weight, mvg_obj) +class StreamingKMeansModel(KMeansModel): + """ + .. note:: Experimental + + Clustering model which can perform an online update of the centroids. + + The update formula for each centroid is given by + + * c_t+1 = ((c_t * n_t * a) + (x_t * m_t)) / (n_t + m_t) + * n_t+1 = n_t * a + m_t + + where + + * c_t: Centroid at the n_th iteration. + * n_t: Number of samples (or) weights associated with the centroid + at the n_th iteration. + * x_t: Centroid of the new data closest to c_t. 
+ * m_t: Number of samples (or) weights of the new data closest to c_t + * c_t+1: New centroid. + * n_t+1: New number of weights. + * a: Decay Factor, which gives the forgetfulness. + + Note that if a is set to 1, it is the weighted mean of the previous + and new data. If it set to zero, the old centroids are completely + forgotten. + + :param clusterCenters: Initial cluster centers. + :param clusterWeights: List of weights assigned to each cluster. + + >>> initCenters = [[0.0, 0.0], [1.0, 1.0]] + >>> initWeights = [1.0, 1.0] + >>> stkm = StreamingKMeansModel(initCenters, initWeights) + >>> data = sc.parallelize([[-0.1, -0.1], [0.1, 0.1], + ... [0.9, 0.9], [1.1, 1.1]]) + >>> stkm = stkm.update(data, 1.0, u"batches") + >>> stkm.centers + array([[ 0., 0.], + [ 1., 1.]]) + >>> stkm.predict([-0.1, -0.1]) + 0 + >>> stkm.predict([0.9, 0.9]) + 1 + >>> stkm.clusterWeights + [3.0, 3.0] + >>> decayFactor = 0.0 + >>> data = sc.parallelize([DenseVector([1.5, 1.5]), DenseVector([0.2, 0.2])]) + >>> stkm = stkm.update(data, 0.0, u"batches") + >>> stkm.centers + array([[ 0.2, 0.2], + [ 1.5, 1.5]]) + >>> stkm.clusterWeights + [1.0, 1.0] + >>> stkm.predict([0.2, 0.2]) + 0 + >>> stkm.predict([1.5, 1.5]) + 1 + """ + def __init__(self, clusterCenters, clusterWeights): + super(StreamingKMeansModel, self).__init__(centers=clusterCenters) + self._clusterWeights = list(clusterWeights) + + @property + def clusterWeights(self): + """Return the cluster weights.""" + return self._clusterWeights + + @ignore_unicode_prefix + def update(self, data, decayFactor, timeUnit): + """Update the centroids, according to data + + :param data: Should be a RDD that represents the new data. + :param decayFactor: forgetfulness of the previous centroids. + :param timeUnit: Can be "batches" or "points". If points, then the + decay factor is raised to the power of number of new + points and if batches, it is used as it is. + """ + if not isinstance(data, RDD): + raise TypeError("Data should be of an RDD, got %s." % type(data)) + data = data.map(_convert_to_vector) + decayFactor = float(decayFactor) + if timeUnit not in ["batches", "points"]: + raise ValueError( + "timeUnit should be 'batches' or 'points', got %s." % timeUnit) + vectorCenters = [_convert_to_vector(center) for center in self.centers] + updatedModel = callMLlibFunc( + "updateStreamingKMeansModel", vectorCenters, self._clusterWeights, + data, decayFactor, timeUnit) + self.centers = array(updatedModel[0]) + self._clusterWeights = list(updatedModel[1]) + return self + + +class StreamingKMeans(object): + """ + .. note:: Experimental + + Provides methods to set k, decayFactor, timeUnit to configure the + KMeans algorithm for fitting and predicting on incoming dstreams. + More details on how the centroids are updated are provided under the + docs of StreamingKMeansModel. + + :param k: int, number of clusters + :param decayFactor: float, forgetfulness of the previous centroids. + :param timeUnit: can be "batches" or "points". If points, then the + decayfactor is raised to the power of no. of new points. + """ + def __init__(self, k=2, decayFactor=1.0, timeUnit="batches"): + self._k = k + self._decayFactor = decayFactor + if timeUnit not in ["batches", "points"]: + raise ValueError( + "timeUnit should be 'batches' or 'points', got %s." 
% timeUnit) + self._timeUnit = timeUnit + self._model = None + + def latestModel(self): + """Return the latest model""" + return self._model + + def _validate(self, dstream): + if self._model is None: + raise ValueError( + "Initial centers should be set either by setInitialCenters " + "or setRandomCenters.") + if not isinstance(dstream, DStream): + raise TypeError( + "Expected dstream to be of type DStream, " + "got type %s" % type(dstream)) + + def setK(self, k): + """Set number of clusters.""" + self._k = k + return self + + def setDecayFactor(self, decayFactor): + """Set decay factor.""" + self._decayFactor = decayFactor + return self + + def setHalfLife(self, halfLife, timeUnit): + """ + Set number of batches after which the centroids of that + particular batch has half the weightage. + """ + self._timeUnit = timeUnit + self._decayFactor = exp(log(0.5) / halfLife) + return self + + def setInitialCenters(self, centers, weights): + """ + Set initial centers. Should be set before calling trainOn. + """ + self._model = StreamingKMeansModel(centers, weights) + return self + + def setRandomCenters(self, dim, weight, seed): + """ + Set the initial centres to be random samples from + a gaussian population with constant weights. + """ + rng = random.RandomState(seed) + clusterCenters = rng.randn(self._k, dim) + clusterWeights = tile(weight, self._k) + self._model = StreamingKMeansModel(clusterCenters, clusterWeights) + return self + + def trainOn(self, dstream): + """Train the model on the incoming dstream.""" + self._validate(dstream) + + def update(rdd): + self._model.update(rdd, self._decayFactor, self._timeUnit) + + dstream.foreachRDD(update) + + def predictOn(self, dstream): + """ + Make predictions on a dstream. + Returns a transformed dstream object + """ + self._validate(dstream) + return dstream.map(lambda x: self._model.predict(x)) + + def predictOnValues(self, dstream): + """ + Make predictions on a keyed dstream. + Returns a transformed dstream object. 
+ """ + self._validate(dstream) + return dstream.mapValues(lambda x: self._model.predict(x)) + + def _test(): import doctest globs = globals().copy() diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index c482e6b0681e3..744dc112d9209 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -23,8 +23,10 @@ import sys import tempfile import array as pyarray +from time import time, sleep -from numpy import array, array_equal, zeros, inf +from numpy import array, array_equal, zeros, inf, all, random +from numpy import sum as array_sum from py4j.protocol import Py4JJavaError if sys.version_info[:2] <= (2, 6): @@ -38,6 +40,7 @@ from pyspark import SparkContext from pyspark.mllib.common import _to_java_object_rdd +from pyspark.mllib.clustering import StreamingKMeans, StreamingKMeansModel from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\ DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT from pyspark.mllib.regression import LabeledPoint @@ -48,6 +51,7 @@ from pyspark.mllib.feature import StandardScaler from pyspark.mllib.feature import ElementwiseProduct from pyspark.serializers import PickleSerializer +from pyspark.streaming import StreamingContext from pyspark.sql import SQLContext _have_scipy = False @@ -67,6 +71,20 @@ def setUp(self): self.sc = sc +class MLLibStreamingTestCase(unittest.TestCase): + def setUp(self): + self.sc = sc + self.ssc = StreamingContext(self.sc, 1.0) + + def tearDown(self): + self.ssc.stop(False) + + @staticmethod + def _ssc_wait(start_time, end_time, sleep_time): + while time() - start_time < end_time: + sleep(0.01) + + def _squared_distance(a, b): if isinstance(a, Vector): return a.squared_distance(b) @@ -863,6 +881,136 @@ def test_model_transform(self): eprod.transform(sparsevec), SparseVector(3, [0], [3])) +class StreamingKMeansTest(MLLibStreamingTestCase): + def test_model_params(self): + """Test that the model params are set correctly""" + stkm = StreamingKMeans() + stkm.setK(5).setDecayFactor(0.0) + self.assertEquals(stkm._k, 5) + self.assertEquals(stkm._decayFactor, 0.0) + + # Model not set yet. + self.assertIsNone(stkm.latestModel()) + self.assertRaises(ValueError, stkm.trainOn, [0.0, 1.0]) + + stkm.setInitialCenters( + centers=[[0.0, 0.0], [1.0, 1.0]], weights=[1.0, 1.0]) + self.assertEquals( + stkm.latestModel().centers, [[0.0, 0.0], [1.0, 1.0]]) + self.assertEquals(stkm.latestModel().clusterWeights, [1.0, 1.0]) + + def test_accuracy_for_single_center(self): + """Test that parameters obtained are correct for a single center.""" + centers, batches = self.streamingKMeansDataGenerator( + batches=5, numPoints=5, k=1, d=5, r=0.1, seed=0) + stkm = StreamingKMeans(1) + stkm.setInitialCenters([[0., 0., 0., 0., 0.]], [0.]) + input_stream = self.ssc.queueStream( + [self.sc.parallelize(batch, 1) for batch in batches]) + stkm.trainOn(input_stream) + + t = time() + self.ssc.start() + self._ssc_wait(t, 10.0, 0.01) + self.assertEquals(stkm.latestModel().clusterWeights, [25.0]) + realCenters = array_sum(array(centers), axis=0) + for i in range(5): + modelCenters = stkm.latestModel().centers[0][i] + self.assertAlmostEqual(centers[0][i], modelCenters, 1) + self.assertAlmostEqual(realCenters[i], modelCenters, 1) + + def streamingKMeansDataGenerator(self, batches, numPoints, + k, d, r, seed, centers=None): + rng = random.RandomState(seed) + + # Generate centers. 
+ centers = [rng.randn(d) for i in range(k)] + + return centers, [[Vectors.dense(centers[j % k] + r * rng.randn(d)) + for j in range(numPoints)] + for i in range(batches)] + + def test_trainOn_model(self): + """Test the model on toy data with four clusters.""" + stkm = StreamingKMeans() + initCenters = [[1.0, 1.0], [-1.0, 1.0], [-1.0, -1.0], [1.0, -1.0]] + stkm.setInitialCenters( + centers=initCenters, weights=[1.0, 1.0, 1.0, 1.0]) + + # Create a toy dataset by setting a tiny offest for each point. + offsets = [[0, 0.1], [0, -0.1], [0.1, 0], [-0.1, 0]] + batches = [] + for offset in offsets: + batches.append([[offset[0] + center[0], offset[1] + center[1]] + for center in initCenters]) + + batches = [self.sc.parallelize(batch, 1) for batch in batches] + input_stream = self.ssc.queueStream(batches) + stkm.trainOn(input_stream) + t = time() + self.ssc.start() + + # Give enough time to train the model. + self._ssc_wait(t, 6.0, 0.01) + finalModel = stkm.latestModel() + self.assertTrue(all(finalModel.centers == array(initCenters))) + self.assertEquals(finalModel.clusterWeights, [5.0, 5.0, 5.0, 5.0]) + + def test_predictOn_model(self): + """Test that the model predicts correctly on toy data.""" + stkm = StreamingKMeans() + stkm._model = StreamingKMeansModel( + clusterCenters=[[1.0, 1.0], [-1.0, 1.0], [-1.0, -1.0], [1.0, -1.0]], + clusterWeights=[1.0, 1.0, 1.0, 1.0]) + + predict_data = [[[1.5, 1.5]], [[-1.5, 1.5]], [[-1.5, -1.5]], [[1.5, -1.5]]] + predict_data = [sc.parallelize(batch, 1) for batch in predict_data] + predict_stream = self.ssc.queueStream(predict_data) + predict_val = stkm.predictOn(predict_stream) + + result = [] + + def update(rdd): + rdd_collect = rdd.collect() + if rdd_collect: + result.append(rdd_collect) + + predict_val.foreachRDD(update) + t = time() + self.ssc.start() + self._ssc_wait(t, 6.0, 0.01) + self.assertEquals(result, [[0], [1], [2], [3]]) + + def test_trainOn_predictOn(self): + """Test that prediction happens on the updated model.""" + stkm = StreamingKMeans(decayFactor=0.0, k=2) + stkm.setInitialCenters([[0.0], [1.0]], [1.0, 1.0]) + + # Since decay factor is set to zero, once the first batch + # is passed the clusterCenters are updated to [-0.5, 0.7] + # which causes 0.2 & 0.3 to be classified as 1, even though the + # classification based in the initial model would have been 0 + # proving that the model is updated. + batches = [[[-0.5], [0.6], [0.8]], [[0.2], [-0.1], [0.3]]] + batches = [sc.parallelize(batch) for batch in batches] + input_stream = self.ssc.queueStream(batches) + predict_results = [] + + def collect(rdd): + rdd_collect = rdd.collect() + if rdd_collect: + predict_results.append(rdd_collect) + + stkm.trainOn(input_stream) + predict_stream = stkm.predictOn(input_stream) + predict_stream.foreachRDD(collect) + + t = time() + self.ssc.start() + self._ssc_wait(t, 6.0, 0.01) + self.assertEqual(predict_results, [[0, 1, 1], [1, 0, 1]]) + + if __name__ == "__main__": if not _have_scipy: print("NOTE: Skipping SciPy tests as it does not seem to be installed") From 1fa29c2df2a7846405eed6b409b8deb5329fa7c1 Mon Sep 17 00:00:00 2001 From: Hossein Date: Fri, 19 Jun 2015 15:47:22 -0700 Subject: [PATCH 126/151] [SPARK-8452] [SPARKR] expose jobGroup API in SparkR This pull request adds following methods to SparkR: ```R setJobGroup() cancelJobGroup() clearJobGroup() ``` For each method, the spark context is passed as the first argument. There does not seem to be a good way to test these in R. 
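For reference, the R functions added below delegate through `callJMethod` to the corresponding JVM methods on `SparkContext`; a minimal sketch of the equivalent Scala-side calls, with a placeholder app name and master:

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Sketch only: the JVM-side methods the new R wrappers delegate to.
val sc = new SparkContext(new SparkConf().setAppName("job-group-sketch").setMaster("local[*]"))

// Tag all jobs started by this thread until the group is changed or cleared.
sc.setJobGroup("myJobGroup", "My job group description", interruptOnCancel = true)
sc.parallelize(1 to 100).count()   // runs under "myJobGroup"

sc.cancelJobGroup("myJobGroup")    // cancel any still-active jobs in the group
sc.clearJobGroup()                 // subsequent jobs are no longer tagged

sc.stop()
```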
cc shivaram and davies Author: Hossein Closes #6889 from falaki/SPARK-8452 and squashes the following commits: 9ce9f1e [Hossein] Added basic tests to verify methods can be called and won't throw errors c706af9 [Hossein] Added examples a2c19af [Hossein] taking spark context as first argument 343ca77 [Hossein] Added setJobGroup, cancelJobGroup and clearJobGroup to SparkR --- R/pkg/NAMESPACE | 5 ++++ R/pkg/R/sparkR.R | 44 +++++++++++++++++++++++++++++++++ R/pkg/inst/tests/test_context.R | 7 ++++++ 3 files changed, 56 insertions(+) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index f9447f6c3288d..7f857222452d4 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -10,6 +10,11 @@ export("sparkR.init") export("sparkR.stop") export("print.jobj") +# Job group lifecycle management methods +export("setJobGroup", + "clearJobGroup", + "cancelJobGroup") + exportClasses("DataFrame") exportMethods("arrange", diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 5ced7c688f98a..2efd4f0742e77 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -278,3 +278,47 @@ sparkRHive.init <- function(jsc = NULL) { assign(".sparkRHivesc", hiveCtx, envir = .sparkREnv) hiveCtx } + +#' Assigns a group ID to all the jobs started by this thread until the group ID is set to a +#' different value or cleared. +#' +#' @param sc existing spark context +#' @param groupid the ID to be assigned to job groups +#' @param description description for the the job group ID +#' @param interruptOnCancel flag to indicate if the job is interrupted on job cancellation +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' setJobGroup(sc, "myJobGroup", "My job group description", TRUE) +#'} + +setJobGroup <- function(sc, groupId, description, interruptOnCancel) { + callJMethod(sc, "setJobGroup", groupId, description, interruptOnCancel) +} + +#' Clear current job group ID and its description +#' +#' @param sc existing spark context +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' clearJobGroup(sc) +#'} + +clearJobGroup <- function(sc) { + callJMethod(sc, "clearJobGroup") +} + +#' Cancel active jobs for the specified group +#' +#' @param sc existing spark context +#' @param groupId the ID of job group to be cancelled +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' cancelJobGroup(sc, "myJobGroup") +#'} + +cancelJobGroup <- function(sc, groupId) { + callJMethod(sc, "cancelJobGroup", groupId) +} diff --git a/R/pkg/inst/tests/test_context.R b/R/pkg/inst/tests/test_context.R index e4aab37436a74..513bbc8e62059 100644 --- a/R/pkg/inst/tests/test_context.R +++ b/R/pkg/inst/tests/test_context.R @@ -48,3 +48,10 @@ test_that("rdd GC across sparkR.stop", { count(rdd3) count(rdd4) }) + +test_that("job group functions can be called", { + sc <- sparkR.init() + setJobGroup(sc, "groupId", "job description", TRUE) + cancelJobGroup(sc, "groupId") + clearJobGroup(sc) +}) From 9814b971f07dff8a99f1b8ad2adf70614f1c690b Mon Sep 17 00:00:00 2001 From: Nathan Howell Date: Fri, 19 Jun 2015 16:19:28 -0700 Subject: [PATCH 127/151] [SPARK-8093] [SQL] Remove empty structs inferred from JSON documents Author: Nathan Howell Closes #6799 from NathanHowell/spark-8093 and squashes the following commits: 76ac3e8 [Nathan Howell] [SPARK-8093] [SQL] Remove empty structs inferred from JSON documents --- .../apache/spark/sql/json/InferSchema.scala | 52 +++++++++++++------ .../org/apache/spark/sql/json/JsonSuite.scala | 4 ++ .../apache/spark/sql/json/TestJsonData.scala | 9 ++++ 3 files changed, 48 insertions(+), 17 deletions(-) diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala index 565d10247f10e..afe2c6c11ac69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala @@ -43,7 +43,7 @@ private[sql] object InferSchema { } // perform schema inference on each row and merge afterwards - schemaData.mapPartitions { iter => + val rootType = schemaData.mapPartitions { iter => val factory = new JsonFactory() iter.map { row => try { @@ -55,8 +55,13 @@ private[sql] object InferSchema { StructType(Seq(StructField(columnNameOfCorruptRecords, StringType))) } } - }.treeAggregate[DataType](StructType(Seq()))(compatibleRootType, compatibleRootType) match { - case st: StructType => nullTypeToStringType(st) + }.treeAggregate[DataType](StructType(Seq()))(compatibleRootType, compatibleRootType) + + canonicalizeType(rootType) match { + case Some(st: StructType) => st + case _ => + // canonicalizeType erases all empty structs, including the only one we want to keep + StructType(Seq()) } } @@ -116,22 +121,35 @@ private[sql] object InferSchema { } } - private def nullTypeToStringType(struct: StructType): StructType = { - val fields = struct.fields.map { - case StructField(fieldName, dataType, nullable, _) => - val newType = dataType match { - case NullType => StringType - case ArrayType(NullType, containsNull) => ArrayType(StringType, containsNull) - case ArrayType(struct: StructType, containsNull) => - ArrayType(nullTypeToStringType(struct), containsNull) - case struct: StructType => nullTypeToStringType(struct) - case other: DataType => other - } + /** + * Convert NullType to StringType and remove StructTypes with no fields + */ + private def canonicalizeType: DataType => Option[DataType] = { + case at@ArrayType(elementType, _) => + for { + canonicalType <- canonicalizeType(elementType) + } yield { + at.copy(canonicalType) + } - StructField(fieldName, newType, nullable) - } + case StructType(fields) => + val canonicalFields = for { + field <- fields + if field.name.nonEmpty + canonicalType <- canonicalizeType(field.dataType) + } yield { + field.copy(dataType = canonicalType) + } + + if (canonicalFields.nonEmpty) { + Some(StructType(canonicalFields)) + } else { + // per SPARK-8093: empty structs should be deleted + None + } - StructType(fields) + case NullType => Some(StringType) + case other => Some(other) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 945d4375035fd..c32d9f88dd6ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -1103,4 +1103,8 @@ class JsonSuite extends QueryTest with TestJsonData { } } + test("SPARK-8093 Erase empty structs") { + val emptySchema = InferSchema(emptyRecords, 1.0, "") + assert(StructType(Seq()) === emptySchema) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala index b6a6a8dc6a63c..eb62066ac6430 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala @@ -189,5 +189,14 @@ trait TestJsonData { """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: """]""" :: Nil) + def emptyRecords: RDD[String] = + 
ctx.sparkContext.parallelize( + """{""" :: + """""" :: + """{"a": {}}""" :: + """{"a": {"b": {}}}""" :: + """{"b": [{"c": {}}]}""" :: + """]""" :: Nil) + def empty: RDD[String] = ctx.sparkContext.parallelize(Seq[String]()) } From a333a72e029d2546a66b36d6b3458e965430c530 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 19 Jun 2015 16:54:51 -0700 Subject: [PATCH 128/151] [SPARK-8420] [SQL] Fix comparision of timestamps/dates with strings In earlier versions of Spark SQL we casted `TimestampType` and `DataType` to `StringType` when it was involved in a binary comparison with a `StringType`. This allowed comparing a timestamp with a partial date as a user would expect. - `time > "2014-06-10"` - `time > "2014"` In 1.4.0 we tried to cast the String instead into a Timestamp. However, since partial dates are not a valid complete timestamp this results in `null` which results in the tuple being filtered. This PR restores the earlier behavior. Note that we still special case equality so that these comparisons are not affected by not printing zeros for subsecond precision. Author: Michael Armbrust Closes #6888 from marmbrus/timeCompareString and squashes the following commits: bdef29c [Michael Armbrust] test partial date 1f09adf [Michael Armbrust] special handling of equality 1172c60 [Michael Armbrust] more test fixing 4dfc412 [Michael Armbrust] fix tests aaa9508 [Michael Armbrust] newline 04d908f [Michael Armbrust] [SPARK-8420][SQL] Fix comparision of timestamps/dates with strings --- .../catalyst/analysis/HiveTypeCoercion.scala | 17 +++++- .../sql/catalyst/expressions/predicates.scala | 9 +++ .../apache/spark/sql/DataFrameDateSuite.scala | 56 +++++++++++++++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 4 ++ .../scala/org/apache/spark/sql/TestData.scala | 6 -- .../columnar/InMemoryColumnarQuerySuite.scala | 7 ++- 6 files changed, 88 insertions(+), 11 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DataFrameDateSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 8012b224eb444..d4ab1fc643c33 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -277,15 +277,26 @@ trait HiveTypeCoercion { case a @ BinaryArithmetic(left, right @ StringType()) => a.makeCopy(Array(left, Cast(right, DoubleType))) - // we should cast all timestamp/date/string compare into string compare + // For equality between string and timestamp we cast the string to a timestamp + // so that things like rounding of subsecond precision does not affect the comparison. + case p @ Equality(left @ StringType(), right @ TimestampType()) => + p.makeCopy(Array(Cast(left, TimestampType), right)) + case p @ Equality(left @ TimestampType(), right @ StringType()) => + p.makeCopy(Array(left, Cast(right, TimestampType))) + + // We should cast all relative timestamp/date/string comparison into string comparisions + // This behaves as a user would expect because timestamp strings sort lexicographically. + // i.e. TimeStamp(2013-01-01 00:00 ...) 
< "2014" = true case p @ BinaryComparison(left @ StringType(), right @ DateType()) => p.makeCopy(Array(left, Cast(right, StringType))) case p @ BinaryComparison(left @ DateType(), right @ StringType()) => p.makeCopy(Array(Cast(left, StringType), right)) case p @ BinaryComparison(left @ StringType(), right @ TimestampType()) => - p.makeCopy(Array(Cast(left, TimestampType), right)) + p.makeCopy(Array(left, Cast(right, StringType))) case p @ BinaryComparison(left @ TimestampType(), right @ StringType()) => - p.makeCopy(Array(left, Cast(right, TimestampType))) + p.makeCopy(Array(Cast(left, StringType), right)) + + // Comparisons between dates and timestamps. case p @ BinaryComparison(left @ TimestampType(), right @ DateType()) => p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType))) case p @ BinaryComparison(left @ DateType(), right @ TimestampType()) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 082d72eb438fa..3a12d03ba6bb9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -266,6 +266,15 @@ private[sql] object BinaryComparison { def unapply(e: BinaryComparison): Option[(Expression, Expression)] = Some((e.left, e.right)) } +/** An extractor that matches both standard 3VL equality and null-safe equality. */ +private[sql] object Equality { + def unapply(e: BinaryComparison): Option[(Expression, Expression)] = e match { + case EqualTo(l, r) => Some((l, r)) + case EqualNullSafe(l, r) => Some((l, r)) + case _ => None + } +} + case class EqualTo(left: Expression, right: Expression) extends BinaryComparison { override def symbol: String = "=" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameDateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameDateSuite.scala new file mode 100644 index 0000000000000..a4719a38de1d4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameDateSuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import java.sql.{Date, Timestamp} + +class DataFrameDateTimeSuite extends QueryTest { + + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + + test("timestamp comparison with date strings") { + val df = Seq( + (1, Timestamp.valueOf("2015-01-01 00:00:00")), + (2, Timestamp.valueOf("2014-01-01 00:00:00"))).toDF("i", "t") + + checkAnswer( + df.select("t").filter($"t" <= "2014-06-01"), + Row(Timestamp.valueOf("2014-01-01 00:00:00")) :: Nil) + + + checkAnswer( + df.select("t").filter($"t" >= "2014-06-01"), + Row(Timestamp.valueOf("2015-01-01 00:00:00")) :: Nil) + } + + test("date comparison with date strings") { + val df = Seq( + (1, Date.valueOf("2015-01-01")), + (2, Date.valueOf("2014-01-01"))).toDF("i", "t") + + checkAnswer( + df.select("t").filter($"t" <= "2014-06-01"), + Row(Date.valueOf("2014-01-01")) :: Nil) + + + checkAnswer( + df.select("t").filter($"t" >= "2015"), + Row(Date.valueOf("2015-01-01")) :: Nil) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 82f3fdb48b557..4441afd6bd811 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfterAll +import java.sql.Timestamp + import org.apache.spark.sql.catalyst.DefaultParserDialect import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.GeneratedAggregate @@ -345,6 +347,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SPARK-3173 Timestamp support in the parser") { + (0 to 3).map(i => Tuple1(new Timestamp(i))).toDF("time").registerTempTable("timestamps") + checkAnswer(sql( "SELECT time FROM timestamps WHERE time='1969-12-31 16:00:00.0'"), Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00"))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 725a18bfae3a7..520a862ea0838 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -174,12 +174,6 @@ object TestData { "3, C3, true, null" :: "4, D4, true, 2147483644" :: Nil) - case class TimestampField(time: Timestamp) - val timestamps = TestSQLContext.sparkContext.parallelize((0 to 3).map { i => - TimestampField(new Timestamp(i)) - }) - timestamps.toDF().registerTempTable("timestamps") - case class IntField(i: Int) // An RDD with 4 elements and 8 partitions val withEmptyParts = TestSQLContext.sparkContext.parallelize((1 to 4).map(IntField), 8) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 12f95eb557c04..01bc23277fa88 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -91,15 +91,18 @@ class InMemoryColumnarQuerySuite extends QueryTest { } test("SPARK-2729 regression: timestamp data type") { + val timestamps = (0 to 3).map(i => Tuple1(new Timestamp(i))).toDF("time") + timestamps.registerTempTable("timestamps") + checkAnswer( sql("SELECT time FROM timestamps"), - 
timestamps.collect().toSeq.map(Row.fromTuple)) + timestamps.collect().toSeq) ctx.cacheTable("timestamps") checkAnswer( sql("SELECT time FROM timestamps"), - timestamps.collect().toSeq.map(Row.fromTuple)) + timestamps.collect().toSeq) } test("SPARK-3320 regression: batched column buffer building should work with empty partitions") { From b305e377fb0a2ca67d9924b995c51e483a4944ad Mon Sep 17 00:00:00 2001 From: cody koeninger Date: Fri, 19 Jun 2015 17:16:56 -0700 Subject: [PATCH 129/151] [SPARK-8390] [STREAMING] [KAFKA] fix docs related to HasOffsetRanges Author: cody koeninger Closes #6863 from koeninger/SPARK-8390 and squashes the following commits: 26a06bd [cody koeninger] Merge branch 'master' into SPARK-8390 3744492 [cody koeninger] [Streaming][Kafka][SPARK-8390] doc changes per TD, test to make sure approach shown in docs actually compiles + runs b108c9d [cody koeninger] [Streaming][Kafka][SPARK-8390] further doc fixes, clean up spacing bb4336b [cody koeninger] [Streaming][Kafka][SPARK-8390] fix docs related to HasOffsetRanges, cleanup 3f3c57a [cody koeninger] [Streaming][Kafka][SPARK-8389] Example of getting offset ranges out of the existing java direct stream api --- docs/streaming-kafka-integration.md | 70 +++++++++++++------ .../kafka/JavaDirectKafkaStreamSuite.java | 11 ++- .../kafka/DirectKafkaStreamSuite.scala | 16 +++-- 3 files changed, 71 insertions(+), 26 deletions(-) diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md index 02bc95d0e95f9..775d508d4879b 100644 --- a/docs/streaming-kafka-integration.md +++ b/docs/streaming-kafka-integration.md @@ -2,7 +2,7 @@ layout: global title: Spark Streaming + Kafka Integration Guide --- -[Apache Kafka](http://kafka.apache.org/) is publish-subscribe messaging rethought as a distributed, partitioned, replicated commit log service. Here we explain how to configure Spark Streaming to receive data from Kafka. There are two approaches to this - the old approach using Receivers and Kafka's high-level API, and a new experimental approach (introduced in Spark 1.3) without using Receivers. They have different programming models, performance characteristics, and semantics guarantees, so read on for more details. +[Apache Kafka](http://kafka.apache.org/) is publish-subscribe messaging rethought as a distributed, partitioned, replicated commit log service. Here we explain how to configure Spark Streaming to receive data from Kafka. There are two approaches to this - the old approach using Receivers and Kafka's high-level API, and a new experimental approach (introduced in Spark 1.3) without using Receivers. They have different programming models, performance characteristics, and semantics guarantees, so read on for more details. ## Approach 1: Receiver-based Approach This approach uses a Receiver to receive the data. The Received is implemented using the Kafka high-level consumer API. As with all receivers, the data received from Kafka through a Receiver is stored in Spark executors, and then jobs launched by Spark Streaming processes the data. @@ -74,15 +74,15 @@ Next, we discuss how to use this approach in your streaming application. [Maven repository](http://search.maven.org/#search|ga|1|a%3A%22spark-streaming-kafka-assembly_2.10%22%20AND%20v%3A%22{{site.SPARK_VERSION_SHORT}}%22) and add it to `spark-submit` with `--jars`. ## Approach 2: Direct Approach (No Receivers) -This is a new receiver-less "direct" approach has been introduced in Spark 1.3 to ensure stronger end-to-end guarantees. 
Instead of using receivers to receive data, this approach periodically queries Kafka for the latest offsets in each topic+partition, and accordingly defines the offset ranges to process in each batch. When the jobs to process the data are launched, Kafka's simple consumer API is used to read the defined ranges of offsets from Kafka (similar to read files from a file system). Note that this is an experimental feature in Spark 1.3 and is only available in the Scala and Java API. +This new receiver-less "direct" approach has been introduced in Spark 1.3 to ensure stronger end-to-end guarantees. Instead of using receivers to receive data, this approach periodically queries Kafka for the latest offsets in each topic+partition, and accordingly defines the offset ranges to process in each batch. When the jobs to process the data are launched, Kafka's simple consumer API is used to read the defined ranges of offsets from Kafka (similar to read files from a file system). Note that this is an experimental feature introduced in Spark 1.3 for the Scala and Java API. Spark 1.4 added a Python API, but it is not yet at full feature parity. -This approach has the following advantages over the received-based approach (i.e. Approach 1). +This approach has the following advantages over the receiver-based approach (i.e. Approach 1). -- *Simplified Parallelism:* No need to create multiple input Kafka streams and union-ing them. With `directStream`, Spark Streaming will create as many RDD partitions as there is Kafka partitions to consume, which will all read data from Kafka in parallel. So there is one-to-one mapping between Kafka and RDD partitions, which is easier to understand and tune. +- *Simplified Parallelism:* No need to create multiple input Kafka streams and union them. With `directStream`, Spark Streaming will create as many RDD partitions as there are Kafka partitions to consume, which will all read data from Kafka in parallel. So there is a one-to-one mapping between Kafka and RDD partitions, which is easier to understand and tune. -- *Efficiency:* Achieving zero-data loss in the first approach required the data to be stored in a Write Ahead Log, which further replicated the data. This is actually inefficient as the data effectively gets replicated twice - once by Kafka, and a second time by the Write Ahead Log. This second approach eliminate the problem as there is no receiver, and hence no need for Write Ahead Logs. +- *Efficiency:* Achieving zero-data loss in the first approach required the data to be stored in a Write Ahead Log, which further replicated the data. This is actually inefficient as the data effectively gets replicated twice - once by Kafka, and a second time by the Write Ahead Log. This second approach eliminates the problem as there is no receiver, and hence no need for Write Ahead Logs. As long as you have sufficient Kafka retention, messages can be recovered from Kafka. -- *Exactly-once semantics:* The first approach uses Kafka's high level API to store consumed offsets in Zookeeper. This is traditionally the way to consume data from Kafka. While this approach (in combination with write ahead logs) can ensure zero data loss (i.e. at-least once semantics), there is a small chance some records may get consumed twice under some failures. This occurs because of inconsistencies between data reliably received by Spark Streaming and offsets tracked by Zookeeper. 
Hence, in this second approach, we use simple Kafka API that does not use Zookeeper and offsets tracked only by Spark Streaming within its checkpoints. This eliminates inconsistencies between Spark Streaming and Zookeeper/Kafka, and so each record is received by Spark Streaming effectively exactly once despite failures. +- *Exactly-once semantics:* The first approach uses Kafka's high level API to store consumed offsets in Zookeeper. This is traditionally the way to consume data from Kafka. While this approach (in combination with write ahead logs) can ensure zero data loss (i.e. at-least once semantics), there is a small chance some records may get consumed twice under some failures. This occurs because of inconsistencies between data reliably received by Spark Streaming and offsets tracked by Zookeeper. Hence, in this second approach, we use simple Kafka API that does not use Zookeeper. Offsets are tracked by Spark Streaming within its checkpoints. This eliminates inconsistencies between Spark Streaming and Zookeeper/Kafka, and so each record is received by Spark Streaming effectively exactly once despite failures. In order to achieve exactly-once semantics for output of your results, your output operation that saves the data to an external data store must be either idempotent, or an atomic transaction that saves results and offsets (see [Semanitcs of output operations](streaming-programming-guide.html#semantics-of-output-operations) in the main programming guide for further information). Note that one disadvantage of this approach is that it does not update offsets in Zookeeper, hence Zookeeper-based Kafka monitoring tools will not show progress. However, you can access the offsets processed by this approach in each batch and update Zookeeper yourself (see below). @@ -135,32 +135,60 @@ Next, we discuss how to use this approach in your streaming application.
- directKafkaStream.foreachRDD { rdd => - val offsetRanges = rdd.asInstanceOf[HasOffsetRanges] - // offsetRanges.length = # of Kafka partitions being consumed - ... + // Hold a reference to the current offset ranges, so it can be used downstream + var offsetRanges = Array[OffsetRange]() + + directKafkaStream.transform { rdd => + offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + rdd + }.map { + ... + }.foreachRDD { rdd => + for (o <- offsetRanges) { + println(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}") + } + ... }
- directKafkaStream.foreachRDD( - new Function, Void>() { - @Override - public Void call(JavaPairRDD rdd) throws IOException { - OffsetRange[] offsetRanges = ((HasOffsetRanges)rdd).offsetRanges - // offsetRanges.length = # of Kafka partitions being consumed - ... - return null; - } + // Hold a reference to the current offset ranges, so it can be used downstream + final AtomicReference offsetRanges = new AtomicReference(); + + directKafkaStream.transformToPair( + new Function, JavaPairRDD>() { + @Override + public JavaPairRDD call(JavaPairRDD rdd) throws Exception { + OffsetRange[] offsets = ((HasOffsetRanges) rdd.rdd()).offsetRanges(); + offsetRanges.set(offsets); + return rdd; } + } + ).map( + ... + ).foreachRDD( + new Function, Void>() { + @Override + public Void call(JavaPairRDD rdd) throws IOException { + for (OffsetRange o : offsetRanges.get()) { + System.out.println( + o.topic() + " " + o.partition() + " " + o.fromOffset() + " " + o.untilOffset() + ); + } + ... + return null; + } + } );
- Not supported + Not supported yet
You can use this to update Zookeeper yourself if you want Zookeeper-based Kafka monitoring tools to show progress of the streaming application. - Another thing to note is that since this approach does not use Receivers, the standard receiver-related (that is, [configurations](configuration.html) of the form `spark.streaming.receiver.*` ) will not apply to the input DStreams created by this approach (will apply to other input DStreams though). Instead, use the [configurations](configuration.html) `spark.streaming.kafka.*`. An important one is `spark.streaming.kafka.maxRatePerPartition` which is the maximum rate at which each Kafka partition will be read by this direct API. + Note that the typecast to HasOffsetRanges will only succeed if it is done in the first method called on the directKafkaStream, not later down a chain of methods. You can use transform() instead of foreachRDD() as your first method call in order to access offsets, then call further Spark methods. However, be aware that the one-to-one mapping between RDD partition and Kafka partition does not remain after any methods that shuffle or repartition, e.g. reduceByKey() or window(). + + Another thing to note is that since this approach does not use Receivers, the standard receiver-related (that is, [configurations](configuration.html) of the form `spark.streaming.receiver.*` ) will not apply to the input DStreams created by this approach (will apply to other input DStreams though). Instead, use the [configurations](configuration.html) `spark.streaming.kafka.*`. An important one is `spark.streaming.kafka.maxRatePerPartition` which is the maximum rate (in messages per second) at which each Kafka partition will be read by this direct API. 3. **Deploying:** This is same as the first approach, for Scala, Java and Python. 
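As a rough illustration of the `spark.streaming.kafka.maxRatePerPartition` setting mentioned above, a sketch with placeholder broker address, topic name, and batch interval:

```scala
import kafka.serializer.StringDecoder

import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.kafka.KafkaUtils

// Sketch only: broker list, topic, and batch interval are placeholder values.
val conf = new SparkConf()
  .setAppName("direct-kafka-rate-limit")
  // Maximum rate (messages per second) at which each Kafka partition will be read.
  .set("spark.streaming.kafka.maxRatePerPartition", "1000")

val ssc = new StreamingContext(conf, Seconds(2))
val kafkaParams = Map("metadata.broker.list" -> "localhost:9092")
val stream = KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder](
  ssc, kafkaParams, Set("topic1"))

stream.count().print()
ssc.start()
ssc.awaitTermination()
```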
diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java index 3913b711ba28b..02cd24a35906f 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java @@ -19,6 +19,7 @@ import java.io.Serializable; import java.util.*; +import java.util.concurrent.atomic.AtomicReference; import scala.Tuple2; @@ -68,6 +69,8 @@ public void tearDown() { public void testKafkaStream() throws InterruptedException { final String topic1 = "topic1"; final String topic2 = "topic2"; + // hold a reference to the current offset ranges, so it can be used downstream + final AtomicReference offsetRanges = new AtomicReference(); String[] topic1data = createTopicAndSendData(topic1); String[] topic2data = createTopicAndSendData(topic2); @@ -93,7 +96,8 @@ public void testKafkaStream() throws InterruptedException { new Function, JavaPairRDD>() { @Override public JavaPairRDD call(JavaPairRDD rdd) throws Exception { - OffsetRange[] offsets = ((HasOffsetRanges)rdd.rdd()).offsetRanges(); + OffsetRange[] offsets = ((HasOffsetRanges) rdd.rdd()).offsetRanges(); + offsetRanges.set(offsets); Assert.assertEquals(offsets[0].topic(), topic1); return rdd; } @@ -131,6 +135,11 @@ public String call(MessageAndMetadata msgAndMd) throws Exception @Override public Void call(JavaRDD rdd) throws Exception { result.addAll(rdd.collect()); + for (OffsetRange o : offsetRanges.get()) { + System.out.println( + o.topic() + " " + o.partition() + " " + o.fromOffset() + " " + o.untilOffset() + ); + } return null; } } diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala index 212eb35c61b66..8e1715f6dbb95 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala @@ -102,13 +102,21 @@ class DirectKafkaStreamSuite val allReceived = new ArrayBuffer[(String, String)] with mutable.SynchronizedBuffer[(String, String)] - stream.foreachRDD { rdd => - // Get the offset ranges in the RDD - val offsets = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + // hold a reference to the current offset ranges, so it can be used downstream + var offsetRanges = Array[OffsetRange]() + + stream.transform { rdd => + // Get the offset ranges in the RDD + offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + rdd + }.foreachRDD { rdd => + for (o <- offsetRanges) { + println(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}") + } val collected = rdd.mapPartitionsWithIndex { (i, iter) => // For each partition, get size of the range in the partition, // and the number of items in the partition - val off = offsets(i) + val off = offsetRanges(i) val all = iter.toSeq val partSize = all.size val rangeSize = off.untilOffset - off.fromOffset From 093c34838d1db7a9375f36a9a2ab5d96a23ae683 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 19 Jun 2015 17:34:09 -0700 Subject: [PATCH 130/151] [SPARK-8498] [SQL] Add regression test for SPARK-8470 **Summary of the problem in SPARK-8470.** When using `HiveContext` to create a data frame of a user case class, Spark throws 
`scala.reflect.internal.MissingRequirementError` when it tries to infer the schema using reflection. This is caused by `HiveContext` silently overwriting the context class loader containing the user classes. **What this issue is about.** This issue adds regression tests for SPARK-8470, which is already fixed in #6891. We closed SPARK-8470 as a duplicate because it is a different manifestation of the same problem in SPARK-8368. Due to the complexity of the reproduction, this requires us to pre-package a special test jar and include it in the Spark project itself. I tested this with and without the fix in #6891 and verified that it passes only if the fix is present. Author: Andrew Or Closes #6909 from andrewor14/SPARK-8498 and squashes the following commits: 5e9d688 [Andrew Or] Add regression test for SPARK-8470 --- .../regression-test-SPARK-8498/Main.scala | 43 ++++++++++++++++++ .../MyCoolClass.scala | 20 ++++++++ .../regression-test-SPARK-8498/test.jar | Bin 0 -> 6811 bytes .../spark/sql/hive/HiveSparkSubmitSuite.scala | 13 ++++++ 4 files changed, 76 insertions(+) create mode 100644 sql/hive/src/test/resources/regression-test-SPARK-8498/Main.scala create mode 100644 sql/hive/src/test/resources/regression-test-SPARK-8498/MyCoolClass.scala create mode 100644 sql/hive/src/test/resources/regression-test-SPARK-8498/test.jar diff --git a/sql/hive/src/test/resources/regression-test-SPARK-8498/Main.scala b/sql/hive/src/test/resources/regression-test-SPARK-8498/Main.scala new file mode 100644 index 0000000000000..858dd6b5ddb05 --- /dev/null +++ b/sql/hive/src/test/resources/regression-test-SPARK-8498/Main.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.sql.hive.HiveContext + +/** + * Entry point in test application for SPARK-8498. + * + * This file is not meant to be compiled during tests. It is already included + * in a pre-built "test.jar" located in the same directory as this file. + * This is included here for reference only and should NOT be modified without + * rebuilding the test jar itself. + * + * This is used in org.apache.spark.sql.hive.HiveSparkSubmitSuite. + */ +object Main { + def main(args: Array[String]) { + println("Running regression test for SPARK-8498.") + val sc = new SparkContext("local", "testing") + val hc = new HiveContext(sc) + // This line should not throw scala.reflect.internal.MissingRequirementError. + // See SPARK-8470 for more detail. 
+ val df = hc.createDataFrame(Seq(MyCoolClass("1", "2", "3"))) + df.collect() + println("Regression test for SPARK-8498 success!") + } +} + diff --git a/sql/hive/src/test/resources/regression-test-SPARK-8498/MyCoolClass.scala b/sql/hive/src/test/resources/regression-test-SPARK-8498/MyCoolClass.scala new file mode 100644 index 0000000000000..a72c063a38197 --- /dev/null +++ b/sql/hive/src/test/resources/regression-test-SPARK-8498/MyCoolClass.scala @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Dummy class used in regression test SPARK-8498. */ +case class MyCoolClass(past: String, present: String, future: String) + diff --git a/sql/hive/src/test/resources/regression-test-SPARK-8498/test.jar b/sql/hive/src/test/resources/regression-test-SPARK-8498/test.jar new file mode 100644 index 0000000000000000000000000000000000000000..4f59fba9eab558131b2587e51b7c2e2d54348bd1 GIT binary patch literal 6811 zcmaJ`Wl&t(md1iK5TudD-2#n6aCdjt;0f-ojk~)Cg1cLA2^N9|cXyWrCYk%DZr+`F z^PL}O@2a!cUUh2ex0GaI;jo|(5fPy(9Hu3p{s0yf9F)AcnlM;eL4rwMSV3AsTvd%x zUgG%+6x5ac*r=Q=m~jeK7W`&nY_eX3d4_F$=W+}@EIlqiPA|;>4S=D#Qn`$*z)BE$ z14ITmtB$HfGOGofrfh%2R1t&lx8=}*is5iSQd<<8QE$1Rm60D$2){-A&e0W~?$I{E z9X>Fi0U+gM!ezpRQj$eLWT8oGwELwV88p)Zk`3 zoy1!W350DZzI{rPQTs0z`^dqW!YmE*;x?jaL6eUqDR??bD$=~BTF5I=Rp(S)lo%ct z#raLIKI%fFybC*|r!$L`>HtKkv3jH*IWM>A*1n;pQ5AQKj|yQ-#&YFN1BXu_QfK`v z&UecA*p5PUz#}UXLUQ-B1>sP<`y|gnDBDX&b^o~e4OO&itSD;;2G)&)X%x?5Lllc- zrj-%7?!ulpy#rBXU%xe!@fgdsA;!Rk7NQXV&5BV#aST9s)528d;xFqi^Hsj;xMR;u zwES{V+5r~%9$HXdJ;8XZc;{^cB3`Q7bOhC!1!EkkJ0UEv-lGa6ER$q_*fg0 z*|Gh`d-M*QU;{?a+NS=XeEP~m?|Mh2_2Lf)!`dx~El2w2V5a!!LkAp|z*0Z{tF)L5vMlLg4I~yW$onI@W>eJj( zhuLBdkx7)>U))|3PjB`Rcz%gN@^S4)P0&~(FTnVII!m~7JGJN=)8p;TqtSQrnT&iQ zoWY&RG(^OMjpT|Hb)FgZ40*vyG-URSDgD9S`>>B!K*^He zj(Wwx$J0cQQ;Lu!i(5)_0aB#3_sMAHRLv1G=q3N=I*XTJA=R>Ou4mx=N$Wl3;JlpN z9RQ>Lcy&B6REQ)>%x`^wT{>zTg^Az*$zTQ~1}JweA)S-DT3F<tfRIj1~cW8 zB$~0s;!SndNyj5LR?BM%s$NU*^*kjN!~@AcCaM?@=b$F3v_ZQ4QqgJDUL>PQ@d+jp;T;QRF5kp|q%3jwkXj zl?(=!nOR4bmhO&icDgz-D2b=T_>jw<_Eo{75QqK8tGSJ|F{z4jx;SyAM@=S)PKD@# z)eM``xF9}QY``1lVN=>~=rTv|zy?0f=J|bFq}DU|04xl^OZ-e1k+;32aGw$0M3&Qh zq1ee|7qE$+SbPM(Y+hE-^(?EFpjCCU)>|;l;}~9G_M2ZS0xG&t;5KQrNFRC%9X&=Y znFXdwtI1$mA&v?**D@()lvZ9G_--?0Tyx`@nA|z2Tsh9xq6VRjwG{^=#$J zW%420X|nUj(PzJBWGK1bJvfGgda+hLIjlwGm10;KMZEOxx1Qiqv4PD5R^+VUkbOLB%LBrx=#MNb)r%T*|z{-eb$W#g9oBKKfL z)G+HP>zr_&rVN{TODaL%$FLROil{xJ za6s=<9+A;dJ9m9<+2Xfr6zf zg?j`hg&P9e4JgcIa-N_MW)GGHey4^0Ss9y}US`g2E0Un8gqRxQS#lkNMv0@3EQ@=1 z+JpD(3gLuag-ptm>2i1Bx6Z*XZ}5@mroIFwkw?*W6ZR7s2k%q>Cew3)wa7m+%nNfS zt-n9SCmp$--LE@SG=iKhIdKWT-DXATn-Wy#jAUe6OybiDa3vbLd3>UIEM{2N*3$sS zJ96&2>EJbB(fwvNzBg} zXOOhp5jZQ1zydx+l@KN#e0$BxmYmVooVgm=F>HPR2?C*$=3;;(X~mK<&kTZCuU)va%j0VpCsd+-Egg=S|n zk=r5cn5Vg3c$=mI`58@F1-(P+=G^VIO3L~W<19TYu`z7^-8Hb+ZVGGFrtKo{e4kZp 
z>5m*$y-$}H#~UKSd;sHUL>Vcw36;o;h(pfn4oTPCn~HSMam2Raz2h;S+F`kLmEVumYjmMir$bX8)6 zw=-^|lda`UB)Y4lqM$2id~93lcs34S|3!(c>Z|MfX-R=5`)VB1=yad>6|I%*g)O-X zB|LSUA{>5!VBAV87E$sAeLjmAMA0LDOuQ+rM1(EW{YJ#j_EO5L72d`mbYT>Kz|TvFP=lL~44?4F2*`eMr#_#;jc$8UVf zB%32zTpZJ|TJPg$-si_i)f5I`Wmsb1Mn2X;`%YROlJAcf;=>aAD#|8bJ|5^4cDaI> z-pV{IZdxe_=DSdI^TYU50$Y{KaAR)lJG{&0Iq{3YKhU9|6JLM%q`RY5T19f84pTLH z)VQD4A1izk7)z?qt;h6MIMFBo-ZD`9a8J=V<2cZxB^Em85rZP3&H2Vs22dl@y$ax0 zpD)+xjOs$kI)mu7N}T~gs`o?$Ec>4cGaSBK28ZZwli_Rq#6vR4wGOL!ovF2WQxD{S zsaXlCotl-!EojYkk_xn0I1uq_4^vj42W#N#)S{+qJi~aUx~+KVj4L)SPK$BJw#zU3 z3A*WBP*4fu^{~7H z7_4jVivx`bOKEX`0KLzX$8ows79bJ9Qfvb8`;US8cZkY*KJ>;cMKJlQ08`%jwzLsZ$Goafim`nfFl0 zwZ=gX1Nuq4%l30kcoWRvgIxOg+xn`1G|V+T#ErCmu#8cZ%{AvUh@lTq<>K&@&09C% z3A&?y@Cpg~+ps&-hswv?{&+yf=MXUaa|scOQ)pbq(iWAkv8-*dL{I*Neg0?IE8^7R z(B;-pLeKk_aZ#AuER&_(iBBhpZjPZ}ue$1epLsUWJqOuuq;ZPIrGAow9{GPTCK2~p z8U%gE@`5EKAaF;`4Sj|ms|eriG3lLWNPVD@Y9~G=G4g_q%iiTBYQU8hC)3w&1pvYLou5NjT5Mvo6E-qFeI} zr2s7_xV~~KiM992=p6VG#?tq`JN2BRNi09xCH%u_H$U$ETPc!4j}X=ptZqO5q5?sc zAG+)!&8Pa|?i;6u&#!xiRQhz>#pAc3o)LZ*kF{k*kFE#utEY0!`1RZ3eb!E(S^myio%Gp9f}9U1ER~9 z&`ISooG*=vUu52Lwb|=E6%tO$^P*Trm^zT`dGGPwbzC0Z-H|}eDU(IbLQdZ6EZAxn z+l7CbtBPw2T8s~i@K~FWYD}LUu#%hx&W@V7H_>;zT=?hyJ^`xM3`^<;V!^nn<%G` z`3{w?D(#2LI$7YCC!DpI_RHSOR}vpaz32-~aMSL?OYLaUH+^ExxtrMB^$+t{(7@*O z>nu6UCh$Nlp?GD$rIx8SiA=?r1yPGfvq=c+T8oE!a+k|$5oi51HoYOEehHEB{D67_ zyX!kOcsqB+`f@_o2lgS0Y*9T$pn(x=_5Eb(yN3^^9;Tc%BX`Z0jeh@4Q!$OI^hf<_ zhU{Pce>RQ$KQ{d@eoEuc2`CBjp-)Ym0H7<&*F#~L4&{MhS(;S4olQwo0=THjEFKt0 zvd;PsOqn6qtMpbML$3`t$A(4uM(+21D5rPuBk`q7^a95%?_BSmx_k`3H$P4O>V8m- zczuB-N3$sO#X^V-+@GR9AgR`~hHr`PT{o2HNeCswEIC?@&%hpA0)hS^SA}l5c+75 zkYx40jlz0$Op=RZt+|9EzLfKG}LOcs>t=P>{{AjvlqFdz~exWt) zpU&yIE=C+DM>u$?NGb`_O9z79}VV;X#GlSvKRgHEmT_$tFvB9h&~P94a*$1+uJX7!G~ zNg5blM(r9rrIED6XisHuFdk^`oqayEBS?=dH1HbxgU}~VO-fp~9=Ry|EzeE#%%x!a z&)tLWH`XZ(XBtxh3OKnWDk=22B1SKB z790Fmu=uF#IZt6LwvYJYyds_+Br3gu)j7Eu4%m~-FT^T7xx4QX2yLQ91lcCJwZnFH zW#$sF+?IHneP2(?e)H@(Q^}+#U#BWAhYM`B=EVN!zE0$r-P)^ zgLB3xL7;n3EP~t2I!oznmRwa-mdrI3`tU^_`$|9@9j8n!_!%@RZ4fbZnG8{Ks3=-a z6B_MRgyjg4^v;EKQYAym6TXbb+~hrUsvE2%r(9d+o(FW9-XvkS3wh^rVSD#|)JWDi zC&~@02g<3&@{eIgQ|_=GB<+F{a-Ym7)|^9%NJe8(4-~8<+P*m3wDp z>H!@+t;L3}jcb)5R^i_=-{g7XW3VI;tlv4k53?PdpB`;c>lXQ9wJiu(!fbMJTo} z^0z>CudQ#YZStR2l(n|Q%!H71$6BLd3vYp|Qw{C5h}vc%WZiMkl;6`>cf_KBVvPk_ z1^OjRwsa_}HcTWA@nO%A1k8?o3vS;}408^2m<>+L&*46aUy`n-OqdF*r(cnt8CU$ML3x3&LK ze!c(S+?8Ypd7R^~Fdc^SJ5)o%U_t#o2>El`@Ou#Q@Ayw+k$+10lTZCUnfM!xzn1@d zPVwjY->mDO?AmYE^>5()Qt59F_Rn*_xwk(HtKZz)-;jm>w|oCeKmR=bXAb*aef^Cv t^uMnCziY97%K7sL|IT=SLl5{bIsc=;t6zFSLBaoee1EZt2zbAD{{ Date: Fri, 19 Jun 2015 17:39:26 -0700 Subject: [PATCH 131/151] [HOTFIX] [SPARK-8489] Correct JIRA number in previous commit It should be SPARK-8489, not SPARK-8498. 
--- .../Main.scala | 6 +++--- .../MyCoolClass.scala | 2 +- .../test.jar | Bin .../spark/sql/hive/HiveSparkSubmitSuite.scala | 8 ++++---- 4 files changed, 8 insertions(+), 8 deletions(-) rename sql/hive/src/test/resources/{regression-test-SPARK-8498 => regression-test-SPARK-8489}/Main.scala (90%) rename sql/hive/src/test/resources/{regression-test-SPARK-8498 => regression-test-SPARK-8489}/MyCoolClass.scala (94%) rename sql/hive/src/test/resources/{regression-test-SPARK-8498 => regression-test-SPARK-8489}/test.jar (100%) diff --git a/sql/hive/src/test/resources/regression-test-SPARK-8498/Main.scala b/sql/hive/src/test/resources/regression-test-SPARK-8489/Main.scala similarity index 90% rename from sql/hive/src/test/resources/regression-test-SPARK-8498/Main.scala rename to sql/hive/src/test/resources/regression-test-SPARK-8489/Main.scala index 858dd6b5ddb05..e1715177e3f1b 100644 --- a/sql/hive/src/test/resources/regression-test-SPARK-8498/Main.scala +++ b/sql/hive/src/test/resources/regression-test-SPARK-8489/Main.scala @@ -19,7 +19,7 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql.hive.HiveContext /** - * Entry point in test application for SPARK-8498. + * Entry point in test application for SPARK-8489. * * This file is not meant to be compiled during tests. It is already included * in a pre-built "test.jar" located in the same directory as this file. @@ -30,14 +30,14 @@ import org.apache.spark.sql.hive.HiveContext */ object Main { def main(args: Array[String]) { - println("Running regression test for SPARK-8498.") + println("Running regression test for SPARK-8489.") val sc = new SparkContext("local", "testing") val hc = new HiveContext(sc) // This line should not throw scala.reflect.internal.MissingRequirementError. // See SPARK-8470 for more detail. val df = hc.createDataFrame(Seq(MyCoolClass("1", "2", "3"))) df.collect() - println("Regression test for SPARK-8498 success!") + println("Regression test for SPARK-8489 success!") } } diff --git a/sql/hive/src/test/resources/regression-test-SPARK-8498/MyCoolClass.scala b/sql/hive/src/test/resources/regression-test-SPARK-8489/MyCoolClass.scala similarity index 94% rename from sql/hive/src/test/resources/regression-test-SPARK-8498/MyCoolClass.scala rename to sql/hive/src/test/resources/regression-test-SPARK-8489/MyCoolClass.scala index a72c063a38197..b1681745c2ef7 100644 --- a/sql/hive/src/test/resources/regression-test-SPARK-8498/MyCoolClass.scala +++ b/sql/hive/src/test/resources/regression-test-SPARK-8489/MyCoolClass.scala @@ -15,6 +15,6 @@ * limitations under the License. */ -/** Dummy class used in regression test SPARK-8498. */ +/** Dummy class used in regression test SPARK-8489. 
*/ case class MyCoolClass(past: String, present: String, future: String) diff --git a/sql/hive/src/test/resources/regression-test-SPARK-8498/test.jar b/sql/hive/src/test/resources/regression-test-SPARK-8489/test.jar similarity index 100% rename from sql/hive/src/test/resources/regression-test-SPARK-8498/test.jar rename to sql/hive/src/test/resources/regression-test-SPARK-8489/test.jar diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 820af801a76ef..ab443032be20d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -67,13 +67,13 @@ class HiveSparkSubmitSuite runSparkSubmit(args) } - test("SPARK-8498: MissingRequirementError during reflection") { - // This test uses a pre-built jar to test SPARK-8498. In a nutshell, this test creates + test("SPARK-8489: MissingRequirementError during reflection") { + // This test uses a pre-built jar to test SPARK-8489. In a nutshell, this test creates // a HiveContext and uses it to create a data frame from an RDD using reflection. // Before the fix in SPARK-8470, this results in a MissingRequirementError because // the HiveContext code mistakenly overrides the class loader that contains user classes. - // For more detail, see sql/hive/src/test/resources/regression-test-SPARK-8498/*scala. - val testJar = "sql/hive/src/test/resources/regression-test-SPARK-8498/test.jar" + // For more detail, see sql/hive/src/test/resources/regression-test-SPARK-8489/*scala. + val testJar = "sql/hive/src/test/resources/regression-test-SPARK-8489/test.jar" val args = Seq("--class", "Main", testJar) runSparkSubmit(args) } From 1b6fe9b1a70aa3f81448c2705ea3a4b501cbda9d Mon Sep 17 00:00:00 2001 From: cody koeninger Date: Fri, 19 Jun 2015 18:54:07 -0700 Subject: [PATCH 132/151] [SPARK-8127] [STREAMING] [KAFKA] KafkaRDD optimize count() take() isEmpty() MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …ed KafkaRDD methods. Possible fix for [SPARK-7122], but probably a worthwhile optimization regardless. Author: cody koeninger Closes #6632 from koeninger/kafka-rdd-count and squashes the following commits: 321340d [cody koeninger] [SPARK-8127][Streaming][Kafka] additional test of ordering of take() 5a05d0f [cody koeninger] [SPARK-8127][Streaming][Kafka] additional test of isEmpty f68bd32 [cody koeninger] [Streaming][Kafka][SPARK-8127] code cleanup 9555b73 [cody koeninger] Merge branch 'master' into kafka-rdd-count 253031d [cody koeninger] [Streaming][Kafka][SPARK-8127] mima exclusion for change to private method 8974b9e [cody koeninger] [Streaming][Kafka][SPARK-8127] check offset ranges before constructing KafkaRDD c3768c5 [cody koeninger] [Streaming][Kafka] Take advantage of offset range info for size-related KafkaRDD methods. Possible fix for [SPARK-7122], but probably a worthwhile optimization regardless. 
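For context, a rough usage sketch (not part of the original patch) of what the optimization below means for callers; `sc`, `kafkaParams` and `offsetRanges` are assumed to be set up as in the KafkaRDDSuite changes further down:

    import kafka.serializer.StringDecoder
    import org.apache.spark.streaming.kafka.{KafkaUtils, OffsetRange}

    // count(), isEmpty() and take(n) are now answered from the offset range metadata,
    // so none of them needs to fetch every message from Kafka.
    val rdd = KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder](
      sc, kafkaParams, offsetRanges)
    val expected = offsetRanges.map(o => o.untilOffset - o.fromOffset).sum
    assert(rdd.count == expected)   // derived from the ranges, not a full scan
    assert(!rdd.isEmpty)            // assumes at least one non-empty range
    val firstTwo = rdd.take(2)      // only reads from the partitions it needs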
--- .../kafka/DirectKafkaInputDStream.scala | 8 +--- .../spark/streaming/kafka/KafkaCluster.scala | 8 ++++ .../spark/streaming/kafka/KafkaRDD.scala | 44 ++++++++++++++++++ .../streaming/kafka/KafkaRDDPartition.scala | 5 +- .../spark/streaming/kafka/KafkaUtils.scala | 46 +++++++++++++------ .../spark/streaming/kafka/OffsetRange.scala | 6 +++ .../spark/streaming/kafka/KafkaRDDSuite.scala | 26 +++++++++-- project/MimaExcludes.scala | 3 ++ 8 files changed, 122 insertions(+), 24 deletions(-) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala index 060c2f23eded8..876456c964770 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala @@ -120,8 +120,7 @@ class DirectKafkaInputDStream[ context.sparkContext, kafkaParams, currentOffsets, untilOffsets, messageHandler) // Report the record number of this batch interval to InputInfoTracker. - val numRecords = rdd.offsetRanges.map(r => r.untilOffset - r.fromOffset).sum - val inputInfo = InputInfo(id, numRecords) + val inputInfo = InputInfo(id, rdd.count) ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo) currentOffsets = untilOffsets.map(kv => kv._1 -> kv._2.offset) @@ -153,10 +152,7 @@ class DirectKafkaInputDStream[ override def restore() { // this is assuming that the topics don't change during execution, which is true currently val topics = fromOffsets.keySet - val leaders = kc.findLeaders(topics).fold( - errs => throw new SparkException(errs.mkString("\n")), - ok => ok - ) + val leaders = KafkaCluster.checkErrors(kc.findLeaders(topics)) batchForTime.toSeq.sortBy(_._1)(Time.ordering).foreach { case (t, b) => logInfo(s"Restoring KafkaRDD for time $t ${b.mkString("[", ", ", "]")}") diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala index 65d51d87f8486..3e6b937af57b0 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala @@ -360,6 +360,14 @@ private[spark] object KafkaCluster { type Err = ArrayBuffer[Throwable] + /** If the result is right, return it, otherwise throw SparkException */ + def checkErrors[T](result: Either[Err, T]): T = { + result.fold( + errs => throw new SparkException(errs.mkString("\n")), + ok => ok + ) + } + private[spark] case class LeaderOffset(host: String, port: Int, offset: Long) diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala index a1b4a12e5d6a0..c5cd2154772ac 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala @@ -17,9 +17,11 @@ package org.apache.spark.streaming.kafka +import scala.collection.mutable.ArrayBuffer import scala.reflect.{classTag, ClassTag} import org.apache.spark.{Logging, Partition, SparkContext, SparkException, TaskContext} +import org.apache.spark.partial.{PartialResult, BoundedDouble} import org.apache.spark.rdd.RDD import org.apache.spark.util.NextIterator @@ -60,6 +62,48 @@ class KafkaRDD[ }.toArray } + 
override def count(): Long = offsetRanges.map(_.count).sum + + override def countApprox( + timeout: Long, + confidence: Double = 0.95 + ): PartialResult[BoundedDouble] = { + val c = count + new PartialResult(new BoundedDouble(c, 1.0, c, c), true) + } + + override def isEmpty(): Boolean = count == 0L + + override def take(num: Int): Array[R] = { + val nonEmptyPartitions = this.partitions + .map(_.asInstanceOf[KafkaRDDPartition]) + .filter(_.count > 0) + + if (num < 1 || nonEmptyPartitions.size < 1) { + return new Array[R](0) + } + + // Determine in advance how many messages need to be taken from each partition + val parts = nonEmptyPartitions.foldLeft(Map[Int, Int]()) { (result, part) => + val remain = num - result.values.sum + if (remain > 0) { + val taken = Math.min(remain, part.count) + result + (part.index -> taken.toInt) + } else { + result + } + } + + val buf = new ArrayBuffer[R] + val res = context.runJob( + this, + (tc: TaskContext, it: Iterator[R]) => it.take(parts(tc.partitionId)).toArray, + parts.keys.toArray, + allowLocal = true) + res.foreach(buf ++= _) + buf.toArray + } + override def getPreferredLocations(thePart: Partition): Seq[String] = { val part = thePart.asInstanceOf[KafkaRDDPartition] // TODO is additional hostname resolution necessary here diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala index a842a6f17766f..a660d2a00c35d 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala @@ -35,4 +35,7 @@ class KafkaRDDPartition( val untilOffset: Long, val host: String, val port: Int -) extends Partition +) extends Partition { + /** Number of messages this partition refers to */ + def count(): Long = untilOffset - fromOffset +} diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index 0b8a391a2c569..0e33362d34acd 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -158,15 +158,31 @@ object KafkaUtils { /** get leaders for the given offset ranges, or throw an exception */ private def leadersForRanges( - kafkaParams: Map[String, String], + kc: KafkaCluster, offsetRanges: Array[OffsetRange]): Map[TopicAndPartition, (String, Int)] = { - val kc = new KafkaCluster(kafkaParams) val topics = offsetRanges.map(o => TopicAndPartition(o.topic, o.partition)).toSet - val leaders = kc.findLeaders(topics).fold( - errs => throw new SparkException(errs.mkString("\n")), - ok => ok - ) - leaders + val leaders = kc.findLeaders(topics) + KafkaCluster.checkErrors(leaders) + } + + /** Make sure offsets are available in kafka, or throw an exception */ + private def checkOffsets( + kc: KafkaCluster, + offsetRanges: Array[OffsetRange]): Unit = { + val topics = offsetRanges.map(_.topicAndPartition).toSet + val result = for { + low <- kc.getEarliestLeaderOffsets(topics).right + high <- kc.getLatestLeaderOffsets(topics).right + } yield { + offsetRanges.filterNot { o => + low(o.topicAndPartition).offset <= o.fromOffset && + o.untilOffset <= high(o.topicAndPartition).offset + } + } + val badRanges = KafkaCluster.checkErrors(result) + if (!badRanges.isEmpty) { + throw new 
SparkException("Offsets not available on leader: " + badRanges.mkString(",")) + } } /** @@ -191,7 +207,9 @@ object KafkaUtils { offsetRanges: Array[OffsetRange] ): RDD[(K, V)] = sc.withScope { val messageHandler = (mmd: MessageAndMetadata[K, V]) => (mmd.key, mmd.message) - val leaders = leadersForRanges(kafkaParams, offsetRanges) + val kc = new KafkaCluster(kafkaParams) + val leaders = leadersForRanges(kc, offsetRanges) + checkOffsets(kc, offsetRanges) new KafkaRDD[K, V, KD, VD, (K, V)](sc, kafkaParams, offsetRanges, leaders, messageHandler) } @@ -225,8 +243,9 @@ object KafkaUtils { leaders: Map[TopicAndPartition, Broker], messageHandler: MessageAndMetadata[K, V] => R ): RDD[R] = sc.withScope { + val kc = new KafkaCluster(kafkaParams) val leaderMap = if (leaders.isEmpty) { - leadersForRanges(kafkaParams, offsetRanges) + leadersForRanges(kc, offsetRanges) } else { // This could be avoided by refactoring KafkaRDD.leaders and KafkaCluster to use Broker leaders.map { @@ -234,6 +253,7 @@ object KafkaUtils { }.toMap } val cleanedHandler = sc.clean(messageHandler) + checkOffsets(kc, offsetRanges) new KafkaRDD[K, V, KD, VD, R](sc, kafkaParams, offsetRanges, leaderMap, cleanedHandler) } @@ -399,7 +419,7 @@ object KafkaUtils { val kc = new KafkaCluster(kafkaParams) val reset = kafkaParams.get("auto.offset.reset").map(_.toLowerCase) - (for { + val result = for { topicPartitions <- kc.getPartitions(topics).right leaderOffsets <- (if (reset == Some("smallest")) { kc.getEarliestLeaderOffsets(topicPartitions) @@ -412,10 +432,8 @@ object KafkaUtils { } new DirectKafkaInputDStream[K, V, KD, VD, (K, V)]( ssc, kafkaParams, fromOffsets, messageHandler) - }).fold( - errs => throw new SparkException(errs.mkString("\n")), - ok => ok - ) + } + KafkaCluster.checkErrors(result) } /** diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala index 9c3dfeb8f5928..2675042666304 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala @@ -55,6 +55,12 @@ final class OffsetRange private( val untilOffset: Long) extends Serializable { import OffsetRange.OffsetRangeTuple + /** Kafka TopicAndPartition object, for convenience */ + def topicAndPartition(): TopicAndPartition = TopicAndPartition(topic, partition) + + /** Number of messages this OffsetRange refers to */ + def count(): Long = untilOffset - fromOffset + override def equals(obj: Any): Boolean = obj match { case that: OffsetRange => this.topic == that.topic && diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala index d5baf5fd89994..f52a738afd65b 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala @@ -55,8 +55,8 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { test("basic usage") { val topic = s"topicbasic-${Random.nextInt}" kafkaTestUtils.createTopic(topic) - val messages = Set("the", "quick", "brown", "fox") - kafkaTestUtils.sendMessages(topic, messages.toArray) + val messages = Array("the", "quick", "brown", "fox") + kafkaTestUtils.sendMessages(topic, messages) val kafkaParams = Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress, "group.id" 
-> s"test-consumer-${Random.nextInt}") @@ -67,7 +67,27 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { sc, kafkaParams, offsetRanges) val received = rdd.map(_._2).collect.toSet - assert(received === messages) + assert(received === messages.toSet) + + // size-related method optimizations return sane results + assert(rdd.count === messages.size) + assert(rdd.countApprox(0).getFinalValue.mean === messages.size) + assert(!rdd.isEmpty) + assert(rdd.take(1).size === 1) + assert(rdd.take(1).head._2 === messages.head) + assert(rdd.take(messages.size + 10).size === messages.size) + + val emptyRdd = KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder]( + sc, kafkaParams, Array(OffsetRange(topic, 0, 0, 0))) + + assert(emptyRdd.isEmpty) + + // invalid offset ranges throw exceptions + val badRanges = Array(OffsetRange(topic, 0, 0, messages.size + 1)) + intercept[SparkException] { + KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder]( + sc, kafkaParams, badRanges) + } } test("iterator boundary conditions") { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 8a93ca2999510..015d0296dd369 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -44,6 +44,9 @@ object MimaExcludes { // JavaRDDLike is not meant to be extended by user programs ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.api.java.JavaRDDLike.partitioner"), + // Modification of private static method + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.streaming.kafka.KafkaUtils.org$apache$spark$streaming$kafka$KafkaUtils$$leadersForRanges"), // Mima false positive (was a private[spark] class) ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.util.collection.PairIterator"), From 0b8995168f02bb55afb0a5b7dbdb941c3c89cb4c Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 20 Jun 2015 13:01:59 -0700 Subject: [PATCH 133/151] [SPARK-8468] [ML] Take the negative of some metrics in RegressionEvaluator to get correct cross validation JIRA: https://issues.apache.org/jira/browse/SPARK-8468 Author: Liang-Chi Hsieh Closes #6905 from viirya/cv_min and squashes the following commits: 930d3db [Liang-Chi Hsieh] Fix python unit test and add document. d632135 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into cv_min 16e3b2c [Liang-Chi Hsieh] Take the negative instead of reciprocal. c3dd8d9 [Liang-Chi Hsieh] For comments. b5f52c1 [Liang-Chi Hsieh] Add param to CrossValidator for choosing whether to maximize evaulation value. 
--- .../ml/evaluation/RegressionEvaluator.scala | 10 ++++-- .../org/apache/spark/ml/param/params.scala | 2 +- .../evaluation/RegressionEvaluatorSuite.scala | 4 +-- .../spark/ml/tuning/CrossValidatorSuite.scala | 35 +++++++++++++++++-- python/pyspark/ml/evaluation.py | 8 +++-- 5 files changed, 48 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala index 8670e9679d055..01c000b47514c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -37,6 +37,10 @@ final class RegressionEvaluator(override val uid: String) /** * param for metric name in evaluation (supports `"rmse"` (default), `"mse"`, `"r2"`, and `"mae"`) + * + * Because we will maximize evaluation value (ref: `CrossValidator`), + * when we evaluate a metric that is needed to minimize (e.g., `"rmse"`, `"mse"`, `"mae"`), + * we take and output the negative of this metric. * @group param */ val metricName: Param[String] = { @@ -70,13 +74,13 @@ final class RegressionEvaluator(override val uid: String) val metrics = new RegressionMetrics(predictionAndLabels) val metric = $(metricName) match { case "rmse" => - metrics.rootMeanSquaredError + -metrics.rootMeanSquaredError case "mse" => - metrics.meanSquaredError + -metrics.meanSquaredError case "r2" => metrics.r2 case "mae" => - metrics.meanAbsoluteError + -metrics.meanAbsoluteError } metric } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 15ebad8838a2a..50c0d855066f8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -297,7 +297,7 @@ class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array /** * :: Experimental :: - * A param amd its value. + * A param and its value. 
*/ @Experimental case class ParamPair[T](param: Param[T], value: T) { diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala index aa722da323935..5b203784559e2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala @@ -63,7 +63,7 @@ class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext // default = rmse val evaluator = new RegressionEvaluator() - assert(evaluator.evaluate(predictions) ~== 0.1019382 absTol 0.001) + assert(evaluator.evaluate(predictions) ~== -0.1019382 absTol 0.001) // r2 score evaluator.setMetricName("r2") @@ -71,6 +71,6 @@ class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext // mae evaluator.setMetricName("mae") - assert(evaluator.evaluate(predictions) ~== 0.08036075 absTol 0.001) + assert(evaluator.evaluate(predictions) ~== -0.08036075 absTol 0.001) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 36af4b34a9e40..db64511a76055 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -20,11 +20,12 @@ package org.apache.spark.ml.tuning import org.apache.spark.SparkFunSuite import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.classification.LogisticRegression -import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator} +import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.HasInputCol +import org.apache.spark.ml.regression.LinearRegression import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.types.StructType @@ -58,6 +59,36 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { assert(cvModel.avgMetrics.length === lrParamMaps.length) } + test("cross validation with linear regression") { + val dataset = sqlContext.createDataFrame( + sc.parallelize(LinearDataGenerator.generateLinearInput( + 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2)) + + val trainer = new LinearRegression + val lrParamMaps = new ParamGridBuilder() + .addGrid(trainer.regParam, Array(1000.0, 0.001)) + .addGrid(trainer.maxIter, Array(0, 10)) + .build() + val eval = new RegressionEvaluator() + val cv = new CrossValidator() + .setEstimator(trainer) + .setEstimatorParamMaps(lrParamMaps) + .setEvaluator(eval) + .setNumFolds(3) + val cvModel = cv.fit(dataset) + val parent = cvModel.bestModel.parent.asInstanceOf[LinearRegression] + assert(parent.getRegParam === 0.001) + assert(parent.getMaxIter === 10) + assert(cvModel.avgMetrics.length === lrParamMaps.length) + + eval.setMetricName("r2") + val cvModel2 = cv.fit(dataset) + val parent2 = cvModel2.bestModel.parent.asInstanceOf[LinearRegression] + assert(parent2.getRegParam === 0.001) + assert(parent2.getMaxIter === 10) + assert(cvModel2.avgMetrics.length === 
lrParamMaps.length) + } + test("validateParams should check estimatorParamMaps") { import CrossValidatorSuite._ diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index d8ddb78c6d639..595593a7f2cde 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -160,13 +160,15 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol): ... >>> evaluator = RegressionEvaluator(predictionCol="raw") >>> evaluator.evaluate(dataset) - 2.842... + -2.842... >>> evaluator.evaluate(dataset, {evaluator.metricName: "r2"}) 0.993... >>> evaluator.evaluate(dataset, {evaluator.metricName: "mae"}) - 2.649... + -2.649... """ - # a placeholder to make it appear in the generated doc + # Because we will maximize evaluation value (ref: `CrossValidator`), + # when we evaluate a metric that is needed to minimize (e.g., `"rmse"`, `"mse"`, `"mae"`), + # we take and output the negative of this metric. metricName = Param(Params._dummy(), "metricName", "metric name in evaluation (mse|rmse|r2|mae)") From 7a3c424ecf815b9d5e06e222dd875e5a31a26400 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 20 Jun 2015 16:04:45 -0700 Subject: [PATCH 134/151] [SPARK-8422] [BUILD] [PROJECT INFRA] Add a module abstraction to dev/run-tests This patch builds upon #5694 to add a 'module' abstraction to the `dev/run-tests` script which groups together the per-module test logic, including the mapping from file paths to modules, the mapping from modules to test goals and build profiles, and the dependencies / relationships between modules. This refactoring makes it much easier to increase the granularity of test modules, which will let us skip even more tests. It's also a prerequisite for other changes that will reduce test time, such as running subsets of the Python tests based on which files / modules have changed. This patch also adds doctests for the new graph traversal / change mapping code. Author: Josh Rosen Closes #6866 from JoshRosen/more-dev-run-tests-refactoring and squashes the following commits: 75de450 [Josh Rosen] Use module system to determine which build profiles to enable. 4224da5 [Josh Rosen] Add documentation to Module. a86a953 [Josh Rosen] Clean up modules; add new modules for streaming external projects e46539f [Josh Rosen] Fix camel-cased endswith() 35a3052 [Josh Rosen] Enable Hive tests when running all tests df10e23 [Josh Rosen] update to reflect fact that no module depends on root 3670d50 [Josh Rosen] mllib should depend on streaming dc6f1c6 [Josh Rosen] Use changed files' extensions to decide whether to run style checks 7092d3e [Josh Rosen] Skip SBT tests if no test goals are specified 43a0ced [Josh Rosen] Minor fixes 3371441 [Josh Rosen] Test everything if nothing has changed (needed for non-PRB builds) 37f3fb3 [Josh Rosen] Remove doc profiles option, since it's not actually needed (see #6865) f53864b [Josh Rosen] Finish integrating module changes f0249bd [Josh Rosen] WIP --- dev/run-tests.py | 567 ++++++++++++++++++++++++++++++++++------------- 1 file changed, 411 insertions(+), 156 deletions(-) diff --git a/dev/run-tests.py b/dev/run-tests.py index c64c71f4f723f..2cccfed75edee 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -17,6 +17,7 @@ # limitations under the License. 
# +import itertools import os import re import sys @@ -28,6 +29,361 @@ USER_HOME = os.environ.get("HOME") +# ------------------------------------------------------------------------------------------------- +# Test module definitions and functions for traversing module dependency graph +# ------------------------------------------------------------------------------------------------- + + +all_modules = [] + + +class Module(object): + """ + A module is the basic abstraction in our test runner script. Each module consists of a set of + source files, a set of test commands, and a set of dependencies on other modules. We use modules + to define a dependency graph that lets determine which tests to run based on which files have + changed. + """ + + def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(), + sbt_test_goals=(), should_run_python_tests=False, should_run_r_tests=False): + """ + Define a new module. + + :param name: A short module name, for display in logging and error messages. + :param dependencies: A set of dependencies for this module. This should only include direct + dependencies; transitive dependencies are resolved automatically. + :param source_file_regexes: a set of regexes that match source files belonging to this + module. These regexes are applied by attempting to match at the beginning of the + filename strings. + :param build_profile_flags: A set of profile flags that should be passed to Maven or SBT in + order to build and test this module (e.g. '-PprofileName'). + :param sbt_test_goals: A set of SBT test goals for testing this module. + :param should_run_python_tests: If true, changes in this module will trigger Python tests. + For now, this has the effect of causing _all_ Python tests to be run, although in the + future this should be changed to run only a subset of the Python tests that depend + on this module. + :param should_run_r_tests: If true, changes in this module will trigger all R tests. 
+ """ + self.name = name + self.dependencies = dependencies + self.source_file_prefixes = source_file_regexes + self.sbt_test_goals = sbt_test_goals + self.build_profile_flags = build_profile_flags + self.should_run_python_tests = should_run_python_tests + self.should_run_r_tests = should_run_r_tests + + self.dependent_modules = set() + for dep in dependencies: + dep.dependent_modules.add(self) + all_modules.append(self) + + def contains_file(self, filename): + return any(re.match(p, filename) for p in self.source_file_prefixes) + + +sql = Module( + name="sql", + dependencies=[], + source_file_regexes=[ + "sql/(?!hive-thriftserver)", + "bin/spark-sql", + ], + build_profile_flags=[ + "-Phive", + ], + sbt_test_goals=[ + "catalyst/test", + "sql/test", + "hive/test", + ] +) + + +hive_thriftserver = Module( + name="hive-thriftserver", + dependencies=[sql], + source_file_regexes=[ + "sql/hive-thriftserver", + "sbin/start-thriftserver.sh", + ], + build_profile_flags=[ + "-Phive-thriftserver", + ], + sbt_test_goals=[ + "hive-thriftserver/test", + ] +) + + +graphx = Module( + name="graphx", + dependencies=[], + source_file_regexes=[ + "graphx/", + ], + sbt_test_goals=[ + "graphx/test" + ] +) + + +streaming = Module( + name="streaming", + dependencies=[], + source_file_regexes=[ + "streaming", + ], + sbt_test_goals=[ + "streaming/test", + ] +) + + +streaming_kinesis_asl = Module( + name="kinesis-asl", + dependencies=[streaming], + source_file_regexes=[ + "extras/kinesis-asl/", + ], + build_profile_flags=[ + "-Pkinesis-asl", + ], + sbt_test_goals=[ + "kinesis-asl/test", + ] +) + + +streaming_zeromq = Module( + name="streaming-zeromq", + dependencies=[streaming], + source_file_regexes=[ + "external/zeromq", + ], + sbt_test_goals=[ + "streaming-zeromq/test", + ] +) + + +streaming_twitter = Module( + name="streaming-twitter", + dependencies=[streaming], + source_file_regexes=[ + "external/twitter", + ], + sbt_test_goals=[ + "streaming-twitter/test", + ] +) + + +streaming_mqqt = Module( + name="streaming-mqqt", + dependencies=[streaming], + source_file_regexes=[ + "external/mqqt", + ], + sbt_test_goals=[ + "streaming-mqqt/test", + ] +) + + +streaming_kafka = Module( + name="streaming-kafka", + dependencies=[streaming], + source_file_regexes=[ + "external/kafka", + "external/kafka-assembly", + ], + sbt_test_goals=[ + "streaming-kafka/test", + ] +) + + +streaming_flume_sink = Module( + name="streaming-flume-sink", + dependencies=[streaming], + source_file_regexes=[ + "external/flume-sink", + ], + sbt_test_goals=[ + "streaming-flume-sink/test", + ] +) + + +streaming_flume = Module( + name="streaming_flume", + dependencies=[streaming], + source_file_regexes=[ + "external/flume", + ], + sbt_test_goals=[ + "streaming-flume/test", + ] +) + + +mllib = Module( + name="mllib", + dependencies=[streaming, sql], + source_file_regexes=[ + "data/mllib/", + "mllib/", + ], + sbt_test_goals=[ + "mllib/test", + ] +) + + +examples = Module( + name="examples", + dependencies=[graphx, mllib, streaming, sql], + source_file_regexes=[ + "examples/", + ], + sbt_test_goals=[ + "examples/test", + ] +) + + +pyspark = Module( + name="pyspark", + dependencies=[mllib, streaming, streaming_kafka, sql], + source_file_regexes=[ + "python/" + ], + should_run_python_tests=True +) + + +sparkr = Module( + name="sparkr", + dependencies=[sql, mllib], + source_file_regexes=[ + "R/", + ], + should_run_r_tests=True +) + + +docs = Module( + name="docs", + dependencies=[], + source_file_regexes=[ + "docs/", + ] +) + + +ec2 = Module( + name="ec2", + 
dependencies=[], + source_file_regexes=[ + "ec2/", + ] +) + + +# The root module is a dummy module which is used to run all of the tests. +# No other modules should directly depend on this module. +root = Module( + name="root", + dependencies=[], + source_file_regexes=[], + # In order to run all of the tests, enable every test profile: + build_profile_flags= + list(set(itertools.chain.from_iterable(m.build_profile_flags for m in all_modules))), + sbt_test_goals=[ + "test", + ], + should_run_python_tests=True, + should_run_r_tests=True +) + + +def determine_modules_for_files(filenames): + """ + Given a list of filenames, return the set of modules that contain those files. + If a file is not associated with a more specific submodule, then this method will consider that + file to belong to the 'root' module. + + >>> sorted(x.name for x in determine_modules_for_files(["python/pyspark/a.py", "sql/test/foo"])) + ['pyspark', 'sql'] + >>> [x.name for x in determine_modules_for_files(["file_not_matched_by_any_subproject"])] + ['root'] + """ + changed_modules = set() + for filename in filenames: + matched_at_least_one_module = False + for module in all_modules: + if module.contains_file(filename): + changed_modules.add(module) + matched_at_least_one_module = True + if not matched_at_least_one_module: + changed_modules.add(root) + return changed_modules + + +def identify_changed_files_from_git_commits(patch_sha, target_branch=None, target_ref=None): + """ + Given a git commit and target ref, use the set of files changed in the diff in order to + determine which modules' tests should be run. + + >>> [x.name for x in determine_modules_for_files( \ + identify_changed_files_from_git_commits("fc0a1475ef", target_ref="5da21f07"))] + ['graphx'] + >>> 'root' in [x.name for x in determine_modules_for_files( \ + identify_changed_files_from_git_commits("50a0496a43", target_ref="6765ef9"))] + True + """ + if target_branch is None and target_ref is None: + raise AttributeError("must specify either target_branch or target_ref") + elif target_branch is not None and target_ref is not None: + raise AttributeError("must specify either target_branch or target_ref, not both") + if target_branch is not None: + diff_target = target_branch + run_cmd(['git', 'fetch', 'origin', str(target_branch+':'+target_branch)]) + else: + diff_target = target_ref + raw_output = subprocess.check_output(['git', 'diff', '--name-only', patch_sha, diff_target]) + # Remove any empty strings + return [f for f in raw_output.split('\n') if f] + + +def determine_modules_to_test(changed_modules): + """ + Given a set of modules that have changed, compute the transitive closure of those modules' + dependent modules in order to determine the set of modules that should be tested. + + >>> sorted(x.name for x in determine_modules_to_test([root])) + ['root'] + >>> sorted(x.name for x in determine_modules_to_test([graphx])) + ['examples', 'graphx'] + >>> sorted(x.name for x in determine_modules_to_test([sql])) + ['examples', 'hive-thriftserver', 'mllib', 'pyspark', 'sparkr', 'sql'] + """ + # If we're going to have to run all of the tests, then we can just short-circuit + # and return 'root'. No module depends on root, so if it appears then it will be + # in changed_modules. 
+ if root in changed_modules: + return [root] + modules_to_test = set() + for module in changed_modules: + modules_to_test = modules_to_test.union(determine_modules_to_test(module.dependent_modules)) + return modules_to_test.union(set(changed_modules)) + + +# ------------------------------------------------------------------------------------------------- +# Functions for working with subprocesses and shell tools +# ------------------------------------------------------------------------------------------------- + def get_error_codes(err_code_file): """Function to retrieve all block numbers from the `run-tests-codes.sh` file to maintain backwards compatibility with the `run-tests-jenkins` @@ -43,7 +399,7 @@ def get_error_codes(err_code_file): def exit_from_command_with_retcode(cmd, retcode): - print "[error] running", cmd, "; received return code", retcode + print "[error] running", ' '.join(cmd), "; received return code", retcode sys.exit(int(os.environ.get("CURRENT_BLOCK", 255))) @@ -82,7 +438,7 @@ def which(program): """Find and return the given program by its absolute path or 'None' - from: http://stackoverflow.com/a/377028""" - fpath, fname = os.path.split(program) + fpath = os.path.split(program)[0] if fpath: if is_exe(program): @@ -134,6 +490,11 @@ def determine_java_version(java_exe): update=version_info[3]) +# ------------------------------------------------------------------------------------------------- +# Functions for running the other build and test scripts +# ------------------------------------------------------------------------------------------------- + + def set_title_and_block(title, err_block): os.environ["CURRENT_BLOCK"] = ERROR_CODES[err_block] line_str = '=' * 72 @@ -177,14 +538,14 @@ def build_spark_documentation(): os.chdir(SPARK_HOME) -def exec_maven(mvn_args=[]): +def exec_maven(mvn_args=()): """Will call Maven in the current directory with the list of mvn_args passed in and returns the subprocess for any further processing""" run_cmd([os.path.join(SPARK_HOME, "build", "mvn")] + mvn_args) -def exec_sbt(sbt_args=[]): +def exec_sbt(sbt_args=()): """Will call SBT in the current directory with the list of mvn_args passed in and returns the subprocess for any further processing""" @@ -213,8 +574,10 @@ def exec_sbt(sbt_args=[]): def get_hadoop_profiles(hadoop_version): - """Return a list of profiles indicating which Hadoop version to use from - a Hadoop version tag.""" + """ + For the given Hadoop version tag, return a list of SBT profile flags for + building and testing against that Hadoop version. 
+ """ sbt_maven_hadoop_profiles = { "hadoop1.0": ["-Phadoop-1", "-Dhadoop.version=1.0.4"], @@ -231,35 +594,9 @@ def get_hadoop_profiles(hadoop_version): sys.exit(int(os.environ.get("CURRENT_BLOCK", 255))) -def get_build_profiles(hadoop_version="hadoop2.3", - enable_base_profiles=True, - enable_hive_profiles=False, - enable_doc_profiles=False): - """Returns a list of hadoop profiles to be used as looked up from the passed in hadoop profile - key with the option of adding on the base and hive profiles.""" - - base_profiles = ["-Pkinesis-asl"] - hive_profiles = ["-Phive", "-Phive-thriftserver"] - doc_profiles = [] - hadoop_profiles = get_hadoop_profiles(hadoop_version) - - build_profiles = hadoop_profiles - - if enable_base_profiles: - build_profiles += base_profiles - - if enable_hive_profiles: - build_profiles += hive_profiles - - if enable_doc_profiles: - build_profiles += doc_profiles - - return build_profiles - - def build_spark_maven(hadoop_version): - # we always build with Hive support even if we skip Hive tests in most builds - build_profiles = get_build_profiles(hadoop_version, enable_hive_profiles=True) + # Enable all of the profiles for the build: + build_profiles = get_hadoop_profiles(hadoop_version) + root.build_profile_flags mvn_goals = ["clean", "package", "-DskipTests"] profiles_and_goals = build_profiles + mvn_goals @@ -270,7 +607,8 @@ def build_spark_maven(hadoop_version): def build_spark_sbt(hadoop_version): - build_profiles = get_build_profiles(hadoop_version, enable_hive_profiles=True) + # Enable all of the profiles for the build: + build_profiles = get_hadoop_profiles(hadoop_version) + root.build_profile_flags sbt_goals = ["package", "assembly/assembly", "streaming-kafka-assembly/assembly"] @@ -301,84 +639,6 @@ def detect_binary_inop_with_mima(): run_cmd([os.path.join(SPARK_HOME, "dev", "mima")]) -def identify_changed_modules(test_env): - """Given the passed in environment will determine the changed modules and - return them as a set. If the environment is local, will simply run all tests. 
- If run under the `amplab_jenkins` environment will determine the changed files - as compared to the `ghprbTargetBranch` and execute the necessary set of tests - to provide coverage for the changed code.""" - changed_modules = set() - - if test_env == "amplab_jenkins": - target_branch = os.environ["ghprbTargetBranch"] - - run_cmd(['git', 'fetch', 'origin', str(target_branch+':'+target_branch)]) - - raw_output = subprocess.check_output(['git', 'diff', '--name-only', target_branch]) - # remove any empty strings - changed_files = [f for f in raw_output.split('\n') if f] - - sql_files = [f for f in changed_files - if any(f.startswith(p) for p in - ["sql/", - "bin/spark-sql", - "sbin/start-thriftserver.sh", - "examples/src/main/java/org/apache/spark/examples/sql/", - "examples/src/main/scala/org/apache/spark/examples/sql/"])] - mllib_files = [f for f in changed_files - if any(f.startswith(p) for p in - ["examples/src/main/java/org/apache/spark/examples/mllib/", - "examples/src/main/scala/org/apache/spark/examples/mllib", - "data/mllib/", - "mllib/"])] - streaming_files = [f for f in changed_files - if any(f.startswith(p) for p in - ["examples/scala-2.10/", - "examples/src/main/java/org/apache/spark/examples/streaming/", - "examples/src/main/scala/org/apache/spark/examples/streaming/", - "external/", - "extras/java8-tests/", - "extras/kinesis-asl/", - "streaming/"])] - graphx_files = [f for f in changed_files - if any(f.startswith(p) for p in - ["examples/src/main/scala/org/apache/spark/examples/graphx/", - "graphx/"])] - doc_files = [f for f in changed_files if f.startswith("docs/")] - - # union together all changed top level project files - top_level_project_files = set().union(*[set(f) for f in [sql_files, - mllib_files, - streaming_files, - graphx_files, - doc_files]]) - changed_core_files = set(changed_files).difference(top_level_project_files) - - if changed_core_files: - changed_modules.add("CORE") - if sql_files: - print "[info] Detected changes in SQL. Will run Hive test suite." - changed_modules.add("SQL") - if mllib_files: - print "[info] Detected changes in MLlib. Will run MLlib test suite." - changed_modules.add("MLLIB") - if streaming_files: - print "[info] Detected changes in Streaming. Will run Streaming test suite." - changed_modules.add("STREAMING") - if graphx_files: - print "[info] Detected changes in GraphX. Will run GraphX test suite." - changed_modules.add("GRAPHX") - if doc_files: - print "[info] Detected changes in documentation. Will build spark with documentation." 
- changed_modules.add("DOCS") - - return changed_modules - else: - # we aren't in the Amplab environment so simply run all tests - changed_modules.add("ALL") - return changed_modules - - def run_scala_tests_maven(test_profiles): mvn_test_goals = ["test", "--fail-at-end"] profiles_and_goals = test_profiles + mvn_test_goals @@ -390,38 +650,13 @@ def run_scala_tests_maven(test_profiles): def run_scala_tests_sbt(test_modules, test_profiles): - # declare the variable for reference - sbt_test_goals = [] - if "ALL" in test_modules: - sbt_test_goals = ["test"] - else: - # if we only have changes in SQL, MLlib, Streaming, or GraphX then build - # a custom test list - if "SQL" in test_modules and "CORE" not in test_modules: - sbt_test_goals += ["catalyst/test", - "sql/test", - "hive/test", - "hive-thriftserver/test", - "mllib/test", - "examples/test"] - if "MLLIB" in test_modules and "CORE" not in test_modules: - sbt_test_goals += ["mllib/test", "examples/test"] - if "STREAMING" in test_modules and "CORE" not in test_modules: - sbt_test_goals += ["streaming/test", - "streaming-flume/test", - "streaming-flume-sink/test", - "streaming-kafka/test", - "streaming-mqtt/test", - "streaming-twitter/test", - "streaming-zeromq/test", - "examples/test"] - if "GRAPHX" in test_modules and "CORE" not in test_modules: - sbt_test_goals += ["graphx/test", "examples/test"] - if not sbt_test_goals: - sbt_test_goals = ["test"] - - profiles_and_goals = test_profiles + sbt_test_goals + sbt_test_goals = set(itertools.chain.from_iterable(m.sbt_test_goals for m in test_modules)) + + if not sbt_test_goals: + return + + profiles_and_goals = test_profiles + list(sbt_test_goals) print "[info] Running Spark tests using SBT with these arguments:", print " ".join(profiles_and_goals) @@ -436,9 +671,8 @@ def run_scala_tests(build_tool, hadoop_version, test_modules): test_modules = set(test_modules) - hive_profiles = ("SQL" in test_modules) - test_profiles = get_build_profiles(hadoop_version, enable_hive_profiles=hive_profiles) - + test_profiles = get_hadoop_profiles(hadoop_version) + \ + list(set(itertools.chain.from_iterable(m.build_profile_flags for m in test_modules))) if build_tool == "maven": run_scala_tests_maven(test_profiles) else: @@ -502,19 +736,29 @@ def main(): hadoop_version = "hadoop2.3" test_env = "local" - print "[info] Using build tool", build_tool, "with profile", hadoop_version, + print "[info] Using build tool", build_tool, "with Hadoop profile", hadoop_version, print "under environment", test_env - # determine high level changes - changed_modules = identify_changed_modules(test_env) - print "[info] Found the following changed modules:", ", ".join(changed_modules) + changed_modules = None + changed_files = None + if test_env == "amplab_jenkins" and os.environ.get("AMP_JENKINS_PRB"): + target_branch = os.environ["ghprbTargetBranch"] + changed_files = identify_changed_files_from_git_commits("HEAD", target_branch=target_branch) + changed_modules = determine_modules_for_files(changed_files) + if not changed_modules: + changed_modules = [root] + print "[info] Found the following changed modules:", ", ".join(x.name for x in changed_modules) + + test_modules = determine_modules_to_test(changed_modules) # license checks run_apache_rat_checks() # style checks - run_scala_style_checks() - run_python_style_checks() + if not changed_files or any(f.endswith(".scala") for f in changed_files): + run_scala_style_checks() + if not changed_files or any(f.endswith(".py") for f in changed_files): + run_python_style_checks() # 
determine if docs were changed and if we're inside the amplab environment # note - the below commented out until *all* Jenkins workers can get `jekyll` installed @@ -528,9 +772,20 @@ def main(): detect_binary_inop_with_mima() # run the test suites - run_scala_tests(build_tool, hadoop_version, changed_modules) - run_python_tests() - run_sparkr_tests() + run_scala_tests(build_tool, hadoop_version, test_modules) + + if any(m.should_run_python_tests for m in test_modules): + run_python_tests() + if any(m.should_run_r_tests for m in test_modules): + run_sparkr_tests() + + +def _test(): + import doctest + failure_count = doctest.testmod()[0] + if failure_count: + exit(-1) if __name__ == "__main__": + _test() main() From 004f57374b98c4df32d9f1e19221f68e92639a49 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Sat, 20 Jun 2015 16:10:14 -0700 Subject: [PATCH 135/151] [SPARK-8495] [SPARKR] Add a `.lintr` file to validate the SparkR files and the `lint-r` script Thank Shivaram Venkataraman for your support. This is a prototype script to validate the R files. Author: Yu ISHIKAWA Closes #6922 from yu-iskw/SPARK-6813 and squashes the following commits: c1ffe6b [Yu ISHIKAWA] Modify to save result to a log file and add a rule to validate 5520806 [Yu ISHIKAWA] Exclude the .lintr file not to check Apache lincence 8f94680 [Yu ISHIKAWA] [SPARK-8495][SparkR] Add a `.lintr` file to validate the SparkR files and the `lint-r` script --- .gitignore | 1 + .rat-excludes | 1 + R/pkg/.lintr | 2 ++ dev/lint-r | 30 ++++++++++++++++++++++++++++++ dev/lint-r.R | 29 +++++++++++++++++++++++++++++ 5 files changed, 63 insertions(+) create mode 100644 R/pkg/.lintr create mode 100755 dev/lint-r create mode 100644 dev/lint-r.R diff --git a/.gitignore b/.gitignore index 3624d12269612..debad77ec2ad3 100644 --- a/.gitignore +++ b/.gitignore @@ -66,6 +66,7 @@ scalastyle-output.xml R-unit-tests.log R/unit-tests.out python/lib/pyspark.zip +lint-r-report.log # For Hive metastore_db/ diff --git a/.rat-excludes b/.rat-excludes index aa008e6e920f5..c24667c18dbda 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -86,3 +86,4 @@ local-1430917381535_2 DESCRIPTION NAMESPACE test_support/* +.lintr diff --git a/R/pkg/.lintr b/R/pkg/.lintr new file mode 100644 index 0000000000000..b10ebd35c4ca7 --- /dev/null +++ b/R/pkg/.lintr @@ -0,0 +1,2 @@ +linters: with_defaults(line_length_linter(100), camel_case_linter = NULL) +exclusions: list("inst/profile/general.R" = 1, "inst/profile/shell.R") diff --git a/dev/lint-r b/dev/lint-r new file mode 100755 index 0000000000000..7d5f4cd31153d --- /dev/null +++ b/dev/lint-r @@ -0,0 +1,30 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" +SPARK_ROOT_DIR="$(dirname $SCRIPT_DIR)" +LINT_R_REPORT_FILE_NAME="$SPARK_ROOT_DIR/dev/lint-r-report.log" + + +if ! type "Rscript" > /dev/null; then + echo "ERROR: You should install R" + exit +fi + +`which Rscript` --vanilla "$SPARK_ROOT_DIR/dev/lint-r.R" "$SPARK_ROOT_DIR" | tee "$LINT_R_REPORT_FILE_NAME" diff --git a/dev/lint-r.R b/dev/lint-r.R new file mode 100644 index 0000000000000..dcb1a184291e1 --- /dev/null +++ b/dev/lint-r.R @@ -0,0 +1,29 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Installs lintr from Github. +# NOTE: The CRAN's version is too old to adapt to our rules. +if ("lintr" %in% row.names(installed.packages()) == FALSE) { + devtools::install_github("jimhester/lintr") +} +library(lintr) + +argv <- commandArgs(TRUE) +SPARK_ROOT_DIR <- as.character(argv[1]) + +path.to.package <- file.path(SPARK_ROOT_DIR, "R", "pkg") +lint_package(path.to.package, cache = FALSE) From 41ab2853f41de2abc415358b69671f37a0653533 Mon Sep 17 00:00:00 2001 From: Tarek Auel Date: Sat, 20 Jun 2015 20:03:59 -0700 Subject: [PATCH 136/151] [SPARK-8301] [SQL] Improve UTF8String substring/startsWith/endsWith/contains performance Jira: https://issues.apache.org/jira/browse/SPARK-8301 Added the private method startsWith(prefix, offset) to implement startsWith, endsWith and contains without copying the array I hope that the component SQL is still correct. I copied it from the Jira ticket. Author: Tarek Auel Author: Tarek Auel Closes #6804 from tarekauel/SPARK-8301 and squashes the following commits: f5d6b9a [Tarek Auel] fixed parentheses and annotation 6d7b068 [Tarek Auel] [SPARK-8301] removed null checks 9ca0473 [Tarek Auel] [SPARK-8301] removed null checks 1c327eb [Tarek Auel] [SPARK-8301] removed new 9f17cc8 [Tarek Auel] [SPARK-8301] fixed conversion byte to string in codegen 3a0040f [Tarek Auel] [SPARK-8301] changed call of UTF8String.set to UTF8String.from e4530d2 [Tarek Auel] [SPARK-8301] changed call of UTF8String.set to UTF8String.from a5f853a [Tarek Auel] [SPARK-8301] changed visibility of set to protected. Changed annotation of bytes from Nullable to Nonnull d2fb05f [Tarek Auel] [SPARK-8301] added additional null checks 79cb55b [Tarek Auel] [SPARK-8301] null check. Added test cases for null check. b17909e [Tarek Auel] [SPARK-8301] removed unnecessary copying of UTF8String. Added a private function startsWith(prefix, offset) to implement the check for startsWith, endsWith and contains. 
--- .../sql/catalyst/expressions/UnsafeRow.java | 4 +-- .../spark/sql/catalyst/expressions/Cast.scala | 6 ++-- .../apache/spark/unsafe/types/UTF8String.java | 30 +++++++++++-------- 3 files changed, 22 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index c4b7f8490a05b..ed04d2e50ec84 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -312,7 +312,6 @@ public double getDouble(int i) { public UTF8String getUTF8String(int i) { assertIndexIsValid(i); - final UTF8String str = new UTF8String(); final long offsetToStringSize = getLong(i); final int stringSizeInBytes = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offsetToStringSize); @@ -324,8 +323,7 @@ public UTF8String getUTF8String(int i) { PlatformDependent.BYTE_ARRAY_OFFSET, stringSizeInBytes ); - str.set(strBytes); - return str; + return UTF8String.fromBytes(strBytes); } @Override diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index b20086bcc48b9..ad920f287820c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -438,17 +438,17 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case (BinaryType, StringType) => defineCodeGen (ctx, ev, c => - s"new ${ctx.stringType}().set($c)") + s"${ctx.stringType}.fromBytes($c)") case (DateType, StringType) => defineCodeGen(ctx, ev, c => - s"""new ${ctx.stringType}().set( + s"""${ctx.stringType}.fromString( org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""") // Special handling required for timestamps in hive test cases since the toString function // does not match the expected output. case (TimestampType, StringType) => super.genCode(ctx, ev) case (_, StringType) => - defineCodeGen(ctx, ev, c => s"new ${ctx.stringType}().set(String.valueOf($c))") + defineCodeGen(ctx, ev, c => s"${ctx.stringType}.fromString(String.valueOf($c))") // fallback for DecimalType, this must be before other numeric types case (_, dt: DecimalType) => diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index a35168019549e..9871a70a40e69 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -20,7 +20,7 @@ import java.io.Serializable; import java.io.UnsupportedEncodingException; import java.util.Arrays; -import javax.annotation.Nullable; +import javax.annotation.Nonnull; import org.apache.spark.unsafe.PlatformDependent; @@ -34,7 +34,7 @@ */ public final class UTF8String implements Comparable, Serializable { - @Nullable + @Nonnull private byte[] bytes; private static int[] bytesOfCodePointInUTF8 = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, @@ -55,7 +55,7 @@ public static UTF8String fromString(String str) { /** * Updates the UTF8String with String. 
*/ - public UTF8String set(final String str) { + protected UTF8String set(final String str) { try { bytes = str.getBytes("utf-8"); } catch (UnsupportedEncodingException e) { @@ -69,7 +69,7 @@ public UTF8String set(final String str) { /** * Updates the UTF8String with byte[], which should be encoded in UTF-8. */ - public UTF8String set(final byte[] bytes) { + protected UTF8String set(final byte[] bytes) { this.bytes = bytes; return this; } @@ -131,24 +131,30 @@ public boolean contains(final UTF8String substring) { } for (int i = 0; i <= bytes.length - b.length; i++) { - // TODO: Avoid copying. - if (bytes[i] == b[0] && Arrays.equals(Arrays.copyOfRange(bytes, i, i + b.length), b)) { + if (bytes[i] == b[0] && startsWith(b, i)) { return true; } } return false; } + private boolean startsWith(final byte[] prefix, int offsetInBytes) { + if (prefix.length + offsetInBytes > bytes.length || offsetInBytes < 0) { + return false; + } + int i = 0; + while (i < prefix.length && prefix[i] == bytes[i + offsetInBytes]) { + i++; + } + return i == prefix.length; + } + public boolean startsWith(final UTF8String prefix) { - final byte[] b = prefix.getBytes(); - // TODO: Avoid copying. - return b.length <= bytes.length && Arrays.equals(Arrays.copyOfRange(bytes, 0, b.length), b); + return startsWith(prefix.getBytes(), 0); } public boolean endsWith(final UTF8String suffix) { - final byte[] b = suffix.getBytes(); - return b.length <= bytes.length && - Arrays.equals(Arrays.copyOfRange(bytes, bytes.length - b.length, bytes.length), b); + return startsWith(suffix.getBytes(), bytes.length - suffix.getBytes().length); } public UTF8String toUpperCase() { From a1e3649c8775d71ca78796b6544284e942ac1331 Mon Sep 17 00:00:00 2001 From: jeanlyn Date: Sun, 21 Jun 2015 00:13:40 -0700 Subject: [PATCH 137/151] [SPARK-8379] [SQL] avoid speculative tasks write to the same file The issue link [SPARK-8379](https://issues.apache.org/jira/browse/SPARK-8379) Currently, when we insert data into a dynamic partition with speculative tasks enabled, we can get the following exception: ``` org.apache.hadoop.ipc.RemoteException(org.apache.hadoop.hdfs.server.namenode.LeaseExpiredException): Lease mismatch on /tmp/hive-jeanlyn/hive_2015-06-15_15-20-44_734_8801220787219172413-1/-ext-10000/ds=2015-06-15/type=2/part-00301.lzo owned by DFSClient_attempt_201506031520_0011_m_000189_0_-1513487243_53 but is accessed by DFSClient_attempt_201506031520_0011_m_000042_0_-1275047721_57 ``` This PR writes the data to a temporary dir when using dynamic partitions, so that speculative tasks do not write to the same file. Author: jeanlyn Closes #6833 from jeanlyn/speculation and squashes the following commits: 64bbfab [jeanlyn] use FileOutputFormat.getTaskOutputPath to get the path 8860af0 [jeanlyn] remove the never using code e19a3bd [jeanlyn] avoid speculative tasks write same file --- .../sql/hive/execution/InsertIntoHiveTable.scala | 1 - .../apache/spark/sql/hive/hiveWriterContainers.scala | 11 +++++------ 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 404bb937aaf87..05f425f2b65f3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -198,7 +198,6 @@ case class InsertIntoHiveTable( table.hiveQlTable.getPartCols().foreach { entry =>
orderedPartitionSpec.put(entry.getName, partitionSpec.get(entry.getName).getOrElse("")) } - val partVals = MetaStoreUtils.getPvals(table.hiveQlTable.getPartCols, partitionSpec) // inheritTableSpecs is set to true. It should be set to false for a IMPORT query // which is currently considered as a Hive native command. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index 0bc69c00c241c..8b928861fcc70 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -228,12 +228,11 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( newFileSinkDesc.setCompressCodec(fileSinkConf.getCompressCodec) newFileSinkDesc.setCompressType(fileSinkConf.getCompressType) - val path = { - val outputPath = FileOutputFormat.getOutputPath(conf.value) - assert(outputPath != null, "Undefined job output-path") - val workPath = new Path(outputPath, dynamicPartPath.stripPrefix("/")) - new Path(workPath, getOutputName) - } + // use the path like ${hive_tmp}/_temporary/${attemptId}/ + // to avoid write to the same file when `spark.speculation=true` + val path = FileOutputFormat.getTaskOutputPath( + conf.value, + dynamicPartPath.stripPrefix("/") + "/" + getOutputName) HiveFileFormatUtils.getHiveRecordWriter( conf.value, From 32e3cdaa647722671adcb5068bd5ffbf2f157806 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Sun, 21 Jun 2015 12:04:20 -0700 Subject: [PATCH 138/151] [SPARK-7604] [MLLIB] Python API for PCA and PCAModel Python API for PCA and PCAModel Author: Yanbo Liang Closes #6315 from yanboliang/spark-7604 and squashes the following commits: 1d58734 [Yanbo Liang] remove transform() in PCAModel, use default behavior 4d9d121 [Yanbo Liang] Python API for PCA and PCAModel --- .../mllib/api/python/PythonMLLibAPI.scala | 10 ++++++ python/pyspark/mllib/feature.py | 35 +++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 2897865af6912..634d56d08d17e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -519,6 +519,16 @@ private[python] class PythonMLLibAPI extends Serializable { new ChiSqSelector(numTopFeatures).fit(data.rdd) } + /** + * Java stub for PCA.fit(). This stub returns a + * handle to the Java object instead of the content of the Java object. + * Extra care needs to be taken in the Python code to ensure it gets freed on + * exit; see the Py4J documentation. + */ + def fitPCA(k: Int, data: JavaRDD[Vector]): PCAModel = { + new PCA(k).fit(data.rdd) + } + /** * Java stub for IDF.fit(). This stub returns a * handle to the Java object instead of the content of the Java object. diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index cf5fdf2cf9788..334f5b86cd392 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -252,6 +252,41 @@ def fit(self, data): return ChiSqSelectorModel(jmodel) +class PCAModel(JavaVectorTransformer): + """ + Model fitted by [[PCA]] that can project vectors to a low-dimensional space using PCA. 
+ """ + + +class PCA(object): + """ + A feature transformer that projects vectors to a low-dimensional space using PCA. + + >>> data = [Vectors.sparse(5, [(1, 1.0), (3, 7.0)]), + ... Vectors.dense([2.0, 0.0, 3.0, 4.0, 5.0]), + ... Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0])] + >>> model = PCA(2).fit(sc.parallelize(data)) + >>> pcArray = model.transform(Vectors.sparse(5, [(1, 1.0), (3, 7.0)])).toArray() + >>> pcArray[0] + 1.648... + >>> pcArray[1] + -4.013... + """ + def __init__(self, k): + """ + :param k: number of principal components. + """ + self.k = int(k) + + def fit(self, data): + """ + Computes a [[PCAModel]] that contains the principal components of the input vectors. + :param data: source vectors + """ + jmodel = callMLlibFunc("fitPCA", self.k, data) + return PCAModel(jmodel) + + class HashingTF(object): """ .. note:: Experimental From 83cdfd84f8ca679e1ec451ed88b946e8e7f13a94 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sun, 21 Jun 2015 13:20:28 -0700 Subject: [PATCH 139/151] [SPARK-8508] [SQL] Ignores a test case to cleanup unnecessary testing output until #6882 is merged Currently [the test case for SPARK-7862] [1] writes 100,000 lines of integer triples to stderr and makes Jenkins build output unnecessarily large and it's hard to debug other build errors. A proper fix is on the way in #6882. This PR ignores this test case temporarily until #6882 is merged. [1]: https://github.com/apache/spark/pull/6404/files#diff-1ea02a6fab84e938582f7f87cc4d9ea1R641 Author: Cheng Lian Closes #6925 from liancheng/spark-8508 and squashes the following commits: 41e5b47 [Cheng Lian] Ignores the test case until #6882 is merged --- .../org/apache/spark/sql/hive/execution/SQLQuerySuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index e1c9926bed524..a2e666586c186 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -653,7 +653,7 @@ class SQLQuerySuite extends QueryTest { .queryExecution.toRdd.count()) } - test("test script transform for stderr") { + ignore("test script transform for stderr") { val data = (1 to 100000).map { i => (i, i, i) } data.toDF("d1", "d2", "d3").registerTempTable("script_trans") assert(0 === From a1894422ad6b3335c84c73ba9466da6677d893cb Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Sun, 21 Jun 2015 16:25:25 -0700 Subject: [PATCH 140/151] [SPARK-7715] [MLLIB] [ML] [DOC] Updated MLlib programming guide for release 1.4 Reorganized docs a bit. Added migration guides. **Q**: Do we want to say more for the 1.3 -> 1.4 migration guide for ```spark.ml```? It would be a lot. CC: mengxr Author: Joseph K. Bradley Closes #6897 from jkbradley/ml-guide-1.4 and squashes the following commits: 4bf26d6 [Joseph K. Bradley] tiny fix 8085067 [Joseph K. Bradley] fixed spacing/layout issues in ml guide from previous commit in this PR 6cd5c78 [Joseph K. 
Bradley] Updated MLlib programming guide for release 1.4 --- docs/ml-guide.md | 32 +++++++++++++--------- docs/mllib-feature-extraction.md | 3 +- docs/mllib-guide.md | 47 +++++++++++++++++++------------- docs/mllib-migration-guides.md | 16 +++++++++++ 4 files changed, 65 insertions(+), 33 deletions(-) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 4eb622d4b95e8..c74cb1f1ef8ea 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -3,10 +3,10 @@ layout: global title: Spark ML Programming Guide --- -`spark.ml` is a new package introduced in Spark 1.2, which aims to provide a uniform set of +Spark 1.2 introduced a new package called `spark.ml`, which aims to provide a uniform set of high-level APIs that help users create and tune practical machine learning pipelines. -It is currently an alpha component, and we would like to hear back from the community about -how it fits real-world use cases and how it could be improved. + +*Graduated from Alpha!* The Pipelines API is no longer an alpha component, although many elements of it are still `Experimental` or `DeveloperApi`. Note that we will keep supporting and adding features to `spark.mllib` along with the development of `spark.ml`. @@ -14,6 +14,12 @@ Users should be comfortable using `spark.mllib` features and expect more feature Developers should contribute new algorithms to `spark.mllib` and can optionally contribute to `spark.ml`. +Guides for sub-packages of `spark.ml` include: + +* [Feature Extraction, Transformation, and Selection](ml-features.html): Details on transformers supported in the Pipelines API, including a few not in the lower-level `spark.mllib` API +* [Ensembles](ml-ensembles.html): Details on ensemble learning methods in the Pipelines API + + **Table of Contents** * This will become a table of contents (this text will be scraped). @@ -148,16 +154,6 @@ Parameters belong to specific instances of `Estimator`s and `Transformer`s. For example, if we have two `LogisticRegression` instances `lr1` and `lr2`, then we can build a `ParamMap` with both `maxIter` parameters specified: `ParamMap(lr1.maxIter -> 10, lr2.maxIter -> 20)`. This is useful if there are two algorithms with the `maxIter` parameter in a `Pipeline`. -# Algorithm Guides - -There are now several algorithms in the Pipelines API which are not in the lower-level MLlib API, so we link to documentation for them here. These algorithms are mostly feature transformers, which fit naturally into the `Transformer` abstraction in Pipelines, and ensembles, which fit naturally into the `Estimator` abstraction in the Pipelines. - -**Pipelines API Algorithm Guides** - -* [Feature Extraction, Transformation, and Selection](ml-features.html) -* [Ensembles](ml-ensembles.html) - - # Code Examples This section gives code examples illustrating the functionality discussed above. @@ -783,6 +779,16 @@ Spark ML also depends upon Spark SQL, but the relevant parts of Spark SQL do not # Migration Guide +## From 1.3 to 1.4 + +Several major API changes occurred, including: +* `Param` and other APIs for specifying parameters +* `uid` unique IDs for Pipeline components +* Reorganization of certain classes +Since the `spark.ml` API was an Alpha Component in Spark 1.3, we do not list all changes here. + +However, now that `spark.ml` is no longer an Alpha Component, we will provide details on any API changes for future releases. + ## From 1.2 to 1.3 The main API changes are from Spark SQL. 
We list the most important changes here: diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index 1197dbbb8d982..83e937635a55b 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -576,8 +576,9 @@ parsedData = data.map(lambda x: [float(t) for t in x.split(" ")]) transformingVector = Vectors.dense([0.0, 1.0, 2.0]) transformer = ElementwiseProduct(transformingVector) -# Batch transform. +# Batch transform transformedData = transformer.transform(parsedData) +# Single-row transform transformedData2 = transformer.transform(parsedData.first()) {% endhighlight %} diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index de7d66fb2dedf..d2d1cc93fe006 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -7,7 +7,19 @@ description: MLlib machine learning library overview for Spark SPARK_VERSION_SHO MLlib is Spark's scalable machine learning library consisting of common learning algorithms and utilities, including classification, regression, clustering, collaborative -filtering, dimensionality reduction, as well as underlying optimization primitives, as outlined below: +filtering, dimensionality reduction, as well as underlying optimization primitives. +Guides for individual algorithms are listed below. + +The API is divided into 2 parts: + +* [The original `spark.mllib` API](mllib-guide.html#mllib-types-algorithms-and-utilities) is the primary API. +* [The "Pipelines" `spark.ml` API](mllib-guide.html#sparkml-high-level-apis-for-ml-pipelines) is a higher-level API for constructing ML workflows. + +We list major functionality from both below, with links to detailed guides. + +# MLlib types, algorithms and utilities + +This lists functionality included in `spark.mllib`, the main MLlib API. * [Data types](mllib-data-types.html) * [Basic statistics](mllib-statistics.html) @@ -49,8 +61,8 @@ and the migration guide below will explain all changes between releases. Spark 1.2 introduced a new package called `spark.ml`, which aims to provide a uniform set of high-level APIs that help users create and tune practical machine learning pipelines. -It is currently an alpha component, and we would like to hear back from the community about -how it fits real-world use cases and how it could be improved. + +*Graduated from Alpha!* The Pipelines API is no longer an alpha component, although many elements of it are still `Experimental` or `DeveloperApi`. Note that we will keep supporting and adding features to `spark.mllib` along with the development of `spark.ml`. @@ -58,7 +70,11 @@ Users should be comfortable using `spark.mllib` features and expect more feature Developers should contribute new algorithms to `spark.mllib` and can optionally contribute to `spark.ml`. -See the **[spark.ml programming guide](ml-guide.html)** for more information on this package. +More detailed guides for `spark.ml` include: + +* **[spark.ml programming guide](ml-guide.html)**: overview of the Pipelines API and major concepts +* [Feature transformers](ml-features.html): Details on transformers supported in the Pipelines API, including a few not in the lower-level `spark.mllib` API +* [Ensembles](ml-ensembles.html): Details on ensemble learning methods in the Pipelines API # Dependencies @@ -90,21 +106,14 @@ version 1.4 or newer. For the `spark.ml` package, please see the [spark.ml Migration Guide](ml-guide.html#migration-guide). -## From 1.2 to 1.3 - -In the `spark.mllib` package, there were several breaking changes. 
The first change (in `ALS`) is the only one in a component not marked as Alpha or Experimental. - -* *(Breaking change)* In [`ALS`](api/scala/index.html#org.apache.spark.mllib.recommendation.ALS), the extraneous method `solveLeastSquares` has been removed. The `DeveloperApi` method `analyzeBlocks` was also removed. -* *(Breaking change)* [`StandardScalerModel`](api/scala/index.html#org.apache.spark.mllib.feature.StandardScalerModel) remains an Alpha component. In it, the `variance` method has been replaced with the `std` method. To compute the column variance values returned by the original `variance` method, simply square the standard deviation values returned by `std`. -* *(Breaking change)* [`StreamingLinearRegressionWithSGD`](api/scala/index.html#org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD) remains an Experimental component. In it, there were two changes: - * The constructor taking arguments was removed in favor of a builder patten using the default constructor plus parameter setter methods. - * Variable `model` is no longer public. -* *(Breaking change)* [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) remains an Experimental component. In it and its associated classes, there were several changes: - * In `DecisionTree`, the deprecated class method `train` has been removed. (The object/static `train` methods remain.) - * In `Strategy`, the `checkpointDir` parameter has been removed. Checkpointing is still supported, but the checkpoint directory must be set before calling tree and tree ensemble training. -* `PythonMLlibAPI` (the interface between Scala/Java and Python for MLlib) was a public API but is now private, declared `private[python]`. This was never meant for external use. -* In linear regression (including Lasso and ridge regression), the squared loss is now divided by 2. - So in order to produce the same result as in 1.2, the regularization parameter needs to be divided by 2 and the step size needs to be multiplied by 2. +## From 1.3 to 1.4 + +In the `spark.mllib` package, there were several breaking changes, but all in `DeveloperApi` or `Experimental` APIs: + +* Gradient-Boosted Trees + * *(Breaking change)* The signature of the [`Loss.gradient`](api/scala/index.html#org.apache.spark.mllib.tree.loss.Loss) method was changed. This is only an issues for users who wrote their own losses for GBTs. + * *(Breaking change)* The `apply` and `copy` methods for the case class [`BoostingStrategy`](api/scala/index.html#org.apache.spark.mllib.tree.configuration.BoostingStrategy) have been changed because of a modification to the case class fields. This could be an issue for users who use `BoostingStrategy` to set GBT parameters. +* *(Breaking change)* The return value of [`LDA.run`](api/scala/index.html#org.apache.spark.mllib.clustering.LDA) has changed. It now returns an abstract class `LDAModel` instead of the concrete class `DistributedLDAModel`. The object of type `LDAModel` can still be cast to the appropriate concrete type, which depends on the optimization algorithm. ## Previous Spark Versions diff --git a/docs/mllib-migration-guides.md b/docs/mllib-migration-guides.md index 4de2d9491ac2b..8df68d81f3c78 100644 --- a/docs/mllib-migration-guides.md +++ b/docs/mllib-migration-guides.md @@ -7,6 +7,22 @@ description: MLlib migration guides from before Spark SPARK_VERSION_SHORT The migration guide for the current Spark version is kept on the [MLlib Programming Guide main page](mllib-guide.html#migration-guide). 
+## From 1.2 to 1.3 + +In the `spark.mllib` package, there were several breaking changes. The first change (in `ALS`) is the only one in a component not marked as Alpha or Experimental. + +* *(Breaking change)* In [`ALS`](api/scala/index.html#org.apache.spark.mllib.recommendation.ALS), the extraneous method `solveLeastSquares` has been removed. The `DeveloperApi` method `analyzeBlocks` was also removed. +* *(Breaking change)* [`StandardScalerModel`](api/scala/index.html#org.apache.spark.mllib.feature.StandardScalerModel) remains an Alpha component. In it, the `variance` method has been replaced with the `std` method. To compute the column variance values returned by the original `variance` method, simply square the standard deviation values returned by `std`. +* *(Breaking change)* [`StreamingLinearRegressionWithSGD`](api/scala/index.html#org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD) remains an Experimental component. In it, there were two changes: + * The constructor taking arguments was removed in favor of a builder pattern using the default constructor plus parameter setter methods. + * Variable `model` is no longer public. +* *(Breaking change)* [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) remains an Experimental component. In it and its associated classes, there were several changes: + * In `DecisionTree`, the deprecated class method `train` has been removed. (The object/static `train` methods remain.) + * In `Strategy`, the `checkpointDir` parameter has been removed. Checkpointing is still supported, but the checkpoint directory must be set before calling tree and tree ensemble training. +* `PythonMLlibAPI` (the interface between Scala/Java and Python for MLlib) was a public API but is now private, declared `private[python]`. This was never meant for external use. +* In linear regression (including Lasso and ridge regression), the squared loss is now divided by 2. + So in order to produce the same result as in 1.2, the regularization parameter needs to be divided by 2 and the step size needs to be multiplied by 2. + ## From 1.1 to 1.2 The only API changes in MLlib v1.2 are in From 47c1d5629373566df9d12fdc4ceb22f38b869482 Mon Sep 17 00:00:00 2001 From: Mike Dusenberry Date: Sun, 21 Jun 2015 18:25:36 -0700 Subject: [PATCH 141/151] [SPARK-7426] [MLLIB] [ML] Updated Attribute.fromStructField to allow any NumericType. Updated `Attribute.fromStructField` to allow any `NumericType`, rather than just `DoubleType`, and added unit tests for a few of the other NumericTypes. Author: Mike Dusenberry Closes #6540 from dusenberrymw/SPARK-7426_AttributeFactory.fromStructField_Should_Allow_NumericTypes and squashes the following commits: 87fecb3 [Mike Dusenberry] Updated Attribute.fromStructField to allow any NumericType, rather than just DoubleType, and added unit tests for a few of the other NumericTypes. 
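In user-facing terms, ML attribute metadata attached to an integer or decimal column no longer trips the old `DoubleType` assertion. A minimal sketch of that behaviour follows (e.g. in a spark-shell session); the metadata construction here is illustrative and not part of the patch, which reuses an existing `metadata` value in the test suite.

```
import org.apache.spark.ml.attribute.{Attribute, NumericAttribute}
import org.apache.spark.sql.types.{LongType, StructField}

// Attach a numeric ML attribute to a LongType column; previously only
// DoubleType columns were accepted by Attribute.fromStructField.
val attr = NumericAttribute.defaultAttr.withName("x")
val longField = StructField("x", LongType, false, attr.toMetadata())
assert(Attribute.fromStructField(longField).isNumeric)
```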
--- .../scala/org/apache/spark/ml/attribute/attributes.scala | 4 ++-- .../scala/org/apache/spark/ml/attribute/AttributeSuite.scala | 5 +++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala index ce43a450daad0..e479f169021d8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.attribute import scala.annotation.varargs import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.types.{DoubleType, Metadata, MetadataBuilder, StructField} +import org.apache.spark.sql.types.{DoubleType, NumericType, Metadata, MetadataBuilder, StructField} /** * :: DeveloperApi :: @@ -127,7 +127,7 @@ private[attribute] trait AttributeFactory { * Creates an [[Attribute]] from a [[StructField]] instance. */ def fromStructField(field: StructField): Attribute = { - require(field.dataType == DoubleType) + require(field.dataType.isInstanceOf[NumericType]) val metadata = field.metadata val mlAttr = AttributeKeys.ML_ATTR if (metadata.contains(mlAttr)) { diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala index 72b575d022547..c5fd2f9d5a22a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala @@ -215,5 +215,10 @@ class AttributeSuite extends SparkFunSuite { assert(Attribute.fromStructField(fldWithoutMeta) == UnresolvedAttribute) val fldWithMeta = new StructField("x", DoubleType, false, metadata) assert(Attribute.fromStructField(fldWithMeta).isNumeric) + // Attribute.fromStructField should accept any NumericType, not just DoubleType + val longFldWithMeta = new StructField("x", LongType, false, metadata) + assert(Attribute.fromStructField(longFldWithMeta).isNumeric) + val decimalFldWithMeta = new StructField("x", DecimalType(None), false, metadata) + assert(Attribute.fromStructField(decimalFldWithMeta).isNumeric) } } From 0818fdec3733ec5c0a9caa48a9c0f2cd25f84d13 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Mon, 22 Jun 2015 10:03:57 -0700 Subject: [PATCH 142/151] [SPARK-8406] [SQL] Adding UUID to output file name to avoid accidental overwriting This PR fixes a Parquet output file name collision bug which may cause data loss. Changes made: 1. Identify each write job issued by `InsertIntoHadoopFsRelation` with a UUID. All concrete data sources which extend `HadoopFsRelation` (Parquet and ORC for now) must use this UUID to generate task output file paths to avoid name collisions. 2. Make `TestHive` use a local mode `SparkContext` with 32 threads to increase parallelism. The major reason for this is that the original parallelism of 2 is too low to reproduce the data loss issue. Also, higher concurrency may catch more concurrency bugs during the testing phase. (It did help us spot SPARK-8501.) 3. `OrcSourceSuite` was updated to work around SPARK-8501, which we detected along the way. NOTE: This PR is made a little bit more complicated than expected because we hit two other bugs along the way and had to work around them. See [SPARK-8501] [1] and [SPARK-8513] [2].
[1]: https://github.com/liancheng/spark/tree/spark-8501 [2]: https://github.com/liancheng/spark/tree/spark-8513 ---- Some background and a summary of offline discussion with yhuai about this issue for better understanding: In 1.4.0, we added `HadoopFsRelation` to abstract partition support of all data sources that are based on Hadoop `FileSystem` interface. Specifically, this makes partition discovery, partition pruning, and writing dynamic partitions for data sources much easier. To support appending, the Parquet data source tries to find out the max part number of part-files in the destination directory (i.e., `` in output file name `part-r-.gz.parquet`) at the beginning of the write job. In 1.3.0, this step happens on driver side before any files are written. However, in 1.4.0, this is moved to task side. Unfortunately, for tasks scheduled later, they may see wrong max part number generated of files newly written by other finished tasks within the same job. This actually causes a race condition. In most cases, this only causes nonconsecutive part numbers in output file names. But when the DataFrame contains thousands of RDD partitions, it's likely that two tasks may choose the same part number, then one of them gets overwritten by the other. Before `HadoopFsRelation`, Spark SQL already supports appending data to Hive tables. From a user's perspective, these two look similar. However, they differ a lot internally. When data are inserted into Hive tables via Spark SQL, `InsertIntoHiveTable` simulates Hive's behaviors: 1. Write data to a temporary location 2. Move data in the temporary location to the final destination location using - `Hive.loadTable()` for non-partitioned table - `Hive.loadPartition()` for static partitions - `Hive.loadDynamicPartitions()` for dynamic partitions The important part is that, `Hive.copyFiles()` is invoked in step 2 to move the data to the destination directory (I found the name is kinda confusing since no "copying" occurs here, we are just moving and renaming stuff). If a file in the source directory and another file in the destination directory happen to have the same name, say `part-r-00001.parquet`, the former is moved to the destination directory and renamed with a `_copy_N` postfix (`part-r-00001_copy_1.parquet`). That's how Hive handles appending and avoids name collision between different write jobs. Some alternatives fixes considered for this issue: 1. Use a similar approach as Hive This approach is not preferred in Spark 1.4.0 mainly because file metadata operations in S3 tend to be slow, especially for tables with lots of file and/or partitions. That's why `InsertIntoHadoopFsRelation` just inserts to destination directory directly, and is often used together with `DirectParquetOutputCommitter` to reduce latency when working with S3. This means, we don't have the chance to do renaming, and must avoid name collision from the very beginning. 2. Same as 1.3, just move max part number detection back to driver side This isn't doable because unlike 1.3, 1.4 also takes dynamic partitioning into account. When inserting into dynamic partitions, we don't know which partition directories will be touched on driver side before issuing the write job. Checking all partition directories is simply too expensive for tables with thousands of partitions. 3. Add extra component to output file names to avoid name collision This seems to be the only reasonable solution for now. 
To be more specific, we need a JOB level unique identifier to identify all write jobs issued by `InsertIntoHadoopFile`. Notice that TASK level unique identifiers can NOT be used. Because in this way a speculative task will write to a different output file from the original task. If both tasks succeed, duplicate output will be left behind. Currently, the ORC data source adds `System.currentTimeMillis` to the output file name for uniqueness. This doesn't work because of exactly the same reason. That's why this PR adds a job level random UUID in `BaseWriterContainer` (which is used by `InsertIntoHadoopFsRelation` to issue write jobs). The drawback is that record order is not preserved any more (output files of a later job may be listed before those of a earlier job). However, we never promise to preserve record order when writing data, and Hive doesn't promise this either because the `_copy_N` trick breaks the order. Author: Cheng Lian Closes #6864 from liancheng/spark-8406 and squashes the following commits: db7a46a [Cheng Lian] More comments f5c1133 [Cheng Lian] Addresses comments 85c478e [Cheng Lian] Workarounds SPARK-8513 088c76c [Cheng Lian] Adds comment about SPARK-8501 99a5e7e [Cheng Lian] Uses job level UUID in SimpleTextRelation and avoids double task abortion 4088226 [Cheng Lian] Works around SPARK-8501 1d7d206 [Cheng Lian] Adds more logs 8966bbb [Cheng Lian] Fixes Scala style issue 18b7003 [Cheng Lian] Uses job level UUID to take speculative tasks into account 3806190 [Cheng Lian] Lets TestHive use all cores by default 748dbd7 [Cheng Lian] Adding UUID to output file name to avoid accidental overwriting --- .../apache/spark/sql/parquet/newParquet.scala | 43 ++----------- .../apache/spark/sql/sources/commands.scala | 64 +++++++++++++++---- .../spark/sql/hive/orc/OrcFileOperator.scala | 9 +-- .../spark/sql/hive/orc/OrcRelation.scala | 5 +- .../apache/spark/sql/hive/test/TestHive.scala | 2 +- .../spark/sql/hive/orc/OrcSourceSuite.scala | 28 ++++---- .../sql/sources/SimpleTextRelation.scala | 4 +- .../sql/sources/hadoopFsRelationSuites.scala | 37 +++++++++-- 8 files changed, 120 insertions(+), 72 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index c9de45e0ddfbb..e049d54bf55dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.{SerializableConfiguration, Utils} -import org.apache.spark.{Logging, SparkException, Partition => SparkPartition} +import org.apache.spark.{Logging, Partition => SparkPartition, SparkException} private[sql] class DefaultSource extends HadoopFsRelationProvider { override def createRelation( @@ -60,50 +60,21 @@ private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext extends OutputWriter { private val recordWriter: RecordWriter[Void, InternalRow] = { - val conf = context.getConfiguration val outputFormat = { - // When appending new Parquet files to an existing Parquet file directory, to avoid - // overwriting existing data files, we need to find out the max task ID encoded in these data - // file names. 
- // TODO Make this snippet a utility function for other data source developers - val maxExistingTaskId = { - // Note that `path` may point to a temporary location. Here we retrieve the real - // destination path from the configuration - val outputPath = new Path(conf.get("spark.sql.sources.output.path")) - val fs = outputPath.getFileSystem(conf) - - if (fs.exists(outputPath)) { - // Pattern used to match task ID in part file names, e.g.: - // - // part-r-00001.gz.parquet - // ^~~~~ - val partFilePattern = """part-.-(\d{1,}).*""".r - - fs.listStatus(outputPath).map(_.getPath.getName).map { - case partFilePattern(id) => id.toInt - case name if name.startsWith("_") => 0 - case name if name.startsWith(".") => 0 - case name => throw new AnalysisException( - s"Trying to write Parquet files to directory $outputPath, " + - s"but found items with illegal name '$name'.") - }.reduceOption(_ max _).getOrElse(0) - } else { - 0 - } - } - new ParquetOutputFormat[InternalRow]() { // Here we override `getDefaultWorkFile` for two reasons: // - // 1. To allow appending. We need to generate output file name based on the max available - // task ID computed above. + // 1. To allow appending. We need to generate unique output file names to avoid + // overwriting existing files (either exist before the write job, or are just written + // by other tasks within the same write job). // // 2. To allow dynamic partitioning. Default `getDefaultWorkFile` uses // `FileOutputCommitter.getWorkPath()`, which points to the base directory of all // partitions in the case of dynamic partitioning. override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val split = context.getTaskAttemptID.getTaskID.getId + maxExistingTaskId + 1 - new Path(path, f"part-r-$split%05d$extension") + val uniqueWriteJobId = context.getConfiguration.get("spark.sql.sources.writeJobUUID") + val split = context.getTaskAttemptID.getTaskID.getId + new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index c16bd9ae52c81..215e53c020849 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql.sources -import java.util.Date +import java.util.{Date, UUID} import scala.collection.mutable import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, FileOutputCommitter => MapReduceFileOutputCommitter} -import org.apache.parquet.hadoop.util.ContextUtil +import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter, FileOutputFormat} import org.apache.spark._ import org.apache.spark.mapred.SparkHadoopMapRedUtil @@ -59,6 +58,28 @@ private[sql] case class InsertIntoDataSource( } } +/** + * A command for writing data to a [[HadoopFsRelation]]. Supports both overwriting and appending. + * Writing to dynamic partitions is also supported. Each [[InsertIntoHadoopFsRelation]] issues a + * single write job, and owns a UUID that identifies this job. Each concrete implementation of + * [[HadoopFsRelation]] should use this UUID together with task id to generate unique file path for + * each task output file. This UUID is passed to executor side via a property named + * `spark.sql.sources.writeJobUUID`. 
+ * + * Different writer containers, [[DefaultWriterContainer]] and [[DynamicPartitionWriterContainer]] + * are used to write to normal tables and tables with dynamic partitions. + * + * Basic work flow of this command is: + * + * 1. Driver side setup, including output committer initialization and data source specific + * preparation work for the write job to be issued. + * 2. Issues a write job consists of one or more executor side tasks, each of which writes all + * rows within an RDD partition. + * 3. If no exception is thrown in a task, commits that task, otherwise aborts that task; If any + * exception is thrown during task commitment, also aborts that task. + * 4. If all tasks are committed, commit the job, otherwise aborts the job; If any exception is + * thrown during job commitment, also aborts the job. + */ private[sql] case class InsertIntoHadoopFsRelation( @transient relation: HadoopFsRelation, @transient query: LogicalPlan, @@ -261,7 +282,14 @@ private[sql] abstract class BaseWriterContainer( with Logging with Serializable { - protected val serializableConf = new SerializableConfiguration(ContextUtil.getConfiguration(job)) + protected val serializableConf = new SerializableConfiguration(job.getConfiguration) + + // This UUID is used to avoid output file name collision between different appending write jobs. + // These jobs may belong to different SparkContext instances. Concrete data source implementations + // may use this UUID to generate unique file names (e.g., `part-r--.parquet`). + // The reason why this ID is used to identify a job rather than a single task output file is + // that, speculative tasks must generate the same output file name as the original task. + private val uniqueWriteJobId = UUID.randomUUID() // This is only used on driver side. @transient private val jobContext: JobContext = job @@ -290,6 +318,11 @@ private[sql] abstract class BaseWriterContainer( setupIDs(0, 0, 0) setupConf() + // This UUID is sent to executor side together with the serialized `Configuration` object within + // the `Job` instance. `OutputWriters` on the executor side should use this UUID to generate + // unique task output files. + job.getConfiguration.set("spark.sql.sources.writeJobUUID", uniqueWriteJobId.toString) + // Order of the following two lines is important. For Hadoop 1, TaskAttemptContext constructor // clones the Configuration object passed in. If we initialize the TaskAttemptContext first, // configurations made in prepareJobForWrite(job) are not populated into the TaskAttemptContext. @@ -417,15 +450,16 @@ private[sql] class DefaultWriterContainer( assert(writer != null, "OutputWriter instance should have been initialized") writer.close() super.commitTask() - } catch { - case cause: Throwable => - super.abortTask() - throw new RuntimeException("Failed to commit task", cause) + } catch { case cause: Throwable => + // This exception will be handled in `InsertIntoHadoopFsRelation.insert$writeRows`, and will + // cause `abortTask()` to be invoked. 
+ throw new RuntimeException("Failed to commit task", cause) } } override def abortTask(): Unit = { try { + // It's possible that the task fails before `writer` gets initialized if (writer != null) { writer.close() } @@ -469,21 +503,25 @@ private[sql] class DynamicPartitionWriterContainer( }) } - override def commitTask(): Unit = { - try { + private def clearOutputWriters(): Unit = { + if (outputWriters.nonEmpty) { outputWriters.values.foreach(_.close()) outputWriters.clear() + } + } + + override def commitTask(): Unit = { + try { + clearOutputWriters() super.commitTask() } catch { case cause: Throwable => - super.abortTask() throw new RuntimeException("Failed to commit task", cause) } } override def abortTask(): Unit = { try { - outputWriters.values.foreach(_.close()) - outputWriters.clear() + clearOutputWriters() } finally { super.abortTask() } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala index 1e51173a19882..e3ab9442b4821 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala @@ -27,13 +27,13 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql.hive.HiveMetastoreTypes import org.apache.spark.sql.types.StructType -private[orc] object OrcFileOperator extends Logging{ +private[orc] object OrcFileOperator extends Logging { def getFileReader(pathStr: String, config: Option[Configuration] = None ): Reader = { val conf = config.getOrElse(new Configuration) val fspath = new Path(pathStr) val fs = fspath.getFileSystem(conf) val orcFiles = listOrcFiles(pathStr, conf) - + logDebug(s"Creating ORC Reader from ${orcFiles.head}") // TODO Need to consider all files when schema evolution is taken into account. 
OrcFile.createReader(fs, orcFiles.head) } @@ -42,6 +42,7 @@ private[orc] object OrcFileOperator extends Logging{ val reader = getFileReader(path, conf) val readerInspector = reader.getObjectInspector.asInstanceOf[StructObjectInspector] val schema = readerInspector.getTypeName + logDebug(s"Reading schema from file $path, got Hive schema string: $schema") HiveMetastoreTypes.toDataType(schema).asInstanceOf[StructType] } @@ -52,14 +53,14 @@ private[orc] object OrcFileOperator extends Logging{ def listOrcFiles(pathStr: String, conf: Configuration): Seq[Path] = { val origPath = new Path(pathStr) val fs = origPath.getFileSystem(conf) - val path = origPath.makeQualified(fs) + val path = origPath.makeQualified(fs.getUri, fs.getWorkingDirectory) val paths = SparkHadoopUtil.get.listLeafStatuses(fs, origPath) .filterNot(_.isDir) .map(_.getPath) .filterNot(_.getName.startsWith("_")) .filterNot(_.getName.startsWith(".")) - if (paths == null || paths.size == 0) { + if (paths == null || paths.isEmpty) { throw new IllegalArgumentException( s"orcFileOperator: path $path does not have valid orc files matching the pattern") } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index dbce39f21d271..705f48f1cd9f0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -31,6 +31,7 @@ import org.apache.hadoop.mapred.{InputFormat => MapRedInputFormat, JobConf, Reco import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} +import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.{HadoopRDD, RDD} @@ -39,7 +40,6 @@ import org.apache.spark.sql.hive.{HiveContext, HiveInspectors, HiveMetastoreType import org.apache.spark.sql.sources.{Filter, _} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.{Logging} import org.apache.spark.util.SerializableConfiguration /* Implicit conversions */ @@ -105,8 +105,9 @@ private[orc] class OrcOutputWriter( recordWriterInstantiated = true val conf = context.getConfiguration + val uniqueWriteJobId = conf.get("spark.sql.sources.writeJobUUID") val partition = context.getTaskAttemptID.getTaskID.getId - val filename = f"part-r-$partition%05d-${System.currentTimeMillis}%015d.orc" + val filename = f"part-r-$partition%05d-$uniqueWriteJobId.orc" new OrcOutputFormat().getRecordWriter( new Path(path, filename).getFileSystem(conf), diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index f901bd8171508..ea325cc93cb85 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -49,7 +49,7 @@ import scala.collection.JavaConversions._ object TestHive extends TestHiveContext( new SparkContext( - System.getProperty("spark.sql.test.master", "local[2]"), + System.getProperty("spark.sql.test.master", "local[32]"), "TestSQLContext", new SparkConf() .set("spark.sql.test", "") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index 
82e08caf46457..a0cdd0db42d65 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -43,8 +43,14 @@ abstract class OrcSuite extends QueryTest with BeforeAndAfterAll { orcTableDir.mkdir() import org.apache.spark.sql.hive.test.TestHive.implicits._ + // Originally we were using a 10-row RDD for testing. However, when default parallelism is + // greater than 10 (e.g., running on a node with 32 cores), this RDD contains empty partitions, + // which result in empty ORC files. Unfortunately, ORC doesn't handle empty files properly and + // causes build failure on Jenkins, which happens to have 32 cores. Please refer to SPARK-8501 + // for more details. To workaround this issue before fixing SPARK-8501, we simply increase row + // number in this RDD to avoid empty partitions. sparkContext - .makeRDD(1 to 10) + .makeRDD(1 to 100) .map(i => OrcData(i, s"part-$i")) .toDF() .registerTempTable(s"orc_temp_table") @@ -70,35 +76,35 @@ abstract class OrcSuite extends QueryTest with BeforeAndAfterAll { } test("create temporary orc table") { - checkAnswer(sql("SELECT COUNT(*) FROM normal_orc_source"), Row(10)) + checkAnswer(sql("SELECT COUNT(*) FROM normal_orc_source"), Row(100)) checkAnswer( sql("SELECT * FROM normal_orc_source"), - (1 to 10).map(i => Row(i, s"part-$i"))) + (1 to 100).map(i => Row(i, s"part-$i"))) checkAnswer( sql("SELECT * FROM normal_orc_source where intField > 5"), - (6 to 10).map(i => Row(i, s"part-$i"))) + (6 to 100).map(i => Row(i, s"part-$i"))) checkAnswer( sql("SELECT COUNT(intField), stringField FROM normal_orc_source GROUP BY stringField"), - (1 to 10).map(i => Row(1, s"part-$i"))) + (1 to 100).map(i => Row(1, s"part-$i"))) } test("create temporary orc table as") { - checkAnswer(sql("SELECT COUNT(*) FROM normal_orc_as_source"), Row(10)) + checkAnswer(sql("SELECT COUNT(*) FROM normal_orc_as_source"), Row(100)) checkAnswer( sql("SELECT * FROM normal_orc_source"), - (1 to 10).map(i => Row(i, s"part-$i"))) + (1 to 100).map(i => Row(i, s"part-$i"))) checkAnswer( sql("SELECT * FROM normal_orc_source WHERE intField > 5"), - (6 to 10).map(i => Row(i, s"part-$i"))) + (6 to 100).map(i => Row(i, s"part-$i"))) checkAnswer( sql("SELECT COUNT(intField), stringField FROM normal_orc_source GROUP BY stringField"), - (1 to 10).map(i => Row(1, s"part-$i"))) + (1 to 100).map(i => Row(1, s"part-$i"))) } test("appending insert") { @@ -106,7 +112,7 @@ abstract class OrcSuite extends QueryTest with BeforeAndAfterAll { checkAnswer( sql("SELECT * FROM normal_orc_source"), - (1 to 5).map(i => Row(i, s"part-$i")) ++ (6 to 10).flatMap { i => + (1 to 5).map(i => Row(i, s"part-$i")) ++ (6 to 100).flatMap { i => Seq.fill(2)(Row(i, s"part-$i")) }) } @@ -119,7 +125,7 @@ abstract class OrcSuite extends QueryTest with BeforeAndAfterAll { checkAnswer( sql("SELECT * FROM normal_orc_as_source"), - (6 to 10).map(i => Row(i, s"part-$i"))) + (6 to 100).map(i => Row(i, s"part-$i"))) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index 0f959b3d0b86d..5d7cd16c129cd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -53,9 +53,10 @@ class AppendingTextOutputFormat(outputFile: Path) extends TextOutputFormat[NullW numberFormat.setGroupingUsed(false) override def 
getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { + val uniqueWriteJobId = context.getConfiguration.get("spark.sql.sources.writeJobUUID") val split = context.getTaskAttemptID.getTaskID.getId val name = FileOutputFormat.getOutputName(context) - new Path(outputFile, s"$name-${numberFormat.format(split)}-${UUID.randomUUID()}") + new Path(outputFile, s"$name-${numberFormat.format(split)}-$uniqueWriteJobId") } } @@ -156,6 +157,7 @@ class CommitFailureTestRelation( context: TaskAttemptContext): OutputWriter = { new SimpleTextOutputWriter(path, context) { override def close(): Unit = { + super.close() sys.error("Intentional task commitment failure for testing purpose.") } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 76469d7a3d6a5..e0d8277a8ed3f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -35,7 +35,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { import sqlContext.sql import sqlContext.implicits._ - val dataSourceName = classOf[SimpleTextSource].getCanonicalName + val dataSourceName: String val dataSchema = StructType( @@ -470,6 +470,33 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { checkAnswer(sqlContext.table("t"), df.select('b, 'c, 'a).collect()) } } + + // NOTE: This test suite is not super deterministic. On nodes with only relatively few cores + // (4 or even 1), it's hard to reproduce the data loss issue. But on nodes with for example 8 or + // more cores, the issue can be reproduced steadily. Fortunately our Jenkins builder meets this + // requirement. We probably want to move this test case to spark-integration-tests or spark-perf + // later. + test("SPARK-8406: Avoids name collision while writing Parquet files") { + withTempPath { dir => + val path = dir.getCanonicalPath + sqlContext + .range(10000) + .repartition(250) + .write + .mode(SaveMode.Overwrite) + .format(dataSourceName) + .save(path) + + assertResult(10000) { + sqlContext + .read + .format(dataSourceName) + .option("dataSchema", StructType(StructField("id", LongType) :: Nil).json) + .load(path) + .count() + } + } + } } class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { @@ -502,15 +529,17 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { } class CommitFailureTestRelationSuite extends SparkFunSuite with SQLTestUtils { - import TestHive.implicits._ - override val sqlContext = TestHive + // When committing a task, `CommitFailureTestSource` throws an exception for testing purpose. val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName test("SPARK-7684: commitTask() failure should fallback to abortTask()") { withTempPath { file => - val df = (1 to 3).map(i => i -> s"val_$i").toDF("a", "b") + // Here we coalesce partition number to 1 to ensure that only a single task is issued. This + // prevents race condition happened when FileOutputCommitter tries to remove the `_temporary` + // directory while committing/aborting the job. See SPARK-8513 for more details. 
+ val df = sqlContext.range(0, 10).coalesce(1) intercept[SparkException] { df.write.format(dataSourceName).save(file.getCanonicalPath) } From 42a1f716fa35533507784be5e9117a984a03e62d Mon Sep 17 00:00:00 2001 From: Stefano Parmesan Date: Mon, 22 Jun 2015 11:43:10 -0700 Subject: [PATCH 143/151] [SPARK-8429] [EC2] Add ability to set additional tags Add the `--additional-tags` parameter that allows to set additional tags to all the created instances (masters and slaves). The user can specify multiple tags by separating them with a comma (`,`), while each tag name and value should be separated by a colon (`:`); for example, `Task:MySparkProject,Env:production` would add two tags, `Task` and `Env`, with the given values. Author: Stefano Parmesan Closes #6857 from armisael/patch-1 and squashes the following commits: c5ac92c [Stefano Parmesan] python style (pep8) 8e614f1 [Stefano Parmesan] Set multiple tags in a single request bfc56af [Stefano Parmesan] Address SPARK-7900 by inceasing sleep time daf8615 [Stefano Parmesan] Add ability to set additional tags --- ec2/spark_ec2.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 56087499464e0..103735685485b 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -289,6 +289,10 @@ def parse_args(): parser.add_option( "--additional-security-group", type="string", default="", help="Additional security group to place the machines in") + parser.add_option( + "--additional-tags", type="string", default="", + help="Additional tags to set on the machines; tags are comma-separated, while name and " + + "value are colon separated; ex: \"Task:MySparkProject,Env:production\"") parser.add_option( "--copy-aws-credentials", action="store_true", default=False, help="Add AWS credentials to hadoop configuration to allow Spark to access S3") @@ -684,16 +688,24 @@ def launch_cluster(conn, opts, cluster_name): # This wait time corresponds to SPARK-4983 print("Waiting for AWS to propagate instance metadata...") - time.sleep(5) - # Give the instances descriptive names + time.sleep(15) + + # Give the instances descriptive names and set additional tags + additional_tags = {} + if opts.additional_tags.strip(): + additional_tags = dict( + map(str.strip, tag.split(':', 1)) for tag in opts.additional_tags.split(',') + ) + for master in master_nodes: - master.add_tag( - key='Name', - value='{cn}-master-{iid}'.format(cn=cluster_name, iid=master.id)) + master.add_tags( + dict(additional_tags, Name='{cn}-master-{iid}'.format(cn=cluster_name, iid=master.id)) + ) + for slave in slave_nodes: - slave.add_tag( - key='Name', - value='{cn}-slave-{iid}'.format(cn=cluster_name, iid=slave.id)) + slave.add_tags( + dict(additional_tags, Name='{cn}-slave-{iid}'.format(cn=cluster_name, iid=slave.id)) + ) # Return all the instances return (master_nodes, slave_nodes) From ba8a4537fee7d85f968cccf8d1c607731daae307 Mon Sep 17 00:00:00 2001 From: Pradeep Chhetri Date: Mon, 22 Jun 2015 11:45:31 -0700 Subject: [PATCH 144/151] [SPARK-8482] Added M4 instances to the list. AWS recently added M4 instances (https://aws.amazon.com/blogs/aws/the-new-m4-instance-type-bonus-price-reduction-on-m3-c4/). 
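Looking back at the SPARK-8429 patch above, the tag-string parsing it adds to `launch_cluster()` boils down to the following standalone Python sketch; the helper name and the sample input are illustrative, while the parsing expression mirrors the diff:

```python
def parse_additional_tags(tag_string):
    # Tags are comma-separated; each tag's name and value are colon-separated.
    # split(':', 1) keeps any further colons inside the value, and str.strip
    # tolerates stray spaces around names and values.
    tag_string = tag_string.strip()
    if not tag_string:
        return {}
    return dict(
        map(str.strip, tag.split(':', 1)) for tag in tag_string.split(',')
    )

tags = parse_additional_tags("Task:MySparkProject, Env:production")
assert tags == {"Task": "MySparkProject", "Env": "production"}
```

A tag with no colon makes `dict()` raise a `ValueError`, so a malformed `--additional-tags` value fails fast instead of silently tagging instances incorrectly.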
Author: Pradeep Chhetri Closes #6899 from pradeepchhetri/master and squashes the following commits: 4f4ea79 [Pradeep Chhetri] Added t2.large instance 3d2bb6c [Pradeep Chhetri] Added M4 instances to the list --- ec2/spark_ec2.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 103735685485b..63e2c79669763 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -362,7 +362,7 @@ def get_validate_spark_version(version, repo): # Source: http://aws.amazon.com/amazon-linux-ami/instance-type-matrix/ -# Last Updated: 2015-05-08 +# Last Updated: 2015-06-19 # For easy maintainability, please keep this manually-inputted dictionary sorted by key. EC2_INSTANCE_TYPES = { "c1.medium": "pvm", @@ -404,6 +404,11 @@ def get_validate_spark_version(version, repo): "m3.large": "hvm", "m3.xlarge": "hvm", "m3.2xlarge": "hvm", + "m4.large": "hvm", + "m4.xlarge": "hvm", + "m4.2xlarge": "hvm", + "m4.4xlarge": "hvm", + "m4.10xlarge": "hvm", "r3.large": "hvm", "r3.xlarge": "hvm", "r3.2xlarge": "hvm", @@ -413,6 +418,7 @@ def get_validate_spark_version(version, repo): "t2.micro": "hvm", "t2.small": "hvm", "t2.medium": "hvm", + "t2.large": "hvm", } @@ -923,7 +929,7 @@ def wait_for_cluster_state(conn, opts, cluster_instances, cluster_state): # Get number of local disks available for a given EC2 instance type. def get_num_disks(instance_type): # Source: http://docs.aws.amazon.com/AWSEC2/latest/UserGuide/InstanceStorage.html - # Last Updated: 2015-05-08 + # Last Updated: 2015-06-19 # For easy maintainability, please keep this manually-inputted dictionary sorted by key. disks_by_instance = { "c1.medium": 1, @@ -965,6 +971,11 @@ def get_num_disks(instance_type): "m3.large": 1, "m3.xlarge": 2, "m3.2xlarge": 2, + "m4.large": 0, + "m4.xlarge": 0, + "m4.2xlarge": 0, + "m4.4xlarge": 0, + "m4.10xlarge": 0, "r3.large": 1, "r3.xlarge": 1, "r3.2xlarge": 1, @@ -974,6 +985,7 @@ def get_num_disks(instance_type): "t2.micro": 0, "t2.small": 0, "t2.medium": 0, + "t2.large": 0, } if instance_type in disks_by_instance: return disks_by_instance[instance_type] From 5d89d9f00ba4d6d0767a4c4964d3af324bf6f14b Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Mon, 22 Jun 2015 11:53:11 -0700 Subject: [PATCH 145/151] [SPARK-8511] [PYSPARK] Modify a test to remove a saved model in `regression.py` [[SPARK-8511] Modify a test to remove a saved model in `regression.py` - ASF JIRA](https://issues.apache.org/jira/browse/SPARK-8511) Author: Yu ISHIKAWA Closes #6926 from yu-iskw/SPARK-8511 and squashes the following commits: 7cd0948 [Yu ISHIKAWA] Use `shutil.rmtree()` to temporary directories for saving model testings, instead of `os.removedirs()` 4a01c9e [Yu ISHIKAWA] [SPARK-8511][pyspark] Modify a test to remove a saved model in `regression.py` --- python/pyspark/mllib/classification.py | 9 ++++++--- python/pyspark/mllib/clustering.py | 3 ++- python/pyspark/mllib/recommendation.py | 3 ++- python/pyspark/mllib/regression.py | 14 +++++++++----- python/pyspark/mllib/tests.py | 3 ++- 5 files changed, 21 insertions(+), 11 deletions(-) diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 42e41397bf4bc..758accf4b41eb 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -135,8 +135,9 @@ class LogisticRegressionModel(LinearClassificationModel): 1 >>> sameModel.predict(SparseVector(2, {0: 1.0})) 0 + >>> from shutil import rmtree >>> try: - ... os.removedirs(path) + ... rmtree(path) ... except: ... 
pass >>> multi_class_data = [ @@ -387,8 +388,9 @@ class SVMModel(LinearClassificationModel): 1 >>> sameModel.predict(SparseVector(2, {0: -1.0})) 0 + >>> from shutil import rmtree >>> try: - ... os.removedirs(path) + ... rmtree(path) ... except: ... pass """ @@ -515,8 +517,9 @@ class NaiveBayesModel(Saveable, Loader): >>> sameModel = NaiveBayesModel.load(sc, path) >>> sameModel.predict(SparseVector(2, {0: 1.0})) == model.predict(SparseVector(2, {0: 1.0})) True + >>> from shutil import rmtree >>> try: - ... os.removedirs(path) + ... rmtree(path) ... except OSError: ... pass """ diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index c38229864d3b4..e6ef72942ce77 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -79,8 +79,9 @@ class KMeansModel(Saveable, Loader): >>> sameModel = KMeansModel.load(sc, path) >>> sameModel.predict(sparse_data[0]) == model.predict(sparse_data[0]) True + >>> from shutil import rmtree >>> try: - ... os.removedirs(path) + ... rmtree(path) ... except OSError: ... pass """ diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 9c4647ddfdcfd..506ca2151cce7 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -106,8 +106,9 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader): 0.4... >>> sameModel.predictAll(testset).collect() [Rating(... + >>> from shutil import rmtree >>> try: - ... os.removedirs(path) + ... rmtree(path) ... except OSError: ... pass """ diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 0c4d7d3bbee02..5ddbbee4babdd 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -133,10 +133,11 @@ class LinearRegressionModel(LinearRegressionModelBase): True >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + >>> from shutil import rmtree >>> try: - ... os.removedirs(path) + ... rmtree(path) ... except: - ... pass + ... pass >>> data = [ ... LabeledPoint(0.0, SparseVector(1, {0: 0.0})), ... LabeledPoint(1.0, SparseVector(1, {0: 1.0})), @@ -275,8 +276,9 @@ class LassoModel(LinearRegressionModelBase): True >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + >>> from shutil import rmtree >>> try: - ... os.removedirs(path) + ... rmtree(path) ... except: ... pass >>> data = [ @@ -389,8 +391,9 @@ class RidgeRegressionModel(LinearRegressionModelBase): True >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + >>> from shutil import rmtree >>> try: - ... os.removedirs(path) + ... rmtree(path) ... except: ... pass >>> data = [ @@ -500,8 +503,9 @@ class IsotonicRegressionModel(Saveable, Loader): 2.0 >>> sameModel.predict(5) 16.5 + >>> from shutil import rmtree >>> try: - ... os.removedirs(path) + ... rmtree(path) ... except OSError: ... 
pass """ diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 744dc112d9209..b13159e29d2aa 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -24,6 +24,7 @@ import tempfile import array as pyarray from time import time, sleep +from shutil import rmtree from numpy import array, array_equal, zeros, inf, all, random from numpy import sum as array_sum @@ -398,7 +399,7 @@ def test_classification(self): self.assertEqual(same_gbt_model.toDebugString(), gbt_model.toDebugString()) try: - os.removedirs(temp_dir) + rmtree(temp_dir) except OSError: pass From da7bbb9435dae9a3bedad578599d96ea858f349e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 22 Jun 2015 12:13:00 -0700 Subject: [PATCH 146/151] [SPARK-8104] [SQL] auto alias expressions in analyzer Currently we auto alias expression in parser. However, during parser phase we don't have enough information to do the right alias. For example, Generator that has more than 1 kind of element need MultiAlias, ExtractValue don't need Alias if it's in middle of a ExtractValue chain. Author: Wenchen Fan Closes #6647 from cloud-fan/alias and squashes the following commits: 552eba4 [Wenchen Fan] fix python 5b5786d [Wenchen Fan] fix agg 73a90cb [Wenchen Fan] fix case-preserve of ExtractValue 4cfd23c [Wenchen Fan] fix order by d18f401 [Wenchen Fan] refine 9f07359 [Wenchen Fan] address comments 39c1aef [Wenchen Fan] small fix 33640ec [Wenchen Fan] auto alias expressions in analyzer --- python/pyspark/sql/context.py | 9 ++- .../apache/spark/sql/catalyst/SqlParser.scala | 11 +-- .../sql/catalyst/analysis/Analyzer.scala | 77 ++++++++++++------- .../sql/catalyst/analysis/CheckAnalysis.scala | 9 +-- .../sql/catalyst/analysis/unresolved.scala | 20 ++++- .../catalyst/expressions/ExtractValue.scala | 36 +++++---- .../sql/catalyst/planning/patterns.scala | 6 +- .../catalyst/plans/logical/LogicalPlan.scala | 11 ++- .../plans/logical/basicOperators.scala | 20 ++++- .../scala/org/apache/spark/sql/Column.scala | 1 - .../org/apache/spark/sql/DataFrame.scala | 6 +- .../org/apache/spark/sql/GroupedData.scala | 43 +++++------ .../spark/sql/execution/pythonUdfs.scala | 2 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 6 +- .../scala/org/apache/spark/sql/TestData.scala | 1 - .../org/apache/spark/sql/hive/HiveQl.scala | 9 +-- 16 files changed, 150 insertions(+), 117 deletions(-) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 599c9ac5794a2..dc239226e6d3c 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -86,7 +86,8 @@ def __init__(self, sparkContext, sqlContext=None): >>> df.registerTempTable("allTypes") >>> sqlContext.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' ... 
'from allTypes where b and i > 0').collect() - [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0, time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)] + [Row(_c0=2, _c1=2.0, _c2=False, _c3=2, _c4=0, \ + time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)] >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, x.row.a, x.list)).collect() [(1, u'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])] """ @@ -176,17 +177,17 @@ def registerFunction(self, name, f, returnType=StringType()): >>> sqlContext.registerFunction("stringLengthString", lambda x: len(x)) >>> sqlContext.sql("SELECT stringLengthString('test')").collect() - [Row(c0=u'4')] + [Row(_c0=u'4')] >>> from pyspark.sql.types import IntegerType >>> sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() - [Row(c0=4)] + [Row(_c0=4)] >>> from pyspark.sql.types import IntegerType >>> sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() - [Row(c0=4)] + [Row(_c0=4)] """ func = lambda _, it: map(lambda x: f(*x), it) ser = AutoBatchedSerializer(PickleSerializer()) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index da3a717f90058..79f526e823cd4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -99,13 +99,6 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected val WHERE = Keyword("WHERE") protected val WITH = Keyword("WITH") - protected def assignAliases(exprs: Seq[Expression]): Seq[NamedExpression] = { - exprs.zipWithIndex.map { - case (ne: NamedExpression, _) => ne - case (e, i) => Alias(e, s"c$i")() - } - } - protected lazy val start: Parser[LogicalPlan] = start1 | insert | cte @@ -130,8 +123,8 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { val base = r.getOrElse(OneRowRelation) val withFilter = f.map(Filter(_, base)).getOrElse(base) val withProjection = g - .map(Aggregate(_, assignAliases(p), withFilter)) - .getOrElse(Project(assignAliases(p), withFilter)) + .map(Aggregate(_, p.map(UnresolvedAlias(_)), withFilter)) + .getOrElse(Project(p.map(UnresolvedAlias(_)), withFilter)) val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection) val withHaving = h.map(Filter(_, withDistinct)).getOrElse(withDistinct) val withOrder = o.map(_(withHaving)).getOrElse(withHaving) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 21b05760256b4..6311784422a91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.analysis -import scala.collection.mutable.ArrayBuffer - import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf} import org.apache.spark.sql.catalyst.expressions._ @@ -74,10 +72,10 @@ class Analyzer( ResolveSortReferences :: ResolveGenerate :: ResolveFunctions :: + ResolveAliases :: ExtractWindowExpressions :: GlobalAggregates :: UnresolvedHavingClauseAttributes :: - TrimGroupingAliases :: 
typeCoercionRules ++ extendedResolutionRules : _*) ) @@ -132,12 +130,38 @@ class Analyzer( } /** - * Removes no-op Alias expressions from the plan. + * Replaces [[UnresolvedAlias]]s with concrete aliases. */ - object TrimGroupingAliases extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case Aggregate(groups, aggs, child) => - Aggregate(groups.map(_.transform { case Alias(c, _) => c }), aggs, child) + object ResolveAliases extends Rule[LogicalPlan] { + private def assignAliases(exprs: Seq[NamedExpression]) = { + // The `UnresolvedAlias`s will appear only at root of a expression tree, we don't need + // to transform down the whole tree. + exprs.zipWithIndex.map { + case (u @ UnresolvedAlias(child), i) => + child match { + case _: UnresolvedAttribute => u + case ne: NamedExpression => ne + case ev: ExtractValueWithStruct => Alias(ev, ev.field.name)() + case g: Generator if g.resolved && g.elementTypes.size > 1 => MultiAlias(g, Nil) + case e if !e.resolved => u + case other => Alias(other, s"_c$i")() + } + case (other, _) => other + } + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case Aggregate(groups, aggs, child) + if child.resolved && aggs.exists(_.isInstanceOf[UnresolvedAlias]) => + Aggregate(groups, assignAliases(aggs), child) + + case g: GroupingAnalytics + if g.child.resolved && g.aggregations.exists(_.isInstanceOf[UnresolvedAlias]) => + g.withNewAggs(assignAliases(g.aggregations)) + + case Project(projectList, child) + if child.resolved && projectList.exists(_.isInstanceOf[UnresolvedAlias]) => + Project(assignAliases(projectList), child) } } @@ -228,7 +252,7 @@ class Analyzer( } def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case i@InsertIntoTable(u: UnresolvedRelation, _, _, _, _) => + case i @ InsertIntoTable(u: UnresolvedRelation, _, _, _, _) => i.copy(table = EliminateSubQueries(getTable(u))) case u: UnresolvedRelation => getTable(u) @@ -248,24 +272,24 @@ class Analyzer( Project( projectList.flatMap { case s: Star => s.expand(child.output, resolver) - case Alias(f @ UnresolvedFunction(_, args), name) if containsStar(args) => + case UnresolvedAlias(f @ UnresolvedFunction(_, args)) if containsStar(args) => val expandedArgs = args.flatMap { case s: Star => s.expand(child.output, resolver) case o => o :: Nil } - Alias(child = f.copy(children = expandedArgs), name)() :: Nil - case Alias(c @ CreateArray(args), name) if containsStar(args) => + UnresolvedAlias(child = f.copy(children = expandedArgs)) :: Nil + case UnresolvedAlias(c @ CreateArray(args)) if containsStar(args) => val expandedArgs = args.flatMap { case s: Star => s.expand(child.output, resolver) case o => o :: Nil } - Alias(c.copy(children = expandedArgs), name)() :: Nil - case Alias(c @ CreateStruct(args), name) if containsStar(args) => + UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil + case UnresolvedAlias(c @ CreateStruct(args)) if containsStar(args) => val expandedArgs = args.flatMap { case s: Star => s.expand(child.output, resolver) case o => o :: Nil } - Alias(c.copy(children = expandedArgs), name)() :: Nil + UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil case o => o :: Nil }, child) @@ -353,7 +377,9 @@ class Analyzer( case u @ UnresolvedAttribute(nameParts) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. 
val result = - withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) } + withPosition(u) { + q.resolveChildren(nameParts, resolver).map(trimUnresolvedAlias).getOrElse(u) + } logDebug(s"Resolving $u to $result") result case UnresolvedExtractValue(child, fieldExpr) if child.resolved => @@ -379,6 +405,11 @@ class Analyzer( exprs.exists(_.collect { case _: Star => true }.nonEmpty) } + private def trimUnresolvedAlias(ne: NamedExpression) = ne match { + case UnresolvedAlias(child) => child + case other => other + } + private def resolveSortOrders(ordering: Seq[SortOrder], plan: LogicalPlan, throws: Boolean) = { ordering.map { order => // Resolve SortOrder in one round. @@ -388,7 +419,7 @@ class Analyzer( try { val newOrder = order transformUp { case u @ UnresolvedAttribute(nameParts) => - plan.resolve(nameParts, resolver).getOrElse(u) + plan.resolve(nameParts, resolver).map(trimUnresolvedAlias).getOrElse(u) case UnresolvedExtractValue(child, fieldName) if child.resolved => ExtractValue(child, fieldName, resolver) } @@ -586,18 +617,6 @@ class Analyzer( /** Extracts a [[Generator]] expression and any names assigned by aliases to their output. */ private object AliasedGenerator { def unapply(e: Expression): Option[(Generator, Seq[String])] = e match { - case Alias(g: Generator, name) - if g.resolved && - g.elementTypes.size > 1 && - java.util.regex.Pattern.matches("_c[0-9]+", name) => { - // Assume the default name given by parser is "_c[0-9]+", - // TODO in long term, move the naming logic from Parser to Analyzer. - // In projection, Parser gave default name for TGF as does for normal UDF, - // but the TGF probably have multiple output columns/names. - // e.g. SELECT explode(map(key, value)) FROM src; - // Let's simply ignore the default given name for this case. - Some((g, Nil)) - } case Alias(g: Generator, name) if g.resolved && g.elementTypes.size > 1 => // If not given the default names, and the TGF with multiple output columns failAnalysis( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 7fabd2bfc80ab..c5a1437be6d05 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -95,14 +95,7 @@ trait CheckAnalysis { case e => e.children.foreach(checkValidAggregateExpression) } - val cleaned = aggregateExprs.map(_.transform { - // Should trim aliases around `GetField`s. These aliases are introduced while - // resolving struct field accesses, because `GetField` is not a `NamedExpression`. - // (Should we just turn `GetField` into a `NamedExpression`?) 
- case Alias(g, _) => g - }) - - cleaned.foreach(checkValidAggregateExpression) + aggregateExprs.foreach(checkValidAggregateExpression) case _ => // Fallbacks to the following checks } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index c9d91425788a8..ae3adbab05108 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.{errors, trees} import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ @@ -206,3 +205,22 @@ case class UnresolvedExtractValue(child: Expression, extraction: Expression) override def toString: String = s"$child[$extraction]" } + +/** + * Holds the expression that has yet to be aliased. + */ +case class UnresolvedAlias(child: Expression) extends NamedExpression + with trees.UnaryNode[Expression] { + + override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute") + override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers") + override def exprId: ExprId = throw new UnresolvedException(this, "exprId") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + override def name: String = throw new UnresolvedException(this, "name") + + override lazy val resolved = false + + override def eval(input: InternalRow = null): Any = + throw new TreeNodeException(this, s"No function to evaluate expression. 
type: ${this.nodeName}") +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala index 4aaabff15b6ee..013027b199e63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.Map -import org.apache.spark.sql.{catalyst, AnalysisException} +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.types._ @@ -41,16 +41,22 @@ object ExtractValue { resolver: Resolver): ExtractValue = { (child.dataType, extraction) match { - case (StructType(fields), Literal(fieldName, StringType)) => - val ordinal = findField(fields, fieldName.toString, resolver) - GetStructField(child, fields(ordinal), ordinal) - case (ArrayType(StructType(fields), containsNull), Literal(fieldName, StringType)) => - val ordinal = findField(fields, fieldName.toString, resolver) - GetArrayStructFields(child, fields(ordinal), ordinal, containsNull) + case (StructType(fields), NonNullLiteral(v, StringType)) => + val fieldName = v.toString + val ordinal = findField(fields, fieldName, resolver) + GetStructField(child, fields(ordinal).copy(name = fieldName), ordinal) + + case (ArrayType(StructType(fields), containsNull), NonNullLiteral(v, StringType)) => + val fieldName = v.toString + val ordinal = findField(fields, fieldName, resolver) + GetArrayStructFields(child, fields(ordinal).copy(name = fieldName), ordinal, containsNull) + case (_: ArrayType, _) if extraction.dataType.isInstanceOf[IntegralType] => GetArrayItem(child, extraction) + case (_: MapType, _) => GetMapValue(child, extraction) + case (otherType, _) => val errorMsg = otherType match { case StructType(_) | ArrayType(StructType(_), _) => @@ -94,16 +100,21 @@ trait ExtractValue extends UnaryExpression { self: Product => } +abstract class ExtractValueWithStruct extends ExtractValue { + self: Product => + + def field: StructField + override def toString: String = s"$child.${field.name}" +} + /** * Returns the value of fields in the Struct `child`. 
*/ case class GetStructField(child: Expression, field: StructField, ordinal: Int) - extends ExtractValue { + extends ExtractValueWithStruct { override def dataType: DataType = field.dataType override def nullable: Boolean = child.nullable || field.nullable - override def foldable: Boolean = child.foldable - override def toString: String = s"$child.${field.name}" override def eval(input: InternalRow): Any = { val baseValue = child.eval(input).asInstanceOf[InternalRow] @@ -118,12 +129,9 @@ case class GetArrayStructFields( child: Expression, field: StructField, ordinal: Int, - containsNull: Boolean) extends ExtractValue { + containsNull: Boolean) extends ExtractValueWithStruct { override def dataType: DataType = ArrayType(field.dataType, containsNull) - override def nullable: Boolean = child.nullable - override def foldable: Boolean = child.foldable - override def toString: String = s"$child.${field.name}" override def eval(input: InternalRow): Any = { val baseValue = child.eval(input).asInstanceOf[Seq[InternalRow]] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 3b6f8bfd9ff9b..179a348d5baac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -156,12 +156,8 @@ object PartialAggregation { partialEvaluations(new TreeNodeRef(e)).finalEvaluation case e: Expression => - // Should trim aliases around `GetField`s. These aliases are introduced while - // resolving struct field accesses, because `GetField` is not a `NamedExpression`. - // (Should we just turn `GetField` into a `NamedExpression`?) - val trimmed = e.transform { case Alias(g: ExtractValue, _) => g } namedGroupingExpressions.collectFirst { - case (expr, ne) if expr semanticEquals trimmed => ne.toAttribute + case (expr, ne) if expr semanticEquals e => ne.toAttribute }.getOrElse(e) }).asInstanceOf[Seq[NamedExpression]] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index a853e27c1212d..b009a200b920f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.Logging import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, EliminateSubQueries, Resolver} +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.trees.TreeNode @@ -252,14 +252,13 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { // One match, but we also need to extract the requested nested field. case Seq((a, nestedFields)) => // The foldLeft adds ExtractValues for every remaining parts of the identifier, - // and aliases it with the last part of the identifier. + // and wrap it with UnresolvedAlias which will be removed later. // For example, consider "a.b.c", where "a" is resolved to an existing attribute. - // Then this will add ExtractValue("c", ExtractValue("b", a)), and alias - // the final expression as "c". 
+ // Then this will add ExtractValue("c", ExtractValue("b", a)), and wrap it as + // UnresolvedAlias(ExtractValue("c", ExtractValue("b", a))). val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) => ExtractValue(expr, Literal(fieldName), resolver)) - val aliasName = nestedFields.last - Some(Alias(fieldExprs, aliasName)()) + Some(UnresolvedAlias(fieldExprs)) // No matches. case Seq() => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 963c7820914f3..f8e5916d69f9c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -242,6 +242,8 @@ trait GroupingAnalytics extends UnaryNode { def aggregations: Seq[NamedExpression] override def output: Seq[Attribute] = aggregations.map(_.toAttribute) + + def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics } /** @@ -266,7 +268,11 @@ case class GroupingSets( groupByExprs: Seq[Expression], child: LogicalPlan, aggregations: Seq[NamedExpression], - gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics + gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics { + + def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics = + this.copy(aggregations = aggs) +} /** * Cube is a syntactic sugar for GROUPING SETS, and will be transformed to GroupingSets, @@ -284,7 +290,11 @@ case class Cube( groupByExprs: Seq[Expression], child: LogicalPlan, aggregations: Seq[NamedExpression], - gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics + gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics { + + def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics = + this.copy(aggregations = aggs) +} /** * Rollup is a syntactic sugar for GROUPING SETS, and will be transformed to GroupingSets, @@ -303,7 +313,11 @@ case class Rollup( groupByExprs: Seq[Expression], child: LogicalPlan, aggregations: Seq[NamedExpression], - gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics + gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics { + + def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics = + this.copy(aggregations = aggs) +} case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index b4e008a6e8480..f201c8ea8a110 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -21,7 +21,6 @@ import scala.language.implicitConversions import org.apache.spark.annotation.Experimental import org.apache.spark.Logging -import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.lit import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.analysis._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 466258e76f9f6..492a3321bc0bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -32,7 +32,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.analysis.{MultiAlias, ResolvedStar, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{Filter, _} import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} @@ -629,6 +629,10 @@ class DataFrame private[sql]( @scala.annotation.varargs def select(cols: Column*): DataFrame = { val namedExpressions = cols.map { + // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we + // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to + // make it a NamedExpression. + case Column(u: UnresolvedAttribute) => UnresolvedAlias(u) case Column(expr: NamedExpression) => expr // Leave an unaliased explode with an empty list of names since the analzyer will generate the // correct defaults after the nested expression's type has been resolved. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 45b3e1bc627d5..99d557b03a033 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConversions._ import scala.language.implicitConversions import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.analysis.Star +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, Star} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate} import org.apache.spark.sql.types.NumericType @@ -70,27 +70,31 @@ class GroupedData protected[sql]( groupingExprs: Seq[Expression], private val groupType: GroupedData.GroupType) { - private[this] def toDF(aggExprs: Seq[NamedExpression]): DataFrame = { + private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = { val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) { - val retainedExprs = groupingExprs.map { - case expr: NamedExpression => expr - case expr: Expression => Alias(expr, expr.prettyString)() - } - retainedExprs ++ aggExprs - } else { - aggExprs - } + groupingExprs ++ aggExprs + } else { + aggExprs + } + val aliasedAgg = aggregates.map { + // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we + // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to + // make it a NamedExpression. 
+ case u: UnresolvedAttribute => UnresolvedAlias(u) + case expr: NamedExpression => expr + case expr: Expression => Alias(expr, expr.prettyString)() + } groupType match { case GroupedData.GroupByType => DataFrame( - df.sqlContext, Aggregate(groupingExprs, aggregates, df.logicalPlan)) + df.sqlContext, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan)) case GroupedData.RollupType => DataFrame( - df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aggregates)) + df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aliasedAgg)) case GroupedData.CubeType => DataFrame( - df.sqlContext, Cube(groupingExprs, df.logicalPlan, aggregates)) + df.sqlContext, Cube(groupingExprs, df.logicalPlan, aliasedAgg)) } } @@ -112,10 +116,7 @@ class GroupedData protected[sql]( namedExpr } } - toDF(columnExprs.map { c => - val a = f(c) - Alias(a, a.prettyString)() - }) + toDF(columnExprs.map(f)) } private[this] def strToExpr(expr: String): (Expression => Expression) = { @@ -169,8 +170,7 @@ class GroupedData protected[sql]( */ def agg(exprs: Map[String, String]): DataFrame = { toDF(exprs.map { case (colName, expr) => - val a = strToExpr(expr)(df(colName).expr) - Alias(a, a.prettyString)() + strToExpr(expr)(df(colName).expr) }.toSeq) } @@ -224,10 +224,7 @@ class GroupedData protected[sql]( */ @scala.annotation.varargs def agg(expr: Column, exprs: Column*): DataFrame = { - toDF((expr +: exprs).map(_.expr).map { - case expr: NamedExpression => expr - case expr: Expression => Alias(expr, expr.prettyString)() - }) + toDF((expr +: exprs).map(_.expr)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index 1ce150ceaf5f9..c8c67ce334002 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -74,7 +74,7 @@ private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] { // Skip EvaluatePython nodes. case plan: EvaluatePython => plan - case plan: LogicalPlan => + case plan: LogicalPlan if plan.resolved => // Extract any PythonUDFs from the current operator. 
val udfs = plan.expressions.flatMap(_.collect { case udf: PythonUDF => udf }) if (udfs.isEmpty) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 4441afd6bd811..73bc6c999164e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1367,9 +1367,9 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { test("SPARK-6145: special cases") { sqlContext.read.json(sqlContext.sparkContext.makeRDD( - """{"a": {"b": [1]}, "b": [{"a": 1}], "c0": {"a": 1}}""" :: Nil)).registerTempTable("t") - checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY c0.a"), Row(1)) - checkAnswer(sql("SELECT b[0].a FROM t ORDER BY c0.a"), Row(1)) + """{"a": {"b": [1]}, "b": [{"a": 1}], "_c0": {"a": 1}}""" :: Nil)).registerTempTable("t") + checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY _c0.a"), Row(1)) + checkAnswer(sql("SELECT b[0].a FROM t ORDER BY _c0.a"), Row(1)) } test("SPARK-6898: complete support for special chars in column names") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 520a862ea0838..207d7a352c7b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql import java.sql.Timestamp -import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.test._ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index ca4b80b51b23f..7c4620952ba4b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -415,13 +415,6 @@ private[hive] object HiveQl { throw new NotImplementedError(s"No parse rules for StructField:\n ${dumpTree(a).toString} ") } - protected def nameExpressions(exprs: Seq[Expression]): Seq[NamedExpression] = { - exprs.zipWithIndex.map { - case (ne: NamedExpression, _) => ne - case (e, i) => Alias(e, s"_c$i")() - } - } - protected def extractDbNameTableName(tableNameParts: Node): (Option[String], String) = { val (db, tableName) = tableNameParts.getChildren.map { case Token(part, Nil) => cleanIdentifier(part) } match { @@ -942,7 +935,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // (if there is a group by) or a script transformation. val withProject: LogicalPlan = transformation.getOrElse { val selectExpressions = - nameExpressions(select.getChildren.flatMap(selExprNodeToExpr).toSeq) + select.getChildren.flatMap(selExprNodeToExpr).map(UnresolvedAlias(_)).toSeq Seq( groupByClause.map(e => e match { case Token("TOK_GROUPBY", children) => From 5ab9fcfb01a0ad2f6c103f67c1a785d3b49e33f0 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 22 Jun 2015 13:51:23 -0700 Subject: [PATCH 147/151] [SPARK-8532] [SQL] In Python's DataFrameWriter, save/saveAsTable/json/parquet/jdbc always override mode https://issues.apache.org/jira/browse/SPARK-8532 This PR has two changes. First, it fixes the bug that save actions (i.e. `save/saveAsTable/json/parquet/jdbc`) always override mode. Second, it adds input argument `partitionBy` to `save/saveAsTable/parquet`. 
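As a rough usage sketch of what this enables on the Python side (the DataFrame `df`, its partitioning column `year`, and the temp paths are assumptions for illustration, not part of the patch):

```python
import os
import tempfile

json_path = os.path.join(tempfile.mkdtemp(), "data")
parquet_path = os.path.join(tempfile.mkdtemp(), "data")

# Fix: a mode set on the builder is no longer clobbered by the save action's
# default ("error"), so overwriting an existing path now behaves as requested.
df.write.mode("overwrite").json(json_path)
df.write.mode("overwrite").json(json_path)  # second write overwrites instead of raising

# Addition: partitionBy is accepted directly by save/saveAsTable/parquet.
df.write.parquet(parquet_path, mode="overwrite", partitionBy=["year"])
df.write.save(path=parquet_path, format="parquet", mode="overwrite",
              partitionBy=["year"])
```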
Author: Yin Huai Closes #6937 from yhuai/SPARK-8532 and squashes the following commits: f972d5d [Yin Huai] davies's comment. d37abd2 [Yin Huai] style. d21290a [Yin Huai] Python doc. 889eb25 [Yin Huai] Minor refactoring and add partitionBy to save, saveAsTable, and parquet. 7fbc24b [Yin Huai] Use None instead of "error" as the default value of mode since JVM-side already uses "error" as the default value. d696dff [Yin Huai] Python style. 88eb6c4 [Yin Huai] If mode is "error", do not call mode method. c40c461 [Yin Huai] Regression test. --- python/pyspark/sql/readwriter.py | 30 +++++++++++++++++++----------- python/pyspark/sql/tests.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 11 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index f036644acc961..1b7bc0f9a12be 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -218,7 +218,10 @@ def mode(self, saveMode): >>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data')) """ - self._jwrite = self._jwrite.mode(saveMode) + # At the JVM side, the default value of mode is already set to "error". + # So, if the given saveMode is None, we will not call JVM-side's mode method. + if saveMode is not None: + self._jwrite = self._jwrite.mode(saveMode) return self @since(1.4) @@ -253,11 +256,12 @@ def partitionBy(self, *cols): """ if len(cols) == 1 and isinstance(cols[0], (list, tuple)): cols = cols[0] - self._jwrite = self._jwrite.partitionBy(_to_seq(self._sqlContext._sc, cols)) + if len(cols) > 0: + self._jwrite = self._jwrite.partitionBy(_to_seq(self._sqlContext._sc, cols)) return self @since(1.4) - def save(self, path=None, format=None, mode="error", **options): + def save(self, path=None, format=None, mode=None, partitionBy=(), **options): """Saves the contents of the :class:`DataFrame` to a data source. The data source is specified by the ``format`` and a set of ``options``. @@ -272,11 +276,12 @@ def save(self, path=None, format=None, mode="error", **options): * ``overwrite``: Overwrite existing data. * ``ignore``: Silently ignore this operation if data already exists. * ``error`` (default case): Throw an exception if data already exists. + :param partitionBy: names of partitioning columns :param options: all other string options >>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data')) """ - self.mode(mode).options(**options) + self.partitionBy(partitionBy).mode(mode).options(**options) if format is not None: self.format(format) if path is None: @@ -296,7 +301,7 @@ def insertInto(self, tableName, overwrite=False): self._jwrite.mode("overwrite" if overwrite else "append").insertInto(tableName) @since(1.4) - def saveAsTable(self, name, format=None, mode="error", **options): + def saveAsTable(self, name, format=None, mode=None, partitionBy=(), **options): """Saves the content of the :class:`DataFrame` as the specified table. 
In the case the table already exists, behavior of this function depends on the @@ -312,15 +317,16 @@ def saveAsTable(self, name, format=None, mode="error", **options): :param name: the table name :param format: the format used to save :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error) + :param partitionBy: names of partitioning columns :param options: all other string options """ - self.mode(mode).options(**options) + self.partitionBy(partitionBy).mode(mode).options(**options) if format is not None: self.format(format) self._jwrite.saveAsTable(name) @since(1.4) - def json(self, path, mode="error"): + def json(self, path, mode=None): """Saves the content of the :class:`DataFrame` in JSON format at the specified path. :param path: the path in any Hadoop supported file system @@ -333,10 +339,10 @@ def json(self, path, mode="error"): >>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data')) """ - self._jwrite.mode(mode).json(path) + self.mode(mode)._jwrite.json(path) @since(1.4) - def parquet(self, path, mode="error"): + def parquet(self, path, mode=None, partitionBy=()): """Saves the content of the :class:`DataFrame` in Parquet format at the specified path. :param path: the path in any Hadoop supported file system @@ -346,13 +352,15 @@ def parquet(self, path, mode="error"): * ``overwrite``: Overwrite existing data. * ``ignore``: Silently ignore this operation if data already exists. * ``error`` (default case): Throw an exception if data already exists. + :param partitionBy: names of partitioning columns >>> df.write.parquet(os.path.join(tempfile.mkdtemp(), 'data')) """ - self._jwrite.mode(mode).parquet(path) + self.partitionBy(partitionBy).mode(mode) + self._jwrite.parquet(path) @since(1.4) - def jdbc(self, url, table, mode="error", properties={}): + def jdbc(self, url, table, mode=None, properties={}): """Saves the content of the :class:`DataFrame` to a external database table via JDBC. .. 
note:: Don't create too many partitions in parallel on a large cluster;\ diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index b5fbb7d098820..13f4556943ac8 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -539,6 +539,38 @@ def test_save_and_load(self): shutil.rmtree(tmpPath) + def test_save_and_load_builder(self): + df = self.df + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + df.write.json(tmpPath) + actual = self.sqlCtx.read.json(tmpPath) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + + schema = StructType([StructField("value", StringType(), True)]) + actual = self.sqlCtx.read.json(tmpPath, schema) + self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect())) + + df.write.mode("overwrite").json(tmpPath) + actual = self.sqlCtx.read.json(tmpPath) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + + df.write.mode("overwrite").options(noUse="this options will not be used in save.")\ + .format("json").save(path=tmpPath) + actual =\ + self.sqlCtx.read.format("json")\ + .load(path=tmpPath, noUse="this options will not be used in load.") + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + + defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default", + "org.apache.spark.sql.parquet") + self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") + actual = self.sqlCtx.load(path=tmpPath) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName) + + shutil.rmtree(tmpPath) + def test_help_command(self): # Regression test for SPARK-5464 rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) From afe35f0519bc7dcb85010a7eedcff854d4fc313a Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Mon, 22 Jun 2015 14:15:35 -0700 Subject: [PATCH 148/151] [SPARK-8455] [ML] Implement n-gram feature transformer Implementation of n-gram feature transformer for ML. Author: Feynman Liang Closes #6887 from feynmanliang/ngram-featurizer and squashes the following commits: d2c839f [Feynman Liang] Make n > input length yield empty output 9fadd36 [Feynman Liang] Add empty and corner test cases, fix names and spaces fe93873 [Feynman Liang] Implement n-gram feature transformer --- .../org/apache/spark/ml/feature/NGram.scala | 69 ++++++++++++++ .../apache/spark/ml/feature/NGramSuite.scala | 94 +++++++++++++++++++ 2 files changed, 163 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala new file mode 100644 index 0000000000000..8de10eb51f923 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.UnaryTransformer +import org.apache.spark.ml.param._ +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql.types.{ArrayType, DataType, StringType} + +/** + * :: Experimental :: + * A feature transformer that converts the input array of strings into an array of n-grams. Null + * values in the input array are ignored. + * It returns an array of n-grams where each n-gram is represented by a space-separated string of + * words. + * + * When the input is empty, an empty array is returned. + * When the input array length is less than n (number of elements per n-gram), no n-grams are + * returned. + */ +@Experimental +class NGram(override val uid: String) + extends UnaryTransformer[Seq[String], Seq[String], NGram] { + + def this() = this(Identifiable.randomUID("ngram")) + + /** + * Minimum n-gram length, >= 1. + * Default: 2, bigram features + * @group param + */ + val n: IntParam = new IntParam(this, "n", "number elements per n-gram (>=1)", + ParamValidators.gtEq(1)) + + /** @group setParam */ + def setN(value: Int): this.type = set(n, value) + + /** @group getParam */ + def getN: Int = $(n) + + setDefault(n -> 2) + + override protected def createTransformFunc: Seq[String] => Seq[String] = { + _.iterator.sliding($(n)).withPartial(false).map(_.mkString(" ")).toSeq + } + + override protected def validateInputType(inputType: DataType): Unit = { + require(inputType.sameType(ArrayType(StringType)), + s"Input type must be ArrayType(StringType) but got $inputType.") + } + + override protected def outputDataType: DataType = new ArrayType(StringType, false) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala new file mode 100644 index 0000000000000..ab97e3dbc6ee0 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.feature + +import scala.beans.BeanInfo + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, Row} + +@BeanInfo +case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String]) + +class NGramSuite extends SparkFunSuite with MLlibTestSparkContext { + import org.apache.spark.ml.feature.NGramSuite._ + + test("default behavior yields bigram features") { + val nGram = new NGram() + .setInputCol("inputTokens") + .setOutputCol("nGrams") + val dataset = sqlContext.createDataFrame(Seq( + NGramTestData( + Array("Test", "for", "ngram", "."), + Array("Test for", "for ngram", "ngram .") + ))) + testNGram(nGram, dataset) + } + + test("NGramLength=4 yields length 4 n-grams") { + val nGram = new NGram() + .setInputCol("inputTokens") + .setOutputCol("nGrams") + .setN(4) + val dataset = sqlContext.createDataFrame(Seq( + NGramTestData( + Array("a", "b", "c", "d", "e"), + Array("a b c d", "b c d e") + ))) + testNGram(nGram, dataset) + } + + test("empty input yields empty output") { + val nGram = new NGram() + .setInputCol("inputTokens") + .setOutputCol("nGrams") + .setN(4) + val dataset = sqlContext.createDataFrame(Seq( + NGramTestData( + Array(), + Array() + ))) + testNGram(nGram, dataset) + } + + test("input array < n yields empty output") { + val nGram = new NGram() + .setInputCol("inputTokens") + .setOutputCol("nGrams") + .setN(6) + val dataset = sqlContext.createDataFrame(Seq( + NGramTestData( + Array("a", "b", "c", "d", "e"), + Array() + ))) + testNGram(nGram, dataset) + } +} + +object NGramSuite extends SparkFunSuite { + + def testNGram(t: NGram, dataset: DataFrame): Unit = { + t.transform(dataset) + .select("nGrams", "wantedNGrams") + .collect() + .foreach { case Row(actualNGrams, wantedNGrams) => + assert(actualNGrams === wantedNGrams) + } + } +} From b1f3a489efc6f4f9d172344c3345b9b38ae235e0 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Mon, 22 Jun 2015 14:35:38 -0700 Subject: [PATCH 149/151] [SPARK-8537] [SPARKR] Add a validation rule about the curly braces in SparkR to `.lintr` [[SPARK-8537] Add a validation rule about the curly braces in SparkR to `.lintr` - ASF JIRA](https://issues.apache.org/jira/browse/SPARK-8537) Author: Yu ISHIKAWA Closes #6940 from yu-iskw/SPARK-8537 and squashes the following commits: 7eec1a0 [Yu ISHIKAWA] [SPARK-8537][SparkR] Add a validation rule about the curly braces in SparkR to `.lintr` --- R/pkg/.lintr | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/pkg/.lintr b/R/pkg/.lintr index b10ebd35c4ca7..038236fc149e6 100644 --- a/R/pkg/.lintr +++ b/R/pkg/.lintr @@ -1,2 +1,2 @@ -linters: with_defaults(line_length_linter(100), camel_case_linter = NULL) +linters: with_defaults(line_length_linter(100), camel_case_linter = NULL, open_curly_linter(allow_single_line = TRUE), closed_curly_linter(allow_single_line = TRUE)) exclusions: list("inst/profile/general.R" = 1, "inst/profile/shell.R") From 50d3242d6a5530a51dacab249e3f3d49e2d50635 Mon Sep 17 00:00:00 2001 From: BenFradet Date: Mon, 22 Jun 2015 15:06:47 -0700 Subject: [PATCH 150/151] [SPARK-8356] [SQL] Reconcile callUDF and callUdf Deprecates ```callUdf``` in favor of ```callUDF```. 
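The patch itself touches only the Scala API, but for orientation, the closest PySpark analogue of the new by-name `callUDF(udfName, cols*)` overload is invoking a registered function inside an expression; this sketch (the data and UDF name are made up) is not part of the change:

```python
from pyspark.sql.types import IntegerType

# Assumes an existing SQLContext `sqlContext`.
df = sqlContext.createDataFrame([("id1", 1), ("id2", 4), ("id3", 5)], ["id", "value"])
sqlContext.registerFunction("simpleUdf", lambda v: v * v, IntegerType())
df.selectExpr("id", "simpleUdf(value)").show()
```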
Author: BenFradet Closes #6902 from BenFradet/SPARK-8356 and squashes the following commits: ef4e9d8 [BenFradet] deprecated callUDF, use udf instead 9b1de4d [BenFradet] reinstated unit test for the deprecated callUdf cbd80a5 [BenFradet] deprecated callUdf in favor of callUDF --- .../org/apache/spark/sql/functions.scala | 45 +++++++++++++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 11 ++++- 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 7e7a099a8318b..8cea826ae6921 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1448,7 +1448,9 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function$x[$fTypes], returnType: DataType${if (args.length > 0) ", " + args else ""}): Column = { ScalaUdf(f, returnType, Seq($argsInUdf)) }""") @@ -1584,7 +1586,9 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function0[_], returnType: DataType): Column = { ScalaUdf(f, returnType, Seq()) } @@ -1595,7 +1599,9 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function1[_, _], returnType: DataType, arg1: Column): Column = { ScalaUdf(f, returnType, Seq(arg1.expr)) } @@ -1606,7 +1612,9 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function2[_, _, _], returnType: DataType, arg1: Column, arg2: Column): Column = { ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr)) } @@ -1617,7 +1625,9 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function3[_, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column): Column = { ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr)) } @@ -1628,7 +1638,9 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function4[_, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = { ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr)) } @@ -1639,7 +1651,9 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function5[_, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = { ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr)) } @@ -1650,7 +1664,9 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function6[_, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = { ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, 
arg5.expr, arg6.expr)) } @@ -1661,7 +1677,9 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function7[_, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = { ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr)) } @@ -1672,7 +1690,9 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function8[_, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = { ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr)) } @@ -1683,7 +1703,9 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function9[_, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = { ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr)) } @@ -1694,13 +1716,34 @@ object functions { * * @group udf_funcs * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() */ + @deprecated("Use udf", "1.5.0") def callUDF(f: Function10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = { ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr)) } // scalastyle:on + /** + * Call an user-defined function. + * Example: + * {{{ + * import org.apache.spark.sql._ + * + * val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") + * val sqlContext = df.sqlContext + * sqlContext.udf.register("simpleUdf", (v: Int) => v * v) + * df.select($"id", callUDF("simpleUdf", $"value")) + * }}} + * + * @group udf_funcs + * @since 1.5.0 + */ + def callUDF(udfName: String, cols: Column*): Column = { + UnresolvedFunction(udfName, cols.map(_.expr)) + } + /** * Call an user-defined function. 
* Example: @@ -1715,7 +1758,9 @@ object functions { * * @group udf_funcs * @since 1.4.0 + * @deprecated As of 1.5.0, since it was not coherent to have two functions callUdf and callUDF */ + @deprecated("Use callUDF", "1.5.0") def callUdf(udfName: String, cols: Column*): Column = { UnresolvedFunction(udfName, cols.map(_.expr)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index ba1d020f22f11..47443a917b765 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -301,7 +301,7 @@ class DataFrameSuite extends QueryTest { ) } - test("call udf in SQLContext") { + test("deprecated callUdf in SQLContext") { val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") val sqlctx = df.sqlContext sqlctx.udf.register("simpleUdf", (v: Int) => v * v) @@ -310,6 +310,15 @@ class DataFrameSuite extends QueryTest { Row("id1", 1) :: Row("id2", 16) :: Row("id3", 25) :: Nil) } + test("callUDF in SQLContext") { + val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") + val sqlctx = df.sqlContext + sqlctx.udf.register("simpleUDF", (v: Int) => v * v) + checkAnswer( + df.select($"id", callUDF("simpleUDF", $"value")), + Row("id1", 1) :: Row("id2", 16) :: Row("id3", 25) :: Nil) + } + test("withColumn") { val df = testData.toDF().withColumn("newCol", col("key") + 1) checkAnswer( From 96aa01378e3b3dbb4601d31c7312a311cb65b22e Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 22 Jun 2015 15:22:17 -0700 Subject: [PATCH 151/151] [SPARK-8492] [SQL] support binaryType in UnsafeRow Support BinaryType in UnsafeRow, just like StringType. Also change the layout of StringType and BinaryType in UnsafeRow, by combining offset and size together as Long, which will limit the size of Row to under 2G (given that fact that any single buffer can not be bigger than 2G in JVM). 
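[Editor's note] To make the layout change concrete: after this patch, the 8-byte word for a variable-length field holds the field's relative offset in its upper 32 bits and its byte length in the lower 32 bits, which is why a single row is bounded by roughly 2GB. A small self-contained sketch of that packing, mirroring the arithmetic in `BytesUnsafeColumnWriter.write` and `UnsafeRow.getBinary` in the diff below; the object and method names here are illustrative only:

```scala
object OffsetAndSizePacking {
  // Pack a field's relative offset and byte length into one 8-byte word,
  // as the writer does with (appendCursor.toLong << 32L) | numBytes.toLong.
  def pack(offset: Int, size: Int): Long =
    (offset.toLong << 32) | (size.toLong & 0xFFFFFFFFL)

  // Unpack them again, as UnsafeRow.getBinary does: offset from the high
  // 32 bits, size from the low 32 bits.
  def unpack(offsetAndSize: Long): (Int, Int) = {
    val offset = (offsetAndSize >> 32).toInt
    val size = (offsetAndSize & ((1L << 32) - 1)).toInt
    (offset, size)
  }

  def main(args: Array[String]): Unit = {
    // Example: a 5-byte payload ("World") written at relative offset 40 within the row.
    val word = pack(40, 5)
    val (offset, size) = unpack(word)
    assert(offset == 40 && size == 5)
    println(s"offset=$offset size=$size") // offset=40 size=5
  }
}
```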
Author: Davies Liu Closes #6911 from davies/unsafe_bin and squashes the following commits: d68706f [Davies Liu] update comment 519f698 [Davies Liu] address comment 98a964b [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_bin 180b49d [Davies Liu] fix zero-out 22e4c0a [Davies Liu] zero-out padding bytes 6abfe93 [Davies Liu] fix style 447dea0 [Davies Liu] support binaryType in UnsafeRow --- .../UnsafeFixedWidthAggregationMap.java | 8 --- .../sql/catalyst/expressions/UnsafeRow.java | 34 ++++++----- .../expressions/UnsafeRowConverter.scala | 60 ++++++++++++++----- .../expressions/UnsafeRowConverterSuite.scala | 16 ++--- 4 files changed, 72 insertions(+), 46 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java index f7849ebebc573..83f2a312972fb 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions; -import java.util.Arrays; import java.util.Iterator; import org.apache.spark.sql.catalyst.InternalRow; @@ -142,14 +141,7 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) { final int groupingKeySize = groupingKeyToUnsafeRowConverter.getSizeRequirement(groupingKey); // Make sure that the buffer is large enough to hold the key. If it's not, grow it: if (groupingKeySize > groupingKeyConversionScratchSpace.length) { - // This new array will be initially zero, so there's no need to zero it out here groupingKeyConversionScratchSpace = new byte[groupingKeySize]; - } else { - // Zero out the buffer that's used to hold the current row. This is necessary in order - // to ensure that rows hash properly, since garbage data from the previous row could - // otherwise end up as padding in this row. As a performance optimization, we only zero out - // the portion of the buffer that we'll actually write to. - Arrays.fill(groupingKeyConversionScratchSpace, 0, groupingKeySize, (byte) 0); } final int actualGroupingKeySize = groupingKeyToUnsafeRowConverter.writeRow( groupingKey, diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index ed04d2e50ec84..bb2f2079b40f0 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -47,7 +47,8 @@ * In the `values` region, we store one 8-byte word per field. For fields that hold fixed-length * primitive types, such as long, double, or int, we store the value directly in the word. For * fields with non-primitive or variable-length values, we store a relative offset (w.r.t. the - * base address of the row) that points to the beginning of the variable-length field. + * base address of the row) that points to the beginning of the variable-length field, and length + * (they are combined into a long). * * Instances of `UnsafeRow` act as pointers to row data stored in this format. 
*/ @@ -92,6 +93,7 @@ public static int calculateBitSetWidthInBytes(int numFields) { */ public static final Set readableFieldTypes; + // TODO: support DecimalType static { settableFieldTypes = Collections.unmodifiableSet( new HashSet( @@ -111,7 +113,8 @@ public static int calculateBitSetWidthInBytes(int numFields) { // We support get() on a superset of the types for which we support set(): final Set _readableFieldTypes = new HashSet( Arrays.asList(new DataType[]{ - StringType + StringType, + BinaryType })); _readableFieldTypes.addAll(settableFieldTypes); readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes); @@ -221,11 +224,6 @@ public void setFloat(int ordinal, float value) { PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value); } - @Override - public void setString(int ordinal, String value) { - throw new UnsupportedOperationException(); - } - @Override public int size() { return numFields; @@ -249,6 +247,8 @@ public Object get(int i) { return null; } else if (dataType == StringType) { return getUTF8String(i); + } else if (dataType == BinaryType) { + return getBinary(i); } else { throw new UnsupportedOperationException(); } @@ -311,19 +311,23 @@ public double getDouble(int i) { } public UTF8String getUTF8String(int i) { + return UTF8String.fromBytes(getBinary(i)); + } + + public byte[] getBinary(int i) { assertIndexIsValid(i); - final long offsetToStringSize = getLong(i); - final int stringSizeInBytes = - (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offsetToStringSize); - final byte[] strBytes = new byte[stringSizeInBytes]; + final long offsetAndSize = getLong(i); + final int offset = (int)(offsetAndSize >> 32); + final int size = (int)(offsetAndSize & ((1L << 32) - 1)); + final byte[] bytes = new byte[size]; PlatformDependent.copyMemory( baseObject, - baseOffset + offsetToStringSize + 8, // The `+ 8` is to skip past the size to get the data - strBytes, + baseOffset + offset, + bytes, PlatformDependent.BYTE_ARRAY_OFFSET, - stringSizeInBytes + size ); - return UTF8String.fromBytes(strBytes); + return bytes; } @Override diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala index 72f740ecaead3..89adaf053b1a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.util.DateUtils -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ import org.apache.spark.unsafe.PlatformDependent import org.apache.spark.unsafe.array.ByteArrayMethods @@ -72,6 +70,19 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) { */ def writeRow(row: InternalRow, baseObject: Object, baseOffset: Long): Int = { unsafeRow.pointTo(baseObject, baseOffset, writers.length, null) + + if (writers.length > 0) { + // zero-out the bitset + var n = writers.length / 64 + while (n >= 0) { + PlatformDependent.UNSAFE.putLong( + unsafeRow.getBaseObject, + unsafeRow.getBaseOffset + n * 8, + 0L) + n -= 1 + } + } + var fieldNumber = 0 var appendCursor: Int = fixedLengthSize while (fieldNumber < writers.length) { @@ -122,6 +133,7 @@ private object UnsafeColumnWriter { case FloatType => FloatUnsafeColumnWriter case DoubleType => 
DoubleUnsafeColumnWriter case StringType => StringUnsafeColumnWriter + case BinaryType => BinaryUnsafeColumnWriter case DateType => IntUnsafeColumnWriter case TimestampType => LongUnsafeColumnWriter case t => @@ -141,6 +153,7 @@ private object LongUnsafeColumnWriter extends LongUnsafeColumnWriter private object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter private object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter +private object BinaryUnsafeColumnWriter extends BinaryUnsafeColumnWriter private abstract class PrimitiveUnsafeColumnWriter extends UnsafeColumnWriter { // Primitives don't write to the variable-length region: @@ -235,10 +248,13 @@ private class DoubleUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWr } } -private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter { +private abstract class BytesUnsafeColumnWriter extends UnsafeColumnWriter { + + def getBytes(source: InternalRow, column: Int): Array[Byte] + def getSize(source: InternalRow, column: Int): Int = { - val numBytes = source.get(column).asInstanceOf[UTF8String].getBytes.length - 8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) + val numBytes = getBytes(source, column).length + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) } override def write( @@ -246,19 +262,33 @@ private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter { target: UnsafeRow, column: Int, appendCursor: Int): Int = { - val value = source.get(column).asInstanceOf[UTF8String] - val baseObject = target.getBaseObject - val baseOffset = target.getBaseOffset - val numBytes = value.getBytes.length - PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + appendCursor, numBytes) + val offset = target.getBaseOffset + appendCursor + val bytes = getBytes(source, column) + val numBytes = bytes.length + if ((numBytes & 0x07) > 0) { + // zero-out the padding bytes + PlatformDependent.UNSAFE.putLong(target.getBaseObject, offset + ((numBytes >> 3) << 3), 0L) + } PlatformDependent.copyMemory( - value.getBytes, + bytes, PlatformDependent.BYTE_ARRAY_OFFSET, - baseObject, - baseOffset + appendCursor + 8, + target.getBaseObject, + offset, numBytes ) - target.setLong(column, appendCursor) - 8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) + target.setLong(column, (appendCursor.toLong << 32L) | numBytes.toLong) + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) + } +} + +private class StringUnsafeColumnWriter private() extends BytesUnsafeColumnWriter { + def getBytes(source: InternalRow, column: Int): Array[Byte] = { + source.getAs[UTF8String](column).getBytes + } +} + +private class BinaryUnsafeColumnWriter private() extends BytesUnsafeColumnWriter { + def getBytes(source: InternalRow, column: Int): Array[Byte] = { + source.getAs[Array[Byte]](column) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index 721ef8a22608c..d8f3351d6dff6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -23,8 +23,8 @@ import java.util.Arrays import org.scalatest.Matchers import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.types._ import 
org.apache.spark.sql.catalyst.util.DateUtils +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.PlatformDependent import org.apache.spark.unsafe.array.ByteArrayMethods @@ -52,19 +52,19 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { unsafeRow.getInt(2) should be (2) } - test("basic conversion with primitive and string types") { - val fieldTypes: Array[DataType] = Array(LongType, StringType, StringType) + test("basic conversion with primitive, string and binary types") { + val fieldTypes: Array[DataType] = Array(LongType, StringType, BinaryType) val converter = new UnsafeRowConverter(fieldTypes) val row = new SpecificMutableRow(fieldTypes) row.setLong(0, 0) row.setString(1, "Hello") - row.setString(2, "World") + row.update(2, "World".getBytes) val sizeRequired: Int = converter.getSizeRequirement(row) sizeRequired should be (8 + (8 * 3) + - ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length + 8) + - ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length + 8)) + ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length) + + ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length)) val buffer: Array[Long] = new Array[Long](sizeRequired / 8) val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET) numBytesWritten should be (sizeRequired) @@ -73,7 +73,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) unsafeRow.getLong(0) should be (0) unsafeRow.getString(1) should be ("Hello") - unsafeRow.getString(2) should be ("World") + unsafeRow.getBinary(2) should be ("World".getBytes) } test("basic conversion with primitive, string, date and timestamp types") { @@ -88,7 +88,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val sizeRequired: Int = converter.getSizeRequirement(row) sizeRequired should be (8 + (8 * 4) + - ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length + 8)) + ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length)) val buffer: Array[Long] = new Array[Long](sizeRequired / 8) val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET) numBytesWritten should be (sizeRequired)