From d96d7b55746cf034e3935ec4b22614a99e48c498 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 23 Jun 2015 14:19:21 -0700 Subject: [PATCH 01/44] [DOC] [SQL] Addes Hive metastore Parquet table conversion section This PR adds a section about Hive metastore Parquet table conversion. It documents: 1. Schema reconciliation rules introduced in #5214 (see [this comment] [1] in #5188) 2. Metadata refreshing requirement introduced in #5339 [1]: https://github.com/apache/spark/pull/5188#issuecomment-86531248 Author: Cheng Lian Closes #5348 from liancheng/sql-doc-parquet-conversion and squashes the following commits: 42ae0d0 [Cheng Lian] Adds Python `refreshTable` snippet 4c9847d [Cheng Lian] Resorts to SQL for Python metadata refreshing snippet 756e660 [Cheng Lian] Adds Python snippet for metadata refreshing 50675db [Cheng Lian] Addes Hive metastore Parquet table conversion section --- docs/sql-programming-guide.md | 94 ++++++++++++++++++++++++++++++++--- 1 file changed, 88 insertions(+), 6 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 26c036f6648da..9107c9b67681f 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -22,7 +22,7 @@ The DataFrame API is available in [Scala](api/scala/index.html#org.apache.spark. All of the examples on this page use sample data included in the Spark distribution and can be run in the `spark-shell`, `pyspark` shell, or `sparkR` shell. -## Starting Point: `SQLContext` +## Starting Point: SQLContext
@@ -1036,6 +1036,15 @@ for (teenName in collect(teenNames)) {
+
+ +{% highlight python %} +# sqlContext is an existing HiveContext +sqlContext.sql("REFRESH TABLE my_table") +{% endhighlight %} + +
+
{% highlight sql %} @@ -1054,7 +1063,7 @@ SELECT * FROM parquetTable
-### Partition discovery +### Partition Discovery Table partitioning is a common optimization approach used in systems like Hive. In a partitioned table, data are usually stored in different directories, with partitioning column values encoded in @@ -1108,7 +1117,7 @@ can be configured by `spark.sql.sources.partitionColumnTypeInference.enabled`, w `true`. When type inference is disabled, string type will be used for the partitioning columns. -### Schema merging +### Schema Merging Like ProtocolBuffer, Avro, and Thrift, Parquet also supports schema evolution. Users can start with a simple schema, and gradually add more columns to the schema as needed. In this way, users may end @@ -1208,6 +1217,79 @@ printSchema(df3)
+### Hive metastore Parquet table conversion + +When reading from and writing to Hive metastore Parquet tables, Spark SQL will try to use its own +Parquet support instead of Hive SerDe for better performance. This behavior is controlled by the +`spark.sql.hive.convertMetastoreParquet` configuration, and is turned on by default. + +#### Hive/Parquet Schema Reconciliation + +There are two key differences between Hive and Parquet from the perspective of table schema +processing. + +1. Hive is case insensitive, while Parquet is not +1. Hive considers all columns nullable, while nullability in Parquet is significant + +Due to this reason, we must reconcile Hive metastore schema with Parquet schema when converting a +Hive metastore Parquet table to a Spark SQL Parquet table. The reconciliation rules are: + +1. Fields that have the same name in both schema must have the same data type regardless of + nullability. The reconciled field should have the data type of the Parquet side, so that + nullability is respected. + +1. The reconciled schema contains exactly those fields defined in Hive metastore schema. + + - Any fields that only appear in the Parquet schema are dropped in the reconciled schema. + - Any fileds that only appear in the Hive metastore schema are added as nullable field in the + reconciled schema. + +#### Metadata Refreshing + +Spark SQL caches Parquet metadata for better performance. When Hive metastore Parquet table +conversion is enabled, metadata of those converted tables are also cached. If these tables are +updated by Hive or other external tools, you need to refresh them manually to ensure consistent +metadata. + +
+ +
+ +{% highlight scala %} +// sqlContext is an existing HiveContext +sqlContext.refreshTable("my_table") +{% endhighlight %} + +
+ +
+ +{% highlight java %} +// sqlContext is an existing HiveContext +sqlContext.refreshTable("my_table") +{% endhighlight %} + +
+ +
+ +{% highlight python %} +# sqlContext is an existing HiveContext +sqlContext.refreshTable("my_table") +{% endhighlight %} + +
+ +
+ +{% highlight sql %} +REFRESH TABLE my_table; +{% endhighlight %} + +
+ +
+ ### Configuration Configuration of Parquet can be done using the `setConf` method on `SQLContext` or by running @@ -1445,8 +1527,8 @@ This command builds a new assembly jar that includes Hive. Note that this Hive a on all of the worker nodes, as they will need access to the Hive serialization and deserialization libraries (SerDes) in order to access data stored in Hive. -Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. Please note when running -the query on a YARN cluster (`yarn-cluster` mode), the `datanucleus` jars under the `lib_managed/jars` directory +Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. Please note when running +the query on a YARN cluster (`yarn-cluster` mode), the `datanucleus` jars under the `lib_managed/jars` directory and `hive-site.xml` under `conf/` directory need to be available on the driver and all executors launched by the YARN cluster. The convenient way to do this is adding them through the `--jars` option and `--file` option of the `spark-submit` command. @@ -1889,7 +1971,7 @@ options. #### DataFrame data reader/writer interface Based on user feedback, we created a new, more fluid API for reading data in (`SQLContext.read`) -and writing data out (`DataFrame.write`), +and writing data out (`DataFrame.write`), and deprecated the old APIs (e.g. `SQLContext.parquetFile`, `SQLContext.jsonFile`). See the API docs for `SQLContext.read` ( From 7fb5ae5024284593204779ff463bfbdb4d1c6da5 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 23 Jun 2015 15:51:16 -0700 Subject: [PATCH 02/44] [SPARK-8573] [SPARK-8568] [SQL] [PYSPARK] raise Exception if column is used in booelan expression It's a common mistake that user will put Column in a boolean expression (together with `and` , `or`), which does not work as expected, we should raise a exception in that case, and suggest user to use `&`, `|` instead. Author: Davies Liu Closes #6961 from davies/column_bool and squashes the following commits: 9f19beb [Davies Liu] update message af74bd6 [Davies Liu] fix tests 07dff84 [Davies Liu] address comments, fix tests f70c08e [Davies Liu] raise Exception if column is used in booelan expression --- python/pyspark/sql/column.py | 5 +++++ python/pyspark/sql/tests.py | 10 +++++++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 1ecec5b126505..0a85da7443d3d 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -396,6 +396,11 @@ def over(self, window): jc = self._jc.over(window._jspec) return Column(jc) + def __nonzero__(self): + raise ValueError("Cannot convert column into bool: please use '&' for 'and', '|' for 'or', " + "'~' for 'not' when building DataFrame boolean expressions.") + __bool__ = __nonzero__ + def __repr__(self): return 'Column<%s>' % self._jc.toString().encode('utf8') diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 13f4556943ac8..e6a434e4b2dff 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -164,6 +164,14 @@ def test_explode(self): self.assertEqual(result[0][0], "a") self.assertEqual(result[0][1], "b") + def test_and_in_expression(self): + self.assertEqual(4, self.df.filter((self.df.key <= 10) & (self.df.value <= "2")).count()) + self.assertRaises(ValueError, lambda: (self.df.key <= 10) and (self.df.value <= "2")) + self.assertEqual(14, self.df.filter((self.df.key <= 3) | (self.df.value < "2")).count()) + self.assertRaises(ValueError, lambda: self.df.key <= 3 or self.df.value < "2") + self.assertEqual(99, self.df.filter(~(self.df.key == 1)).count()) + self.assertRaises(ValueError, lambda: not self.df.key == 1) + def test_udf_with_callable(self): d = [Row(number=i, squared=i**2) for i in range(10)] rdd = self.sc.parallelize(d) @@ -408,7 +416,7 @@ def test_column_operators(self): self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column)) rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci) self.assertTrue(all(isinstance(c, Column) for c in rcc)) - cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7, ci and cs, ci or cs] + cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7] self.assertTrue(all(isinstance(c, Column) for c in cb)) cbool = (ci & ci), (ci | ci), (~ci) self.assertTrue(all(isinstance(c, Column) for c in cbool)) From 111d6b9b8a584b962b6ae80c7aa8c45845ce0099 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 23 Jun 2015 17:24:26 -0700 Subject: [PATCH 03/44] [SPARK-8139] [SQL] Updates docs and comments of data sources and Parquet output committer options This PR only applies to master branch (1.5.0-SNAPSHOT) since it references `org.apache.parquet` classes which only appear in Parquet 1.7.0. Author: Cheng Lian Closes #6683 from liancheng/output-committer-docs and squashes the following commits: b4648b8 [Cheng Lian] Removes spark.sql.sources.outputCommitterClass as it's not a public option ee63923 [Cheng Lian] Updates docs and comments of data sources and Parquet output committer options --- docs/sql-programming-guide.md | 30 +++++++++++++++- .../scala/org/apache/spark/sql/SQLConf.scala | 30 ++++++++++++---- .../DirectParquetOutputCommitter.scala | 34 +++++++++++++------ .../apache/spark/sql/parquet/newParquet.scala | 4 +-- 4 files changed, 78 insertions(+), 20 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 9107c9b67681f..2786e3d2cd6bf 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1348,6 +1348,34 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext` support. + + spark.sql.parquet.output.committer.class + org.apache.parquet.hadoop.
ParquetOutputCommitter
+ +

+ The output committer class used by Parquet. The specified class needs to be a subclass of + org.apache.hadoop.
mapreduce.OutputCommitter
. Typically, it's also a + subclass of org.apache.parquet.hadoop.ParquetOutputCommitter. +

+

+ Note: +

    +
  • + This option must be set via Hadoop Configuration rather than Spark + SQLConf. +
  • +
  • + This option overrides spark.sql.sources.
    outputCommitterClass
    . +
  • +
+

+

+ Spark SQL comes with a builtin + org.apache.spark.sql.
parquet.DirectParquetOutputCommitter
, which can be more + efficient then the default Parquet output committer when writing data to S3. +

+ + ## JSON Datasets @@ -1876,7 +1904,7 @@ that these options will be deprecated in future release as more optimizations ar Configures the number of partitions to use when shuffling data for joins or aggregations. - + spark.sql.planner.externalSort false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 16493c3d7c19c..265352647fa9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -22,6 +22,8 @@ import java.util.Properties import scala.collection.immutable import scala.collection.JavaConversions._ +import org.apache.parquet.hadoop.ParquetOutputCommitter + import org.apache.spark.sql.catalyst.CatalystConf private[spark] object SQLConf { @@ -252,9 +254,9 @@ private[spark] object SQLConf { val PARQUET_FILTER_PUSHDOWN_ENABLED = booleanConf("spark.sql.parquet.filterPushdown", defaultValue = Some(false), - doc = "Turn on Parquet filter pushdown optimization. This feature is turned off by default" + - " because of a known bug in Paruet 1.6.0rc3 " + - "(PARQUET-136). However, " + + doc = "Turn on Parquet filter pushdown optimization. This feature is turned off by default " + + "because of a known bug in Parquet 1.6.0rc3 " + + "(PARQUET-136, https://issues.apache.org/jira/browse/PARQUET-136). However, " + "if your table doesn't contain any nullable string or binary columns, it's still safe to " + "turn this feature on.") @@ -262,11 +264,21 @@ private[spark] object SQLConf { defaultValue = Some(true), doc = "") + val PARQUET_OUTPUT_COMMITTER_CLASS = stringConf( + key = "spark.sql.parquet.output.committer.class", + defaultValue = Some(classOf[ParquetOutputCommitter].getName), + doc = "The output committer class used by Parquet. The specified class needs to be a " + + "subclass of org.apache.hadoop.mapreduce.OutputCommitter. Typically, it's also a subclass " + + "of org.apache.parquet.hadoop.ParquetOutputCommitter. NOTE: 1. Instead of SQLConf, this " + + "option must be set in Hadoop Configuration. 2. This option overrides " + + "\"spark.sql.sources.outputCommitterClass\"." + ) + val ORC_FILTER_PUSHDOWN_ENABLED = booleanConf("spark.sql.orc.filterPushdown", defaultValue = Some(false), doc = "") - val HIVE_VERIFY_PARTITIONPATH = booleanConf("spark.sql.hive.verifyPartitionPath", + val HIVE_VERIFY_PARTITION_PATH = booleanConf("spark.sql.hive.verifyPartitionPath", defaultValue = Some(true), doc = "") @@ -325,9 +337,13 @@ private[spark] object SQLConf { defaultValue = Some(true), doc = "") - // The output committer class used by FSBasedRelation. The specified class needs to be a + // The output committer class used by HadoopFsRelation. The specified class needs to be a // subclass of org.apache.hadoop.mapreduce.OutputCommitter. - // NOTE: This property should be set in Hadoop `Configuration` rather than Spark `SQLConf` + // + // NOTE: + // + // 1. Instead of SQLConf, this option *must be set in Hadoop Configuration*. + // 2. This option can be overriden by "spark.sql.parquet.output.committer.class". val OUTPUT_COMMITTER_CLASS = stringConf("spark.sql.sources.outputCommitterClass", isPublic = false) @@ -415,7 +431,7 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED) /** When true uses verifyPartitionPath to prune the path which is not exists. */ - private[spark] def verifyPartitionPath: Boolean = getConf(HIVE_VERIFY_PARTITIONPATH) + private[spark] def verifyPartitionPath: Boolean = getConf(HIVE_VERIFY_PARTITION_PATH) /** When true the planner will use the external sort, which may spill to disk. */ private[spark] def externalSortEnabled: Boolean = getConf(EXTERNAL_SORT) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala index 62c4e92ebec68..1551afd7b7bf2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala @@ -17,19 +17,35 @@ package org.apache.spark.sql.parquet +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter - +import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} import org.apache.parquet.Log import org.apache.parquet.hadoop.util.ContextUtil import org.apache.parquet.hadoop.{ParquetFileReader, ParquetFileWriter, ParquetOutputCommitter, ParquetOutputFormat} +/** + * An output committer for writing Parquet files. In stead of writing to the `_temporary` folder + * like what [[ParquetOutputCommitter]] does, this output committer writes data directly to the + * destination folder. This can be useful for data stored in S3, where directory operations are + * relatively expensive. + * + * To enable this output committer, users may set the "spark.sql.parquet.output.committer.class" + * property via Hadoop [[Configuration]]. Not that this property overrides + * "spark.sql.sources.outputCommitterClass". + * + * *NOTE* + * + * NEVER use [[DirectParquetOutputCommitter]] when appending data, because currently there's + * no safe way undo a failed appending job (that's why both `abortTask()` and `abortJob()` are + * left * empty). + */ private[parquet] class DirectParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) extends ParquetOutputCommitter(outputPath, context) { val LOG = Log.getLog(classOf[ParquetOutputCommitter]) - override def getWorkPath(): Path = outputPath + override def getWorkPath: Path = outputPath override def abortTask(taskContext: TaskAttemptContext): Unit = {} override def commitTask(taskContext: TaskAttemptContext): Unit = {} override def needsTaskCommit(taskContext: TaskAttemptContext): Boolean = true @@ -46,13 +62,11 @@ private[parquet] class DirectParquetOutputCommitter(outputPath: Path, context: T val footers = ParquetFileReader.readAllFootersInParallel(configuration, outputStatus) try { ParquetFileWriter.writeMetadataFile(configuration, outputPath, footers) - } catch { - case e: Exception => { - LOG.warn("could not write summary file for " + outputPath, e) - val metadataPath = new Path(outputPath, ParquetFileWriter.PARQUET_METADATA_FILE) - if (fileSystem.exists(metadataPath)) { - fileSystem.delete(metadataPath, true) - } + } catch { case e: Exception => + LOG.warn("could not write summary file for " + outputPath, e) + val metadataPath = new Path(outputPath, ParquetFileWriter.PARQUET_METADATA_FILE) + if (fileSystem.exists(metadataPath)) { + fileSystem.delete(metadataPath, true) } } } catch { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index e049d54bf55dc..1d353bd8e1114 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -178,11 +178,11 @@ private[sql] class ParquetRelation2( val committerClass = conf.getClass( - "spark.sql.parquet.output.committer.class", + SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key, classOf[ParquetOutputCommitter], classOf[ParquetOutputCommitter]) - if (conf.get("spark.sql.parquet.output.committer.class") == null) { + if (conf.get(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key) == null) { logInfo("Using default output committer for Parquet: " + classOf[ParquetOutputCommitter].getCanonicalName) } else { From 0401cbaa8ee51c71f43604f338b65022a479da0a Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 23 Jun 2015 17:46:29 -0700 Subject: [PATCH 04/44] [SPARK-7157][SQL] add sampleBy to DataFrame Add `sampleBy` to DataFrame. rxin Author: Xiangrui Meng Closes #6769 from mengxr/SPARK-7157 and squashes the following commits: 991f26f [Xiangrui Meng] fix seed 4a14834 [Xiangrui Meng] move sampleBy to stat 832f7cc [Xiangrui Meng] add sampleBy to DataFrame --- python/pyspark/sql/dataframe.py | 40 +++++++++++++++++++ .../spark/sql/DataFrameStatFunctions.scala | 24 +++++++++++ .../apache/spark/sql/DataFrameStatSuite.scala | 12 +++++- 3 files changed, 74 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 152b87351db31..213338dfe58a4 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -448,6 +448,41 @@ def sample(self, withReplacement, fraction, seed=None): rdd = self._jdf.sample(withReplacement, fraction, long(seed)) return DataFrame(rdd, self.sql_ctx) + @since(1.5) + def sampleBy(self, col, fractions, seed=None): + """ + Returns a stratified sample without replacement based on the + fraction given on each stratum. + + :param col: column that defines strata + :param fractions: + sampling fraction for each stratum. If a stratum is not + specified, we treat its fraction as zero. + :param seed: random seed + :return: a new DataFrame that represents the stratified sample + + >>> from pyspark.sql.functions import col + >>> dataset = sqlContext.range(0, 100).select((col("id") % 3).alias("key")) + >>> sampled = dataset.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=0) + >>> sampled.groupBy("key").count().orderBy("key").show() + +---+-----+ + |key|count| + +---+-----+ + | 0| 5| + | 1| 8| + +---+-----+ + """ + if not isinstance(col, str): + raise ValueError("col must be a string, but got %r" % type(col)) + if not isinstance(fractions, dict): + raise ValueError("fractions must be a dict but got %r" % type(fractions)) + for k, v in fractions.items(): + if not isinstance(k, (float, int, long, basestring)): + raise ValueError("key must be float, int, long, or string, but got %r" % type(k)) + fractions[k] = float(v) + seed = seed if seed is not None else random.randint(0, sys.maxsize) + return DataFrame(self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sql_ctx) + @since(1.4) def randomSplit(self, weights, seed=None): """Randomly splits this :class:`DataFrame` with the provided weights. @@ -1322,6 +1357,11 @@ def freqItems(self, cols, support=None): freqItems.__doc__ = DataFrame.freqItems.__doc__ + def sampleBy(self, col, fractions, seed=None): + return self.df.sampleBy(col, fractions, seed) + + sampleBy.__doc__ = DataFrame.sampleBy.__doc__ + def _test(): import doctest diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index edb9ed7bba56a..955d28771b4df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import java.util.UUID + import org.apache.spark.annotation.Experimental import org.apache.spark.sql.execution.stat._ @@ -163,4 +165,26 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { def freqItems(cols: Seq[String]): DataFrame = { FrequentItems.singlePassFreqItems(df, cols, 0.01) } + + /** + * Returns a stratified sample without replacement based on the fraction given on each stratum. + * @param col column that defines strata + * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat + * its fraction as zero. + * @param seed random seed + * @return a new [[DataFrame]] that represents the stratified sample + * + * @since 1.5.0 + */ + def sampleBy(col: String, fractions: Map[Any, Double], seed: Long): DataFrame = { + require(fractions.values.forall(p => p >= 0.0 && p <= 1.0), + s"Fractions must be in [0, 1], but got $fractions.") + import org.apache.spark.sql.functions.rand + val c = Column(col) + val r = rand(seed).as("rand_" + UUID.randomUUID().toString.take(8)) + val expr = fractions.toSeq.map { case (k, v) => + (c === k) && (r < v) + }.reduce(_ || _) || false + df.filter(expr) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 0d3ff899dad72..3dd46889127ff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql import org.scalatest.Matchers._ -import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.functions.col -class DataFrameStatSuite extends SparkFunSuite { +class DataFrameStatSuite extends QueryTest { private val sqlCtx = org.apache.spark.sql.test.TestSQLContext import sqlCtx.implicits._ @@ -98,4 +98,12 @@ class DataFrameStatSuite extends SparkFunSuite { val items2 = singleColResults.collect().head items2.getSeq[Double](0) should contain (-1.0) } + + test("sampleBy") { + val df = sqlCtx.range(0, 100).select((col("id") % 3).as("key")) + val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L) + checkAnswer( + sampled.groupBy("key").count().orderBy("key"), + Seq(Row(0, 4), Row(1, 9))) + } } From a458efc66c31dc281af379b914bfa2b077ca6635 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 23 Jun 2015 19:30:25 -0700 Subject: [PATCH 05/44] Revert "[SPARK-7157][SQL] add sampleBy to DataFrame" This reverts commit 0401cbaa8ee51c71f43604f338b65022a479da0a. The new test case on Jenkins is failing. --- python/pyspark/sql/dataframe.py | 40 ------------------- .../spark/sql/DataFrameStatFunctions.scala | 24 ----------- .../apache/spark/sql/DataFrameStatSuite.scala | 12 +----- 3 files changed, 2 insertions(+), 74 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 213338dfe58a4..152b87351db31 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -448,41 +448,6 @@ def sample(self, withReplacement, fraction, seed=None): rdd = self._jdf.sample(withReplacement, fraction, long(seed)) return DataFrame(rdd, self.sql_ctx) - @since(1.5) - def sampleBy(self, col, fractions, seed=None): - """ - Returns a stratified sample without replacement based on the - fraction given on each stratum. - - :param col: column that defines strata - :param fractions: - sampling fraction for each stratum. If a stratum is not - specified, we treat its fraction as zero. - :param seed: random seed - :return: a new DataFrame that represents the stratified sample - - >>> from pyspark.sql.functions import col - >>> dataset = sqlContext.range(0, 100).select((col("id") % 3).alias("key")) - >>> sampled = dataset.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=0) - >>> sampled.groupBy("key").count().orderBy("key").show() - +---+-----+ - |key|count| - +---+-----+ - | 0| 5| - | 1| 8| - +---+-----+ - """ - if not isinstance(col, str): - raise ValueError("col must be a string, but got %r" % type(col)) - if not isinstance(fractions, dict): - raise ValueError("fractions must be a dict but got %r" % type(fractions)) - for k, v in fractions.items(): - if not isinstance(k, (float, int, long, basestring)): - raise ValueError("key must be float, int, long, or string, but got %r" % type(k)) - fractions[k] = float(v) - seed = seed if seed is not None else random.randint(0, sys.maxsize) - return DataFrame(self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sql_ctx) - @since(1.4) def randomSplit(self, weights, seed=None): """Randomly splits this :class:`DataFrame` with the provided weights. @@ -1357,11 +1322,6 @@ def freqItems(self, cols, support=None): freqItems.__doc__ = DataFrame.freqItems.__doc__ - def sampleBy(self, col, fractions, seed=None): - return self.df.sampleBy(col, fractions, seed) - - sampleBy.__doc__ = DataFrame.sampleBy.__doc__ - def _test(): import doctest diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 955d28771b4df..edb9ed7bba56a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql -import java.util.UUID - import org.apache.spark.annotation.Experimental import org.apache.spark.sql.execution.stat._ @@ -165,26 +163,4 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { def freqItems(cols: Seq[String]): DataFrame = { FrequentItems.singlePassFreqItems(df, cols, 0.01) } - - /** - * Returns a stratified sample without replacement based on the fraction given on each stratum. - * @param col column that defines strata - * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat - * its fraction as zero. - * @param seed random seed - * @return a new [[DataFrame]] that represents the stratified sample - * - * @since 1.5.0 - */ - def sampleBy(col: String, fractions: Map[Any, Double], seed: Long): DataFrame = { - require(fractions.values.forall(p => p >= 0.0 && p <= 1.0), - s"Fractions must be in [0, 1], but got $fractions.") - import org.apache.spark.sql.functions.rand - val c = Column(col) - val r = rand(seed).as("rand_" + UUID.randomUUID().toString.take(8)) - val expr = fractions.toSeq.map { case (k, v) => - (c === k) && (r < v) - }.reduce(_ || _) || false - df.filter(expr) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 3dd46889127ff..0d3ff899dad72 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql import org.scalatest.Matchers._ -import org.apache.spark.sql.functions.col +import org.apache.spark.SparkFunSuite -class DataFrameStatSuite extends QueryTest { +class DataFrameStatSuite extends SparkFunSuite { private val sqlCtx = org.apache.spark.sql.test.TestSQLContext import sqlCtx.implicits._ @@ -98,12 +98,4 @@ class DataFrameStatSuite extends QueryTest { val items2 = singleColResults.collect().head items2.getSeq[Double](0) should contain (-1.0) } - - test("sampleBy") { - val df = sqlCtx.range(0, 100).select((col("id") % 3).as("key")) - val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L) - checkAnswer( - sampled.groupBy("key").count().orderBy("key"), - Seq(Row(0, 4), Row(1, 9))) - } } From 50c3a86f42d7dfd1acbda65c1e5afbd3db1406df Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Tue, 23 Jun 2015 22:27:17 -0700 Subject: [PATCH 06/44] [SPARK-6749] [SQL] Make metastore client robust to underlying socket connection loss This works around a bug in the underlying RetryingMetaStoreClient (HIVE-10384) by refreshing the metastore client on thrift exceptions. We attempt to emulate the proper hive behavior by retrying only as configured by hiveconf. Author: Eric Liang Closes #6912 from ericl/spark-6749 and squashes the following commits: 2d54b55 [Eric Liang] use conf from state 0e3a74e [Eric Liang] use shim properly 980b3e5 [Eric Liang] Fix conf parsing hive 0.14 conf. 92459b6 [Eric Liang] Work around RetryingMetaStoreClient bug --- .../spark/sql/hive/client/ClientWrapper.scala | 55 ++++++++++++++++++- .../spark/sql/hive/client/HiveShim.scala | 19 +++++++ 2 files changed, 72 insertions(+), 2 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index 42c2d4c98ffb2..2f771d76793e5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive.client import java.io.{BufferedReader, InputStreamReader, File, PrintStream} import java.net.URI import java.util.{ArrayList => JArrayList, Map => JMap, List => JList, Set => JSet} +import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConversions._ import scala.language.reflectiveCalls @@ -136,12 +137,62 @@ private[hive] class ClientWrapper( // TODO: should be a def?s // When we create this val client, the HiveConf of it (conf) is the one associated with state. - private val client = Hive.get(conf) + @GuardedBy("this") + private var client = Hive.get(conf) + + // We use hive's conf for compatibility. + private val retryLimit = conf.getIntVar(HiveConf.ConfVars.METASTORETHRIFTFAILURERETRIES) + private val retryDelayMillis = shim.getMetastoreClientConnectRetryDelayMillis(conf) + + /** + * Runs `f` with multiple retries in case the hive metastore is temporarily unreachable. + */ + private def retryLocked[A](f: => A): A = synchronized { + // Hive sometimes retries internally, so set a deadline to avoid compounding delays. + val deadline = System.nanoTime + (retryLimit * retryDelayMillis * 1e6).toLong + var numTries = 0 + var caughtException: Exception = null + do { + numTries += 1 + try { + return f + } catch { + case e: Exception if causedByThrift(e) => + caughtException = e + logWarning( + "HiveClientWrapper got thrift exception, destroying client and retrying " + + s"(${retryLimit - numTries} tries remaining)", e) + Thread.sleep(retryDelayMillis) + try { + client = Hive.get(state.getConf, true) + } catch { + case e: Exception if causedByThrift(e) => + logWarning("Failed to refresh hive client, will retry.", e) + } + } + } while (numTries <= retryLimit && System.nanoTime < deadline) + if (System.nanoTime > deadline) { + logWarning("Deadline exceeded") + } + throw caughtException + } + + private def causedByThrift(e: Throwable): Boolean = { + var target = e + while (target != null) { + val msg = target.getMessage() + if (msg != null && msg.matches("(?s).*(TApplication|TProtocol|TTransport)Exception.*")) { + return true + } + target = target.getCause() + } + false + } /** * Runs `f` with ThreadLocal session state and classloaders configured for this version of hive. */ - private def withHiveState[A](f: => A): A = synchronized { + private def withHiveState[A](f: => A): A = retryLocked { val original = Thread.currentThread().getContextClassLoader // Set the thread local metastore client to the client associated with this ClientWrapper. Hive.set(client) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 5ae2dbb50d86b..e7c1779f80ce6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -21,6 +21,7 @@ import java.lang.{Boolean => JBoolean, Integer => JInteger} import java.lang.reflect.{Method, Modifier} import java.net.URI import java.util.{ArrayList => JArrayList, List => JList, Map => JMap, Set => JSet} +import java.util.concurrent.TimeUnit import scala.collection.JavaConversions._ @@ -64,6 +65,8 @@ private[client] sealed abstract class Shim { def getDriverResults(driver: Driver): Seq[String] + def getMetastoreClientConnectRetryDelayMillis(conf: HiveConf): Long + def loadPartition( hive: Hive, loadPath: Path, @@ -192,6 +195,10 @@ private[client] class Shim_v0_12 extends Shim { res.toSeq } + override def getMetastoreClientConnectRetryDelayMillis(conf: HiveConf): Long = { + conf.getIntVar(HiveConf.ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY) * 1000 + } + override def loadPartition( hive: Hive, loadPath: Path, @@ -321,6 +328,12 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { JBoolean.TYPE, JBoolean.TYPE, JBoolean.TYPE) + private lazy val getTimeVarMethod = + findMethod( + classOf[HiveConf], + "getTimeVar", + classOf[HiveConf.ConfVars], + classOf[TimeUnit]) override def loadPartition( hive: Hive, @@ -359,4 +372,10 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { numDP: JInteger, holdDDLTime: JBoolean, listBucketingEnabled: JBoolean, JBoolean.FALSE) } + override def getMetastoreClientConnectRetryDelayMillis(conf: HiveConf): Long = { + getTimeVarMethod.invoke( + conf, + HiveConf.ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY, + TimeUnit.MILLISECONDS).asInstanceOf[Long] + } } From 13ae806b255cfb2bd5470b599a95c87a2cd5e978 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 23 Jun 2015 23:03:59 -0700 Subject: [PATCH 07/44] [HOTFIX] [BUILD] Fix MiMa checks in master branch; enable MiMa for launcher project This commit changes the MiMa tests to test against the released 1.4.0 artifacts rather than 1.4.0-rc4; this change is necessary to fix a Jenkins build break since it seems that the RC4 snapshot is no longer available via Maven. I also enabled MiMa checks for the `launcher` subproject, which we should have done right after 1.4.0 was released. Author: Josh Rosen Closes #6974 from JoshRosen/mima-hotfix and squashes the following commits: 4b4175a [Josh Rosen] [HOTFIX] [BUILD] Fix MiMa checks in master branch; enable MiMa for launcher project --- project/MimaBuild.scala | 3 +-- project/SparkBuild.scala | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index 5812b72f0aa78..f16bf989f200b 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -91,8 +91,7 @@ object MimaBuild { def mimaSettings(sparkHome: File, projectRef: ProjectRef) = { val organization = "org.apache.spark" - // TODO: Change this once Spark 1.4.0 is released - val previousSparkVersion = "1.4.0-rc4" + val previousSparkVersion = "1.4.0" val fullId = "spark-" + projectRef.project + "_2.10" mimaDefaultSettings ++ Seq(previousArtifact := Some(organization % fullId % previousSparkVersion), diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index e01720296fed0..f5f1c9a1a247a 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -166,9 +166,8 @@ object SparkBuild extends PomBuild { /* Enable tests settings for all projects except examples, assembly and tools */ (allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings)) - // TODO: remove launcher from this list after 1.4.0 allProjects.filterNot(x => Seq(spark, hive, hiveThriftServer, catalyst, repl, - networkCommon, networkShuffle, networkYarn, launcher, unsafe).contains(x)).foreach { + networkCommon, networkShuffle, networkYarn, unsafe).contains(x)).foreach { x => enable(MimaBuild.mimaSettings(sparkHome, x))(x) } From 09fcf96b8f881988a4bc7fe26a3f6ed12dfb6adb Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 23 Jun 2015 23:11:42 -0700 Subject: [PATCH 08/44] [SPARK-8371] [SQL] improve unit test for MaxOf and MinOf and fix bugs a follow up of https://github.com/apache/spark/pull/6813 Author: Wenchen Fan Closes #6825 from cloud-fan/cg and squashes the following commits: 43170cc [Wenchen Fan] fix bugs in code gen --- .../expressions/codegen/CodeGenerator.scala | 4 +- .../ArithmeticExpressionSuite.scala | 46 +++++++++++++------ 2 files changed, 34 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index bd5475d2066fc..47c5455435ec6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -175,8 +175,10 @@ class CodeGenContext { * Generate code for compare expression in Java */ def genComp(dataType: DataType, c1: String, c2: String): String = dataType match { + // java boolean doesn't support > or < operator + case BooleanType => s"($c1 == $c2 ? 0 : ($c1 ? 1 : -1))" // use c1 - c2 may overflow - case dt: DataType if isPrimitiveType(dt) => s"(int)($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)" + case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)" case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)" case other => s"$c1.compare($c2)" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 4bbbbe6c7f091..6c93698f8017b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.{Decimal, DoubleType, IntegerType} +import org.apache.spark.sql.types.Decimal class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -123,23 +123,39 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper } } - test("MaxOf") { - checkEvaluation(MaxOf(1, 2), 2) - checkEvaluation(MaxOf(2, 1), 2) - checkEvaluation(MaxOf(1L, 2L), 2L) - checkEvaluation(MaxOf(2L, 1L), 2L) + test("MaxOf basic") { + testNumericDataTypes { convert => + val small = Literal(convert(1)) + val large = Literal(convert(2)) + checkEvaluation(MaxOf(small, large), convert(2)) + checkEvaluation(MaxOf(large, small), convert(2)) + checkEvaluation(MaxOf(Literal.create(null, small.dataType), large), convert(2)) + checkEvaluation(MaxOf(large, Literal.create(null, small.dataType)), convert(2)) + } + } - checkEvaluation(MaxOf(Literal.create(null, IntegerType), 2), 2) - checkEvaluation(MaxOf(2, Literal.create(null, IntegerType)), 2) + test("MaxOf for atomic type") { + checkEvaluation(MaxOf(true, false), true) + checkEvaluation(MaxOf("abc", "bcd"), "bcd") + checkEvaluation(MaxOf(Array(1.toByte, 2.toByte), Array(1.toByte, 3.toByte)), + Array(1.toByte, 3.toByte)) } - test("MinOf") { - checkEvaluation(MinOf(1, 2), 1) - checkEvaluation(MinOf(2, 1), 1) - checkEvaluation(MinOf(1L, 2L), 1L) - checkEvaluation(MinOf(2L, 1L), 1L) + test("MinOf basic") { + testNumericDataTypes { convert => + val small = Literal(convert(1)) + val large = Literal(convert(2)) + checkEvaluation(MinOf(small, large), convert(1)) + checkEvaluation(MinOf(large, small), convert(1)) + checkEvaluation(MinOf(Literal.create(null, small.dataType), large), convert(2)) + checkEvaluation(MinOf(small, Literal.create(null, small.dataType)), convert(1)) + } + } - checkEvaluation(MinOf(Literal.create(null, IntegerType), 1), 1) - checkEvaluation(MinOf(1, Literal.create(null, IntegerType)), 1) + test("MinOf for atomic type") { + checkEvaluation(MinOf(true, false), false) + checkEvaluation(MinOf("abc", "bcd"), "abc") + checkEvaluation(MinOf(Array(1.toByte, 2.toByte), Array(1.toByte, 3.toByte)), + Array(1.toByte, 2.toByte)) } } From cc465fd92482737c21971d82e30d4cf247acf932 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 24 Jun 2015 02:17:12 -0700 Subject: [PATCH 09/44] [SPARK-8138] [SQL] Improves error message when conflicting partition columns are found This PR improves the error message shown when conflicting partition column names are detected. This can be particularly annoying and confusing when there are a large number of partitions while a handful of them happened to contain unexpected temporary file(s). Now all suspicious directories are listed as below: ``` java.lang.AssertionError: assertion failed: Conflicting partition column names detected: Partition column name list #0: b, c, d Partition column name list #1: b, c Partition column name list #2: b For partitioned table directories, data files should only live in leaf directories. Please check the following directories for unexpected files: file:/tmp/foo/b=0 file:/tmp/foo/b=1 file:/tmp/foo/b=1/c=1 file:/tmp/foo/b=0/c=0 ``` Author: Cheng Lian Closes #6610 from liancheng/part-errmsg and squashes the following commits: 7d05f2c [Cheng Lian] Fixes Scala style issue a149250 [Cheng Lian] Adds test case for the error message 6b74dd8 [Cheng Lian] Also lists suspicious non-leaf partition directories a935eb8 [Cheng Lian] Improves error message when conflicting partition columns are found --- .../spark/sql/sources/PartitioningUtils.scala | 47 +++++++++++++++---- .../ParquetPartitionDiscoverySuite.scala | 45 ++++++++++++++++++ 2 files changed, 82 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala index c6f535dde7676..8b2a45d8e970a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala @@ -84,7 +84,7 @@ private[sql] object PartitioningUtils { } else { // This dataset is partitioned. We need to check whether all partitions have the same // partition columns and resolve potential type conflicts. - val resolvedPartitionValues = resolvePartitions(pathsWithPartitionValues.map(_._2)) + val resolvedPartitionValues = resolvePartitions(pathsWithPartitionValues) // Creates the StructType which represents the partition columns. val fields = { @@ -181,19 +181,18 @@ private[sql] object PartitioningUtils { * StringType * }}} */ - private[sql] def resolvePartitions(values: Seq[PartitionValues]): Seq[PartitionValues] = { - // Column names of all partitions must match - val distinctPartitionsColNames = values.map(_.columnNames).distinct - - if (distinctPartitionsColNames.isEmpty) { + private[sql] def resolvePartitions( + pathsWithPartitionValues: Seq[(Path, PartitionValues)]): Seq[PartitionValues] = { + if (pathsWithPartitionValues.isEmpty) { Seq.empty } else { - assert(distinctPartitionsColNames.size == 1, { - val list = distinctPartitionsColNames.mkString("\t", "\n\t", "") - s"Conflicting partition column names detected:\n$list" - }) + val distinctPartColNames = pathsWithPartitionValues.map(_._2.columnNames).distinct + assert( + distinctPartColNames.size == 1, + listConflictingPartitionColumns(pathsWithPartitionValues)) // Resolves possible type conflicts for each column + val values = pathsWithPartitionValues.map(_._2) val columnCount = values.head.columnNames.size val resolvedValues = (0 until columnCount).map { i => resolveTypeConflicts(values.map(_.literals(i))) @@ -206,6 +205,34 @@ private[sql] object PartitioningUtils { } } + private[sql] def listConflictingPartitionColumns( + pathWithPartitionValues: Seq[(Path, PartitionValues)]): String = { + val distinctPartColNames = pathWithPartitionValues.map(_._2.columnNames).distinct + + def groupByKey[K, V](seq: Seq[(K, V)]): Map[K, Iterable[V]] = + seq.groupBy { case (key, _) => key }.mapValues(_.map { case (_, value) => value }) + + val partColNamesToPaths = groupByKey(pathWithPartitionValues.map { + case (path, partValues) => partValues.columnNames -> path + }) + + val distinctPartColLists = distinctPartColNames.map(_.mkString(", ")).zipWithIndex.map { + case (names, index) => + s"Partition column name list #$index: $names" + } + + // Lists out those non-leaf partition directories that also contain files + val suspiciousPaths = distinctPartColNames.sortBy(_.length).flatMap(partColNamesToPaths) + + s"Conflicting partition column names detected:\n" + + distinctPartColLists.mkString("\n\t", "\n\t", "\n\n") + + "For partitioned table directories, data files should only live in leaf directories.\n" + + "And directories at the same level should have the same partition column name.\n" + + "Please check the following directories for unexpected files or " + + "inconsistent partition column names:\n" + + suspiciousPaths.map("\t" + _).mkString("\n", "\n", "") + } + /** * Converts a string to a [[Literal]] with automatic type inference. Currently only supports * [[IntegerType]], [[LongType]], [[DoubleType]], [[DecimalType.Unlimited]], and diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala index 01df189d1f3be..d0ebb11b063f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala @@ -538,4 +538,49 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { checkAnswer(sqlContext.read.format("parquet").load(dir.getCanonicalPath), df) } } + + test("listConflictingPartitionColumns") { + def makeExpectedMessage(colNameLists: Seq[String], paths: Seq[String]): String = { + val conflictingColNameLists = colNameLists.zipWithIndex.map { case (list, index) => + s"\tPartition column name list #$index: $list" + }.mkString("\n", "\n", "\n") + + // scalastyle:off + s"""Conflicting partition column names detected: + |$conflictingColNameLists + |For partitioned table directories, data files should only live in leaf directories. + |And directories at the same level should have the same partition column name. + |Please check the following directories for unexpected files or inconsistent partition column names: + |${paths.map("\t" + _).mkString("\n", "\n", "")} + """.stripMargin.trim + // scalastyle:on + } + + assert( + listConflictingPartitionColumns( + Seq( + (new Path("file:/tmp/foo/a=1"), PartitionValues(Seq("a"), Seq(Literal(1)))), + (new Path("file:/tmp/foo/b=1"), PartitionValues(Seq("b"), Seq(Literal(1)))))).trim === + makeExpectedMessage(Seq("a", "b"), Seq("file:/tmp/foo/a=1", "file:/tmp/foo/b=1"))) + + assert( + listConflictingPartitionColumns( + Seq( + (new Path("file:/tmp/foo/a=1/_temporary"), PartitionValues(Seq("a"), Seq(Literal(1)))), + (new Path("file:/tmp/foo/a=1"), PartitionValues(Seq("a"), Seq(Literal(1)))))).trim === + makeExpectedMessage( + Seq("a"), + Seq("file:/tmp/foo/a=1/_temporary", "file:/tmp/foo/a=1"))) + + assert( + listConflictingPartitionColumns( + Seq( + (new Path("file:/tmp/foo/a=1"), + PartitionValues(Seq("a"), Seq(Literal(1)))), + (new Path("file:/tmp/foo/a=1/b=foo"), + PartitionValues(Seq("a", "b"), Seq(Literal(1), Literal("foo")))))).trim === + makeExpectedMessage( + Seq("a", "a, b"), + Seq("file:/tmp/foo/a=1", "file:/tmp/foo/a=1/b=foo"))) + } } From 9d36ec24312f0a9865b4392f89e9611a5b80916d Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 24 Jun 2015 09:49:20 -0700 Subject: [PATCH 10/44] [SPARK-8567] [SQL] Debugging flaky HiveSparkSubmitSuite Using similar approach used in `HiveThriftServer2Suite` to print stdout/stderr of the spawned process instead of logging them to see what happens on Jenkins. (This test suite only fails on Jenkins and doesn't spill out any log...) cc yhuai Author: Cheng Lian Closes #6978 from liancheng/debug-hive-spark-submit-suite and squashes the following commits: b031647 [Cheng Lian] Prints process stdout/stderr instead of logging them --- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index ab443032be20d..d85516ab0878e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.hive import java.io.File +import scala.sys.process.{ProcessLogger, Process} + import org.apache.spark._ import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} import org.apache.spark.util.{ResetSystemProperties, Utils} @@ -82,12 +84,18 @@ class HiveSparkSubmitSuite // This is copied from org.apache.spark.deploy.SparkSubmitSuite private def runSparkSubmit(args: Seq[String]): Unit = { val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) - val process = Utils.executeCommand( + val process = Process( Seq("./bin/spark-submit") ++ args, new File(sparkHome), - Map("SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome)) + "SPARK_TESTING" -> "1", + "SPARK_HOME" -> sparkHome + ).run(ProcessLogger( + (line: String) => { println(s"out> $line") }, + (line: String) => { println(s"err> $line") } + )) + try { - val exitCode = failAfter(120 seconds) { process.waitFor() } + val exitCode = failAfter(120 seconds) { process.exitValue() } if (exitCode != 0) { fail(s"Process returned with exit code $exitCode. See the log4j logs for more detail.") } From bba6699d0e9093bc041a9a33dd31992790f32174 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 24 Jun 2015 09:50:03 -0700 Subject: [PATCH 11/44] [SPARK-8578] [SQL] Should ignore user defined output committer when appending data https://issues.apache.org/jira/browse/SPARK-8578 It is not very safe to use a custom output committer when append data to an existing dir. This changes adds the logic to check if we are appending data, and if so, we use the output committer associated with the file output format. Author: Yin Huai Closes #6964 from yhuai/SPARK-8578 and squashes the following commits: 43544c4 [Yin Huai] Do not use a custom output commiter when appendiing data. --- .../apache/spark/sql/sources/commands.scala | 89 +++++++++++-------- .../sql/sources/hadoopFsRelationSuites.scala | 83 ++++++++++++++++- 2 files changed, 136 insertions(+), 36 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index 215e53c020849..fb6173f58ece6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -96,7 +96,8 @@ private[sql] case class InsertIntoHadoopFsRelation( val fs = outputPath.getFileSystem(hadoopConf) val qualifiedOutputPath = outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - val doInsertion = (mode, fs.exists(qualifiedOutputPath)) match { + val pathExists = fs.exists(qualifiedOutputPath) + val doInsertion = (mode, pathExists) match { case (SaveMode.ErrorIfExists, true) => sys.error(s"path $qualifiedOutputPath already exists.") case (SaveMode.Overwrite, true) => @@ -107,6 +108,8 @@ private[sql] case class InsertIntoHadoopFsRelation( case (SaveMode.Ignore, exists) => !exists } + // If we are appending data to an existing dir. + val isAppend = (pathExists) && (mode == SaveMode.Append) if (doInsertion) { val job = new Job(hadoopConf) @@ -130,10 +133,10 @@ private[sql] case class InsertIntoHadoopFsRelation( val partitionColumns = relation.partitionColumns.fieldNames if (partitionColumns.isEmpty) { - insert(new DefaultWriterContainer(relation, job), df) + insert(new DefaultWriterContainer(relation, job, isAppend), df) } else { val writerContainer = new DynamicPartitionWriterContainer( - relation, job, partitionColumns, PartitioningUtils.DEFAULT_PARTITION_NAME) + relation, job, partitionColumns, PartitioningUtils.DEFAULT_PARTITION_NAME, isAppend) insertWithDynamicPartitions(sqlContext, writerContainer, df, partitionColumns) } } @@ -277,7 +280,8 @@ private[sql] case class InsertIntoHadoopFsRelation( private[sql] abstract class BaseWriterContainer( @transient val relation: HadoopFsRelation, - @transient job: Job) + @transient job: Job, + isAppend: Boolean) extends SparkHadoopMapReduceUtil with Logging with Serializable { @@ -356,34 +360,47 @@ private[sql] abstract class BaseWriterContainer( } private def newOutputCommitter(context: TaskAttemptContext): OutputCommitter = { - val committerClass = context.getConfiguration.getClass( - SQLConf.OUTPUT_COMMITTER_CLASS.key, null, classOf[OutputCommitter]) - - Option(committerClass).map { clazz => - logInfo(s"Using user defined output committer class ${clazz.getCanonicalName}") - - // Every output format based on org.apache.hadoop.mapreduce.lib.output.OutputFormat - // has an associated output committer. To override this output committer, - // we will first try to use the output committer set in SQLConf.OUTPUT_COMMITTER_CLASS. - // If a data source needs to override the output committer, it needs to set the - // output committer in prepareForWrite method. - if (classOf[MapReduceFileOutputCommitter].isAssignableFrom(clazz)) { - // The specified output committer is a FileOutputCommitter. - // So, we will use the FileOutputCommitter-specified constructor. - val ctor = clazz.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext]) - ctor.newInstance(new Path(outputPath), context) - } else { - // The specified output committer is just a OutputCommitter. - // So, we will use the no-argument constructor. - val ctor = clazz.getDeclaredConstructor() - ctor.newInstance() + val defaultOutputCommitter = outputFormatClass.newInstance().getOutputCommitter(context) + + if (isAppend) { + // If we are appending data to an existing dir, we will only use the output committer + // associated with the file output format since it is not safe to use a custom + // committer for appending. For example, in S3, direct parquet output committer may + // leave partial data in the destination dir when the the appending job fails. + logInfo( + s"Using output committer class ${defaultOutputCommitter.getClass.getCanonicalName} " + + "for appending.") + defaultOutputCommitter + } else { + val committerClass = context.getConfiguration.getClass( + SQLConf.OUTPUT_COMMITTER_CLASS.key, null, classOf[OutputCommitter]) + + Option(committerClass).map { clazz => + logInfo(s"Using user defined output committer class ${clazz.getCanonicalName}") + + // Every output format based on org.apache.hadoop.mapreduce.lib.output.OutputFormat + // has an associated output committer. To override this output committer, + // we will first try to use the output committer set in SQLConf.OUTPUT_COMMITTER_CLASS. + // If a data source needs to override the output committer, it needs to set the + // output committer in prepareForWrite method. + if (classOf[MapReduceFileOutputCommitter].isAssignableFrom(clazz)) { + // The specified output committer is a FileOutputCommitter. + // So, we will use the FileOutputCommitter-specified constructor. + val ctor = clazz.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext]) + ctor.newInstance(new Path(outputPath), context) + } else { + // The specified output committer is just a OutputCommitter. + // So, we will use the no-argument constructor. + val ctor = clazz.getDeclaredConstructor() + ctor.newInstance() + } + }.getOrElse { + // If output committer class is not set, we will use the one associated with the + // file output format. + logInfo( + s"Using output committer class ${defaultOutputCommitter.getClass.getCanonicalName}") + defaultOutputCommitter } - }.getOrElse { - // If output committer class is not set, we will use the one associated with the - // file output format. - val outputCommitter = outputFormatClass.newInstance().getOutputCommitter(context) - logInfo(s"Using output committer class ${outputCommitter.getClass.getCanonicalName}") - outputCommitter } } @@ -433,8 +450,9 @@ private[sql] abstract class BaseWriterContainer( private[sql] class DefaultWriterContainer( @transient relation: HadoopFsRelation, - @transient job: Job) - extends BaseWriterContainer(relation, job) { + @transient job: Job, + isAppend: Boolean) + extends BaseWriterContainer(relation, job, isAppend) { @transient private var writer: OutputWriter = _ @@ -473,8 +491,9 @@ private[sql] class DynamicPartitionWriterContainer( @transient relation: HadoopFsRelation, @transient job: Job, partitionColumns: Array[String], - defaultPartitionName: String) - extends BaseWriterContainer(relation, job) { + defaultPartitionName: String, + isAppend: Boolean) + extends BaseWriterContainer(relation, job, isAppend) { // All output writers are created on executor side. @transient protected var outputWriters: mutable.Map[String, OutputWriter] = _ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index e0d8277a8ed3f..a16ab3a00ddb8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -17,10 +17,16 @@ package org.apache.spark.sql.sources +import scala.collection.JavaConversions._ + import java.io.File import com.google.common.io.Files +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter +import org.apache.parquet.hadoop.ParquetOutputCommitter import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.deploy.SparkHadoopUtil @@ -476,7 +482,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { // more cores, the issue can be reproduced steadily. Fortunately our Jenkins builder meets this // requirement. We probably want to move this test case to spark-integration-tests or spark-perf // later. - test("SPARK-8406: Avoids name collision while writing Parquet files") { + test("SPARK-8406: Avoids name collision while writing files") { withTempPath { dir => val path = dir.getCanonicalPath sqlContext @@ -497,6 +503,81 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { } } } + + test("SPARK-8578 specified custom output committer will not be used to append data") { + val clonedConf = new Configuration(configuration) + try { + val df = sqlContext.range(1, 10).toDF("i") + withTempPath { dir => + df.write.mode("append").format(dataSourceName).save(dir.getCanonicalPath) + configuration.set( + SQLConf.OUTPUT_COMMITTER_CLASS.key, + classOf[AlwaysFailOutputCommitter].getName) + // Since Parquet has its own output committer setting, also set it + // to AlwaysFailParquetOutputCommitter at here. + configuration.set("spark.sql.parquet.output.committer.class", + classOf[AlwaysFailParquetOutputCommitter].getName) + // Because there data already exists, + // this append should succeed because we will use the output committer associated + // with file format and AlwaysFailOutputCommitter will not be used. + df.write.mode("append").format(dataSourceName).save(dir.getCanonicalPath) + checkAnswer( + sqlContext.read + .format(dataSourceName) + .option("dataSchema", df.schema.json) + .load(dir.getCanonicalPath), + df.unionAll(df)) + + // This will fail because AlwaysFailOutputCommitter is used when we do append. + intercept[Exception] { + df.write.mode("overwrite").format(dataSourceName).save(dir.getCanonicalPath) + } + } + withTempPath { dir => + configuration.set( + SQLConf.OUTPUT_COMMITTER_CLASS.key, + classOf[AlwaysFailOutputCommitter].getName) + // Since Parquet has its own output committer setting, also set it + // to AlwaysFailParquetOutputCommitter at here. + configuration.set("spark.sql.parquet.output.committer.class", + classOf[AlwaysFailParquetOutputCommitter].getName) + // Because there is no existing data, + // this append will fail because AlwaysFailOutputCommitter is used when we do append + // and there is no existing data. + intercept[Exception] { + df.write.mode("append").format(dataSourceName).save(dir.getCanonicalPath) + } + } + } finally { + // Hadoop 1 doesn't have `Configuration.unset` + configuration.clear() + clonedConf.foreach(entry => configuration.set(entry.getKey, entry.getValue)) + } + } +} + +// This class is used to test SPARK-8578. We should not use any custom output committer when +// we actually append data to an existing dir. +class AlwaysFailOutputCommitter( + outputPath: Path, + context: TaskAttemptContext) + extends FileOutputCommitter(outputPath, context) { + + override def commitJob(context: JobContext): Unit = { + sys.error("Intentional job commitment failure for testing purpose.") + } +} + +// This class is used to test SPARK-8578. We should not use any custom output committer when +// we actually append data to an existing dir. +class AlwaysFailParquetOutputCommitter( + outputPath: Path, + context: TaskAttemptContext) + extends ParquetOutputCommitter(outputPath, context) { + + override def commitJob(context: JobContext): Unit = { + sys.error("Intentional job commitment failure for testing purpose.") + } } class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { From 31f48e5af887a9ccc9cea0218c36bf52bbf49d24 Mon Sep 17 00:00:00 2001 From: Nicholas Chammas Date: Wed, 24 Jun 2015 11:20:51 -0700 Subject: [PATCH 12/44] [SPARK-8576] Add spark-ec2 options to set IAM roles and instance-initiated shutdown behavior Both of these options are useful when spark-ec2 is being used as part of an automated pipeline and the engineers want to minimize the need to pass around AWS keys for access to things like S3 (keys are replaced by the IAM role) and to be able to launch a cluster that can terminate itself cleanly. Author: Nicholas Chammas Closes #6962 from nchammas/additional-ec2-options and squashes the following commits: fcf252e [Nicholas Chammas] PEP8 fixes efba9ee [Nicholas Chammas] add help for --instance-initiated-shutdown-behavior 598aecf [Nicholas Chammas] option to launch instances into IAM role 2743632 [Nicholas Chammas] add option for instance initiated shutdown --- ec2/spark_ec2.py | 56 ++++++++++++++++++++++++++++++------------------ 1 file changed, 35 insertions(+), 21 deletions(-) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 63e2c79669763..e4932cfa7a4fc 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -306,6 +306,13 @@ def parse_args(): "--private-ips", action="store_true", default=False, help="Use private IPs for instances rather than public if VPC/subnet " + "requires that.") + parser.add_option( + "--instance-initiated-shutdown-behavior", default="stop", + choices=["stop", "terminate"], + help="Whether instances should terminate when shut down or just stop") + parser.add_option( + "--instance-profile-name", default=None, + help="IAM profile name to launch instances under") (opts, args) = parser.parse_args() if len(args) != 2: @@ -602,7 +609,8 @@ def launch_cluster(conn, opts, cluster_name): block_device_map=block_map, subnet_id=opts.subnet_id, placement_group=opts.placement_group, - user_data=user_data_content) + user_data=user_data_content, + instance_profile_name=opts.instance_profile_name) my_req_ids += [req.id for req in slave_reqs] i += 1 @@ -647,16 +655,19 @@ def launch_cluster(conn, opts, cluster_name): for zone in zones: num_slaves_this_zone = get_partition(opts.slaves, num_zones, i) if num_slaves_this_zone > 0: - slave_res = image.run(key_name=opts.key_pair, - security_group_ids=[slave_group.id] + additional_group_ids, - instance_type=opts.instance_type, - placement=zone, - min_count=num_slaves_this_zone, - max_count=num_slaves_this_zone, - block_device_map=block_map, - subnet_id=opts.subnet_id, - placement_group=opts.placement_group, - user_data=user_data_content) + slave_res = image.run( + key_name=opts.key_pair, + security_group_ids=[slave_group.id] + additional_group_ids, + instance_type=opts.instance_type, + placement=zone, + min_count=num_slaves_this_zone, + max_count=num_slaves_this_zone, + block_device_map=block_map, + subnet_id=opts.subnet_id, + placement_group=opts.placement_group, + user_data=user_data_content, + instance_initiated_shutdown_behavior=opts.instance_initiated_shutdown_behavior, + instance_profile_name=opts.instance_profile_name) slave_nodes += slave_res.instances print("Launched {s} slave{plural_s} in {z}, regid = {r}".format( s=num_slaves_this_zone, @@ -678,16 +689,19 @@ def launch_cluster(conn, opts, cluster_name): master_type = opts.instance_type if opts.zone == 'all': opts.zone = random.choice(conn.get_all_zones()).name - master_res = image.run(key_name=opts.key_pair, - security_group_ids=[master_group.id] + additional_group_ids, - instance_type=master_type, - placement=opts.zone, - min_count=1, - max_count=1, - block_device_map=block_map, - subnet_id=opts.subnet_id, - placement_group=opts.placement_group, - user_data=user_data_content) + master_res = image.run( + key_name=opts.key_pair, + security_group_ids=[master_group.id] + additional_group_ids, + instance_type=master_type, + placement=opts.zone, + min_count=1, + max_count=1, + block_device_map=block_map, + subnet_id=opts.subnet_id, + placement_group=opts.placement_group, + user_data=user_data_content, + instance_initiated_shutdown_behavior=opts.instance_initiated_shutdown_behavior, + instance_profile_name=opts.instance_profile_name) master_nodes = master_res.instances print("Launched master in %s, regid = %s" % (zone, master_res.id)) From 1173483f3f465a4c63246e83d0aaa2af521395f5 Mon Sep 17 00:00:00 2001 From: BenFradet Date: Wed, 24 Jun 2015 11:53:03 -0700 Subject: [PATCH 13/44] [SPARK-8399] [STREAMING] [WEB UI] Overlap between histograms and axis' name in Spark Streaming UI Moved where the X axis' name (#batches) is written in histograms in the spark streaming web ui so the histograms and the axis' name do not overlap. Author: BenFradet Closes #6845 from BenFradet/SPARK-8399 and squashes the following commits: b63695f [BenFradet] adjusted inner histograms eb610ee [BenFradet] readjusted #batches on the x axis dd46f98 [BenFradet] aligned all unit labels and ticks 0564b62 [BenFradet] readjusted #batches placement edd0936 [BenFradet] moved where the X axis' name (#batches) is written in histograms in the spark streaming web ui --- .../apache/spark/streaming/ui/static/streaming-page.js | 10 ++++++---- .../org/apache/spark/streaming/ui/StreamingPage.scala | 4 ++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.js b/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.js index 75251f493ad22..4886b68eeaf76 100644 --- a/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.js +++ b/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.js @@ -31,6 +31,8 @@ var maxXForHistogram = 0; var histogramBinCount = 10; var yValueFormat = d3.format(",.2f"); +var unitLabelYOffset = -10; + // Show a tooltip "text" for "node" function showBootstrapTooltip(node, text) { $(node).tooltip({title: text, trigger: "manual", container: "body"}); @@ -133,7 +135,7 @@ function drawTimeline(id, data, minX, maxX, minY, maxY, unitY, batchInterval) { .attr("class", "y axis") .call(yAxis) .append("text") - .attr("transform", "translate(0," + (-3) + ")") + .attr("transform", "translate(0," + unitLabelYOffset + ")") .text(unitY); @@ -223,10 +225,10 @@ function drawHistogram(id, values, minY, maxY, unitY, batchInterval) { .style("border-left", "0px solid white"); var margin = {top: 20, right: 30, bottom: 30, left: 10}; - var width = 300 - margin.left - margin.right; + var width = 350 - margin.left - margin.right; var height = 150 - margin.top - margin.bottom; - var x = d3.scale.linear().domain([0, maxXForHistogram]).range([0, width]); + var x = d3.scale.linear().domain([0, maxXForHistogram]).range([0, width - 50]); var y = d3.scale.linear().domain([minY, maxY]).range([height, 0]); var xAxis = d3.svg.axis().scale(x).orient("top").ticks(5); @@ -248,7 +250,7 @@ function drawHistogram(id, values, minY, maxY, unitY, batchInterval) { .attr("class", "x axis") .call(xAxis) .append("text") - .attr("transform", "translate(" + (margin.left + width - 40) + ", 15)") + .attr("transform", "translate(" + (margin.left + width - 45) + ", " + unitLabelYOffset + ")") .text("#batches"); svg.append("g") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index 4ee7a486e370b..87af902428ec8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -310,7 +310,7 @@ private[ui] class StreamingPage(parent: StreamingTab) Timelines (Last {batchTimes.length} batches, {numActiveBatches} active, {numCompletedBatches} completed) - Histograms + Histograms @@ -456,7 +456,7 @@ private[ui] class StreamingPage(parent: StreamingTab) {receiverActive} {receiverLocation} {receiverLastErrorTime} -
{receiverLastError}
+
{receiverLastError}
From 43e66192f45a23f7232116e9f664158862df5015 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 24 Jun 2015 11:55:20 -0700 Subject: [PATCH 14/44] [SPARK-8506] Add pakages to R context created through init. Author: Holden Karau Closes #6928 from holdenk/SPARK-8506-sparkr-does-not-provide-an-easy-way-to-depend-on-spark-packages-when-performing-init-from-inside-of-r and squashes the following commits: b60dd63 [Holden Karau] Add an example with the spark-csv package fa8bc92 [Holden Karau] typo: sparm -> spark 865a90c [Holden Karau] strip spaces for comparision c7a4471 [Holden Karau] Add some documentation c1a9233 [Holden Karau] refactor for testing c818556 [Holden Karau] Add pakages to R --- R/pkg/R/client.R | 26 +++++++++++++++++++------- R/pkg/R/sparkR.R | 7 +++++-- R/pkg/inst/tests/test_client.R | 32 ++++++++++++++++++++++++++++++++ docs/sparkr.md | 17 +++++++++++++---- 4 files changed, 69 insertions(+), 13 deletions(-) create mode 100644 R/pkg/inst/tests/test_client.R diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index 1281c41213e32..cf2e5ddeb7a9d 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -34,24 +34,36 @@ connectBackend <- function(hostname, port, timeout = 6000) { con } -launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts) { +determineSparkSubmitBin <- function() { if (.Platform$OS.type == "unix") { sparkSubmitBinName = "spark-submit" } else { sparkSubmitBinName = "spark-submit.cmd" } + sparkSubmitBinName +} + +generateSparkSubmitArgs <- function(args, sparkHome, jars, sparkSubmitOpts, packages) { + if (jars != "") { + jars <- paste("--jars", jars) + } + + if (packages != "") { + packages <- paste("--packages", packages) + } + combinedArgs <- paste(jars, packages, sparkSubmitOpts, args, sep = " ") + combinedArgs +} + +launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts, packages) { + sparkSubmitBin <- determineSparkSubmitBin() if (sparkHome != "") { sparkSubmitBin <- file.path(sparkHome, "bin", sparkSubmitBinName) } else { sparkSubmitBin <- sparkSubmitBinName } - - if (jars != "") { - jars <- paste("--jars", jars) - } - - combinedArgs <- paste(jars, sparkSubmitOpts, args, sep = " ") + combinedArgs <- generateSparkSubmitArgs(args, sparkHome, jars, sparkSubmitOpts, packages) cat("Launching java with spark-submit command", sparkSubmitBin, combinedArgs, "\n") invisible(system2(sparkSubmitBin, combinedArgs, wait = F)) } diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index dbde0c44c55d5..8f81d5640c1d0 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -81,6 +81,7 @@ sparkR.stop <- function() { #' @param sparkExecutorEnv Named list of environment variables to be used when launching executors. #' @param sparkJars Character string vector of jar files to pass to the worker nodes. #' @param sparkRLibDir The path where R is installed on the worker nodes. +#' @param sparkPackages Character string vector of packages from spark-packages.org #' @export #' @examples #'\dontrun{ @@ -100,7 +101,8 @@ sparkR.init <- function( sparkEnvir = list(), sparkExecutorEnv = list(), sparkJars = "", - sparkRLibDir = "") { + sparkRLibDir = "", + sparkPackages = "") { if (exists(".sparkRjsc", envir = .sparkREnv)) { cat("Re-using existing Spark Context. Please stop SparkR with sparkR.stop() or restart R to create a new Spark Context\n") @@ -129,7 +131,8 @@ sparkR.init <- function( args = path, sparkHome = sparkHome, jars = jars, - sparkSubmitOpts = Sys.getenv("SPARKR_SUBMIT_ARGS", "sparkr-shell")) + sparkSubmitOpts = Sys.getenv("SPARKR_SUBMIT_ARGS", "sparkr-shell"), + sparkPackages = sparkPackages) # wait atmost 100 seconds for JVM to launch wait <- 0.1 for (i in 1:25) { diff --git a/R/pkg/inst/tests/test_client.R b/R/pkg/inst/tests/test_client.R new file mode 100644 index 0000000000000..30b05c1a2afcd --- /dev/null +++ b/R/pkg/inst/tests/test_client.R @@ -0,0 +1,32 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("functions in client.R") + +test_that("adding spark-testing-base as a package works", { + args <- generateSparkSubmitArgs("", "", "", "", + "holdenk:spark-testing-base:1.3.0_0.0.5") + expect_equal(gsub("[[:space:]]", "", args), + gsub("[[:space:]]", "", + "--packages holdenk:spark-testing-base:1.3.0_0.0.5")) +}) + +test_that("no package specified doesn't add packages flag", { + args <- generateSparkSubmitArgs("", "", "", "", "") + expect_equal(gsub("[[:space:]]", "", args), + "") +}) diff --git a/docs/sparkr.md b/docs/sparkr.md index 4d82129921a37..095ea4308cfeb 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -27,9 +27,9 @@ All of the examples on this page use sample data included in R or the Spark dist
The entry point into SparkR is the `SparkContext` which connects your R program to a Spark cluster. You can create a `SparkContext` using `sparkR.init` and pass in options such as the application name -etc. Further, to work with DataFrames we will need a `SQLContext`, which can be created from the -SparkContext. If you are working from the SparkR shell, the `SQLContext` and `SparkContext` should -already be created for you. +, any spark packages depended on, etc. Further, to work with DataFrames we will need a `SQLContext`, +which can be created from the SparkContext. If you are working from the SparkR shell, the +`SQLContext` and `SparkContext` should already be created for you. {% highlight r %} sc <- sparkR.init() @@ -62,7 +62,16 @@ head(df) SparkR supports operating on a variety of data sources through the `DataFrame` interface. This section describes the general methods for loading and saving data using Data Sources. You can check the Spark SQL programming guide for more [specific options](sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources. -The general method for creating DataFrames from data sources is `read.df`. This method takes in the `SQLContext`, the path for the file to load and the type of data source. SparkR supports reading JSON and Parquet files natively and through [Spark Packages](http://spark-packages.org/) you can find data source connectors for popular file formats like [CSV](http://spark-packages.org/package/databricks/spark-csv) and [Avro](http://spark-packages.org/package/databricks/spark-avro). +The general method for creating DataFrames from data sources is `read.df`. This method takes in the `SQLContext`, the path for the file to load and the type of data source. SparkR supports reading JSON and Parquet files natively and through [Spark Packages](http://spark-packages.org/) you can find data source connectors for popular file formats like [CSV](http://spark-packages.org/package/databricks/spark-csv) and [Avro](http://spark-packages.org/package/databricks/spark-avro). These packages can either be added by +specifying `--packages` with `spark-submit` or `sparkR` commands, or if creating context through `init` +you can specify the packages with the `packages` argument. + +
+{% highlight r %} +sc <- sparkR.init(packages="com.databricks:spark-csv_2.11:1.0.3") +sqlContext <- sparkRSQL.init(sc) +{% endhighlight %} +
We can see how to use data sources using an example JSON input file. Note that the file that is used here is _not_ a typical JSON file. Each line in the file must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail. From b84d4b4dfe8ced1b96a0c74ef968a20a1bba8231 Mon Sep 17 00:00:00 2001 From: "Santiago M. Mola" Date: Wed, 24 Jun 2015 12:29:07 -0700 Subject: [PATCH 15/44] [SPARK-7088] [SQL] Fix analysis for 3rd party logical plan. ResolveReferences analysis rule now does not throw when it cannot resolve references in a self-join. Author: Santiago M. Mola Closes #6853 from smola/SPARK-7088 and squashes the following commits: af71ac7 [Santiago M. Mola] [SPARK-7088] Fix analysis for 3rd party logical plan. --- .../sql/catalyst/analysis/Analyzer.scala | 38 ++++++++++--------- .../sql/catalyst/analysis/CheckAnalysis.scala | 12 ++++++ 2 files changed, 32 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 0a3f5a7b5cade..b06759f144fd9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -283,7 +283,7 @@ class Analyzer( val conflictingAttributes = left.outputSet.intersect(right.outputSet) logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} in $j") - val (oldRelation, newRelation) = right.collect { + right.collect { // Handle base relations that might appear more than once. case oldVersion: MultiInstanceRelation if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => @@ -308,25 +308,27 @@ class Analyzer( if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes) .nonEmpty => (oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions))) - }.headOption.getOrElse { // Only handle first case, others will be fixed on the next pass. - sys.error( - s""" - |Failure when resolving conflicting references in Join: - |$plan - | - |Conflicting attributes: ${conflictingAttributes.mkString(",")} - """.stripMargin) } - - val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) - val newRight = right transformUp { - case r if r == oldRelation => newRelation - } transformUp { - case other => other transformExpressions { - case a: Attribute => attributeRewrites.get(a).getOrElse(a) - } + // Only handle first case, others will be fixed on the next pass. + .headOption match { + case None => + /* + * No result implies that there is a logical plan node that produces new references + * that this rule cannot handle. When that is the case, there must be another rule + * that resolves these conflicts. Otherwise, the analysis will fail. + */ + j + case Some((oldRelation, newRelation)) => + val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) + val newRight = right transformUp { + case r if r == oldRelation => newRelation + } transformUp { + case other => other transformExpressions { + case a: Attribute => attributeRewrites.get(a).getOrElse(a) + } + } + j.copy(right = newRight) } - j.copy(right = newRight) // When resolve `SortOrder`s in Sort based on child, don't report errors as // we still have chance to resolve it based on grandchild diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index c5a1437be6d05..a069b4710f38c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -48,6 +48,7 @@ trait CheckAnalysis { // We transform up and order the rules so as to catch the first possible failure instead // of the result of cascading resolution failures. plan.foreachUp { + case operator: LogicalPlan => operator transformExpressionsUp { case a: Attribute if !a.resolved => @@ -121,6 +122,17 @@ trait CheckAnalysis { case _ => // Analysis successful! } + + // Special handling for cases when self-join introduce duplicate expression ids. + case j @ Join(left, right, _, _) if left.outputSet.intersect(right.outputSet).nonEmpty => + val conflictingAttributes = left.outputSet.intersect(right.outputSet) + failAnalysis( + s""" + |Failure when resolving conflicting references in Join: + |$plan + |Conflicting attributes: ${conflictingAttributes.mkString(",")} + |""".stripMargin) + } extendedCheckRules.foreach(_(plan)) } From f04b5672c5a5562f8494df3b0df23235285c9e9e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 24 Jun 2015 13:28:50 -0700 Subject: [PATCH 16/44] [SPARK-7289] handle project -> limit -> sort efficiently make the `TakeOrdered` strategy and operator more general, such that it can optionally handle a projection when necessary Author: Wenchen Fan Closes #6780 from cloud-fan/limit and squashes the following commits: 34aa07b [Wenchen Fan] revert 07d5456 [Wenchen Fan] clean closure 20821ec [Wenchen Fan] fix 3676a82 [Wenchen Fan] address comments b558549 [Wenchen Fan] address comments 214842b [Wenchen Fan] fix style 2d8be83 [Wenchen Fan] add LimitPushDown 948f740 [Wenchen Fan] fix existing --- .../sql/catalyst/optimizer/Optimizer.scala | 52 ++++++++++--------- .../optimizer/UnionPushdownSuite.scala | 4 +- .../org/apache/spark/sql/SQLContext.scala | 2 +- .../spark/sql/execution/SparkPlan.scala | 1 - .../spark/sql/execution/SparkStrategies.scala | 8 ++- .../spark/sql/execution/basicOperators.scala | 27 +++++++--- .../spark/sql/execution/PlannerSuite.scala | 6 +++ .../apache/spark/sql/hive/HiveContext.scala | 2 +- 8 files changed, 62 insertions(+), 40 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 98b4476076854..bfd24287c9645 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -39,19 +39,22 @@ object DefaultOptimizer extends Optimizer { Batch("Distinct", FixedPoint(100), ReplaceDistinctWithAggregate) :: Batch("Operator Optimizations", FixedPoint(100), - UnionPushdown, - CombineFilters, + // Operator push down + UnionPushDown, + PushPredicateThroughJoin, PushPredicateThroughProject, PushPredicateThroughGenerate, ColumnPruning, + // Operator combine ProjectCollapsing, + CombineFilters, CombineLimits, + // Constant folding NullPropagation, OptimizeIn, ConstantFolding, LikeSimplification, BooleanSimplification, - PushPredicateThroughJoin, RemovePositive, SimplifyFilters, SimplifyCasts, @@ -63,25 +66,25 @@ object DefaultOptimizer extends Optimizer { } /** - * Pushes operations to either side of a Union. - */ -object UnionPushdown extends Rule[LogicalPlan] { + * Pushes operations to either side of a Union. + */ +object UnionPushDown extends Rule[LogicalPlan] { /** - * Maps Attributes from the left side to the corresponding Attribute on the right side. - */ - def buildRewrites(union: Union): AttributeMap[Attribute] = { + * Maps Attributes from the left side to the corresponding Attribute on the right side. + */ + private def buildRewrites(union: Union): AttributeMap[Attribute] = { assert(union.left.output.size == union.right.output.size) AttributeMap(union.left.output.zip(union.right.output)) } /** - * Rewrites an expression so that it can be pushed to the right side of a Union operator. - * This method relies on the fact that the output attributes of a union are always equal - * to the left child's output. - */ - def pushToRight[A <: Expression](e: A, rewrites: AttributeMap[Attribute]): A = { + * Rewrites an expression so that it can be pushed to the right side of a Union operator. + * This method relies on the fact that the output attributes of a union are always equal + * to the left child's output. + */ + private def pushToRight[A <: Expression](e: A, rewrites: AttributeMap[Attribute]) = { val result = e transform { case a: Attribute => rewrites(a) } @@ -108,7 +111,6 @@ object UnionPushdown extends Rule[LogicalPlan] { } } - /** * Attempts to eliminate the reading of unneeded columns from the query plan using the following * transformations: @@ -117,7 +119,6 @@ object UnionPushdown extends Rule[LogicalPlan] { * - Aggregate * - Project <- Join * - LeftSemiJoin - * - Performing alias substitution. */ object ColumnPruning extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { @@ -159,10 +160,11 @@ object ColumnPruning extends Rule[LogicalPlan] { Join(left, prunedChild(right, allReferences), LeftSemi, condition) + // Push down project through limit, so that we may have chance to push it further. case Project(projectList, Limit(exp, child)) => Limit(exp, Project(projectList, child)) - // push down project if possible when the child is sort + // Push down project if possible when the child is sort case p @ Project(projectList, s @ Sort(_, _, grandChild)) if s.references.subsetOf(p.outputSet) => s.copy(child = Project(projectList, grandChild)) @@ -181,8 +183,8 @@ object ColumnPruning extends Rule[LogicalPlan] { } /** - * Combines two adjacent [[Project]] operators into one, merging the - * expressions into one single expression. + * Combines two adjacent [[Project]] operators into one and perform alias substitution, + * merging the expressions into one single expression. */ object ProjectCollapsing extends Rule[LogicalPlan] { @@ -222,10 +224,10 @@ object ProjectCollapsing extends Rule[LogicalPlan] { object LikeSimplification extends Rule[LogicalPlan] { // if guards below protect from escapes on trailing %. // Cases like "something\%" are not optimized, but this does not affect correctness. - val startsWith = "([^_%]+)%".r - val endsWith = "%([^_%]+)".r - val contains = "%([^_%]+)%".r - val equalTo = "([^_%]*)".r + private val startsWith = "([^_%]+)%".r + private val endsWith = "%([^_%]+)".r + private val contains = "%([^_%]+)%".r + private val equalTo = "([^_%]*)".r def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case Like(l, Literal(utf, StringType)) => @@ -497,7 +499,7 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] { grandChild)) } - def replaceAlias(condition: Expression, sourceAliases: Map[Attribute, Expression]): Expression = { + private def replaceAlias(condition: Expression, sourceAliases: Map[Attribute, Expression]) = { condition transform { case a: AttributeReference => sourceAliases.getOrElse(a, a) } @@ -682,7 +684,7 @@ object DecimalAggregates extends Rule[LogicalPlan] { import Decimal.MAX_LONG_DIGITS /** Maximum number of decimal digits representable precisely in a Double */ - val MAX_DOUBLE_DIGITS = 15 + private val MAX_DOUBLE_DIGITS = 15 def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala index 35f50be46b76f..ec379489a6d1e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala @@ -24,13 +24,13 @@ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ -class UnionPushdownSuite extends PlanTest { +class UnionPushDownSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Subqueries", Once, EliminateSubQueries) :: Batch("Union Pushdown", Once, - UnionPushdown) :: Nil + UnionPushDown) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 04fc798bf3738..5708df82de12f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -858,7 +858,7 @@ class SQLContext(@transient val sparkContext: SparkContext) experimental.extraStrategies ++ ( DataSourceStrategy :: DDLStrategy :: - TakeOrdered :: + TakeOrderedAndProject :: HashAggregation :: LeftSemiJoin :: HashJoin :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 2b8d30294293c..47f56b2b7ebe6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -169,7 +169,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ log.debug( s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") if(codegenEnabled && expressions.forall(_.isThreadSafe)) { - GenerateMutableProjection.generate(expressions, inputSchema) } else { () => new InterpretedMutableProjection(expressions, inputSchema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 1ff1cc224de8c..21912cf24933e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -213,10 +213,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { protected lazy val singleRowRdd = sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): InternalRow), 1) - object TakeOrdered extends Strategy { + object TakeOrderedAndProject extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) => - execution.TakeOrdered(limit, order, planLater(child)) :: Nil + execution.TakeOrderedAndProject(limit, order, None, planLater(child)) :: Nil + case logical.Limit( + IntegerLiteral(limit), + logical.Project(projectList, logical.Sort(order, true, child))) => + execution.TakeOrderedAndProject(limit, order, Some(projectList), planLater(child)) :: Nil case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 7aedd630e3871..647c4ab5cb651 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -39,8 +39,8 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends @transient lazy val buildProjection = newMutableProjection(projectList, child.output) protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => - val resuableProjection = buildProjection() - iter.map(resuableProjection) + val reusableProjection = buildProjection() + iter.map(reusableProjection) } override def outputOrdering: Seq[SortOrder] = child.outputOrdering @@ -147,12 +147,18 @@ case class Limit(limit: Int, child: SparkPlan) /** * :: DeveloperApi :: - * Take the first limit elements as defined by the sortOrder. This is logically equivalent to - * having a [[Limit]] operator after a [[Sort]] operator. This could have been named TopK, but - * Spark's top operator does the opposite in ordering so we name it TakeOrdered to avoid confusion. + * Take the first limit elements as defined by the sortOrder, and do projection if needed. + * This is logically equivalent to having a [[Limit]] operator after a [[Sort]] operator, + * or having a [[Project]] operator between them. + * This could have been named TopK, but Spark's top operator does the opposite in ordering + * so we name it TakeOrdered to avoid confusion. */ @DeveloperApi -case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) extends UnaryNode { +case class TakeOrderedAndProject( + limit: Int, + sortOrder: Seq[SortOrder], + projectList: Option[Seq[NamedExpression]], + child: SparkPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output @@ -160,8 +166,13 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) private val ord: RowOrdering = new RowOrdering(sortOrder, child.output) - private def collectData(): Array[InternalRow] = - child.execute().map(_.copy()).takeOrdered(limit)(ord) + // TODO: remove @transient after figure out how to clean closure at InsertIntoHiveTable. + @transient private val projection = projectList.map(new InterpretedProjection(_, child.output)) + + private def collectData(): Array[InternalRow] = { + val data = child.execute().map(_.copy()).takeOrdered(limit)(ord) + projection.map(data.map(_)).getOrElse(data) + } override def executeCollect(): Array[Row] = { val converter = CatalystTypeConverters.createToScalaConverter(schema) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 5854ab48db552..3dd24130af81a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -141,4 +141,10 @@ class PlannerSuite extends SparkFunSuite { setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold) } + + test("efficient limit -> project -> sort") { + val query = testData.sort('key).select('value).limit(2).logicalPlan + val planned = planner.TakeOrderedAndProject(query) + assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject]) + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index cf05c6c989655..8021f915bb821 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -442,7 +442,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { HiveCommandStrategy(self), HiveDDLStrategy, DDLStrategy, - TakeOrdered, + TakeOrderedAndProject, ParquetOperations, InMemoryScans, ParquetConversion, // Must be before HiveTableScans From fb32c388985ce65c1083cb435cf1f7479fecbaac Mon Sep 17 00:00:00 2001 From: MechCoder Date: Wed, 24 Jun 2015 14:58:43 -0700 Subject: [PATCH 17/44] [SPARK-7633] [MLLIB] [PYSPARK] Python bindings for StreamingLogisticRegressionwithSGD Add Python bindings to StreamingLogisticRegressionwithSGD. No Java wrappers are needed as models are updated directly using train. Author: MechCoder Closes #6849 from MechCoder/spark-3258 and squashes the following commits: b4376a5 [MechCoder] minor d7e5fc1 [MechCoder] Refactor into StreamingLinearAlgorithm Better docs 9c09d4e [MechCoder] [SPARK-7633] Python bindings for StreamingLogisticRegressionwithSGD --- python/pyspark/mllib/classification.py | 96 +++++++++++++++++- python/pyspark/mllib/tests.py | 135 ++++++++++++++++++++++++- 2 files changed, 229 insertions(+), 2 deletions(-) diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 758accf4b41eb..2698f10d06883 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -21,6 +21,7 @@ from numpy import array from pyspark import RDD +from pyspark.streaming import DStream from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py from pyspark.mllib.linalg import DenseVector, SparseVector, _convert_to_vector from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper @@ -28,7 +29,8 @@ __all__ = ['LogisticRegressionModel', 'LogisticRegressionWithSGD', 'LogisticRegressionWithLBFGS', - 'SVMModel', 'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes'] + 'SVMModel', 'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes', + 'StreamingLogisticRegressionWithSGD'] class LinearClassificationModel(LinearModel): @@ -583,6 +585,98 @@ def train(cls, data, lambda_=1.0): return NaiveBayesModel(labels.toArray(), pi.toArray(), numpy.array(theta)) +class StreamingLinearAlgorithm(object): + """ + Base class that has to be inherited by any StreamingLinearAlgorithm. + + Prevents reimplementation of methods predictOn and predictOnValues. + """ + def __init__(self, model): + self._model = model + + def latestModel(self): + """ + Returns the latest model. + """ + return self._model + + def _validate(self, dstream): + if not isinstance(dstream, DStream): + raise TypeError( + "dstream should be a DStream object, got %s" % type(dstream)) + if not self._model: + raise ValueError( + "Model must be intialized using setInitialWeights") + + def predictOn(self, dstream): + """ + Make predictions on a dstream. + + :return: Transformed dstream object. + """ + self._validate(dstream) + return dstream.map(lambda x: self._model.predict(x)) + + def predictOnValues(self, dstream): + """ + Make predictions on a keyed dstream. + + :return: Transformed dstream object. + """ + self._validate(dstream) + return dstream.mapValues(lambda x: self._model.predict(x)) + + +@inherit_doc +class StreamingLogisticRegressionWithSGD(StreamingLinearAlgorithm): + """ + Run LogisticRegression with SGD on a stream of data. + + The weights obtained at the end of training a stream are used as initial + weights for the next stream. + + :param stepSize: Step size for each iteration of gradient descent. + :param numIterations: Number of iterations run for each batch of data. + :param miniBatchFraction: Fraction of data on which SGD is run for each + iteration. + :param regParam: L2 Regularization parameter. + """ + def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0, regParam=0.01): + self.stepSize = stepSize + self.numIterations = numIterations + self.regParam = regParam + self.miniBatchFraction = miniBatchFraction + self._model = None + super(StreamingLogisticRegressionWithSGD, self).__init__( + model=self._model) + + def setInitialWeights(self, initialWeights): + """ + Set the initial value of weights. + + This must be set before running trainOn and predictOn. + """ + initialWeights = _convert_to_vector(initialWeights) + + # LogisticRegressionWithSGD does only binary classification. + self._model = LogisticRegressionModel( + initialWeights, 0, initialWeights.size, 2) + return self + + def trainOn(self, dstream): + """Train the model on the incoming dstream.""" + self._validate(dstream) + + def update(rdd): + # LogisticRegressionWithSGD.train raises an error for an empty RDD. + if not rdd.isEmpty(): + self._model = LogisticRegressionWithSGD.train( + rdd, self.numIterations, self.stepSize, + self.miniBatchFraction, self._model.weights) + + dstream.foreachRDD(update) + + def _test(): import doctest from pyspark import SparkContext diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 509faa11df170..cd80c3e07a4f7 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -26,7 +26,8 @@ from time import time, sleep from shutil import rmtree -from numpy import array, array_equal, zeros, inf, all, random +from numpy import ( + array, array_equal, zeros, inf, random, exp, dot, all, mean) from numpy import sum as array_sum from py4j.protocol import Py4JJavaError @@ -45,6 +46,7 @@ from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\ DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD from pyspark.mllib.random import RandomRDDs from pyspark.mllib.stat import Statistics from pyspark.mllib.feature import Word2Vec @@ -1037,6 +1039,137 @@ def test_dim(self): self.assertEqual(len(point.features), 2) +class StreamingLogisticRegressionWithSGDTests(MLLibStreamingTestCase): + + @staticmethod + def generateLogisticInput(offset, scale, nPoints, seed): + """ + Generate 1 / (1 + exp(-x * scale + offset)) + + where, + x is randomnly distributed and the threshold + and labels for each sample in x is obtained from a random uniform + distribution. + """ + rng = random.RandomState(seed) + x = rng.randn(nPoints) + sigmoid = 1. / (1 + exp(-(dot(x, scale) + offset))) + y_p = rng.rand(nPoints) + cut_off = y_p <= sigmoid + y_p[cut_off] = 1.0 + y_p[~cut_off] = 0.0 + return [ + LabeledPoint(y_p[i], Vectors.dense([x[i]])) + for i in range(nPoints)] + + def test_parameter_accuracy(self): + """ + Test that the final value of weights is close to the desired value. + """ + input_batches = [ + self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i)) + for i in range(20)] + input_stream = self.ssc.queueStream(input_batches) + + slr = StreamingLogisticRegressionWithSGD( + stepSize=0.2, numIterations=25) + slr.setInitialWeights([0.0]) + slr.trainOn(input_stream) + + t = time() + self.ssc.start() + self._ssc_wait(t, 20.0, 0.01) + rel = (1.5 - slr.latestModel().weights.array[0]) / 1.5 + self.assertAlmostEqual(rel, 0.1, 1) + + def test_convergence(self): + """ + Test that weights converge to the required value on toy data. + """ + input_batches = [ + self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i)) + for i in range(20)] + input_stream = self.ssc.queueStream(input_batches) + models = [] + + slr = StreamingLogisticRegressionWithSGD( + stepSize=0.2, numIterations=25) + slr.setInitialWeights([0.0]) + slr.trainOn(input_stream) + input_stream.foreachRDD( + lambda x: models.append(slr.latestModel().weights[0])) + + t = time() + self.ssc.start() + self._ssc_wait(t, 15.0, 0.01) + t_models = array(models) + diff = t_models[1:] - t_models[:-1] + + # Test that weights improve with a small tolerance, + self.assertTrue(all(diff >= -0.1)) + self.assertTrue(array_sum(diff > 0) > 1) + + @staticmethod + def calculate_accuracy_error(true, predicted): + return sum(abs(array(true) - array(predicted))) / len(true) + + def test_predictions(self): + """Test predicted values on a toy model.""" + input_batches = [] + for i in range(20): + batch = self.sc.parallelize( + self.generateLogisticInput(0, 1.5, 100, 42 + i)) + input_batches.append(batch.map(lambda x: (x.label, x.features))) + input_stream = self.ssc.queueStream(input_batches) + + slr = StreamingLogisticRegressionWithSGD( + stepSize=0.2, numIterations=25) + slr.setInitialWeights([1.5]) + predict_stream = slr.predictOnValues(input_stream) + true_predicted = [] + predict_stream.foreachRDD(lambda x: true_predicted.append(x.collect())) + t = time() + self.ssc.start() + self._ssc_wait(t, 5.0, 0.01) + + # Test that the accuracy error is no more than 0.4 on each batch. + for batch in true_predicted: + true, predicted = zip(*batch) + self.assertTrue( + self.calculate_accuracy_error(true, predicted) < 0.4) + + def test_training_and_prediction(self): + """Test that the model improves on toy data with no. of batches""" + input_batches = [ + self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i)) + for i in range(20)] + predict_batches = [ + b.map(lambda lp: (lp.label, lp.features)) for b in input_batches] + + slr = StreamingLogisticRegressionWithSGD( + stepSize=0.01, numIterations=25) + slr.setInitialWeights([-0.1]) + errors = [] + + def collect_errors(rdd): + true, predicted = zip(*rdd.collect()) + errors.append(self.calculate_accuracy_error(true, predicted)) + + true_predicted = [] + input_stream = self.ssc.queueStream(input_batches) + predict_stream = self.ssc.queueStream(predict_batches) + slr.trainOn(input_stream) + ps = slr.predictOnValues(predict_stream) + ps.foreachRDD(lambda x: collect_errors(x)) + + t = time() + self.ssc.start() + self._ssc_wait(t, 20.0, 0.01) + + # Test that the improvement in error is atleast 0.3 + self.assertTrue(errors[1] - errors[-1] > 0.3) + + if __name__ == "__main__": if not _have_scipy: print("NOTE: Skipping SciPy tests as it does not seem to be installed") From 8ab50765cd793169091d983b50d87a391f6ac1f4 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 24 Jun 2015 15:03:43 -0700 Subject: [PATCH 18/44] [SPARK-6777] [SQL] Implements backwards compatibility rules in CatalystSchemaConverter This PR introduces `CatalystSchemaConverter` for converting Parquet schema to Spark SQL schema and vice versa. Original conversion code in `ParquetTypesConverter` is removed. Benefits of the new version are: 1. When converting Spark SQL schemas, it generates standard Parquet schemas conforming to [the most updated Parquet format spec] [1]. Converting to old style Parquet schemas is also supported via feature flag `spark.sql.parquet.followParquetFormatSpec` (which is set to `false` for now, and should be set to `true` after both read and write paths are fixed). Note that although this version of Parquet format spec hasn't been officially release yet, Parquet MR 1.7.0 already sticks to it. So it should be safe to follow. 1. It implements backwards-compatibility rules described in the most updated Parquet format spec. Thus can recognize more schema patterns generated by other/legacy systems/tools. 1. Code organization follows convention used in [parquet-mr] [2], which is easier to follow. (Structure of `CatalystSchemaConverter` is similar to `AvroSchemaConverter`). To fully implement backwards-compatibility rules in both read and write path, we also need to update `CatalystRowConverter` (which is responsible for converting Parquet records to `Row`s), `RowReadSupport`, and `RowWriteSupport`. These would be done in follow-up PRs. TODO - [x] More schema conversion test cases for legacy schema patterns. [1]: https://github.com/apache/parquet-format/blob/ea095226597fdbecd60c2419d96b54b2fdb4ae6c/LogicalTypes.md [2]: https://github.com/apache/parquet-mr/ Author: Cheng Lian Closes #6617 from liancheng/spark-6777 and squashes the following commits: 2a2062d [Cheng Lian] Don't convert decimals without precision information b60979b [Cheng Lian] Adds a constructor which accepts a Configuration, and fixes default value of assumeBinaryIsString 743730f [Cheng Lian] Decimal scale shouldn't be larger than precision a104a9e [Cheng Lian] Fixes Scala style issue 1f71d8d [Cheng Lian] Adds feature flag to allow falling back to old style Parquet schema conversion ba84f4b [Cheng Lian] Fixes MapType schema conversion bug 13cb8d5 [Cheng Lian] Fixes MiMa failure 81de5b0 [Cheng Lian] Fixes UDT, workaround read path, and add tests 28ef95b [Cheng Lian] More AnalysisExceptions b10c322 [Cheng Lian] Replaces require() with analysisRequire() which throws AnalysisException cceaf3f [Cheng Lian] Implements backwards compatibility rules in CatalystSchemaConverter --- project/MimaExcludes.scala | 7 +- .../apache/spark/sql/types/DecimalType.scala | 9 +- .../scala/org/apache/spark/sql/SQLConf.scala | 14 + .../sql/parquet/CatalystSchemaConverter.scala | 565 ++++++++++++++ .../sql/parquet/ParquetTableSupport.scala | 6 +- .../spark/sql/parquet/ParquetTypes.scala | 374 +-------- .../spark/sql/parquet/ParquetIOSuite.scala | 6 +- .../sql/parquet/ParquetSchemaSuite.scala | 736 ++++++++++++++++-- .../sql/hive/execution/SQLQuerySuite.scala | 2 +- 9 files changed, 1297 insertions(+), 422 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index f678c69a6dfa9..6f86a505b3ae4 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -69,7 +69,12 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.sql.parquet.CatalystTimestampConverter"), ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.CatalystTimestampConverter$") + "org.apache.spark.sql.parquet.CatalystTimestampConverter$"), + // SPARK-6777 Implements backwards compatibility rules in CatalystSchemaConverter + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetTypeInfo"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetTypeInfo$") ) case v if v.startsWith("1.4") => Seq( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 407dc27326c2e..18cdfa7238f39 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -20,13 +20,18 @@ package org.apache.spark.sql.types import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.sql.catalyst.expressions.Expression /** Precision parameters for a Decimal */ -case class PrecisionInfo(precision: Int, scale: Int) - +case class PrecisionInfo(precision: Int, scale: Int) { + if (scale > precision) { + throw new AnalysisException( + s"Decimal scale ($scale) cannot be greater than precision ($precision).") + } +} /** * :: DeveloperApi :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 265352647fa9f..9a10a23937fbb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -264,6 +264,14 @@ private[spark] object SQLConf { defaultValue = Some(true), doc = "") + val PARQUET_FOLLOW_PARQUET_FORMAT_SPEC = booleanConf( + key = "spark.sql.parquet.followParquetFormatSpec", + defaultValue = Some(false), + doc = "Wether to stick to Parquet format specification when converting Parquet schema to " + + "Spark SQL schema and vice versa. Sticks to the specification if set to true; falls back " + + "to compatible mode if set to false.", + isPublic = false) + val PARQUET_OUTPUT_COMMITTER_CLASS = stringConf( key = "spark.sql.parquet.output.committer.class", defaultValue = Some(classOf[ParquetOutputCommitter].getName), @@ -498,6 +506,12 @@ private[sql] class SQLConf extends Serializable with CatalystConf { */ private[spark] def isParquetINT96AsTimestamp: Boolean = getConf(PARQUET_INT96_AS_TIMESTAMP) + /** + * When set to true, sticks to Parquet format spec when converting Parquet schema to Spark SQL + * schema and vice versa. Otherwise, falls back to compatible mode. + */ + private[spark] def followParquetFormatSpec: Boolean = getConf(PARQUET_FOLLOW_PARQUET_FORMAT_SPEC) + /** * When set to true, partition pruning for in-memory columnar tables is enabled. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala new file mode 100644 index 0000000000000..4fd3e93b70311 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala @@ -0,0 +1,565 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parquet + +import scala.collection.JavaConversions._ + +import org.apache.hadoop.conf.Configuration +import org.apache.parquet.schema.OriginalType._ +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ +import org.apache.parquet.schema.Type.Repetition._ +import org.apache.parquet.schema._ + +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{AnalysisException, SQLConf} + +/** + * This converter class is used to convert Parquet [[MessageType]] to Spark SQL [[StructType]] and + * vice versa. + * + * Parquet format backwards-compatibility rules are respected when converting Parquet + * [[MessageType]] schemas. + * + * @see https://github.com/apache/parquet-format/blob/master/LogicalTypes.md + * + * @constructor + * @param assumeBinaryIsString Whether unannotated BINARY fields should be assumed to be Spark SQL + * [[StringType]] fields when converting Parquet a [[MessageType]] to Spark SQL + * [[StructType]]. + * @param assumeInt96IsTimestamp Whether unannotated INT96 fields should be assumed to be Spark SQL + * [[TimestampType]] fields when converting Parquet a [[MessageType]] to Spark SQL + * [[StructType]]. Note that Spark SQL [[TimestampType]] is similar to Hive timestamp, which + * has optional nanosecond precision, but different from `TIME_MILLS` and `TIMESTAMP_MILLIS` + * described in Parquet format spec. + * @param followParquetFormatSpec Whether to generate standard DECIMAL, LIST, and MAP structure when + * converting Spark SQL [[StructType]] to Parquet [[MessageType]]. For Spark 1.4.x and + * prior versions, Spark SQL only supports decimals with a max precision of 18 digits, and + * uses non-standard LIST and MAP structure. Note that the current Parquet format spec is + * backwards-compatible with these settings. If this argument is set to `false`, we fallback + * to old style non-standard behaviors. + */ +private[parquet] class CatalystSchemaConverter( + private val assumeBinaryIsString: Boolean, + private val assumeInt96IsTimestamp: Boolean, + private val followParquetFormatSpec: Boolean) { + + // Only used when constructing converter for converting Spark SQL schema to Parquet schema, in + // which case `assumeInt96IsTimestamp` and `assumeBinaryIsString` are irrelevant. + def this() = this( + assumeBinaryIsString = SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get, + assumeInt96IsTimestamp = SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get, + followParquetFormatSpec = SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.defaultValue.get) + + def this(conf: SQLConf) = this( + assumeBinaryIsString = conf.isParquetBinaryAsString, + assumeInt96IsTimestamp = conf.isParquetINT96AsTimestamp, + followParquetFormatSpec = conf.followParquetFormatSpec) + + def this(conf: Configuration) = this( + assumeBinaryIsString = + conf.getBoolean( + SQLConf.PARQUET_BINARY_AS_STRING.key, + SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get), + assumeInt96IsTimestamp = + conf.getBoolean( + SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, + SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get), + followParquetFormatSpec = + conf.getBoolean( + SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.key, + SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.defaultValue.get)) + + /** + * Converts Parquet [[MessageType]] `parquetSchema` to a Spark SQL [[StructType]]. + */ + def convert(parquetSchema: MessageType): StructType = convert(parquetSchema.asGroupType()) + + private def convert(parquetSchema: GroupType): StructType = { + val fields = parquetSchema.getFields.map { field => + field.getRepetition match { + case OPTIONAL => + StructField(field.getName, convertField(field), nullable = true) + + case REQUIRED => + StructField(field.getName, convertField(field), nullable = false) + + case REPEATED => + throw new AnalysisException( + s"REPEATED not supported outside LIST or MAP. Type: $field") + } + } + + StructType(fields) + } + + /** + * Converts a Parquet [[Type]] to a Spark SQL [[DataType]]. + */ + def convertField(parquetType: Type): DataType = parquetType match { + case t: PrimitiveType => convertPrimitiveField(t) + case t: GroupType => convertGroupField(t.asGroupType()) + } + + private def convertPrimitiveField(field: PrimitiveType): DataType = { + val typeName = field.getPrimitiveTypeName + val originalType = field.getOriginalType + + def typeString = + if (originalType == null) s"$typeName" else s"$typeName ($originalType)" + + def typeNotImplemented() = + throw new AnalysisException(s"Parquet type not yet supported: $typeString") + + def illegalType() = + throw new AnalysisException(s"Illegal Parquet type: $typeString") + + // When maxPrecision = -1, we skip precision range check, and always respect the precision + // specified in field.getDecimalMetadata. This is useful when interpreting decimal types stored + // as binaries with variable lengths. + def makeDecimalType(maxPrecision: Int = -1): DecimalType = { + val precision = field.getDecimalMetadata.getPrecision + val scale = field.getDecimalMetadata.getScale + + CatalystSchemaConverter.analysisRequire( + maxPrecision == -1 || 1 <= precision && precision <= maxPrecision, + s"Invalid decimal precision: $typeName cannot store $precision digits (max $maxPrecision)") + + DecimalType(precision, scale) + } + + field.getPrimitiveTypeName match { + case BOOLEAN => BooleanType + + case FLOAT => FloatType + + case DOUBLE => DoubleType + + case INT32 => + field.getOriginalType match { + case INT_8 => ByteType + case INT_16 => ShortType + case INT_32 | null => IntegerType + case DATE => DateType + case DECIMAL => makeDecimalType(maxPrecisionForBytes(4)) + case TIME_MILLIS => typeNotImplemented() + case _ => illegalType() + } + + case INT64 => + field.getOriginalType match { + case INT_64 | null => LongType + case DECIMAL => makeDecimalType(maxPrecisionForBytes(8)) + case TIMESTAMP_MILLIS => typeNotImplemented() + case _ => illegalType() + } + + case INT96 => + CatalystSchemaConverter.analysisRequire( + assumeInt96IsTimestamp, + "INT96 is not supported unless it's interpreted as timestamp. " + + s"Please try to set ${SQLConf.PARQUET_INT96_AS_TIMESTAMP.key} to true.") + TimestampType + + case BINARY => + field.getOriginalType match { + case UTF8 => StringType + case null if assumeBinaryIsString => StringType + case null => BinaryType + case DECIMAL => makeDecimalType() + case _ => illegalType() + } + + case FIXED_LEN_BYTE_ARRAY => + field.getOriginalType match { + case DECIMAL => makeDecimalType(maxPrecisionForBytes(field.getTypeLength)) + case INTERVAL => typeNotImplemented() + case _ => illegalType() + } + + case _ => illegalType() + } + } + + private def convertGroupField(field: GroupType): DataType = { + Option(field.getOriginalType).fold(convert(field): DataType) { + // A Parquet list is represented as a 3-level structure: + // + // group (LIST) { + // repeated group list { + // element; + // } + // } + // + // However, according to the most recent Parquet format spec (not released yet up until + // writing), some 2-level structures are also recognized for backwards-compatibility. Thus, + // we need to check whether the 2nd level or the 3rd level refers to list element type. + // + // See: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#lists + case LIST => + CatalystSchemaConverter.analysisRequire( + field.getFieldCount == 1, s"Invalid list type $field") + + val repeatedType = field.getType(0) + CatalystSchemaConverter.analysisRequire( + repeatedType.isRepetition(REPEATED), s"Invalid list type $field") + + if (isElementType(repeatedType, field.getName)) { + ArrayType(convertField(repeatedType), containsNull = false) + } else { + val elementType = repeatedType.asGroupType().getType(0) + val optional = elementType.isRepetition(OPTIONAL) + ArrayType(convertField(elementType), containsNull = optional) + } + + // scalastyle:off + // `MAP_KEY_VALUE` is for backwards-compatibility + // See: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#backward-compatibility-rules-1 + // scalastyle:on + case MAP | MAP_KEY_VALUE => + CatalystSchemaConverter.analysisRequire( + field.getFieldCount == 1 && !field.getType(0).isPrimitive, + s"Invalid map type: $field") + + val keyValueType = field.getType(0).asGroupType() + CatalystSchemaConverter.analysisRequire( + keyValueType.isRepetition(REPEATED) && keyValueType.getFieldCount == 2, + s"Invalid map type: $field") + + val keyType = keyValueType.getType(0) + CatalystSchemaConverter.analysisRequire( + keyType.isPrimitive, + s"Map key type is expected to be a primitive type, but found: $keyType") + + val valueType = keyValueType.getType(1) + val valueOptional = valueType.isRepetition(OPTIONAL) + MapType( + convertField(keyType), + convertField(valueType), + valueContainsNull = valueOptional) + + case _ => + throw new AnalysisException(s"Unrecognized Parquet type: $field") + } + } + + // scalastyle:off + // Here we implement Parquet LIST backwards-compatibility rules. + // See: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#backward-compatibility-rules + // scalastyle:on + private def isElementType(repeatedType: Type, parentName: String) = { + { + // For legacy 2-level list types with primitive element type, e.g.: + // + // // List (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated int32 element; + // } + // + repeatedType.isPrimitive + } || { + // For legacy 2-level list types whose element type is a group type with 2 or more fields, + // e.g.: + // + // // List> (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated group element { + // required binary str (UTF8); + // required int32 num; + // }; + // } + // + repeatedType.asGroupType().getFieldCount > 1 + } || { + // For legacy 2-level list types generated by parquet-avro (Parquet version < 1.6.0), e.g.: + // + // // List> (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated group array { + // required binary str (UTF8); + // }; + // } + // + repeatedType.getName == "array" + } || { + // For Parquet data generated by parquet-thrift, e.g.: + // + // // List> (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated group my_list_tuple { + // required binary str (UTF8); + // }; + // } + // + repeatedType.getName == s"${parentName}_tuple" + } + } + + /** + * Converts a Spark SQL [[StructType]] to a Parquet [[MessageType]]. + */ + def convert(catalystSchema: StructType): MessageType = { + Types.buildMessage().addFields(catalystSchema.map(convertField): _*).named("root") + } + + /** + * Converts a Spark SQL [[StructField]] to a Parquet [[Type]]. + */ + def convertField(field: StructField): Type = { + convertField(field, if (field.nullable) OPTIONAL else REQUIRED) + } + + private def convertField(field: StructField, repetition: Type.Repetition): Type = { + CatalystSchemaConverter.checkFieldName(field.name) + + field.dataType match { + // =================== + // Simple atomic types + // =================== + + case BooleanType => + Types.primitive(BOOLEAN, repetition).named(field.name) + + case ByteType => + Types.primitive(INT32, repetition).as(INT_8).named(field.name) + + case ShortType => + Types.primitive(INT32, repetition).as(INT_16).named(field.name) + + case IntegerType => + Types.primitive(INT32, repetition).named(field.name) + + case LongType => + Types.primitive(INT64, repetition).named(field.name) + + case FloatType => + Types.primitive(FLOAT, repetition).named(field.name) + + case DoubleType => + Types.primitive(DOUBLE, repetition).named(field.name) + + case StringType => + Types.primitive(BINARY, repetition).as(UTF8).named(field.name) + + case DateType => + Types.primitive(INT32, repetition).as(DATE).named(field.name) + + // NOTE: !! This timestamp type is not specified in Parquet format spec !! + // However, Impala and older versions of Spark SQL use INT96 to store timestamps with + // nanosecond precision (not TIME_MILLIS or TIMESTAMP_MILLIS described in the spec). + case TimestampType => + Types.primitive(INT96, repetition).named(field.name) + + case BinaryType => + Types.primitive(BINARY, repetition).named(field.name) + + // ===================================== + // Decimals (for Spark version <= 1.4.x) + // ===================================== + + // Spark 1.4.x and prior versions only support decimals with a maximum precision of 18 and + // always store decimals in fixed-length byte arrays. + case DecimalType.Fixed(precision, scale) + if precision <= maxPrecisionForBytes(8) && !followParquetFormatSpec => + Types + .primitive(FIXED_LEN_BYTE_ARRAY, repetition) + .as(DECIMAL) + .precision(precision) + .scale(scale) + .length(minBytesForPrecision(precision)) + .named(field.name) + + case dec @ DecimalType() if !followParquetFormatSpec => + throw new AnalysisException( + s"Data type $dec is not supported. " + + s"When ${SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.key} is set to false," + + "decimal precision and scale must be specified, " + + "and precision must be less than or equal to 18.") + + // ===================================== + // Decimals (follow Parquet format spec) + // ===================================== + + // Uses INT32 for 1 <= precision <= 9 + case DecimalType.Fixed(precision, scale) + if precision <= maxPrecisionForBytes(4) && followParquetFormatSpec => + Types + .primitive(INT32, repetition) + .as(DECIMAL) + .precision(precision) + .scale(scale) + .named(field.name) + + // Uses INT64 for 1 <= precision <= 18 + case DecimalType.Fixed(precision, scale) + if precision <= maxPrecisionForBytes(8) && followParquetFormatSpec => + Types + .primitive(INT64, repetition) + .as(DECIMAL) + .precision(precision) + .scale(scale) + .named(field.name) + + // Uses FIXED_LEN_BYTE_ARRAY for all other precisions + case DecimalType.Fixed(precision, scale) if followParquetFormatSpec => + Types + .primitive(FIXED_LEN_BYTE_ARRAY, repetition) + .as(DECIMAL) + .precision(precision) + .scale(scale) + .length(minBytesForPrecision(precision)) + .named(field.name) + + case dec @ DecimalType.Unlimited if followParquetFormatSpec => + throw new AnalysisException( + s"Data type $dec is not supported. Decimal precision and scale must be specified.") + + // =================================================== + // ArrayType and MapType (for Spark versions <= 1.4.x) + // =================================================== + + // Spark 1.4.x and prior versions convert ArrayType with nullable elements into a 3-level + // LIST structure. This behavior mimics parquet-hive (1.6.0rc3). Note that this case is + // covered by the backwards-compatibility rules implemented in `isElementType()`. + case ArrayType(elementType, nullable @ true) if !followParquetFormatSpec => + // group (LIST) { + // optional group bag { + // repeated element; + // } + // } + ConversionPatterns.listType( + repetition, + field.name, + Types + .buildGroup(REPEATED) + .addField(convertField(StructField("element", elementType, nullable))) + .named(CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME)) + + // Spark 1.4.x and prior versions convert ArrayType with non-nullable elements into a 2-level + // LIST structure. This behavior mimics parquet-avro (1.6.0rc3). Note that this case is + // covered by the backwards-compatibility rules implemented in `isElementType()`. + case ArrayType(elementType, nullable @ false) if !followParquetFormatSpec => + // group (LIST) { + // repeated element; + // } + ConversionPatterns.listType( + repetition, + field.name, + convertField(StructField("element", elementType, nullable), REPEATED)) + + // Spark 1.4.x and prior versions convert MapType into a 3-level group annotated by + // MAP_KEY_VALUE. This is covered by `convertGroupField(field: GroupType): DataType`. + case MapType(keyType, valueType, valueContainsNull) if !followParquetFormatSpec => + // group (MAP) { + // repeated group map (MAP_KEY_VALUE) { + // required key; + // value; + // } + // } + ConversionPatterns.mapType( + repetition, + field.name, + convertField(StructField("key", keyType, nullable = false)), + convertField(StructField("value", valueType, valueContainsNull))) + + // ================================================== + // ArrayType and MapType (follow Parquet format spec) + // ================================================== + + case ArrayType(elementType, containsNull) if followParquetFormatSpec => + // group (LIST) { + // repeated group list { + // element; + // } + // } + Types + .buildGroup(repetition).as(LIST) + .addField( + Types.repeatedGroup() + .addField(convertField(StructField("element", elementType, containsNull))) + .named("list")) + .named(field.name) + + case MapType(keyType, valueType, valueContainsNull) => + // group (MAP) { + // repeated group key_value { + // required key; + // value; + // } + // } + Types + .buildGroup(repetition).as(MAP) + .addField( + Types + .repeatedGroup() + .addField(convertField(StructField("key", keyType, nullable = false))) + .addField(convertField(StructField("value", valueType, valueContainsNull))) + .named("key_value")) + .named(field.name) + + // =========== + // Other types + // =========== + + case StructType(fields) => + fields.foldLeft(Types.buildGroup(repetition)) { (builder, field) => + builder.addField(convertField(field)) + }.named(field.name) + + case udt: UserDefinedType[_] => + convertField(field.copy(dataType = udt.sqlType)) + + case _ => + throw new AnalysisException(s"Unsupported data type $field.dataType") + } + } + + // Max precision of a decimal value stored in `numBytes` bytes + private def maxPrecisionForBytes(numBytes: Int): Int = { + Math.round( // convert double to long + Math.floor(Math.log10( // number of base-10 digits + Math.pow(2, 8 * numBytes - 1) - 1))) // max value stored in numBytes + .asInstanceOf[Int] + } + + // Min byte counts needed to store decimals with various precisions + private val minBytesForPrecision: Array[Int] = Array.tabulate(38) { precision => + var numBytes = 1 + while (math.pow(2.0, 8 * numBytes - 1) < math.pow(10.0, precision)) { + numBytes += 1 + } + numBytes + } +} + + +private[parquet] object CatalystSchemaConverter { + def checkFieldName(name: String): Unit = { + // ,;{}()\n\t= and space are special characters in Parquet schema + analysisRequire( + !name.matches(".*[ ,;{}()\n\t=].*"), + s"""Attribute name "$name" contains invalid character(s) among " ,;{}()\\n\\t=". + |Please use alias to rename it. + """.stripMargin.split("\n").mkString(" ")) + } + + def analysisRequire(f: => Boolean, message: String): Unit = { + if (!f) { + throw new AnalysisException(message) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index e65fa0030e179..0d96a1e8070b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -86,8 +86,7 @@ private[parquet] class RowReadSupport extends ReadSupport[InternalRow] with Logg // TODO: Why it can be null? if (schema == null) { log.debug("falling back to Parquet read schema") - schema = ParquetTypesConverter.convertToAttributes( - parquetSchema, false, true) + schema = ParquetTypesConverter.convertToAttributes(parquetSchema, false, true) } log.debug(s"list of attributes that will be read: $schema") new RowRecordMaterializer(parquetSchema, schema) @@ -105,8 +104,7 @@ private[parquet] class RowReadSupport extends ReadSupport[InternalRow] with Logg // If the parquet file is thrift derived, there is a good chance that // it will have the thrift class in metadata. val isThriftDerived = keyValueMetaData.keySet().contains("thrift.class") - parquetSchema = ParquetTypesConverter - .convertFromAttributes(requestedAttributes, isThriftDerived) + parquetSchema = ParquetTypesConverter.convertFromAttributes(requestedAttributes) metadata.put( RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA, ParquetTypesConverter.convertToString(requestedAttributes)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index ba2a35b74ef82..4d5199a140344 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -29,214 +29,19 @@ import org.apache.parquet.format.converter.ParquetMetadataConverter import org.apache.parquet.hadoop.metadata.{FileMetaData, ParquetMetadata} import org.apache.parquet.hadoop.util.ContextUtil import org.apache.parquet.hadoop.{Footer, ParquetFileReader, ParquetFileWriter} -import org.apache.parquet.schema.PrimitiveType.{PrimitiveTypeName => ParquetPrimitiveTypeName} -import org.apache.parquet.schema.Type.Repetition -import org.apache.parquet.schema.{ConversionPatterns, DecimalMetadata, GroupType => ParquetGroupType, MessageType, OriginalType => ParquetOriginalType, PrimitiveType => ParquetPrimitiveType, Type => ParquetType, Types => ParquetTypes} +import org.apache.parquet.schema.MessageType import org.apache.spark.Logging -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.types._ -/** A class representing Parquet info fields we care about, for passing back to Parquet */ -private[parquet] case class ParquetTypeInfo( - primitiveType: ParquetPrimitiveTypeName, - originalType: Option[ParquetOriginalType] = None, - decimalMetadata: Option[DecimalMetadata] = None, - length: Option[Int] = None) - private[parquet] object ParquetTypesConverter extends Logging { def isPrimitiveType(ctype: DataType): Boolean = ctype match { case _: NumericType | BooleanType | StringType | BinaryType => true case _: DataType => false } - def toPrimitiveDataType( - parquetType: ParquetPrimitiveType, - binaryAsString: Boolean, - int96AsTimestamp: Boolean): DataType = { - val originalType = parquetType.getOriginalType - val decimalInfo = parquetType.getDecimalMetadata - parquetType.getPrimitiveTypeName match { - case ParquetPrimitiveTypeName.BINARY - if (originalType == ParquetOriginalType.UTF8 || binaryAsString) => StringType - case ParquetPrimitiveTypeName.BINARY => BinaryType - case ParquetPrimitiveTypeName.BOOLEAN => BooleanType - case ParquetPrimitiveTypeName.DOUBLE => DoubleType - case ParquetPrimitiveTypeName.FLOAT => FloatType - case ParquetPrimitiveTypeName.INT32 - if originalType == ParquetOriginalType.DATE => DateType - case ParquetPrimitiveTypeName.INT32 => IntegerType - case ParquetPrimitiveTypeName.INT64 => LongType - case ParquetPrimitiveTypeName.INT96 if int96AsTimestamp => TimestampType - case ParquetPrimitiveTypeName.INT96 => - // TODO: add BigInteger type? TODO(andre) use DecimalType instead???? - throw new AnalysisException("Potential loss of precision: cannot convert INT96") - case ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY - if (originalType == ParquetOriginalType.DECIMAL && decimalInfo.getPrecision <= 18) => - // TODO: for now, our reader only supports decimals that fit in a Long - DecimalType(decimalInfo.getPrecision, decimalInfo.getScale) - case _ => throw new AnalysisException(s"Unsupported parquet datatype $parquetType") - } - } - - /** - * Converts a given Parquet `Type` into the corresponding - * [[org.apache.spark.sql.types.DataType]]. - * - * We apply the following conversion rules: - *
    - *
  • Primitive types are converter to the corresponding primitive type.
  • - *
  • Group types that have a single field that is itself a group, which has repetition - * level `REPEATED`, are treated as follows:
      - *
    • If the nested group has name `values`, the surrounding group is converted - * into an [[ArrayType]] with the corresponding field type (primitive or - * complex) as element type.
    • - *
    • If the nested group has name `map` and two fields (named `key` and `value`), - * the surrounding group is converted into a [[MapType]] - * with the corresponding key and value (value possibly complex) types. - * Note that we currently assume map values are not nullable.
    • - *
    • Other group types are converted into a [[StructType]] with the corresponding - * field types.
  • - *
- * Note that fields are determined to be `nullable` if and only if their Parquet repetition - * level is not `REQUIRED`. - * - * @param parquetType The type to convert. - * @return The corresponding Catalyst type. - */ - def toDataType(parquetType: ParquetType, - isBinaryAsString: Boolean, - isInt96AsTimestamp: Boolean): DataType = { - def correspondsToMap(groupType: ParquetGroupType): Boolean = { - if (groupType.getFieldCount != 1 || groupType.getFields.apply(0).isPrimitive) { - false - } else { - // This mostly follows the convention in ``parquet.schema.ConversionPatterns`` - val keyValueGroup = groupType.getFields.apply(0).asGroupType() - keyValueGroup.getRepetition == Repetition.REPEATED && - keyValueGroup.getName == CatalystConverter.MAP_SCHEMA_NAME && - keyValueGroup.getFieldCount == 2 && - keyValueGroup.getFields.apply(0).getName == CatalystConverter.MAP_KEY_SCHEMA_NAME && - keyValueGroup.getFields.apply(1).getName == CatalystConverter.MAP_VALUE_SCHEMA_NAME - } - } - - def correspondsToArray(groupType: ParquetGroupType): Boolean = { - groupType.getFieldCount == 1 && - groupType.getFieldName(0) == CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME && - groupType.getFields.apply(0).getRepetition == Repetition.REPEATED - } - - if (parquetType.isPrimitive) { - toPrimitiveDataType(parquetType.asPrimitiveType, isBinaryAsString, isInt96AsTimestamp) - } else { - val groupType = parquetType.asGroupType() - parquetType.getOriginalType match { - // if the schema was constructed programmatically there may be hints how to convert - // it inside the metadata via the OriginalType field - case ParquetOriginalType.LIST => { // TODO: check enums! - assert(groupType.getFieldCount == 1) - val field = groupType.getFields.apply(0) - if (field.getName == CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME) { - val bag = field.asGroupType() - assert(bag.getFieldCount == 1) - ArrayType( - toDataType(bag.getFields.apply(0), isBinaryAsString, isInt96AsTimestamp), - containsNull = true) - } else { - ArrayType( - toDataType(field, isBinaryAsString, isInt96AsTimestamp), containsNull = false) - } - } - case ParquetOriginalType.MAP => { - assert( - !groupType.getFields.apply(0).isPrimitive, - "Parquet Map type malformatted: expected nested group for map!") - val keyValueGroup = groupType.getFields.apply(0).asGroupType() - assert( - keyValueGroup.getFieldCount == 2, - "Parquet Map type malformatted: nested group should have 2 (key, value) fields!") - assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) - - val keyType = - toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString, isInt96AsTimestamp) - val valueType = - toDataType(keyValueGroup.getFields.apply(1), isBinaryAsString, isInt96AsTimestamp) - MapType(keyType, valueType, - keyValueGroup.getFields.apply(1).getRepetition != Repetition.REQUIRED) - } - case _ => { - // Note: the order of these checks is important! - if (correspondsToMap(groupType)) { // MapType - val keyValueGroup = groupType.getFields.apply(0).asGroupType() - assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) - - val keyType = - toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString, isInt96AsTimestamp) - val valueType = - toDataType(keyValueGroup.getFields.apply(1), isBinaryAsString, isInt96AsTimestamp) - MapType(keyType, valueType, - keyValueGroup.getFields.apply(1).getRepetition != Repetition.REQUIRED) - } else if (correspondsToArray(groupType)) { // ArrayType - val field = groupType.getFields.apply(0) - if (field.getName == CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME) { - val bag = field.asGroupType() - assert(bag.getFieldCount == 1) - ArrayType( - toDataType(bag.getFields.apply(0), isBinaryAsString, isInt96AsTimestamp), - containsNull = true) - } else { - ArrayType( - toDataType(field, isBinaryAsString, isInt96AsTimestamp), containsNull = false) - } - } else { // everything else: StructType - val fields = groupType - .getFields - .map(ptype => new StructField( - ptype.getName, - toDataType(ptype, isBinaryAsString, isInt96AsTimestamp), - ptype.getRepetition != Repetition.REQUIRED)) - StructType(fields) - } - } - } - } - } - - /** - * For a given Catalyst [[org.apache.spark.sql.types.DataType]] return - * the name of the corresponding Parquet primitive type or None if the given type - * is not primitive. - * - * @param ctype The type to convert - * @return The name of the corresponding Parquet type properties - */ - def fromPrimitiveDataType(ctype: DataType): Option[ParquetTypeInfo] = ctype match { - case StringType => Some(ParquetTypeInfo( - ParquetPrimitiveTypeName.BINARY, Some(ParquetOriginalType.UTF8))) - case BinaryType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.BINARY)) - case BooleanType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.BOOLEAN)) - case DoubleType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.DOUBLE)) - case FloatType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.FLOAT)) - case IntegerType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32)) - // There is no type for Byte or Short so we promote them to INT32. - case ShortType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32)) - case ByteType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32)) - case DateType => Some(ParquetTypeInfo( - ParquetPrimitiveTypeName.INT32, Some(ParquetOriginalType.DATE))) - case LongType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT64)) - case TimestampType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT96)) - case DecimalType.Fixed(precision, scale) if precision <= 18 => - // TODO: for now, our writer only supports decimals that fit in a Long - Some(ParquetTypeInfo(ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, - Some(ParquetOriginalType.DECIMAL), - Some(new DecimalMetadata(precision, scale)), - Some(BYTES_FOR_PRECISION(precision)))) - case _ => None - } - /** * Compute the FIXED_LEN_BYTE_ARRAY length needed to represent a given DECIMAL precision. */ @@ -248,177 +53,29 @@ private[parquet] object ParquetTypesConverter extends Logging { length } - /** - * Converts a given Catalyst [[org.apache.spark.sql.types.DataType]] into - * the corresponding Parquet `Type`. - * - * The conversion follows the rules below: - *
    - *
  • Primitive types are converted into Parquet's primitive types.
  • - *
  • [[org.apache.spark.sql.types.StructType]]s are converted - * into Parquet's `GroupType` with the corresponding field types.
  • - *
  • [[org.apache.spark.sql.types.ArrayType]]s are converted - * into a 2-level nested group, where the outer group has the inner - * group as sole field. The inner group has name `values` and - * repetition level `REPEATED` and has the element type of - * the array as schema. We use Parquet's `ConversionPatterns` for this - * purpose.
  • - *
  • [[org.apache.spark.sql.types.MapType]]s are converted - * into a nested (2-level) Parquet `GroupType` with two fields: a key - * type and a value type. The nested group has repetition level - * `REPEATED` and name `map`. We use Parquet's `ConversionPatterns` - * for this purpose
  • - *
- * Parquet's repetition level is generally set according to the following rule: - *
    - *
  • If the call to `fromDataType` is recursive inside an enclosing `ArrayType` or - * `MapType`, then the repetition level is set to `REPEATED`.
  • - *
  • Otherwise, if the attribute whose type is converted is `nullable`, the Parquet - * type gets repetition level `OPTIONAL` and otherwise `REQUIRED`.
  • - *
- * - *@param ctype The type to convert - * @param name The name of the [[org.apache.spark.sql.catalyst.expressions.Attribute]] - * whose type is converted - * @param nullable When true indicates that the attribute is nullable - * @param inArray When true indicates that this is a nested attribute inside an array. - * @return The corresponding Parquet type. - */ - def fromDataType( - ctype: DataType, - name: String, - nullable: Boolean = true, - inArray: Boolean = false, - toThriftSchemaNames: Boolean = false): ParquetType = { - val repetition = - if (inArray) { - Repetition.REPEATED - } else { - if (nullable) Repetition.OPTIONAL else Repetition.REQUIRED - } - val arraySchemaName = if (toThriftSchemaNames) { - name + CatalystConverter.THRIFT_ARRAY_ELEMENTS_SCHEMA_NAME_SUFFIX - } else { - CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME - } - val typeInfo = fromPrimitiveDataType(ctype) - typeInfo.map { - case ParquetTypeInfo(primitiveType, originalType, decimalMetadata, length) => - val builder = ParquetTypes.primitive(primitiveType, repetition).as(originalType.orNull) - for (len <- length) { - builder.length(len) - } - for (metadata <- decimalMetadata) { - builder.precision(metadata.getPrecision).scale(metadata.getScale) - } - builder.named(name) - }.getOrElse { - ctype match { - case udt: UserDefinedType[_] => { - fromDataType(udt.sqlType, name, nullable, inArray, toThriftSchemaNames) - } - case ArrayType(elementType, false) => { - val parquetElementType = fromDataType( - elementType, - arraySchemaName, - nullable = false, - inArray = true, - toThriftSchemaNames) - ConversionPatterns.listType(repetition, name, parquetElementType) - } - case ArrayType(elementType, true) => { - val parquetElementType = fromDataType( - elementType, - arraySchemaName, - nullable = true, - inArray = false, - toThriftSchemaNames) - ConversionPatterns.listType( - repetition, - name, - new ParquetGroupType( - Repetition.REPEATED, - CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME, - parquetElementType)) - } - case StructType(structFields) => { - val fields = structFields.map { - field => fromDataType(field.dataType, field.name, field.nullable, - inArray = false, toThriftSchemaNames) - } - new ParquetGroupType(repetition, name, fields.toSeq) - } - case MapType(keyType, valueType, valueContainsNull) => { - val parquetKeyType = - fromDataType( - keyType, - CatalystConverter.MAP_KEY_SCHEMA_NAME, - nullable = false, - inArray = false, - toThriftSchemaNames) - val parquetValueType = - fromDataType( - valueType, - CatalystConverter.MAP_VALUE_SCHEMA_NAME, - nullable = valueContainsNull, - inArray = false, - toThriftSchemaNames) - ConversionPatterns.mapType( - repetition, - name, - parquetKeyType, - parquetValueType) - } - case _ => throw new AnalysisException(s"Unsupported datatype $ctype") - } - } - } - - def convertToAttributes(parquetSchema: ParquetType, - isBinaryAsString: Boolean, - isInt96AsTimestamp: Boolean): Seq[Attribute] = { - parquetSchema - .asGroupType() - .getFields - .map( - field => - new AttributeReference( - field.getName, - toDataType(field, isBinaryAsString, isInt96AsTimestamp), - field.getRepetition != Repetition.REQUIRED)()) + def convertToAttributes( + parquetSchema: MessageType, + isBinaryAsString: Boolean, + isInt96AsTimestamp: Boolean): Seq[Attribute] = { + val converter = new CatalystSchemaConverter( + isBinaryAsString, isInt96AsTimestamp, followParquetFormatSpec = false) + converter.convert(parquetSchema).toAttributes } - def convertFromAttributes(attributes: Seq[Attribute], - toThriftSchemaNames: Boolean = false): MessageType = { - checkSpecialCharacters(attributes) - val fields = attributes.map( - attribute => - fromDataType(attribute.dataType, attribute.name, attribute.nullable, - toThriftSchemaNames = toThriftSchemaNames)) - new MessageType("root", fields) + def convertFromAttributes(attributes: Seq[Attribute]): MessageType = { + val converter = new CatalystSchemaConverter() + converter.convert(StructType.fromAttributes(attributes)) } def convertFromString(string: String): Seq[Attribute] = { Try(DataType.fromJson(string)).getOrElse(DataType.fromCaseClassString(string)) match { case s: StructType => s.toAttributes - case other => throw new AnalysisException(s"Can convert $string to row") - } - } - - private def checkSpecialCharacters(schema: Seq[Attribute]) = { - // ,;{}()\n\t= and space character are special characters in Parquet schema - schema.map(_.name).foreach { name => - if (name.matches(".*[ ,;{}()\n\t=].*")) { - throw new AnalysisException( - s"""Attribute name "$name" contains invalid character(s) among " ,;{}()\\n\\t=". - |Please use alias to rename it. - """.stripMargin.split("\n").mkString(" ")) - } + case other => sys.error(s"Can convert $string to row") } } def convertToString(schema: Seq[Attribute]): String = { - checkSpecialCharacters(schema) + schema.map(_.name).foreach(CatalystSchemaConverter.checkFieldName) StructType.fromAttributes(schema).json } @@ -450,8 +107,7 @@ private[parquet] object ParquetTypesConverter extends Logging { ParquetTypesConverter.convertToString(attributes)) // TODO: add extra data, e.g., table name, date, etc.? - val parquetSchema: MessageType = - ParquetTypesConverter.convertFromAttributes(attributes) + val parquetSchema: MessageType = ParquetTypesConverter.convertFromAttributes(attributes) val metaData: FileMetaData = new FileMetaData( parquetSchema, extraMetadata, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index 47a7be1c6a664..7b16eba00d6fb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -99,7 +99,6 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { } test("fixed-length decimals") { - def makeDecimalRDD(decimal: DecimalType): DataFrame = sqlContext.sparkContext .parallelize(0 to 1000) @@ -158,6 +157,11 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { checkParquetFile(data) } + test("array and double") { + val data = (1 to 4).map(i => (i.toDouble, Seq(i.toDouble, (i + 1).toDouble))) + checkParquetFile(data) + } + test("struct") { val data = (1 to 4).map(i => Tuple1((i, s"val_$i"))) withParquetDataFrame(data) { df => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala index 171a656f0e01e..d0bfcde7e032b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala @@ -24,26 +24,109 @@ import org.apache.parquet.schema.MessageTypeParser import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types._ -class ParquetSchemaSuite extends SparkFunSuite with ParquetTest { - lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext +abstract class ParquetSchemaTest extends SparkFunSuite with ParquetTest { + val sqlContext = TestSQLContext /** * Checks whether the reflected Parquet message type for product type `T` conforms `messageType`. */ - private def testSchema[T <: Product: ClassTag: TypeTag]( - testName: String, messageType: String, isThriftDerived: Boolean = false): Unit = { - test(testName) { - val actual = ParquetTypesConverter.convertFromAttributes( - ScalaReflection.attributesFor[T], isThriftDerived) - val expected = MessageTypeParser.parseMessageType(messageType) + protected def testSchemaInference[T <: Product: ClassTag: TypeTag]( + testName: String, + messageType: String, + binaryAsString: Boolean = true, + int96AsTimestamp: Boolean = true, + followParquetFormatSpec: Boolean = false, + isThriftDerived: Boolean = false): Unit = { + testSchema( + testName, + StructType.fromAttributes(ScalaReflection.attributesFor[T]), + messageType, + binaryAsString, + int96AsTimestamp, + followParquetFormatSpec, + isThriftDerived) + } + + protected def testParquetToCatalyst( + testName: String, + sqlSchema: StructType, + parquetSchema: String, + binaryAsString: Boolean = true, + int96AsTimestamp: Boolean = true, + followParquetFormatSpec: Boolean = false, + isThriftDerived: Boolean = false): Unit = { + val converter = new CatalystSchemaConverter( + assumeBinaryIsString = binaryAsString, + assumeInt96IsTimestamp = int96AsTimestamp, + followParquetFormatSpec = followParquetFormatSpec) + + test(s"sql <= parquet: $testName") { + val actual = converter.convert(MessageTypeParser.parseMessageType(parquetSchema)) + val expected = sqlSchema + assert( + actual === expected, + s"""Schema mismatch. + |Expected schema: ${expected.json} + |Actual schema: ${actual.json} + """.stripMargin) + } + } + + protected def testCatalystToParquet( + testName: String, + sqlSchema: StructType, + parquetSchema: String, + binaryAsString: Boolean = true, + int96AsTimestamp: Boolean = true, + followParquetFormatSpec: Boolean = false, + isThriftDerived: Boolean = false): Unit = { + val converter = new CatalystSchemaConverter( + assumeBinaryIsString = binaryAsString, + assumeInt96IsTimestamp = int96AsTimestamp, + followParquetFormatSpec = followParquetFormatSpec) + + test(s"sql => parquet: $testName") { + val actual = converter.convert(sqlSchema) + val expected = MessageTypeParser.parseMessageType(parquetSchema) actual.checkContains(expected) expected.checkContains(actual) } } - testSchema[(Boolean, Int, Long, Float, Double, Array[Byte])]( + protected def testSchema( + testName: String, + sqlSchema: StructType, + parquetSchema: String, + binaryAsString: Boolean = true, + int96AsTimestamp: Boolean = true, + followParquetFormatSpec: Boolean = false, + isThriftDerived: Boolean = false): Unit = { + + testCatalystToParquet( + testName, + sqlSchema, + parquetSchema, + binaryAsString, + int96AsTimestamp, + followParquetFormatSpec, + isThriftDerived) + + testParquetToCatalyst( + testName, + sqlSchema, + parquetSchema, + binaryAsString, + int96AsTimestamp, + followParquetFormatSpec, + isThriftDerived) + } +} + +class ParquetSchemaInferenceSuite extends ParquetSchemaTest { + testSchemaInference[(Boolean, Int, Long, Float, Double, Array[Byte])]( "basic types", """ |message root { @@ -54,9 +137,10 @@ class ParquetSchemaSuite extends SparkFunSuite with ParquetTest { | required double _5; | optional binary _6; |} - """.stripMargin) + """.stripMargin, + binaryAsString = false) - testSchema[(Byte, Short, Int, Long, java.sql.Date)]( + testSchemaInference[(Byte, Short, Int, Long, java.sql.Date)]( "logical integral types", """ |message root { @@ -68,27 +152,79 @@ class ParquetSchemaSuite extends SparkFunSuite with ParquetTest { |} """.stripMargin) - // Currently String is the only supported logical binary type. - testSchema[Tuple1[String]]( - "binary logical types", + testSchemaInference[Tuple1[String]]( + "string", """ |message root { | optional binary _1 (UTF8); |} + """.stripMargin, + binaryAsString = true) + + testSchemaInference[Tuple1[Seq[Int]]]( + "non-nullable array - non-standard", + """ + |message root { + | optional group _1 (LIST) { + | repeated int32 element; + | } + |} """.stripMargin) - testSchema[Tuple1[Seq[Int]]]( - "array", + testSchemaInference[Tuple1[Seq[Int]]]( + "non-nullable array - standard", + """ + |message root { + | optional group _1 (LIST) { + | repeated group list { + | required int32 element; + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchemaInference[Tuple1[Seq[Integer]]]( + "nullable array - non-standard", """ |message root { | optional group _1 (LIST) { - | repeated int32 array; + | repeated group bag { + | optional int32 element; + | } | } |} """.stripMargin) - testSchema[Tuple1[Map[Int, String]]]( - "map", + testSchemaInference[Tuple1[Seq[Integer]]]( + "nullable array - standard", + """ + |message root { + | optional group _1 (LIST) { + | repeated group list { + | optional int32 element; + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchemaInference[Tuple1[Map[Int, String]]]( + "map - standard", + """ + |message root { + | optional group _1 (MAP) { + | repeated group key_value { + | required int32 key; + | optional binary value (UTF8); + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchemaInference[Tuple1[Map[Int, String]]]( + "map - non-standard", """ |message root { | optional group _1 (MAP) { @@ -100,7 +236,7 @@ class ParquetSchemaSuite extends SparkFunSuite with ParquetTest { |} """.stripMargin) - testSchema[Tuple1[Pair[Int, String]]]( + testSchemaInference[Tuple1[Pair[Int, String]]]( "struct", """ |message root { @@ -109,20 +245,21 @@ class ParquetSchemaSuite extends SparkFunSuite with ParquetTest { | optional binary _2 (UTF8); | } |} - """.stripMargin) + """.stripMargin, + followParquetFormatSpec = true) - testSchema[Tuple1[Map[Int, (String, Seq[(Int, Double)])]]]( - "deeply nested type", + testSchemaInference[Tuple1[Map[Int, (String, Seq[(Int, Double)])]]]( + "deeply nested type - non-standard", """ |message root { - | optional group _1 (MAP) { - | repeated group map (MAP_KEY_VALUE) { + | optional group _1 (MAP_KEY_VALUE) { + | repeated group map { | required int32 key; | optional group value { | optional binary _1 (UTF8); | optional group _2 (LIST) { | repeated group bag { - | optional group array { + | optional group element { | required int32 _1; | required double _2; | } @@ -134,43 +271,76 @@ class ParquetSchemaSuite extends SparkFunSuite with ParquetTest { |} """.stripMargin) - testSchema[(Option[Int], Map[Int, Option[Double]])]( - "optional types", + testSchemaInference[Tuple1[Map[Int, (String, Seq[(Int, Double)])]]]( + "deeply nested type - standard", """ |message root { - | optional int32 _1; - | optional group _2 (MAP) { - | repeated group map (MAP_KEY_VALUE) { + | optional group _1 (MAP) { + | repeated group key_value { | required int32 key; - | optional double value; + | optional group value { + | optional binary _1 (UTF8); + | optional group _2 (LIST) { + | repeated group list { + | optional group element { + | required int32 _1; + | required double _2; + | } + | } + | } + | } | } | } |} - """.stripMargin) + """.stripMargin, + followParquetFormatSpec = true) - // Test for SPARK-4520 -- ensure that thrift generated parquet schema is generated - // as expected from attributes - testSchema[(Array[Byte], Array[Byte], Array[Byte], Seq[Int], Map[Array[Byte], Seq[Int]])]( - "thrift generated parquet schema", + testSchemaInference[(Option[Int], Map[Int, Option[Double]])]( + "optional types", """ |message root { - | optional binary _1 (UTF8); - | optional binary _2 (UTF8); - | optional binary _3 (UTF8); - | optional group _4 (LIST) { - | repeated int32 _4_tuple; - | } - | optional group _5 (MAP) { - | repeated group map (MAP_KEY_VALUE) { - | required binary key (UTF8); - | optional group value (LIST) { - | repeated int32 value_tuple; - | } + | optional int32 _1; + | optional group _2 (MAP) { + | repeated group key_value { + | required int32 key; + | optional double value; | } | } |} - """.stripMargin, isThriftDerived = true) + """.stripMargin, + followParquetFormatSpec = true) + // Parquet files generated by parquet-thrift are already handled by the schema converter, but + // let's leave this test here until both read path and write path are all updated. + ignore("thrift generated parquet schema") { + // Test for SPARK-4520 -- ensure that thrift generated parquet schema is generated + // as expected from attributes + testSchemaInference[( + Array[Byte], Array[Byte], Array[Byte], Seq[Int], Map[Array[Byte], Seq[Int]])]( + "thrift generated parquet schema", + """ + |message root { + | optional binary _1 (UTF8); + | optional binary _2 (UTF8); + | optional binary _3 (UTF8); + | optional group _4 (LIST) { + | repeated int32 _4_tuple; + | } + | optional group _5 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required binary key (UTF8); + | optional group value (LIST) { + | repeated int32 value_tuple; + | } + | } + | } + |} + """.stripMargin, + isThriftDerived = true) + } +} + +class ParquetSchemaSuite extends ParquetSchemaTest { test("DataType string parser compatibility") { // This is the generated string from previous versions of the Spark SQL, using the following: // val schema = StructType(List( @@ -180,10 +350,7 @@ class ParquetSchemaSuite extends SparkFunSuite with ParquetTest { "StructType(List(StructField(c1,IntegerType,false), StructField(c2,BinaryType,true)))" // scalastyle:off - val jsonString = - """ - |{"type":"struct","fields":[{"name":"c1","type":"integer","nullable":false,"metadata":{}},{"name":"c2","type":"binary","nullable":true,"metadata":{}}]} - """.stripMargin + val jsonString = """{"type":"struct","fields":[{"name":"c1","type":"integer","nullable":false,"metadata":{}},{"name":"c2","type":"binary","nullable":true,"metadata":{}}]}""" // scalastyle:on val fromCaseClassString = ParquetTypesConverter.convertFromString(caseClassString) @@ -277,4 +444,465 @@ class ParquetSchemaSuite extends SparkFunSuite with ParquetTest { StructField("secondField", StringType, nullable = true)))) }.getMessage.contains("detected conflicting schemas")) } + + // ======================================================= + // Tests for converting Parquet LIST to Catalyst ArrayType + // ======================================================= + + testParquetToCatalyst( + "Backwards-compatibility: LIST with nullable element type - 1 - standard", + StructType(Seq( + StructField( + "f1", + ArrayType(IntegerType, containsNull = true), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group list { + | optional int32 element; + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with nullable element type - 2", + StructType(Seq( + StructField( + "f1", + ArrayType(IntegerType, containsNull = true), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group element { + | optional int32 num; + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type - 1 - standard", + StructType(Seq( + StructField("f1", ArrayType(IntegerType, containsNull = false), nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group list { + | required int32 element; + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type - 2", + StructType(Seq( + StructField("f1", ArrayType(IntegerType, containsNull = false), nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group element { + | required int32 num; + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type - 3", + StructType(Seq( + StructField("f1", ArrayType(IntegerType, containsNull = false), nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated int32 element; + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type - 4", + StructType(Seq( + StructField( + "f1", + ArrayType( + StructType(Seq( + StructField("str", StringType, nullable = false), + StructField("num", IntegerType, nullable = false))), + containsNull = false), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group element { + | required binary str (UTF8); + | required int32 num; + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type - 5 - parquet-avro style", + StructType(Seq( + StructField( + "f1", + ArrayType( + StructType(Seq( + StructField("str", StringType, nullable = false))), + containsNull = false), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group array { + | required binary str (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type - 6 - parquet-thrift style", + StructType(Seq( + StructField( + "f1", + ArrayType( + StructType(Seq( + StructField("str", StringType, nullable = false))), + containsNull = false), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group f1_tuple { + | required binary str (UTF8); + | } + | } + |} + """.stripMargin) + + // ======================================================= + // Tests for converting Catalyst ArrayType to Parquet LIST + // ======================================================= + + testCatalystToParquet( + "Backwards-compatibility: LIST with nullable element type - 1 - standard", + StructType(Seq( + StructField( + "f1", + ArrayType(IntegerType, containsNull = true), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group list { + | optional int32 element; + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testCatalystToParquet( + "Backwards-compatibility: LIST with nullable element type - 2 - prior to 1.4.x", + StructType(Seq( + StructField( + "f1", + ArrayType(IntegerType, containsNull = true), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group bag { + | optional int32 element; + | } + | } + |} + """.stripMargin) + + testCatalystToParquet( + "Backwards-compatibility: LIST with non-nullable element type - 1 - standard", + StructType(Seq( + StructField( + "f1", + ArrayType(IntegerType, containsNull = false), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group list { + | required int32 element; + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testCatalystToParquet( + "Backwards-compatibility: LIST with non-nullable element type - 2 - prior to 1.4.x", + StructType(Seq( + StructField( + "f1", + ArrayType(IntegerType, containsNull = false), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated int32 element; + | } + |} + """.stripMargin) + + // ==================================================== + // Tests for converting Parquet Map to Catalyst MapType + // ==================================================== + + testParquetToCatalyst( + "Backwards-compatibility: MAP with non-nullable value type - 1 - standard", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = false), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group key_value { + | required int32 key; + | required binary value (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: MAP with non-nullable value type - 2", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = false), + nullable = true))), + """message root { + | optional group f1 (MAP_KEY_VALUE) { + | repeated group map { + | required int32 num; + | required binary str (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: MAP with non-nullable value type - 3 - prior to 1.4.x", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = false), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required int32 key; + | required binary value (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: MAP with nullable value type - 1 - standard", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = true), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group key_value { + | required int32 key; + | optional binary value (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: MAP with nullable value type - 2", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = true), + nullable = true))), + """message root { + | optional group f1 (MAP_KEY_VALUE) { + | repeated group map { + | required int32 num; + | optional binary str (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: MAP with nullable value type - 3 - parquet-avro style", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = true), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required int32 key; + | optional binary value (UTF8); + | } + | } + |} + """.stripMargin) + + // ==================================================== + // Tests for converting Catalyst MapType to Parquet Map + // ==================================================== + + testCatalystToParquet( + "Backwards-compatibility: MAP with non-nullable value type - 1 - standard", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = false), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group key_value { + | required int32 key; + | required binary value (UTF8); + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testCatalystToParquet( + "Backwards-compatibility: MAP with non-nullable value type - 2 - prior to 1.4.x", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = false), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required int32 key; + | required binary value (UTF8); + | } + | } + |} + """.stripMargin) + + testCatalystToParquet( + "Backwards-compatibility: MAP with nullable value type - 1 - standard", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = true), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group key_value { + | required int32 key; + | optional binary value (UTF8); + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testCatalystToParquet( + "Backwards-compatibility: MAP with nullable value type - 3 - prior to 1.4.x", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = true), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required int32 key; + | optional binary value (UTF8); + | } + | } + |} + """.stripMargin) + + // ================================= + // Tests for conversion for decimals + // ================================= + + testSchema( + "DECIMAL(1, 0) - standard", + StructType(Seq(StructField("f1", DecimalType(1, 0)))), + """message root { + | optional int32 f1 (DECIMAL(1, 0)); + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchema( + "DECIMAL(8, 3) - standard", + StructType(Seq(StructField("f1", DecimalType(8, 3)))), + """message root { + | optional int32 f1 (DECIMAL(8, 3)); + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchema( + "DECIMAL(9, 3) - standard", + StructType(Seq(StructField("f1", DecimalType(9, 3)))), + """message root { + | optional int32 f1 (DECIMAL(9, 3)); + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchema( + "DECIMAL(18, 3) - standard", + StructType(Seq(StructField("f1", DecimalType(18, 3)))), + """message root { + | optional int64 f1 (DECIMAL(18, 3)); + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchema( + "DECIMAL(19, 3) - standard", + StructType(Seq(StructField("f1", DecimalType(19, 3)))), + """message root { + | optional fixed_len_byte_array(9) f1 (DECIMAL(19, 3)); + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchema( + "DECIMAL(1, 0) - prior to 1.4.x", + StructType(Seq(StructField("f1", DecimalType(1, 0)))), + """message root { + | optional fixed_len_byte_array(1) f1 (DECIMAL(1, 0)); + |} + """.stripMargin) + + testSchema( + "DECIMAL(8, 3) - prior to 1.4.x", + StructType(Seq(StructField("f1", DecimalType(8, 3)))), + """message root { + | optional fixed_len_byte_array(4) f1 (DECIMAL(8, 3)); + |} + """.stripMargin) + + testSchema( + "DECIMAL(9, 3) - prior to 1.4.x", + StructType(Seq(StructField("f1", DecimalType(9, 3)))), + """message root { + | optional fixed_len_byte_array(5) f1 (DECIMAL(9, 3)); + |} + """.stripMargin) + + testSchema( + "DECIMAL(18, 3) - prior to 1.4.x", + StructType(Seq(StructField("f1", DecimalType(18, 3)))), + """message root { + | optional fixed_len_byte_array(8) f1 (DECIMAL(18, 3)); + |} + """.stripMargin) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index a2e666586c186..f0aad8dbbe64d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -638,7 +638,7 @@ class SQLQuerySuite extends QueryTest { test("SPARK-5203 union with different decimal precision") { Seq.empty[(Decimal, Decimal)] .toDF("d1", "d2") - .select($"d1".cast(DecimalType(10, 15)).as("d")) + .select($"d1".cast(DecimalType(10, 5)).as("d")) .registerTempTable("dn") sql("select d from dn union all select d * 2 from dn") From dca21a83ac33813dd8165acb5f20d06e4f9b9034 Mon Sep 17 00:00:00 2001 From: fe2s Date: Wed, 24 Jun 2015 15:12:23 -0700 Subject: [PATCH 19/44] [SPARK-8558] [BUILD] Script /dev/run-tests fails when _JAVA_OPTIONS env var set Author: fe2s Author: Oleksiy Dyagilev Closes #6956 from fe2s/fix-run-tests and squashes the following commits: 31b6edc [fe2s] str is a built-in function, so using it as a variable name will lead to spurious warnings in some Python linters 7d781a0 [fe2s] fixing for openjdk/IBM, seems like they have slightly different wording, but all have 'version' word. Surrounding with spaces for the case if version word appears in _JAVA_OPTIONS cd455ef [fe2s] address comment, looking for java version string rather than expecting to have on a certain line number ad577d7 [Oleksiy Dyagilev] [SPARK-8558][BUILD] Script /dev/run-tests fails when _JAVA_OPTIONS env var set --- dev/run-tests.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/dev/run-tests.py b/dev/run-tests.py index de1b4537eda5f..e7c09b0f40cdc 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -477,7 +477,12 @@ def determine_java_version(java_exe): raw_output = subprocess.check_output([java_exe, "-version"], stderr=subprocess.STDOUT) - raw_version_str = raw_output.split('\n')[0] # eg 'java version "1.8.0_25"' + + raw_output_lines = raw_output.split('\n') + + # find raw version string, eg 'java version "1.8.0_25"' + raw_version_str = next(x for x in raw_output_lines if " version " in x) + version_str = raw_version_str.split()[-1].strip('"') # eg '1.8.0_25' version, update = version_str.split('_') # eg ['1.8.0', '25'] From 7daa70292e47be6a944351ef00c770ad4bcb0877 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 24 Jun 2015 15:52:58 -0700 Subject: [PATCH 20/44] [SPARK-8567] [SQL] Increase the timeout of HiveSparkSubmitSuite https://issues.apache.org/jira/browse/SPARK-8567 Author: Yin Huai Closes #6957 from yhuai/SPARK-8567 and squashes the following commits: 62dff5b [Yin Huai] Increase the timeout. --- .../scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index d85516ab0878e..b875e52b986ab 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -95,7 +95,7 @@ class HiveSparkSubmitSuite )) try { - val exitCode = failAfter(120 seconds) { process.exitValue() } + val exitCode = failAfter(180 seconds) { process.exitValue() } if (exitCode != 0) { fail(s"Process returned with exit code $exitCode. See the log4j logs for more detail.") } From b71d3254e50838ccae43bdb0ff186fda25f03152 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 24 Jun 2015 16:26:00 -0700 Subject: [PATCH 21/44] [SPARK-8075] [SQL] apply type check interface to more expressions a follow up of https://github.com/apache/spark/pull/6405. Note: It's not a big change, a lot of changing is due to I swap some code in `aggregates.scala` to make aggregate functions right below its corresponding aggregate expressions. Author: Wenchen Fan Closes #6723 from cloud-fan/type-check and squashes the following commits: 2124301 [Wenchen Fan] fix tests 5a658bb [Wenchen Fan] add tests 287d3bb [Wenchen Fan] apply type check interface to more expressions --- .../sql/catalyst/analysis/Analyzer.scala | 4 +- .../catalyst/analysis/HiveTypeCoercion.scala | 17 +- .../spark/sql/catalyst/expressions/Cast.scala | 11 +- .../sql/catalyst/expressions/Expression.scala | 4 +- .../catalyst/expressions/ExtractValue.scala | 10 +- .../sql/catalyst/expressions/aggregates.scala | 420 +++++++++--------- .../sql/catalyst/expressions/arithmetic.scala | 2 - .../expressions/complexTypeCreator.scala | 30 +- .../expressions/decimalFunctions.scala | 17 +- .../sql/catalyst/expressions/generators.scala | 13 +- .../spark/sql/catalyst/expressions/math.scala | 4 +- .../expressions/namedExpressions.scala | 4 +- .../catalyst/expressions/nullFunctions.scala | 27 +- .../spark/sql/catalyst/expressions/sets.scala | 10 +- .../expressions/stringOperations.scala | 2 - .../expressions/windowExpressions.scala | 3 +- .../spark/sql/catalyst/util/TypeUtils.scala | 9 + .../sql/catalyst/analysis/AnalysisSuite.scala | 6 +- .../ExpressionTypeCheckingSuite.scala | 26 +- .../spark/sql/execution/pythonUdfs.scala | 2 +- .../execution/HiveTypeCoercionSuite.scala | 6 - 21 files changed, 337 insertions(+), 290 deletions(-) rename sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/{expressions => analysis}/ExpressionTypeCheckingSuite.scala (84%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index b06759f144fd9..cad2c2abe6b1a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -587,8 +587,8 @@ class Analyzer( failAnalysis( s"""Expect multiple names given for ${g.getClass.getName}, |but only single name '${name}' specified""".stripMargin) - case Alias(g: Generator, name) => Some((g, name :: Nil)) - case MultiAlias(g: Generator, names) => Some(g, names) + case Alias(g: Generator, name) if g.resolved => Some((g, name :: Nil)) + case MultiAlias(g: Generator, names) if g.resolved => Some(g, names) case _ => None } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index d4ab1fc643c33..4ef7341a33245 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -317,6 +317,7 @@ trait HiveTypeCoercion { i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType)))) case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) + case SumDistinct(e @ StringType()) => Sum(Cast(e, DoubleType)) case Average(e @ StringType()) => Average(Cast(e, DoubleType)) } } @@ -590,11 +591,12 @@ trait HiveTypeCoercion { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case a @ CreateArray(children) if !a.resolved => - val commonType = a.childTypes.reduce( - (a, b) => findTightestCommonTypeOfTwo(a, b).getOrElse(StringType)) - CreateArray( - children.map(c => if (c.dataType == commonType) c else Cast(c, commonType))) + case a @ CreateArray(children) if children.map(_.dataType).distinct.size > 1 => + val types = children.map(_.dataType) + findTightestCommonTypeAndPromoteToString(types) match { + case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType))) + case None => a + } // Promote SUM, SUM DISTINCT and AVERAGE to largest types to prevent overflows. case s @ Sum(e @ DecimalType()) => s // Decimal is already the biggest. @@ -620,12 +622,11 @@ trait HiveTypeCoercion { // Coalesce should return the first non-null value, which could be any column // from the list. So we need to make sure the return type is deterministic and // compatible with every child column. - case Coalesce(es) if es.map(_.dataType).distinct.size > 1 => + case c @ Coalesce(es) if es.map(_.dataType).distinct.size > 1 => val types = es.map(_.dataType) findTightestCommonTypeAndPromoteToString(types) match { case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType))) - case None => - sys.error(s"Could not determine return type of Coalesce for ${types.mkString(",")}") + case None => c } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index d271434a306dd..8bd7fc18a8dd4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} import org.apache.spark.Logging -import org.apache.spark.sql.catalyst +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ @@ -31,7 +31,14 @@ import org.apache.spark.unsafe.types.UTF8String /** Cast the child expression to the target data type. */ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging { - override lazy val resolved = childrenResolved && resolve(child.dataType, dataType) + override def checkInputDataTypes(): TypeCheckResult = { + if (resolve(child.dataType, dataType)) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure( + s"cannot cast ${child.dataType} to $dataType") + } + } override def foldable: Boolean = child.foldable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index a10a959ae766f..f59db3d5dfc23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -162,9 +162,7 @@ abstract class Expression extends TreeNode[Expression] { /** * Checks the input data types, returns `TypeCheckResult.success` if it's valid, * or returns a `TypeCheckResult` with an error message if invalid. - * Note: it's not valid to call this method until `childrenResolved == true` - * TODO: we should remove the default implementation and implement it for all - * expressions with proper error message. + * Note: it's not valid to call this method until `childrenResolved == true`. */ def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala index 4d6c1c265150d..4d7c95ffd1850 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala @@ -96,6 +96,11 @@ object ExtractValue { } } +/** + * A common interface of all kinds of extract value expressions. + * Note: concrete extract value expressions are created only by `ExtractValue.apply`, + * we don't need to do type check for them. + */ trait ExtractValue extends UnaryExpression { self: Product => } @@ -179,9 +184,6 @@ case class GetArrayItem(child: Expression, ordinal: Expression) override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType - override lazy val resolved = childrenResolved && - child.dataType.isInstanceOf[ArrayType] && ordinal.dataType.isInstanceOf[IntegralType] - protected def evalNotNull(value: Any, ordinal: Any) = { // TODO: consider using Array[_] for ArrayType child to avoid // boxing of primitives @@ -203,8 +205,6 @@ case class GetMapValue(child: Expression, ordinal: Expression) override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType - override lazy val resolved = childrenResolved && child.dataType.isInstanceOf[MapType] - protected def evalNotNull(value: Any, ordinal: Any) = { val baseValue = value.asInstanceOf[Map[Any, _]] baseValue.get(ordinal).orNull diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 00d2e499c5890..a9fc54c548f49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -19,9 +19,10 @@ package org.apache.spark.sql.catalyst.expressions import com.clearspring.analytics.stream.cardinality.HyperLogLog -import org.apache.spark.sql.catalyst -import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.trees +import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet @@ -101,6 +102,9 @@ case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[ } override def newInstance(): MinFunction = new MinFunction(child, this) + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForOrderingExpr(child.dataType, "function min") } case class MinFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -132,6 +136,9 @@ case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[ } override def newInstance(): MaxFunction = new MaxFunction(child, this) + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForOrderingExpr(child.dataType, "function max") } case class MaxFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -165,6 +172,21 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod override def newInstance(): CountFunction = new CountFunction(child, this) } +case class CountFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { + def this() = this(null, null) // Required for serialization. + + var count: Long = _ + + override def update(input: InternalRow): Unit = { + val evaluatedExpr = expr.eval(input) + if (evaluatedExpr != null) { + count += 1L + } + } + + override def eval(input: InternalRow): Any = count +} + case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate { def this() = this(null) @@ -183,6 +205,28 @@ case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate } } +case class CountDistinctFunction( + @transient expr: Seq[Expression], + @transient base: AggregateExpression) + extends AggregateFunction { + + def this() = this(null, null) // Required for serialization. + + val seen = new OpenHashSet[Any]() + + @transient + val distinctValue = new InterpretedProjection(expr) + + override def update(input: InternalRow): Unit = { + val evaluatedExpr = distinctValue(input) + if (!evaluatedExpr.anyNull) { + seen.add(evaluatedExpr) + } + } + + override def eval(input: InternalRow): Any = seen.size.toLong +} + case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression { def this() = this(null) @@ -278,6 +322,25 @@ case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double) } } +case class ApproxCountDistinctPartitionFunction( + expr: Expression, + base: AggregateExpression, + relativeSD: Double) + extends AggregateFunction { + def this() = this(null, null, 0) // Required for serialization. + + private val hyperLogLog = new HyperLogLog(relativeSD) + + override def update(input: InternalRow): Unit = { + val evaluatedExpr = expr.eval(input) + if (evaluatedExpr != null) { + hyperLogLog.offer(evaluatedExpr) + } + } + + override def eval(input: InternalRow): Any = hyperLogLog +} + case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double) extends AggregateExpression with trees.UnaryNode[Expression] { @@ -289,6 +352,23 @@ case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double) } } +case class ApproxCountDistinctMergeFunction( + expr: Expression, + base: AggregateExpression, + relativeSD: Double) + extends AggregateFunction { + def this() = this(null, null, 0) // Required for serialization. + + private val hyperLogLog = new HyperLogLog(relativeSD) + + override def update(input: InternalRow): Unit = { + val evaluatedExpr = expr.eval(input) + hyperLogLog.addAll(evaluatedExpr.asInstanceOf[HyperLogLog]) + } + + override def eval(input: InternalRow): Any = hyperLogLog.cardinality() +} + case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) extends PartialAggregate with trees.UnaryNode[Expression] { @@ -349,159 +429,9 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN } override def newInstance(): AverageFunction = new AverageFunction(child, this) -} - -case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - - override def nullable: Boolean = true - - override def dataType: DataType = child.dataType match { - case DecimalType.Fixed(precision, scale) => - DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive - case DecimalType.Unlimited => - DecimalType.Unlimited - case _ => - child.dataType - } - - override def toString: String = s"SUM($child)" - - override def asPartial: SplitEvaluation = { - child.dataType match { - case DecimalType.Fixed(_, _) => - val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")() - SplitEvaluation( - Cast(CombineSum(partialSum.toAttribute), dataType), - partialSum :: Nil) - - case _ => - val partialSum = Alias(Sum(child), "PartialSum")() - SplitEvaluation( - CombineSum(partialSum.toAttribute), - partialSum :: Nil) - } - } - - override def newInstance(): SumFunction = new SumFunction(child, this) -} - -/** - * Sum should satisfy 3 cases: - * 1) sum of all null values = zero - * 2) sum for table column with no data = null - * 3) sum of column with null and not null values = sum of not null values - * Require separate CombineSum Expression and function as it has to distinguish "No data" case - * versus "data equals null" case, while aggregating results and at each partial expression.i.e., - * Combining PartitionLevel InputData - * <-- null - * Zero <-- Zero <-- null - * - * <-- null <-- no data - * null <-- null <-- no data - */ -case class CombineSum(child: Expression) extends AggregateExpression { - def this() = this(null) - - override def children: Seq[Expression] = child :: Nil - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - override def toString: String = s"CombineSum($child)" - override def newInstance(): CombineSumFunction = new CombineSumFunction(child, this) -} - -case class SumDistinct(child: Expression) - extends PartialAggregate with trees.UnaryNode[Expression] { - - def this() = this(null) - override def nullable: Boolean = true - override def dataType: DataType = child.dataType match { - case DecimalType.Fixed(precision, scale) => - DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive - case DecimalType.Unlimited => - DecimalType.Unlimited - case _ => - child.dataType - } - override def toString: String = s"SUM(DISTINCT $child)" - override def newInstance(): SumDistinctFunction = new SumDistinctFunction(child, this) - - override def asPartial: SplitEvaluation = { - val partialSet = Alias(CollectHashSet(child :: Nil), "partialSets")() - SplitEvaluation( - CombineSetsAndSum(partialSet.toAttribute, this), - partialSet :: Nil) - } -} -case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression { - def this() = this(null, null) - - override def children: Seq[Expression] = inputSet :: Nil - override def nullable: Boolean = true - override def dataType: DataType = base.dataType - override def toString: String = s"CombineAndSum($inputSet)" - override def newInstance(): CombineSetsAndSumFunction = { - new CombineSetsAndSumFunction(inputSet, this) - } -} - -case class CombineSetsAndSumFunction( - @transient inputSet: Expression, - @transient base: AggregateExpression) - extends AggregateFunction { - - def this() = this(null, null) // Required for serialization. - - val seen = new OpenHashSet[Any]() - - override def update(input: InternalRow): Unit = { - val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]] - val inputIterator = inputSetEval.iterator - while (inputIterator.hasNext) { - seen.add(inputIterator.next) - } - } - - override def eval(input: InternalRow): Any = { - val casted = seen.asInstanceOf[OpenHashSet[InternalRow]] - if (casted.size == 0) { - null - } else { - Cast(Literal( - casted.iterator.map(f => f.apply(0)).reduceLeft( - base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)), - base.dataType).eval(null) - } - } -} - -case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - override def toString: String = s"FIRST($child)" - - override def asPartial: SplitEvaluation = { - val partialFirst = Alias(First(child), "PartialFirst")() - SplitEvaluation( - First(partialFirst.toAttribute), - partialFirst :: Nil) - } - override def newInstance(): FirstFunction = new FirstFunction(child, this) -} - -case class Last(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def references: AttributeSet = child.references - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - override def toString: String = s"LAST($child)" - - override def asPartial: SplitEvaluation = { - val partialLast = Alias(Last(child), "PartialLast")() - SplitEvaluation( - Last(partialLast.toAttribute), - partialLast :: Nil) - } - override def newInstance(): LastFunction = new LastFunction(child, this) + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function average") } case class AverageFunction(expr: Expression, base: AggregateExpression) @@ -551,55 +481,41 @@ case class AverageFunction(expr: Expression, base: AggregateExpression) } } -case class CountFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { - def this() = this(null, null) // Required for serialization. +case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - var count: Long = _ + override def nullable: Boolean = true - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input) - if (evaluatedExpr != null) { - count += 1L - } + override def dataType: DataType = child.dataType match { + case DecimalType.Fixed(precision, scale) => + DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive + case DecimalType.Unlimited => + DecimalType.Unlimited + case _ => + child.dataType } - override def eval(input: InternalRow): Any = count -} - -case class ApproxCountDistinctPartitionFunction( - expr: Expression, - base: AggregateExpression, - relativeSD: Double) - extends AggregateFunction { - def this() = this(null, null, 0) // Required for serialization. + override def toString: String = s"SUM($child)" - private val hyperLogLog = new HyperLogLog(relativeSD) + override def asPartial: SplitEvaluation = { + child.dataType match { + case DecimalType.Fixed(_, _) => + val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")() + SplitEvaluation( + Cast(CombineSum(partialSum.toAttribute), dataType), + partialSum :: Nil) - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input) - if (evaluatedExpr != null) { - hyperLogLog.offer(evaluatedExpr) + case _ => + val partialSum = Alias(Sum(child), "PartialSum")() + SplitEvaluation( + CombineSum(partialSum.toAttribute), + partialSum :: Nil) } } - override def eval(input: InternalRow): Any = hyperLogLog -} - -case class ApproxCountDistinctMergeFunction( - expr: Expression, - base: AggregateExpression, - relativeSD: Double) - extends AggregateFunction { - def this() = this(null, null, 0) // Required for serialization. - - private val hyperLogLog = new HyperLogLog(relativeSD) - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input) - hyperLogLog.addAll(evaluatedExpr.asInstanceOf[HyperLogLog]) - } + override def newInstance(): SumFunction = new SumFunction(child, this) - override def eval(input: InternalRow): Any = hyperLogLog.cardinality() + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function sum") } case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -633,6 +549,30 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr } } +/** + * Sum should satisfy 3 cases: + * 1) sum of all null values = zero + * 2) sum for table column with no data = null + * 3) sum of column with null and not null values = sum of not null values + * Require separate CombineSum Expression and function as it has to distinguish "No data" case + * versus "data equals null" case, while aggregating results and at each partial expression.i.e., + * Combining PartitionLevel InputData + * <-- null + * Zero <-- Zero <-- null + * + * <-- null <-- no data + * null <-- null <-- no data + */ +case class CombineSum(child: Expression) extends AggregateExpression { + def this() = this(null) + + override def children: Seq[Expression] = child :: Nil + override def nullable: Boolean = true + override def dataType: DataType = child.dataType + override def toString: String = s"CombineSum($child)" + override def newInstance(): CombineSumFunction = new CombineSumFunction(child, this) +} + case class CombineSumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -670,6 +610,33 @@ case class CombineSumFunction(expr: Expression, base: AggregateExpression) } } +case class SumDistinct(child: Expression) + extends PartialAggregate with trees.UnaryNode[Expression] { + + def this() = this(null) + override def nullable: Boolean = true + override def dataType: DataType = child.dataType match { + case DecimalType.Fixed(precision, scale) => + DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive + case DecimalType.Unlimited => + DecimalType.Unlimited + case _ => + child.dataType + } + override def toString: String = s"SUM(DISTINCT $child)" + override def newInstance(): SumDistinctFunction = new SumDistinctFunction(child, this) + + override def asPartial: SplitEvaluation = { + val partialSet = Alias(CollectHashSet(child :: Nil), "partialSets")() + SplitEvaluation( + CombineSetsAndSum(partialSet.toAttribute, this), + partialSet :: Nil) + } + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function sumDistinct") +} + case class SumDistinctFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -696,8 +663,20 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpression) } } -case class CountDistinctFunction( - @transient expr: Seq[Expression], +case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression { + def this() = this(null, null) + + override def children: Seq[Expression] = inputSet :: Nil + override def nullable: Boolean = true + override def dataType: DataType = base.dataType + override def toString: String = s"CombineAndSum($inputSet)" + override def newInstance(): CombineSetsAndSumFunction = { + new CombineSetsAndSumFunction(inputSet, this) + } +} + +case class CombineSetsAndSumFunction( + @transient inputSet: Expression, @transient base: AggregateExpression) extends AggregateFunction { @@ -705,17 +684,39 @@ case class CountDistinctFunction( val seen = new OpenHashSet[Any]() - @transient - val distinctValue = new InterpretedProjection(expr) - override def update(input: InternalRow): Unit = { - val evaluatedExpr = distinctValue(input) - if (!evaluatedExpr.anyNull) { - seen.add(evaluatedExpr) + val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]] + val inputIterator = inputSetEval.iterator + while (inputIterator.hasNext) { + seen.add(inputIterator.next) } } - override def eval(input: InternalRow): Any = seen.size.toLong + override def eval(input: InternalRow): Any = { + val casted = seen.asInstanceOf[OpenHashSet[InternalRow]] + if (casted.size == 0) { + null + } else { + Cast(Literal( + casted.iterator.map(f => f.apply(0)).reduceLeft( + base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)), + base.dataType).eval(null) + } + } +} + +case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { + override def nullable: Boolean = true + override def dataType: DataType = child.dataType + override def toString: String = s"FIRST($child)" + + override def asPartial: SplitEvaluation = { + val partialFirst = Alias(First(child), "PartialFirst")() + SplitEvaluation( + First(partialFirst.toAttribute), + partialFirst :: Nil) + } + override def newInstance(): FirstFunction = new FirstFunction(child, this) } case class FirstFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -732,6 +733,21 @@ case class FirstFunction(expr: Expression, base: AggregateExpression) extends Ag override def eval(input: InternalRow): Any = result } +case class Last(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { + override def references: AttributeSet = child.references + override def nullable: Boolean = true + override def dataType: DataType = child.dataType + override def toString: String = s"LAST($child)" + + override def asPartial: SplitEvaluation = { + val partialLast = Alias(Last(child), "PartialLast")() + SplitEvaluation( + Last(partialLast.toAttribute), + partialLast :: Nil) + } + override def newInstance(): LastFunction = new LastFunction(child, this) +} + case class LastFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { def this() = this(null, null) // Required for serialization. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index ace8427c8ddaf..3d4d9e2d798f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -25,8 +25,6 @@ import org.apache.spark.sql.types._ abstract class UnaryArithmetic extends UnaryExpression { self: Product => - override def foldable: Boolean = child.foldable - override def nullable: Boolean = child.nullable override def dataType: DataType = child.dataType override def eval(input: InternalRow): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index e0bf07ed182f3..5def57b067424 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ - /** * Returns an Array containing the evaluation of all children expressions. */ @@ -27,15 +28,12 @@ case class CreateArray(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) - lazy val childTypes = children.map(_.dataType).distinct - - override lazy val resolved = - childrenResolved && childTypes.size <= 1 + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function array") override def dataType: DataType = { - assert(resolved, s"Invalid dataType of mixed ArrayType ${childTypes.mkString(",")}") ArrayType( - childTypes.headOption.getOrElse(NullType), + children.headOption.map(_.dataType).getOrElse(NullType), containsNull = children.exists(_.nullable)) } @@ -56,19 +54,15 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) - override lazy val resolved: Boolean = childrenResolved - override lazy val dataType: StructType = { - assert(resolved, - s"CreateStruct contains unresolvable children: ${children.filterNot(_.resolved)}.") - val fields = children.zipWithIndex.map { case (child, idx) => - child match { - case ne: NamedExpression => - StructField(ne.name, ne.dataType, ne.nullable, ne.metadata) - case _ => - StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty) - } + val fields = children.zipWithIndex.map { case (child, idx) => + child match { + case ne: NamedExpression => + StructField(ne.name, ne.dataType, ne.nullable, ne.metadata) + case _ => + StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty) } + } StructType(fields) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala index 2bc893af02641..f5c2dde191cf3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -17,16 +17,17 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.types._ -/** Return the unscaled Long value of a Decimal, assuming it fits in a Long */ +/** + * Return the unscaled Long value of a Decimal, assuming it fits in a Long. + * Note: this expression is internal and created only by the optimizer, + * we don't need to do type check for it. + */ case class UnscaledValue(child: Expression) extends UnaryExpression { override def dataType: DataType = LongType - override def foldable: Boolean = child.foldable - override def nullable: Boolean = child.nullable override def toString: String = s"UnscaledValue($child)" override def eval(input: InternalRow): Any = { @@ -43,12 +44,14 @@ case class UnscaledValue(child: Expression) extends UnaryExpression { } } -/** Create a Decimal from an unscaled Long value */ +/** + * Create a Decimal from an unscaled Long value. + * Note: this expression is internal and created only by the optimizer, + * we don't need to do type check for it. + */ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends UnaryExpression { override def dataType: DataType = DecimalType(precision, scale) - override def foldable: Boolean = child.foldable - override def nullable: Boolean = child.nullable override def toString: String = s"MakeDecimal($child,$precision,$scale)" override def eval(input: InternalRow): Decimal = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index f30cb42d12b83..356560e54cae3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.Map -import org.apache.spark.sql.catalyst +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.{CatalystTypeConverters, trees} import org.apache.spark.sql.types._ @@ -100,9 +100,14 @@ case class UserDefinedGenerator( case class Explode(child: Expression) extends Generator with trees.UnaryNode[Expression] { - override lazy val resolved = - child.resolved && - (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) + override def checkInputDataTypes(): TypeCheckResult = { + if (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure( + s"input to function explode should be array or map type, not ${child.dataType}") + } + } override def elementTypes: Seq[(DataType, Boolean)] = child.dataType match { case ArrayType(et, containsNull) => (et, containsNull) :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 250564dc4b818..5694afc61be05 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions import java.lang.{Long => JLong} -import org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -60,7 +59,6 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) override def expectedChildTypes: Seq[DataType] = Seq(DoubleType) override def dataType: DataType = DoubleType - override def foldable: Boolean = child.foldable override def nullable: Boolean = true override def toString: String = s"$name($child)" @@ -224,7 +222,7 @@ case class Bin(child: Expression) def funcName: String = name.toLowerCase - override def eval(input: catalyst.InternalRow): Any = { + override def eval(input: InternalRow): Any = { val evalE = child.eval(input) if (evalE == null) { null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 9cacdceb13837..6f56a9ec7beb5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} @@ -113,7 +112,8 @@ case class Alias(child: Expression, name: String)( extends NamedExpression with trees.UnaryNode[Expression] { // Alias(Generator, xx) need to be transformed into Generate(generator, ...) - override lazy val resolved = childrenResolved && !child.isInstanceOf[Generator] + override lazy val resolved = + childrenResolved && checkInputDataTypes().isSuccess && !child.isInstanceOf[Generator] override def eval(input: InternalRow): Any = child.eval(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 98acaf23c44c1..5d5911403ece1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -17,33 +17,32 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.analysis.UnresolvedException +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types.DataType case class Coalesce(children: Seq[Expression]) extends Expression { /** Coalesce is nullable if all of its children are nullable, or if it has no children. */ - override def nullable: Boolean = !children.exists(!_.nullable) + override def nullable: Boolean = children.forall(_.nullable) // Coalesce is foldable if all children are foldable. - override def foldable: Boolean = !children.exists(!_.foldable) + override def foldable: Boolean = children.forall(_.foldable) - // Only resolved if all the children are of the same type. - override lazy val resolved = childrenResolved && (children.map(_.dataType).distinct.size == 1) + override def checkInputDataTypes(): TypeCheckResult = { + if (children == Nil) { + TypeCheckResult.TypeCheckFailure("input to function coalesce cannot be empty") + } else { + TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function coalesce") + } + } override def toString: String = s"Coalesce(${children.mkString(",")})" - override def dataType: DataType = if (resolved) { - children.head.dataType - } else { - val childTypes = children.map(c => s"$c: ${c.dataType}").mkString(", ") - throw new UnresolvedException( - this, s"Coalesce cannot have children of different types. $childTypes") - } + override def dataType: DataType = children.head.dataType override def eval(input: InternalRow): Any = { - var i = 0 var result: Any = null val childIterator = children.iterator while (childIterator.hasNext && result == null) { @@ -75,7 +74,6 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } case class IsNull(child: Expression) extends UnaryExpression with Predicate { - override def foldable: Boolean = child.foldable override def nullable: Boolean = false override def eval(input: InternalRow): Any = { @@ -93,7 +91,6 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate { } case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { - override def foldable: Boolean = child.foldable override def nullable: Boolean = false override def toString: String = s"IS NOT NULL $child" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index 30e41677b774b..efc6f50b78943 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -78,6 +78,8 @@ case class NewSet(elementType: DataType) extends LeafExpression { /** * Adds an item to a set. * For performance, this expression mutates its input during evaluation. + * Note: this expression is internal and created only by the GeneratedAggregate, + * we don't need to do type check for it. */ case class AddItemToSet(item: Expression, set: Expression) extends Expression { @@ -85,7 +87,7 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { override def nullable: Boolean = set.nullable - override def dataType: OpenHashSetUDT = set.dataType.asInstanceOf[OpenHashSetUDT] + override def dataType: DataType = set.dataType override def eval(input: InternalRow): Any = { val itemEval = item.eval(input) @@ -128,12 +130,14 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { /** * Combines the elements of two sets. * For performance, this expression mutates its left input set during evaluation. + * Note: this expression is internal and created only by the GeneratedAggregate, + * we don't need to do type check for it. */ case class CombineSets(left: Expression, right: Expression) extends BinaryExpression { override def nullable: Boolean = left.nullable || right.nullable - override def dataType: OpenHashSetUDT = left.dataType.asInstanceOf[OpenHashSetUDT] + override def dataType: DataType = left.dataType override def symbol: String = "++=" @@ -176,6 +180,8 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres /** * Returns the number of elements in the input set. + * Note: this expression is internal and created only by the GeneratedAggregate, + * we don't need to do type check for it. */ case class CountSet(child: Expression) extends UnaryExpression { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 315c63e63c635..44416e79cd7aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -117,8 +117,6 @@ trait CaseConversionExpression extends ExpectsInputTypes { def convert(v: UTF8String): UTF8String - override def foldable: Boolean = child.foldable - override def nullable: Boolean = child.nullable override def dataType: DataType = StringType override def expectedChildTypes: Seq[DataType] = Seq(StringType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 896e383f50eac..12023ad311dc8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -68,7 +68,8 @@ case class WindowSpecDefinition( override def children: Seq[Expression] = partitionSpec ++ orderSpec override lazy val resolved: Boolean = - childrenResolved && frameSpecification.isInstanceOf[SpecifiedWindowFrame] + childrenResolved && checkInputDataTypes().isSuccess && + frameSpecification.isInstanceOf[SpecifiedWindowFrame] override def toString: String = simpleString diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 04857a23f4c1e..8656cc334d09f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -48,6 +48,15 @@ object TypeUtils { } } + def checkForSameTypeInputExpr(types: Seq[DataType], caller: String): TypeCheckResult = { + if (types.distinct.size > 1) { + TypeCheckResult.TypeCheckFailure( + s"input to $caller should all be the same type, but it's ${types.mkString("[", ", ", "]")}") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + def getNumeric(t: DataType): Numeric[Any] = t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]] diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index e09cd790a7187..77ca080f366cd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -193,7 +193,7 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { errorTest( "bad casts", testRelation.select(Literal(1).cast(BinaryType).as('badCast)), - "invalid cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil) + "cannot cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil) errorTest( "non-boolean filters", @@ -264,9 +264,9 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { val plan = Aggregate( Nil, - Alias(Sum(AttributeReference("a", StringType)(exprId = ExprId(1))), "b")() :: Nil, + Alias(Sum(AttributeReference("a", IntegerType)(exprId = ExprId(1))), "b")() :: Nil, LocalRelation( - AttributeReference("a", StringType)(exprId = ExprId(2)))) + AttributeReference("a", IntegerType)(exprId = ExprId(2)))) assert(plan.resolved) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala similarity index 84% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 49b111989799b..bc1537b0715b5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -15,13 +15,13 @@ * limitations under the License. */ -package org.apache.spark.sql.catalyst.expressions +package org.apache.spark.sql.catalyst.analysis import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.types.StringType @@ -136,6 +136,28 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertError( CaseWhen(Seq('booleanField, 'intField, 'intField, 'intField)), "WHEN expressions in CaseWhen should all be boolean type") + } + + test("check types for aggregates") { + // We will cast String to Double for sum and average + assertSuccess(Sum('stringField)) + assertSuccess(SumDistinct('stringField)) + assertSuccess(Average('stringField)) + + assertError(Min('complexField), "function min accepts non-complex type") + assertError(Max('complexField), "function max accepts non-complex type") + assertError(Sum('booleanField), "function sum accepts numeric type") + assertError(SumDistinct('booleanField), "function sumDistinct accepts numeric type") + assertError(Average('booleanField), "function average accepts numeric type") + } + test("check types for others") { + assertError(CreateArray(Seq('intField, 'booleanField)), + "input to function array should all be the same type") + assertError(Coalesce(Seq('intField, 'booleanField)), + "input to function coalesce should all be the same type") + assertError(Coalesce(Nil), "input to function coalesce cannot be empty") + assertError(Explode('intField), + "input to function explode should be array or map type") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index 6db551c543a9c..f9c3fe92c2670 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -55,7 +55,7 @@ private[spark] case class PythonUDF( override def toString: String = s"PythonUDF#$name(${children.mkString(",")})" - def nullable: Boolean = true + override def nullable: Boolean = true override def eval(input: InternalRow): Any = { throw new UnsupportedOperationException("PythonUDFs can not be directly evaluated.") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala index f0f04f8c73fb4..197e9bfb02c4e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala @@ -59,10 +59,4 @@ class HiveTypeCoercionSuite extends HiveComparisonTest { } assert(numEquals === 1) } - - test("COALESCE with different types") { - intercept[RuntimeException] { - TestHive.sql("""SELECT COALESCE(1, true, "abc") FROM src limit 1""").collect() - } - } } From 82f80c1c7dc42c11bca2b6832c10f9610a43391b Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 24 Jun 2015 19:34:07 -0700 Subject: [PATCH 22/44] Two minor SQL cleanup (compiler warning & indent). Author: Reynold Xin Closes #7000 from rxin/minor-cleanup and squashes the following commits: 046044c [Reynold Xin] Two minor SQL cleanup (compiler warning & indent). --- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 4 ++-- .../apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index cad2c2abe6b1a..117c87a785fdb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -309,8 +309,8 @@ class Analyzer( .nonEmpty => (oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions))) } - // Only handle first case, others will be fixed on the next pass. - .headOption match { + // Only handle first case, others will be fixed on the next pass. + .headOption match { case None => /* * No result implies that there is a logical plan node that produces new references diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 4ef7341a33245..976fa57cb98d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -678,8 +678,8 @@ trait HiveTypeCoercion { findTightestCommonTypeAndPromoteToString((c.key +: c.whenList).map(_.dataType)) maybeCommonType.map { commonType => val castedBranches = c.branches.grouped(2).map { - case Seq(when, then) if when.dataType != commonType => - Seq(Cast(when, commonType), then) + case Seq(whenExpr, thenExpr) if whenExpr.dataType != commonType => + Seq(Cast(whenExpr, commonType), thenExpr) case other => other }.reduce(_ ++ _) CaseKeyWhen(Cast(c.key, commonType), castedBranches) From 7bac2fe7717c0102b4875dbd95ae0bbf964536e3 Mon Sep 17 00:00:00 2001 From: Matt Massie Date: Wed, 24 Jun 2015 22:09:31 -0700 Subject: [PATCH 23/44] [SPARK-7884] Move block deserialization from BlockStoreShuffleFetcher to ShuffleReader This commit updates the shuffle read path to enable ShuffleReader implementations more control over the deserialization process. The BlockStoreShuffleFetcher.fetch() method has been renamed to BlockStoreShuffleFetcher.fetchBlockStreams(). Previously, this method returned a record iterator; now, it returns an iterator of (BlockId, InputStream). Deserialization of records is now handled in the ShuffleReader.read() method. This change creates a cleaner separation of concerns and allows implementations of ShuffleReader more flexibility in how records are retrieved. Author: Matt Massie Author: Kay Ousterhout Closes #6423 from massie/shuffle-api-cleanup and squashes the following commits: 8b0632c [Matt Massie] Minor Scala style fixes d0a1b39 [Matt Massie] Merge pull request #1 from kayousterhout/massie_shuffle-api-cleanup 290f1eb [Kay Ousterhout] Added test for HashShuffleReader.read() 5186da0 [Kay Ousterhout] Revert "Add test to ensure HashShuffleReader is freeing resources" f98a1b9 [Matt Massie] Add test to ensure HashShuffleReader is freeing resources a011bfa [Matt Massie] Use PrivateMethodTester on check that delegate stream is closed 4ea1712 [Matt Massie] Small code cleanup for readability 7429a98 [Matt Massie] Update tests to check that BufferReleasingStream is closing delegate InputStream f458489 [Matt Massie] Remove unnecessary map() on return Iterator 4abb855 [Matt Massie] Consolidate metric code. Make it clear why InterrubtibleIterator is needed. 5c30405 [Matt Massie] Return visibility of BlockStoreShuffleFetcher to private[hash] 7eedd1d [Matt Massie] Small Scala import cleanup 28f8085 [Matt Massie] Small import nit f93841e [Matt Massie] Update shuffle read metrics in ShuffleReader instead of BlockStoreShuffleFetcher. 7e8e0fe [Matt Massie] Minor Scala style fixes 01e8721 [Matt Massie] Explicitly cast iterator in branches for type clarity 7c8f73e [Matt Massie] Close Block InputStream immediately after all records are read 208b7a5 [Matt Massie] Small code style changes b70c945 [Matt Massie] Make BlockStoreShuffleFetcher visible to shuffle package 19135f2 [Matt Massie] [SPARK-7884] Allow Spark shuffle APIs to be more customizable --- .../hash/BlockStoreShuffleFetcher.scala | 59 +++---- .../shuffle/hash/HashShuffleReader.scala | 52 +++++- .../storage/ShuffleBlockFetcherIterator.scala | 90 +++++++---- .../shuffle/hash/HashShuffleReaderSuite.scala | 150 ++++++++++++++++++ .../ShuffleBlockFetcherIteratorSuite.scala | 59 ++++--- 5 files changed, 314 insertions(+), 96 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index 597d46a3d2223..9d8e7e9f03aea 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -17,29 +17,29 @@ package org.apache.spark.shuffle.hash -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.util.{Failure, Success, Try} +import java.io.InputStream + +import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.util.{Failure, Success} import org.apache.spark._ -import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.FetchFailedException -import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId} -import org.apache.spark.util.CompletionIterator +import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator, + ShuffleBlockId} private[hash] object BlockStoreShuffleFetcher extends Logging { - def fetch[T]( + def fetchBlockStreams( shuffleId: Int, reduceId: Int, context: TaskContext, - serializer: Serializer) - : Iterator[T] = + blockManager: BlockManager, + mapOutputTracker: MapOutputTracker) + : Iterator[(BlockId, InputStream)] = { logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId)) - val blockManager = SparkEnv.get.blockManager val startTime = System.currentTimeMillis - val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId) + val statuses = mapOutputTracker.getServerStatuses(shuffleId, reduceId) logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format( shuffleId, reduceId, System.currentTimeMillis - startTime)) @@ -53,12 +53,21 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2))) } - def unpackBlock(blockPair: (BlockId, Try[Iterator[Any]])) : Iterator[T] = { + val blockFetcherItr = new ShuffleBlockFetcherIterator( + context, + blockManager.shuffleClient, + blockManager, + blocksByAddress, + // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility + SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) + + // Make sure that fetch failures are wrapped inside a FetchFailedException for the scheduler + blockFetcherItr.map { blockPair => val blockId = blockPair._1 val blockOption = blockPair._2 blockOption match { - case Success(block) => { - block.asInstanceOf[Iterator[T]] + case Success(inputStream) => { + (blockId, inputStream) } case Failure(e) => { blockId match { @@ -72,27 +81,5 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { } } } - - val blockFetcherItr = new ShuffleBlockFetcherIterator( - context, - SparkEnv.get.blockManager.shuffleClient, - blockManager, - blocksByAddress, - serializer, - // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility - SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) - val itr = blockFetcherItr.flatMap(unpackBlock) - - val completionIter = CompletionIterator[T, Iterator[T]](itr, { - context.taskMetrics.updateShuffleReadMetrics() - }) - - new InterruptibleIterator[T](context, completionIter) { - val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() - override def next(): T = { - readMetrics.incRecordsRead(1) - delegate.next() - } - } } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index 41bafabde05b9..d5c9880659dd3 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -17,16 +17,20 @@ package org.apache.spark.shuffle.hash -import org.apache.spark.{InterruptibleIterator, TaskContext} +import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} +import org.apache.spark.storage.BlockManager +import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter private[spark] class HashShuffleReader[K, C]( handle: BaseShuffleHandle[K, _, C], startPartition: Int, endPartition: Int, - context: TaskContext) + context: TaskContext, + blockManager: BlockManager = SparkEnv.get.blockManager, + mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) extends ShuffleReader[K, C] { require(endPartition == startPartition + 1, @@ -36,20 +40,52 @@ private[spark] class HashShuffleReader[K, C]( /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { + val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams( + handle.shuffleId, startPartition, context, blockManager, mapOutputTracker) + + // Wrap the streams for compression based on configuration + val wrappedStreams = blockStreams.map { case (blockId, inputStream) => + blockManager.wrapForCompression(blockId, inputStream) + } + val ser = Serializer.getSerializer(dep.serializer) - val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser) + val serializerInstance = ser.newInstance() + + // Create a key/value iterator for each stream + val recordIter = wrappedStreams.flatMap { wrappedStream => + // Note: the asKeyValueIterator below wraps a key/value iterator inside of a + // NextIterator. The NextIterator makes sure that close() is called on the + // underlying InputStream when all records have been read. + serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator + } + + // Update the context task metrics for each record read. + val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() + val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( + recordIter.map(record => { + readMetrics.incRecordsRead(1) + record + }), + context.taskMetrics().updateShuffleReadMetrics()) + + // An interruptible iterator must be used here in order to support task cancellation + val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter) val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { if (dep.mapSideCombine) { - new InterruptibleIterator(context, dep.aggregator.get.combineCombinersByKey(iter, context)) + // We are reading values that are already combined + val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]] + dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context) } else { - new InterruptibleIterator(context, dep.aggregator.get.combineValuesByKey(iter, context)) + // We don't know the value type, but also don't care -- the dependency *should* + // have made sure its compatible w/ this aggregator, which will convert the value + // type to the combined type C + val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]] + dep.aggregator.get.combineValuesByKey(keyValuesIterator, context) } } else { require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!") - - // Convert the Product2s to pairs since this is what downstream RDDs currently expect - iter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2)) + interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]] } // Sort the output if there is a sort ordering defined. diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index d0faab62c9e9e..e49e39679e940 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -17,23 +17,23 @@ package org.apache.spark.storage +import java.io.InputStream import java.util.concurrent.LinkedBlockingQueue import scala.collection.mutable.{ArrayBuffer, HashSet, Queue} import scala.util.{Failure, Try} import org.apache.spark.{Logging, TaskContext} -import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} import org.apache.spark.network.buffer.ManagedBuffer -import org.apache.spark.serializer.{SerializerInstance, Serializer} -import org.apache.spark.util.{CompletionIterator, Utils} +import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} +import org.apache.spark.util.Utils /** * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block * manager. For remote blocks, it fetches them using the provided BlockTransferService. * - * This creates an iterator of (BlockID, values) tuples so the caller can handle blocks in a - * pipelined fashion as they are received. + * This creates an iterator of (BlockID, Try[InputStream]) tuples so the caller can handle blocks + * in a pipelined fashion as they are received. * * The implementation throttles the remote fetches to they don't exceed maxBytesInFlight to avoid * using too much memory. @@ -44,7 +44,6 @@ import org.apache.spark.util.{CompletionIterator, Utils} * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]]. * For each block we also require the size (in bytes as a long field) in * order to throttle the memory usage. - * @param serializer serializer used to deserialize the data. * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. */ private[spark] @@ -53,9 +52,8 @@ final class ShuffleBlockFetcherIterator( shuffleClient: ShuffleClient, blockManager: BlockManager, blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], - serializer: Serializer, maxBytesInFlight: Long) - extends Iterator[(BlockId, Try[Iterator[Any]])] with Logging { + extends Iterator[(BlockId, Try[InputStream])] with Logging { import ShuffleBlockFetcherIterator._ @@ -83,7 +81,7 @@ final class ShuffleBlockFetcherIterator( /** * A queue to hold our results. This turns the asynchronous model provided by - * [[BlockTransferService]] into a synchronous model (iterator). + * [[org.apache.spark.network.BlockTransferService]] into a synchronous model (iterator). */ private[this] val results = new LinkedBlockingQueue[FetchResult] @@ -102,9 +100,7 @@ final class ShuffleBlockFetcherIterator( /** Current bytes in flight from our requests */ private[this] var bytesInFlight = 0L - private[this] val shuffleMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() - - private[this] val serializerInstance: SerializerInstance = serializer.newInstance() + private[this] val shuffleMetrics = context.taskMetrics().createShuffleReadMetricsForDependency() /** * Whether the iterator is still active. If isZombie is true, the callback interface will no @@ -114,17 +110,23 @@ final class ShuffleBlockFetcherIterator( initialize() - /** - * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet. - */ - private[this] def cleanup() { - isZombie = true + // Decrements the buffer reference count. + // The currentResult is set to null to prevent releasing the buffer again on cleanup() + private[storage] def releaseCurrentResultBuffer(): Unit = { // Release the current buffer if necessary currentResult match { case SuccessFetchResult(_, _, buf) => buf.release() case _ => } + currentResult = null + } + /** + * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet. + */ + private[this] def cleanup() { + isZombie = true + releaseCurrentResultBuffer() // Release buffers in the results queue val iter = results.iterator() while (iter.hasNext) { @@ -272,7 +274,13 @@ final class ShuffleBlockFetcherIterator( override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch - override def next(): (BlockId, Try[Iterator[Any]]) = { + /** + * Fetches the next (BlockId, Try[InputStream]). If a task fails, the ManagedBuffers + * underlying each InputStream will be freed by the cleanup() method registered with the + * TaskCompletionListener. However, callers should close() these InputStreams + * as soon as they are no longer needed, in order to release memory as early as possible. + */ + override def next(): (BlockId, Try[InputStream]) = { numBlocksProcessed += 1 val startFetchWait = System.currentTimeMillis() currentResult = results.take() @@ -290,22 +298,15 @@ final class ShuffleBlockFetcherIterator( sendRequest(fetchRequests.dequeue()) } - val iteratorTry: Try[Iterator[Any]] = result match { + val iteratorTry: Try[InputStream] = result match { case FailureFetchResult(_, e) => Failure(e) case SuccessFetchResult(blockId, _, buf) => // There is a chance that createInputStream can fail (e.g. fetching a local file that does // not exist, SPARK-4085). In that case, we should propagate the right exception so // the scheduler gets a FetchFailedException. - Try(buf.createInputStream()).map { is0 => - val is = blockManager.wrapForCompression(blockId, is0) - val iter = serializerInstance.deserializeStream(is).asKeyValueIterator - CompletionIterator[Any, Iterator[Any]](iter, { - // Once the iterator is exhausted, release the buffer and set currentResult to null - // so we don't release it again in cleanup. - currentResult = null - buf.release() - }) + Try(buf.createInputStream()).map { inputStream => + new BufferReleasingInputStream(inputStream, this) } } @@ -313,6 +314,39 @@ final class ShuffleBlockFetcherIterator( } } +/** + * Helper class that ensures a ManagedBuffer is release upon InputStream.close() + */ +private class BufferReleasingInputStream( + private val delegate: InputStream, + private val iterator: ShuffleBlockFetcherIterator) + extends InputStream { + private[this] var closed = false + + override def read(): Int = delegate.read() + + override def close(): Unit = { + if (!closed) { + delegate.close() + iterator.releaseCurrentResultBuffer() + closed = true + } + } + + override def available(): Int = delegate.available() + + override def mark(readlimit: Int): Unit = delegate.mark(readlimit) + + override def skip(n: Long): Long = delegate.skip(n) + + override def markSupported(): Boolean = delegate.markSupported() + + override def read(b: Array[Byte]): Int = delegate.read(b) + + override def read(b: Array[Byte], off: Int, len: Int): Int = delegate.read(b, off, len) + + override def reset(): Unit = delegate.reset() +} private[storage] object ShuffleBlockFetcherIterator { diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala new file mode 100644 index 0000000000000..28ca68698e3dc --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.hash + +import java.io.{ByteArrayOutputStream, InputStream} +import java.nio.ByteBuffer + +import org.mockito.Matchers.{eq => meq, _} +import org.mockito.Mockito.{mock, when} +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer + +import org.apache.spark._ +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.shuffle.BaseShuffleHandle +import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId} + +/** + * Wrapper for a managed buffer that keeps track of how many times retain and release are called. + * + * We need to define this class ourselves instead of using a spy because the NioManagedBuffer class + * is final (final classes cannot be spied on). + */ +class RecordingManagedBuffer(underlyingBuffer: NioManagedBuffer) extends ManagedBuffer { + var callsToRetain = 0 + var callsToRelease = 0 + + override def size(): Long = underlyingBuffer.size() + override def nioByteBuffer(): ByteBuffer = underlyingBuffer.nioByteBuffer() + override def createInputStream(): InputStream = underlyingBuffer.createInputStream() + override def convertToNetty(): AnyRef = underlyingBuffer.convertToNetty() + + override def retain(): ManagedBuffer = { + callsToRetain += 1 + underlyingBuffer.retain() + } + override def release(): ManagedBuffer = { + callsToRelease += 1 + underlyingBuffer.release() + } +} + +class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { + + /** + * This test makes sure that, when data is read from a HashShuffleReader, the underlying + * ManagedBuffers that contain the data are eventually released. + */ + test("read() releases resources on completion") { + val testConf = new SparkConf(false) + // Create a SparkContext as a convenient way of setting SparkEnv (needed because some of the + // shuffle code calls SparkEnv.get()). + sc = new SparkContext("local", "test", testConf) + + val reduceId = 15 + val shuffleId = 22 + val numMaps = 6 + val keyValuePairsPerMap = 10 + val serializer = new JavaSerializer(testConf) + + // Make a mock BlockManager that will return RecordingManagedByteBuffers of data, so that we + // can ensure retain() and release() are properly called. + val blockManager = mock(classOf[BlockManager]) + + // Create a return function to use for the mocked wrapForCompression method that just returns + // the original input stream. + val dummyCompressionFunction = new Answer[InputStream] { + override def answer(invocation: InvocationOnMock): InputStream = + invocation.getArguments()(1).asInstanceOf[InputStream] + } + + // Create a buffer with some randomly generated key-value pairs to use as the shuffle data + // from each mappers (all mappers return the same shuffle data). + val byteOutputStream = new ByteArrayOutputStream() + val serializationStream = serializer.newInstance().serializeStream(byteOutputStream) + (0 until keyValuePairsPerMap).foreach { i => + serializationStream.writeKey(i) + serializationStream.writeValue(2*i) + } + + // Setup the mocked BlockManager to return RecordingManagedBuffers. + val localBlockManagerId = BlockManagerId("test-client", "test-client", 1) + when(blockManager.blockManagerId).thenReturn(localBlockManagerId) + val buffers = (0 until numMaps).map { mapId => + // Create a ManagedBuffer with the shuffle data. + val nioBuffer = new NioManagedBuffer(ByteBuffer.wrap(byteOutputStream.toByteArray)) + val managedBuffer = new RecordingManagedBuffer(nioBuffer) + + // Setup the blockManager mock so the buffer gets returned when the shuffle code tries to + // fetch shuffle data. + val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) + when(blockManager.getBlockData(shuffleBlockId)).thenReturn(managedBuffer) + when(blockManager.wrapForCompression(meq(shuffleBlockId), isA(classOf[InputStream]))) + .thenAnswer(dummyCompressionFunction) + + managedBuffer + } + + // Make a mocked MapOutputTracker for the shuffle reader to use to determine what + // shuffle data to read. + val mapOutputTracker = mock(classOf[MapOutputTracker]) + // Test a scenario where all data is local, just to avoid creating a bunch of additional mocks + // for the code to read data over the network. + val statuses: Array[(BlockManagerId, Long)] = + Array.fill(numMaps)((localBlockManagerId, byteOutputStream.size().toLong)) + when(mapOutputTracker.getServerStatuses(shuffleId, reduceId)).thenReturn(statuses) + + // Create a mocked shuffle handle to pass into HashShuffleReader. + val shuffleHandle = { + val dependency = mock(classOf[ShuffleDependency[Int, Int, Int]]) + when(dependency.serializer).thenReturn(Some(serializer)) + when(dependency.aggregator).thenReturn(None) + when(dependency.keyOrdering).thenReturn(None) + new BaseShuffleHandle(shuffleId, numMaps, dependency) + } + + val shuffleReader = new HashShuffleReader( + shuffleHandle, + reduceId, + reduceId + 1, + new TaskContextImpl(0, 0, 0, 0, null), + blockManager, + mapOutputTracker) + + assert(shuffleReader.read().length === keyValuePairsPerMap * numMaps) + + // Calling .length above will have exhausted the iterator; make sure that exhausting the + // iterator caused retain and release to be called on each buffer. + buffers.foreach { buffer => + assert(buffer.callsToRetain === 1) + assert(buffer.callsToRelease === 1) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 2a7fe67ad8585..9ced4148d7206 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -17,23 +17,25 @@ package org.apache.spark.storage +import java.io.InputStream import java.util.concurrent.Semaphore -import scala.concurrent.future import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.future import org.mockito.Matchers.{any, eq => meq} import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer +import org.scalatest.PrivateMethodTester -import org.apache.spark.{SparkConf, SparkFunSuite, TaskContextImpl} +import org.apache.spark.{SparkFunSuite, TaskContextImpl} import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.shuffle.BlockFetchingListener -import org.apache.spark.serializer.TestSerializer -class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { + +class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester { // Some of the tests are quite tricky because we are testing the cleanup behavior // in the presence of faults. @@ -57,7 +59,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { transfer } - private val conf = new SparkConf + // Create a mock managed buffer for testing + def createMockManagedBuffer(): ManagedBuffer = { + val mockManagedBuffer = mock(classOf[ManagedBuffer]) + when(mockManagedBuffer.createInputStream()).thenReturn(mock(classOf[InputStream])) + mockManagedBuffer + } test("successful 3 local reads + 2 remote reads") { val blockManager = mock(classOf[BlockManager]) @@ -66,9 +73,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { // Make sure blockManager.getBlockData would return the blocks val localBlocks = Map[BlockId, ManagedBuffer]( - ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]), - ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]), - ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer])) + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer()) localBlocks.foreach { case (blockId, buf) => doReturn(buf).when(blockManager).getBlockData(meq(blockId)) } @@ -76,9 +83,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { // Make sure remote blocks would return val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) val remoteBlocks = Map[BlockId, ManagedBuffer]( - ShuffleBlockId(0, 3, 0) -> mock(classOf[ManagedBuffer]), - ShuffleBlockId(0, 4, 0) -> mock(classOf[ManagedBuffer]) - ) + ShuffleBlockId(0, 3, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 4, 0) -> createMockManagedBuffer()) val transfer = createMockTransfer(remoteBlocks) @@ -92,7 +98,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { transfer, blockManager, blocksByAddress, - new TestSerializer, 48 * 1024 * 1024) // 3 local blocks fetched in initialization @@ -100,15 +105,24 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { for (i <- 0 until 5) { assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements") - val (blockId, subIterator) = iterator.next() - assert(subIterator.isSuccess, + val (blockId, inputStream) = iterator.next() + assert(inputStream.isSuccess, s"iterator should have 5 elements defined but actually has $i elements") - // Make sure we release the buffer once the iterator is exhausted. + // Make sure we release buffers when a wrapped input stream is closed. val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId)) + // Note: ShuffleBlockFetcherIterator wraps input streams in a BufferReleasingInputStream + val wrappedInputStream = inputStream.get.asInstanceOf[BufferReleasingInputStream] verify(mockBuf, times(0)).release() - subIterator.get.foreach(_ => Unit) // exhaust the iterator + val delegateAccess = PrivateMethod[InputStream]('delegate) + + verify(wrappedInputStream.invokePrivate(delegateAccess()), times(0)).close() + wrappedInputStream.close() + verify(mockBuf, times(1)).release() + verify(wrappedInputStream.invokePrivate(delegateAccess()), times(1)).close() + wrappedInputStream.close() // close should be idempotent verify(mockBuf, times(1)).release() + verify(wrappedInputStream.invokePrivate(delegateAccess()), times(1)).close() } // 3 local blocks, and 2 remote blocks @@ -125,10 +139,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { // Make sure remote blocks would return val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) val blocks = Map[BlockId, ManagedBuffer]( - ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]), - ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]), - ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer]) - ) + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer()) // Semaphore to coordinate event sequence in two different threads. val sem = new Semaphore(0) @@ -159,11 +172,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { transfer, blockManager, blocksByAddress, - new TestSerializer, 48 * 1024 * 1024) - // Exhaust the first block, and then it should be released. - iterator.next()._2.get.foreach(_ => Unit) + verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release() + iterator.next()._2.get.close() // close() first block's input stream verify(blocks(ShuffleBlockId(0, 0, 0)), times(1)).release() // Get the 2nd block but do not exhaust the iterator @@ -222,7 +234,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { transfer, blockManager, blocksByAddress, - new TestSerializer, 48 * 1024 * 1024) // Continue only after the mock calls onBlockFetchFailure From c337844ed7f9b2cb7b217dc935183ef5e1096ca1 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 25 Jun 2015 00:06:23 -0700 Subject: [PATCH 24/44] [SPARK-8604] [SQL] HadoopFsRelation subclasses should set their output format class `HadoopFsRelation` subclasses, especially `ParquetRelation2` should set its own output format class, so that the default output committer can be setup correctly when doing appending (where we ignore user defined output committers). Author: Cheng Lian Closes #6998 from liancheng/spark-8604 and squashes the following commits: 9be51d1 [Cheng Lian] Adds more comments 6db1368 [Cheng Lian] HadoopFsRelation subclasses should set their output format class --- .../apache/spark/sql/parquet/newParquet.scala | 6 ++++++ .../spark/sql/hive/orc/OrcRelation.scala | 12 ++++++++++- .../sql/sources/SimpleTextRelation.scala | 2 ++ .../sql/sources/hadoopFsRelationSuites.scala | 21 +++++++++++++++++++ 4 files changed, 40 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 1d353bd8e1114..bc39fae2bcfde 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -194,6 +194,12 @@ private[sql] class ParquetRelation2( committerClass, classOf[ParquetOutputCommitter]) + // We're not really using `ParquetOutputFormat[Row]` for writing data here, because we override + // it in `ParquetOutputWriter` to support appending and dynamic partitioning. The reason why + // we set it here is to setup the output committer class to `ParquetOutputCommitter`, which is + // bundled with `ParquetOutputFormat[Row]`. + job.setOutputFormatClass(classOf[ParquetOutputFormat[Row]]) + // TODO There's no need to use two kinds of WriteSupport // We should unify them. `SpecificMutableRow` can process both atomic (primitive) types and // complex types. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 705f48f1cd9f0..0fd7b3a91d6dd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -27,7 +27,7 @@ import org.apache.hadoop.hive.ql.io.orc.{OrcInputFormat, OrcOutputFormat, OrcSer import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils import org.apache.hadoop.io.{NullWritable, Writable} -import org.apache.hadoop.mapred.{InputFormat => MapRedInputFormat, JobConf, RecordWriter, Reporter} +import org.apache.hadoop.mapred.{InputFormat => MapRedInputFormat, JobConf, OutputFormat => MapRedOutputFormat, RecordWriter, Reporter} import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} @@ -194,6 +194,16 @@ private[sql] class OrcRelation( } override def prepareJobForWrite(job: Job): OutputWriterFactory = { + job.getConfiguration match { + case conf: JobConf => + conf.setOutputFormat(classOf[OrcOutputFormat]) + case conf => + conf.setClass( + "mapred.output.format.class", + classOf[OrcOutputFormat], + classOf[MapRedOutputFormat[_, _]]) + } + new OutputWriterFactory { override def newInstance( path: String, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index 5d7cd16c129cd..e8141923a9b5c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -119,6 +119,8 @@ class SimpleTextRelation( } override def prepareJobForWrite(job: Job): OutputWriterFactory = new OutputWriterFactory { + job.setOutputFormatClass(classOf[TextOutputFormat[_, _]]) + override def newInstance( path: String, dataSchema: StructType, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index a16ab3a00ddb8..afecf9675e11f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -719,4 +719,25 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { } } } + + test("SPARK-8604: Parquet data source should write summary file while doing appending") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = sqlContext.range(0, 5) + df.write.mode(SaveMode.Overwrite).parquet(path) + + val summaryPath = new Path(path, "_metadata") + val commonSummaryPath = new Path(path, "_common_metadata") + + val fs = summaryPath.getFileSystem(configuration) + fs.delete(summaryPath, true) + fs.delete(commonSummaryPath, true) + + df.write.mode(SaveMode.Append).parquet(path) + checkAnswer(sqlContext.read.parquet(path), df.unionAll(df)) + + assert(fs.exists(summaryPath)) + assert(fs.exists(commonSummaryPath)) + } + } } From 085a7216bf5e6c2b4f297feca4af71a751e37975 Mon Sep 17 00:00:00 2001 From: Joshi Date: Thu, 25 Jun 2015 20:21:34 +0900 Subject: [PATCH 25/44] [SPARK-5768] [WEB UI] Fix for incorrect memory in Spark UI Fix for incorrect memory in Spark UI as per SPARK-5768 Author: Joshi Author: Rekha Joshi Closes #6972 from rekhajoshm/SPARK-5768 and squashes the following commits: b678a91 [Joshi] Fix for incorrect memory in Spark UI 2fe53d9 [Joshi] Fix for incorrect memory in Spark UI eb823b8 [Joshi] SPARK-5768: Fix for incorrect memory in Spark UI 0be142d [Rekha Joshi] Merge pull request #3 from apache/master 106fd8e [Rekha Joshi] Merge pull request #2 from apache/master e3677c9 [Rekha Joshi] Merge pull request #1 from apache/master --- core/src/main/scala/org/apache/spark/ui/ToolTips.scala | 4 ++++ .../main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala index 063e2a1f8b18e..e2d25e36365fa 100644 --- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala +++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala @@ -35,6 +35,10 @@ private[spark] object ToolTips { val OUTPUT = "Bytes and records written to Hadoop." + val STORAGE_MEMORY = + "Memory used / total available memory for storage of data " + + "like RDD partitions cached in memory. " + val SHUFFLE_WRITE = "Bytes and records written to disk in order to be read by a shuffle in a future stage." diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index b247e4cdc3bd4..01cddda4c62cd 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -67,7 +67,7 @@ private[ui] class ExecutorsPage( Executor ID Address RDD Blocks - Memory Used + Storage Memory Disk Used Active Tasks Failed Tasks From e988adb58f02d06065837f3d79eee220f6558def Mon Sep 17 00:00:00 2001 From: Tom Graves Date: Thu, 25 Jun 2015 08:27:08 -0500 Subject: [PATCH 26/44] =?UTF-8?q?[SPARK-8574]=20org/apache/spark/unsafe=20?= =?UTF-8?q?doesn't=20honor=20the=20java=20source/ta=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …rget versions. I basically copied the compatibility rules from the top level pom.xml into here. Someone more familiar with all the options in the top level pom may want to make sure nothing else should be copied on down. With this is allows me to build with jdk8 and run with lower versions. Source shows compiled for jdk6 as its supposed to. Author: Tom Graves Author: Thomas Graves Closes #6989 from tgravescs/SPARK-8574 and squashes the following commits: e1ea2d4 [Thomas Graves] Change to use combine.children="append" 150d645 [Tom Graves] [SPARK-8574] org/apache/spark/unsafe doesn't honor the java source/target versions --- unsafe/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsafe/pom.xml b/unsafe/pom.xml index 62c6354f1e203..dd2ae6457f0b9 100644 --- a/unsafe/pom.xml +++ b/unsafe/pom.xml @@ -80,7 +80,7 @@ net.alchim31.maven scala-maven-plugin - + -XDignore.symbol.file From f9b397f54d1c491680d70aba210bb8211fd249c1 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 25 Jun 2015 06:52:03 -0700 Subject: [PATCH 27/44] [SPARK-8567] [SQL] Add logs to record the progress of HiveSparkSubmitSuite. Author: Yin Huai Closes #7009 from yhuai/SPARK-8567 and squashes the following commits: 62fb1f9 [Yin Huai] Add sc.stop(). b22cf7d [Yin Huai] Add logs. --- .../org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index b875e52b986ab..a38ed23b5cf9a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -115,6 +115,7 @@ object SparkSubmitClassLoaderTest extends Logging { val sc = new SparkContext(conf) val hiveContext = new TestHiveContext(sc) val df = hiveContext.createDataFrame((1 to 100).map(i => (i, i))).toDF("i", "j") + logInfo("Testing load classes at the driver side.") // First, we load classes at driver side. try { Class.forName(args(0), true, Thread.currentThread().getContextClassLoader) @@ -124,6 +125,7 @@ object SparkSubmitClassLoaderTest extends Logging { throw new Exception("Could not load user class from jar:\n", t) } // Second, we load classes at the executor side. + logInfo("Testing load classes at the executor side.") val result = df.mapPartitions { x => var exception: String = null try { @@ -141,6 +143,7 @@ object SparkSubmitClassLoaderTest extends Logging { } // Load a Hive UDF from the jar. + logInfo("Registering temporary Hive UDF provided in a jar.") hiveContext.sql( """ |CREATE TEMPORARY FUNCTION example_max @@ -150,18 +153,23 @@ object SparkSubmitClassLoaderTest extends Logging { hiveContext.createDataFrame((1 to 10).map(i => (i, s"str$i"))).toDF("key", "val") source.registerTempTable("sourceTable") // Load a Hive SerDe from the jar. + logInfo("Creating a Hive table with a SerDe provided in a jar.") hiveContext.sql( """ |CREATE TABLE t1(key int, val string) |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe' """.stripMargin) // Actually use the loaded UDF and SerDe. + logInfo("Writing data into the table.") hiveContext.sql( "INSERT INTO TABLE t1 SELECT example_max(key) as key, val FROM sourceTable GROUP BY val") + logInfo("Running a simple query on the table.") val count = hiveContext.table("t1").orderBy("key", "val").count() if (count != 10) { throw new Exception(s"table t1 should have 10 rows instead of $count rows") } + logInfo("Test finishes.") + sc.stop() } } @@ -199,5 +207,6 @@ object SparkSQLConfTest extends Logging { val hiveContext = new TestHiveContext(sc) // Run a simple command to make sure all lazy vals in hiveContext get instantiated. hiveContext.tables().collect() + sc.stop() } } From 2519dcc33bde3a6d341790d73b5d292ea7af961a Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 25 Jun 2015 08:13:17 -0700 Subject: [PATCH 28/44] [MINOR] [MLLIB] rename some functions of PythonMLLibAPI Keep the same naming conventions for PythonMLLibAPI. Only the following three functions is different from others ```scala trainNaiveBayes trainGaussianMixture trainWord2Vec ``` So change them to ```scala trainNaiveBayesModel trainGaussianMixtureModel trainWord2VecModel ``` It does not affect any users and public APIs, only to make better understand for developer and code hacker. Author: Yanbo Liang Closes #7011 from yanboliang/py-mllib-api-rename and squashes the following commits: 771ffec [Yanbo Liang] rename some functions of PythonMLLibAPI --- .../org/apache/spark/mllib/api/python/PythonMLLibAPI.scala | 6 +++--- python/pyspark/mllib/classification.py | 2 +- python/pyspark/mllib/clustering.py | 6 +++--- python/pyspark/mllib/feature.py | 2 +- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index c4bea7c2cad4f..b16903a8d515c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -278,7 +278,7 @@ private[python] class PythonMLLibAPI extends Serializable { /** * Java stub for NaiveBayes.train() */ - def trainNaiveBayes( + def trainNaiveBayesModel( data: JavaRDD[LabeledPoint], lambda: Double): JList[Object] = { val model = NaiveBayes.train(data.rdd, lambda) @@ -346,7 +346,7 @@ private[python] class PythonMLLibAPI extends Serializable { * Java stub for Python mllib GaussianMixture.run() * Returns a list containing weights, mean and covariance of each mixture component. */ - def trainGaussianMixture( + def trainGaussianMixtureModel( data: JavaRDD[Vector], k: Int, convergenceTol: Double, @@ -553,7 +553,7 @@ private[python] class PythonMLLibAPI extends Serializable { * @param seed initial seed for random generator * @return A handle to java Word2VecModelWrapper instance at python side */ - def trainWord2Vec( + def trainWord2VecModel( dataJRDD: JavaRDD[java.util.ArrayList[String]], vectorSize: Int, learningRate: Double, diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 2698f10d06883..735d45ba03d27 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -581,7 +581,7 @@ def train(cls, data, lambda_=1.0): first = data.first() if not isinstance(first, LabeledPoint): raise ValueError("`data` should be an RDD of LabeledPoint") - labels, pi, theta = callMLlibFunc("trainNaiveBayes", data, lambda_) + labels, pi, theta = callMLlibFunc("trainNaiveBayesModel", data, lambda_) return NaiveBayesModel(labels.toArray(), pi.toArray(), numpy.array(theta)) diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index e6ef72942ce77..8bc0654c76ca3 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -265,9 +265,9 @@ def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initia initialModelWeights = initialModel.weights initialModelMu = [initialModel.gaussians[i].mu for i in range(initialModel.k)] initialModelSigma = [initialModel.gaussians[i].sigma for i in range(initialModel.k)] - weight, mu, sigma = callMLlibFunc("trainGaussianMixture", rdd.map(_convert_to_vector), k, - convergenceTol, maxIterations, seed, initialModelWeights, - initialModelMu, initialModelSigma) + weight, mu, sigma = callMLlibFunc("trainGaussianMixtureModel", rdd.map(_convert_to_vector), + k, convergenceTol, maxIterations, seed, + initialModelWeights, initialModelMu, initialModelSigma) mvg_obj = [MultivariateGaussian(mu[i], sigma[i]) for i in range(k)] return GaussianMixtureModel(weight, mvg_obj) diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 334f5b86cd392..f00bb93b7bf40 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -549,7 +549,7 @@ def fit(self, data): """ if not isinstance(data, RDD): raise TypeError("data should be an RDD of list of string") - jmodel = callMLlibFunc("trainWord2Vec", data, int(self.vectorSize), + jmodel = callMLlibFunc("trainWord2VecModel", data, int(self.vectorSize), float(self.learningRate), int(self.numPartitions), int(self.numIterations), int(self.seed), int(self.minCount)) From c392a9efabcb1ec2a2c53f001ecdae33c245ba35 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Thu, 25 Jun 2015 10:56:00 -0700 Subject: [PATCH 29/44] [SPARK-8637] [SPARKR] [HOTFIX] Fix packages argument, sparkSubmitBinName cc cafreeman Author: Shivaram Venkataraman Closes #7022 from shivaram/sparkr-init-hotfix and squashes the following commits: 9178d15 [Shivaram Venkataraman] Fix packages argument, sparkSubmitBinName --- R/pkg/R/client.R | 2 +- R/pkg/R/sparkR.R | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index cf2e5ddeb7a9d..78c7a3037ffac 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -57,7 +57,7 @@ generateSparkSubmitArgs <- function(args, sparkHome, jars, sparkSubmitOpts, pack } launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts, packages) { - sparkSubmitBin <- determineSparkSubmitBin() + sparkSubmitBinName <- determineSparkSubmitBin() if (sparkHome != "") { sparkSubmitBin <- file.path(sparkHome, "bin", sparkSubmitBinName) } else { diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 8f81d5640c1d0..633b869f91784 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -132,7 +132,7 @@ sparkR.init <- function( sparkHome = sparkHome, jars = jars, sparkSubmitOpts = Sys.getenv("SPARKR_SUBMIT_ARGS", "sparkr-shell"), - sparkPackages = sparkPackages) + packages = sparkPackages) # wait atmost 100 seconds for JVM to launch wait <- 0.1 for (i in 1:25) { From 47c874babe7779c7a2f32e0b891503ef6bebcab0 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 25 Jun 2015 22:07:37 -0700 Subject: [PATCH 30/44] [SPARK-8237] [SQL] Add misc function sha2 JIRA: https://issues.apache.org/jira/browse/SPARK-8237 Author: Liang-Chi Hsieh Closes #6934 from viirya/expr_sha2 and squashes the following commits: 35e0bb3 [Liang-Chi Hsieh] For comments. 68b5284 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into expr_sha2 8573aff [Liang-Chi Hsieh] Remove unnecessary Product. ee61e06 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into expr_sha2 59e41aa [Liang-Chi Hsieh] Add misc function: sha2. --- python/pyspark/sql/functions.py | 19 ++++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../spark/sql/catalyst/expressions/misc.scala | 98 ++++++++++++++++++- .../expressions/MiscFunctionsSuite.scala | 14 ++- .../org/apache/spark/sql/functions.scala | 20 ++++ .../spark/sql/DataFrameFunctionsSuite.scala | 17 ++++ 6 files changed, 165 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index cfa87aeea193a..7d3d0361610b7 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -42,6 +42,7 @@ 'monotonicallyIncreasingId', 'rand', 'randn', + 'sha2', 'sparkPartitionId', 'struct', 'udf', @@ -363,6 +364,24 @@ def randn(seed=None): return Column(jc) +@ignore_unicode_prefix +@since(1.5) +def sha2(col, numBits): + """Returns the hex string result of SHA-2 family of hash functions (SHA-224, SHA-256, SHA-384, + and SHA-512). The numBits indicates the desired bit length of the result, which must have a + value of 224, 256, 384, 512, or 0 (which is equivalent to 256). + + >>> digests = df.select(sha2(df.name, 256).alias('s')).collect() + >>> digests[0] + Row(s=u'3bc51062973c458d5a6f2d8d64a023246354ad7e064b1e4e009ec8a0699a3043') + >>> digests[1] + Row(s=u'cd9fb1e148ccd8442e5aa74904cc73bf6fb54d1d54d333bd596aa9bb4bb4e961') + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.sha2(_to_java_column(col), numBits) + return Column(jc) + + @since(1.4) def sparkPartitionId(): """A column for partition ID of the Spark task. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 5fb3369f85d12..457948a800a17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -135,6 +135,7 @@ object FunctionRegistry { // misc functions expression[Md5]("md5"), + expression[Sha2]("sha2"), // aggregate functions expression[Average]("avg"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 4bee8cb728e5c..e80706fc65aff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -17,9 +17,12 @@ package org.apache.spark.sql.catalyst.expressions +import java.security.MessageDigest +import java.security.NoSuchAlgorithmException + import org.apache.commons.codec.digest.DigestUtils import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.types.{BinaryType, StringType, DataType} +import org.apache.spark.sql.types.{BinaryType, IntegerType, StringType, DataType} import org.apache.spark.unsafe.types.UTF8String /** @@ -44,7 +47,96 @@ case class Md5(child: Expression) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => - "org.apache.spark.unsafe.types.UTF8String.fromString" + - s"(org.apache.commons.codec.digest.DigestUtils.md5Hex($c))") + s"${ctx.stringType}.fromString(org.apache.commons.codec.digest.DigestUtils.md5Hex($c))") + } +} + +/** + * A function that calculates the SHA-2 family of functions (SHA-224, SHA-256, SHA-384, and SHA-512) + * and returns it as a hex string. The first argument is the string or binary to be hashed. The + * second argument indicates the desired bit length of the result, which must have a value of 224, + * 256, 384, 512, or 0 (which is equivalent to 256). SHA-224 is supported starting from Java 8. If + * asking for an unsupported SHA function, the return value is NULL. If either argument is NULL or + * the hash length is not one of the permitted values, the return value is NULL. + */ +case class Sha2(left: Expression, right: Expression) + extends BinaryExpression with Serializable with ExpectsInputTypes { + + override def dataType: DataType = StringType + + override def toString: String = s"SHA2($left, $right)" + + override def expectedChildTypes: Seq[DataType] = Seq(BinaryType, IntegerType) + + override def eval(input: InternalRow): Any = { + val evalE1 = left.eval(input) + if (evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + val bitLength = evalE2.asInstanceOf[Int] + val input = evalE1.asInstanceOf[Array[Byte]] + bitLength match { + case 224 => + // DigestUtils doesn't support SHA-224 now + try { + val md = MessageDigest.getInstance("SHA-224") + md.update(input) + UTF8String.fromBytes(md.digest()) + } catch { + // SHA-224 is not supported on the system, return null + case noa: NoSuchAlgorithmException => null + } + case 256 | 0 => + UTF8String.fromString(DigestUtils.sha256Hex(input)) + case 384 => + UTF8String.fromString(DigestUtils.sha384Hex(input)) + case 512 => + UTF8String.fromString(DigestUtils.sha512Hex(input)) + case _ => null + } + } + } + } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + val digestUtils = "org.apache.commons.codec.digest.DigestUtils" + + s""" + ${eval1.code} + boolean ${ev.isNull} = ${eval1.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${eval2.code} + if (!${eval2.isNull}) { + if (${eval2.primitive} == 224) { + try { + java.security.MessageDigest md = java.security.MessageDigest.getInstance("SHA-224"); + md.update(${eval1.primitive}); + ${ev.primitive} = ${ctx.stringType}.fromBytes(md.digest()); + } catch (java.security.NoSuchAlgorithmException e) { + ${ev.isNull} = true; + } + } else if (${eval2.primitive} == 256 || ${eval2.primitive} == 0) { + ${ev.primitive} = + ${ctx.stringType}.fromString(${digestUtils}.sha256Hex(${eval1.primitive})); + } else if (${eval2.primitive} == 384) { + ${ev.primitive} = + ${ctx.stringType}.fromString(${digestUtils}.sha384Hex(${eval1.primitive})); + } else if (${eval2.primitive} == 512) { + ${ev.primitive} = + ${ctx.stringType}.fromString(${digestUtils}.sha512Hex(${eval1.primitive})); + } else { + ${ev.isNull} = true; + } + } else { + ${ev.isNull} = true; + } + } + """ } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala index 48b84130b4556..38482c54c61db 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala @@ -17,8 +17,10 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.commons.codec.digest.DigestUtils + import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.types.{StringType, BinaryType} +import org.apache.spark.sql.types.{IntegerType, StringType, BinaryType} class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -29,4 +31,14 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Md5(Literal.create(null, BinaryType)), null) } + test("sha2") { + checkEvaluation(Sha2(Literal("ABC".getBytes), Literal(256)), DigestUtils.sha256Hex("ABC")) + checkEvaluation(Sha2(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType), Literal(384)), + DigestUtils.sha384Hex(Array[Byte](1, 2, 3, 4, 5, 6))) + // unsupported bit length + checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(1024)), null) + checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(512)), null) + checkEvaluation(Sha2(Literal("ABC".getBytes), Literal.create(null, IntegerType)), null) + checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal.create(null, IntegerType)), null) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 38d9085a505fb..355ce0e3423cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1414,6 +1414,26 @@ object functions { */ def md5(columnName: String): Column = md5(Column(columnName)) + /** + * Calculates the SHA-2 family of hash functions and returns the value as a hex string. + * + * @group misc_funcs + * @since 1.5.0 + */ + def sha2(e: Column, numBits: Int): Column = { + require(Seq(0, 224, 256, 384, 512).contains(numBits), + s"numBits $numBits is not in the permitted values (0, 224, 256, 384, 512)") + Sha2(e.expr, lit(numBits).expr) + } + + /** + * Calculates the SHA-2 family of hash functions and returns the value as a hex string. + * + * @group misc_funcs + * @since 1.5.0 + */ + def sha2(columnName: String, numBits: Int): Column = sha2(Column(columnName), numBits) + ////////////////////////////////////////////////////////////////////////////////////////////// // String functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 8b53b384a22fd..8baed57a7f129 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -144,6 +144,23 @@ class DataFrameFunctionsSuite extends QueryTest { Row("902fbdd2b1df0c4f70b4a5d23525e932", "6ac1e56bc78f031059be7be854522c4c")) } + test("misc sha2 function") { + val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b") + checkAnswer( + df.select(sha2($"a", 256), sha2("b", 256)), + Row("b5d4045c3f466fa91fe2cc6abe79232a1a57cdf104f7a26e716e0a1e2789df78", + "7192385c3c0605de55bb9476ce1d90748190ecb32a8eed7f5207b30cf6a1fe89")) + + checkAnswer( + df.selectExpr("sha2(a, 256)", "sha2(b, 256)"), + Row("b5d4045c3f466fa91fe2cc6abe79232a1a57cdf104f7a26e716e0a1e2789df78", + "7192385c3c0605de55bb9476ce1d90748190ecb32a8eed7f5207b30cf6a1fe89")) + + intercept[IllegalArgumentException] { + df.select(sha2($"a", 1024)) + } + } + test("string length function") { checkAnswer( nullStrings.select(strlen($"s"), strlen("s")), From 40360112c417b5432564f4bcb8a9100f4066b55e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 25 Jun 2015 22:16:53 -0700 Subject: [PATCH 31/44] [SPARK-8620] [SQL] cleanup CodeGenContext fix docs, remove nativeTypes , use java type to get boxed type ,default value, etc. to avoid handle `DateType` and `TimestampType` as int and long again and again. Author: Wenchen Fan Closes #7010 from cloud-fan/cg and squashes the following commits: aa01cf9 [Wenchen Fan] cleanup CodeGenContext --- .../spark/sql/catalyst/expressions/Cast.scala | 5 +- .../expressions/codegen/CodeGenerator.scala | 130 +++++++++--------- .../codegen/GenerateProjection.scala | 34 ++--- .../expressions/stringOperations.scala | 1 - 4 files changed, 82 insertions(+), 88 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 8bd7fc18a8dd4..8d66968a2fc35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -467,11 +467,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w defineCodeGen(ctx, ev, c => s"!$c.isZero()") case (dt: NumericType, BooleanType) => defineCodeGen(ctx, ev, c => s"$c != 0") - - case (_: DecimalType, IntegerType) => - defineCodeGen(ctx, ev, c => s"($c).toInt()") case (_: DecimalType, dt: NumericType) => - defineCodeGen(ctx, ev, c => s"($c).to${ctx.boxedType(dt)}()") + defineCodeGen(ctx, ev, c => s"($c).to${ctx.primitiveTypeName(dt)}()") case (_: NumericType, dt: NumericType) => defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c)") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 47c5455435ec6..e20e3a9dca502 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -59,6 +59,14 @@ class CodeGenContext { val stringType: String = classOf[UTF8String].getName val decimalType: String = classOf[Decimal].getName + final val JAVA_BOOLEAN = "boolean" + final val JAVA_BYTE = "byte" + final val JAVA_SHORT = "short" + final val JAVA_INT = "int" + final val JAVA_LONG = "long" + final val JAVA_FLOAT = "float" + final val JAVA_DOUBLE = "double" + private val curId = new java.util.concurrent.atomic.AtomicInteger() /** @@ -72,98 +80,94 @@ class CodeGenContext { } /** - * Return the code to access a column for given DataType + * Returns the code to access a column in Row for a given DataType. */ def getColumn(dataType: DataType, ordinal: Int): String = { - if (isNativeType(dataType)) { - s"i.${accessorForType(dataType)}($ordinal)" + val jt = javaType(dataType) + if (isPrimitiveType(jt)) { + s"i.get${primitiveTypeName(jt)}($ordinal)" } else { - s"(${boxedType(dataType)})i.apply($ordinal)" + s"($jt)i.apply($ordinal)" } } /** - * Return the code to update a column in Row for given DataType + * Returns the code to update a column in Row for a given DataType. */ def setColumn(dataType: DataType, ordinal: Int, value: String): String = { - if (isNativeType(dataType)) { - s"${mutatorForType(dataType)}($ordinal, $value)" + val jt = javaType(dataType) + if (isPrimitiveType(jt)) { + s"set${primitiveTypeName(jt)}($ordinal, $value)" } else { s"update($ordinal, $value)" } } /** - * Return the name of accessor in Row for a DataType + * Returns the name used in accessor and setter for a Java primitive type. */ - def accessorForType(dt: DataType): String = dt match { - case IntegerType => "getInt" - case other => s"get${boxedType(dt)}" + def primitiveTypeName(jt: String): String = jt match { + case JAVA_INT => "Int" + case _ => boxedType(jt) } - /** - * Return the name of mutator in Row for a DataType - */ - def mutatorForType(dt: DataType): String = dt match { - case IntegerType => "setInt" - case other => s"set${boxedType(dt)}" - } + def primitiveTypeName(dt: DataType): String = primitiveTypeName(javaType(dt)) /** - * Return the Java type for a DataType + * Returns the Java type for a DataType. */ def javaType(dt: DataType): String = dt match { - case IntegerType => "int" - case LongType => "long" - case ShortType => "short" - case ByteType => "byte" - case DoubleType => "double" - case FloatType => "float" - case BooleanType => "boolean" + case BooleanType => JAVA_BOOLEAN + case ByteType => JAVA_BYTE + case ShortType => JAVA_SHORT + case IntegerType => JAVA_INT + case LongType => JAVA_LONG + case FloatType => JAVA_FLOAT + case DoubleType => JAVA_DOUBLE case dt: DecimalType => decimalType case BinaryType => "byte[]" case StringType => stringType - case DateType => "int" - case TimestampType => "long" + case DateType => JAVA_INT + case TimestampType => JAVA_LONG case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName case _ => "Object" } /** - * Return the boxed type in Java + * Returns the boxed type in Java. */ - def boxedType(dt: DataType): String = dt match { - case IntegerType => "Integer" - case LongType => "Long" - case ShortType => "Short" - case ByteType => "Byte" - case DoubleType => "Double" - case FloatType => "Float" - case BooleanType => "Boolean" - case DateType => "Integer" - case TimestampType => "Long" - case _ => javaType(dt) + def boxedType(jt: String): String = jt match { + case JAVA_BOOLEAN => "Boolean" + case JAVA_BYTE => "Byte" + case JAVA_SHORT => "Short" + case JAVA_INT => "Integer" + case JAVA_LONG => "Long" + case JAVA_FLOAT => "Float" + case JAVA_DOUBLE => "Double" + case other => other } + def boxedType(dt: DataType): String = boxedType(javaType(dt)) + /** - * Return the representation of default value for given DataType + * Returns the representation of default value for a given Java Type. */ - def defaultValue(dt: DataType): String = dt match { - case BooleanType => "false" - case FloatType => "-1.0f" - case ShortType => "(short)-1" - case LongType => "-1L" - case ByteType => "(byte)-1" - case DoubleType => "-1.0" - case IntegerType => "-1" - case DateType => "-1" - case TimestampType => "-1L" + def defaultValue(jt: String): String = jt match { + case JAVA_BOOLEAN => "false" + case JAVA_BYTE => "(byte)-1" + case JAVA_SHORT => "(short)-1" + case JAVA_INT => "-1" + case JAVA_LONG => "-1L" + case JAVA_FLOAT => "-1.0f" + case JAVA_DOUBLE => "-1.0" case _ => "null" } + def defaultValue(dt: DataType): String = defaultValue(javaType(dt)) + /** - * Generate code for equal expression in Java + * Generates code for equal expression in Java. */ def genEqual(dataType: DataType, c1: String, c2: String): String = dataType match { case BinaryType => s"java.util.Arrays.equals($c1, $c2)" @@ -172,7 +176,7 @@ class CodeGenContext { } /** - * Generate code for compare expression in Java + * Generates code for compare expression in Java. */ def genComp(dataType: DataType, c1: String, c2: String): String = dataType match { // java boolean doesn't support > or < operator @@ -184,25 +188,17 @@ class CodeGenContext { } /** - * List of data types that have special accessors and setters in [[InternalRow]]. + * List of java data types that have special accessors and setters in [[InternalRow]]. */ - val nativeTypes = - Seq(IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType) + val primitiveTypes = + Seq(JAVA_BOOLEAN, JAVA_BYTE, JAVA_SHORT, JAVA_INT, JAVA_LONG, JAVA_FLOAT, JAVA_DOUBLE) /** - * Returns true if the data type has a special accessor and setter in [[InternalRow]]. + * Returns true if the Java type has a special accessor and setter in [[InternalRow]]. */ - def isNativeType(dt: DataType): Boolean = nativeTypes.contains(dt) + def isPrimitiveType(jt: String): Boolean = primitiveTypes.contains(jt) - /** - * List of data types who's Java type is primitive type - */ - val primitiveTypes = nativeTypes ++ Seq(DateType, TimestampType) - - /** - * Returns true if the Java type is primitive type - */ - def isPrimitiveType(dt: DataType): Boolean = primitiveTypes.contains(dt) + def isPrimitiveType(dt: DataType): Boolean = isPrimitiveType(javaType(dt)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index e362625469e29..624e1cf4e201a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -72,54 +72,56 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { s"case $i: { c$i = (${ctx.boxedType(e.dataType)})value; return;}" }.mkString("\n ") - val specificAccessorFunctions = ctx.nativeTypes.map { dataType => + val specificAccessorFunctions = ctx.primitiveTypes.map { jt => val cases = expressions.zipWithIndex.flatMap { - case (e, i) if ctx.javaType(e.dataType) == ctx.javaType(dataType) => - List(s"case $i: return c$i;") - case _ => Nil + case (e, i) if ctx.javaType(e.dataType) == jt => + Some(s"case $i: return c$i;") + case _ => None }.mkString("\n ") if (cases.length > 0) { + val getter = "get" + ctx.primitiveTypeName(jt) s""" @Override - public ${ctx.javaType(dataType)} ${ctx.accessorForType(dataType)}(int i) { + public $jt $getter(int i) { if (isNullAt(i)) { - return ${ctx.defaultValue(dataType)}; + return ${ctx.defaultValue(jt)}; } switch (i) { $cases } throw new IllegalArgumentException("Invalid index: " + i - + " in ${ctx.accessorForType(dataType)}"); + + " in $getter"); }""" } else { "" } - }.mkString("\n") + }.filter(_.length > 0).mkString("\n") - val specificMutatorFunctions = ctx.nativeTypes.map { dataType => + val specificMutatorFunctions = ctx.primitiveTypes.map { jt => val cases = expressions.zipWithIndex.flatMap { - case (e, i) if ctx.javaType(e.dataType) == ctx.javaType(dataType) => - List(s"case $i: { c$i = value; return; }") - case _ => Nil + case (e, i) if ctx.javaType(e.dataType) == jt => + Some(s"case $i: { c$i = value; return; }") + case _ => None }.mkString("\n ") if (cases.length > 0) { + val setter = "set" + ctx.primitiveTypeName(jt) s""" @Override - public void ${ctx.mutatorForType(dataType)}(int i, ${ctx.javaType(dataType)} value) { + public void $setter(int i, $jt value) { nullBits[i] = false; switch (i) { $cases } throw new IllegalArgumentException("Invalid index: " + i + - " in ${ctx.mutatorForType(dataType)}"); + " in $setter}"); }""" } else { "" } - }.mkString("\n") + }.filter(_.length > 0).mkString("\n") val hashValues = expressions.zipWithIndex.map { case (e, i) => - val col = newTermName(s"c$i") + val col = s"c$i" val nonNull = e.dataType match { case BooleanType => s"$col ? 0 : 1" case ByteType | ShortType | IntegerType | DateType => s"$col" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 44416e79cd7aa..a6225fdafedde 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions import java.util.regex.Pattern import org.apache.spark.sql.catalyst.analysis.UnresolvedException -import org.apache.spark.sql.catalyst.expressions.Substring import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String From 1a79f0eb8da7e850c443383b3bb24e0bf8e1e7cb Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 25 Jun 2015 22:44:26 -0700 Subject: [PATCH 32/44] [SPARK-8635] [SQL] improve performance of CatalystTypeConverters In `CatalystTypeConverters.createToCatalystConverter`, we add special handling for primitive types. We can apply this strategy to more places to improve performance. Author: Wenchen Fan Closes #7018 from cloud-fan/converter and squashes the following commits: 8b16630 [Wenchen Fan] another fix 326c82c [Wenchen Fan] optimize type converter --- .../sql/catalyst/CatalystTypeConverters.scala | 60 ++++++++++++------- .../sql/catalyst/expressions/ScalaUdf.scala | 3 +- .../org/apache/spark/sql/DataFrame.scala | 4 +- .../sql/execution/stat/FrequentItems.scala | 2 +- .../sql/execution/stat/StatFunctions.scala | 2 +- .../sql/sources/DataSourceStrategy.scala | 2 +- .../apache/spark/sql/sources/commands.scala | 4 +- .../spark/sql/sources/TableScanSuite.scala | 4 +- 8 files changed, 48 insertions(+), 33 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 429fc4077be9a..012f8bbecb4d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -52,6 +52,13 @@ object CatalystTypeConverters { } } + private def isWholePrimitive(dt: DataType): Boolean = dt match { + case dt if isPrimitive(dt) => true + case ArrayType(elementType, _) => isWholePrimitive(elementType) + case MapType(keyType, valueType, _) => isWholePrimitive(keyType) && isWholePrimitive(valueType) + case _ => false + } + private def getConverterForType(dataType: DataType): CatalystTypeConverter[Any, Any, Any] = { val converter = dataType match { case udt: UserDefinedType[_] => UDTConverter(udt) @@ -148,6 +155,8 @@ object CatalystTypeConverters { private[this] val elementConverter = getConverterForType(elementType) + private[this] val isNoChange = isWholePrimitive(elementType) + override def toCatalystImpl(scalaValue: Any): Seq[Any] = { scalaValue match { case a: Array[_] => a.toSeq.map(elementConverter.toCatalyst) @@ -166,8 +175,10 @@ object CatalystTypeConverters { override def toScala(catalystValue: Seq[Any]): Seq[Any] = { if (catalystValue == null) { null + } else if (isNoChange) { + catalystValue } else { - catalystValue.asInstanceOf[Seq[_]].map(elementConverter.toScala) + catalystValue.map(elementConverter.toScala) } } @@ -183,6 +194,8 @@ object CatalystTypeConverters { private[this] val keyConverter = getConverterForType(keyType) private[this] val valueConverter = getConverterForType(valueType) + private[this] val isNoChange = isWholePrimitive(keyType) && isWholePrimitive(valueType) + override def toCatalystImpl(scalaValue: Any): Map[Any, Any] = scalaValue match { case m: Map[_, _] => m.map { case (k, v) => @@ -203,6 +216,8 @@ object CatalystTypeConverters { override def toScala(catalystValue: Map[Any, Any]): Map[Any, Any] = { if (catalystValue == null) { null + } else if (isNoChange) { + catalystValue } else { catalystValue.map { case (k, v) => keyConverter.toScala(k) -> valueConverter.toScala(v) @@ -258,16 +273,13 @@ object CatalystTypeConverters { toScala(row(column).asInstanceOf[InternalRow]) } - private object StringConverter extends CatalystTypeConverter[Any, String, Any] { + private object StringConverter extends CatalystTypeConverter[Any, String, UTF8String] { override def toCatalystImpl(scalaValue: Any): UTF8String = scalaValue match { case str: String => UTF8String.fromString(str) case utf8: UTF8String => utf8 } - override def toScala(catalystValue: Any): String = catalystValue match { - case null => null - case str: String => str - case utf8: UTF8String => utf8.toString() - } + override def toScala(catalystValue: UTF8String): String = + if (catalystValue == null) null else catalystValue.toString override def toScalaImpl(row: InternalRow, column: Int): String = row(column).toString } @@ -275,7 +287,8 @@ object CatalystTypeConverters { override def toCatalystImpl(scalaValue: Date): Int = DateTimeUtils.fromJavaDate(scalaValue) override def toScala(catalystValue: Any): Date = if (catalystValue == null) null else DateTimeUtils.toJavaDate(catalystValue.asInstanceOf[Int]) - override def toScalaImpl(row: InternalRow, column: Int): Date = toScala(row.getInt(column)) + override def toScalaImpl(row: InternalRow, column: Int): Date = + DateTimeUtils.toJavaDate(row.getInt(column)) } private object TimestampConverter extends CatalystTypeConverter[Timestamp, Timestamp, Any] { @@ -285,7 +298,7 @@ object CatalystTypeConverters { if (catalystValue == null) null else DateTimeUtils.toJavaTimestamp(catalystValue.asInstanceOf[Long]) override def toScalaImpl(row: InternalRow, column: Int): Timestamp = - toScala(row.getLong(column)) + DateTimeUtils.toJavaTimestamp(row.getLong(column)) } private object BigDecimalConverter extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] { @@ -296,10 +309,7 @@ object CatalystTypeConverters { } override def toScala(catalystValue: Decimal): JavaBigDecimal = catalystValue.toJavaBigDecimal override def toScalaImpl(row: InternalRow, column: Int): JavaBigDecimal = - row.get(column) match { - case d: JavaBigDecimal => d - case d: Decimal => d.toJavaBigDecimal - } + row.get(column).asInstanceOf[Decimal].toJavaBigDecimal } private abstract class PrimitiveConverter[T] extends CatalystTypeConverter[T, Any, Any] { @@ -362,6 +372,19 @@ object CatalystTypeConverters { } } + /** + * Creates a converter function that will convert Catalyst types to Scala type. + * Typical use case would be converting a collection of rows that have the same schema. You will + * call this function once to get a converter, and apply it to every row. + */ + private[sql] def createToScalaConverter(dataType: DataType): Any => Any = { + if (isPrimitive(dataType)) { + identity + } else { + getConverterForType(dataType).toScala + } + } + /** * Converts Scala objects to Catalyst rows / types. * @@ -389,15 +412,6 @@ object CatalystTypeConverters { * produced by createToScalaConverter. */ def convertToScala(catalystValue: Any, dataType: DataType): Any = { - getConverterForType(dataType).toScala(catalystValue) - } - - /** - * Creates a converter function that will convert Catalyst types to Scala type. - * Typical use case would be converting a collection of rows that have the same schema. You will - * call this function once to get a converter, and apply it to every row. - */ - private[sql] def createToScalaConverter(dataType: DataType): Any => Any = { - getConverterForType(dataType).toScala + createToScalaConverter(dataType)(catalystValue) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala index 3992f1f59dad8..55df72f102295 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.types.DataType @@ -39,7 +38,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi (1 to 22).map { x => val anys = (1 to x).map(x => "Any").reduce(_ + ", " + _) val childs = (0 to x - 1).map(x => s"val child$x = children($x)").reduce(_ + "\n " + _) - lazy val converters = (0 to x - 1).map(x => s"lazy val converter$x = CatalystTypeConverters.createToScalaConverter(child$x.dataType)").reduce(_ + "\n " + _) + val converters = (0 to x - 1).map(x => s"lazy val converter$x = CatalystTypeConverters.createToScalaConverter(child$x.dataType)").reduce(_ + "\n " + _) val evals = (0 to x - 1).map(x => s"converter$x(child$x.eval(input))").reduce(_ + ",\n " + _) s"""case $x => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index f3f0f5305318e..0db4df34f9e22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1418,12 +1418,14 @@ class DataFrame private[sql]( lazy val rdd: RDD[Row] = { // use a local variable to make sure the map closure doesn't capture the whole DataFrame val schema = this.schema - queryExecution.executedPlan.execute().mapPartitions { rows => + internalRowRdd.mapPartitions { rows => val converter = CatalystTypeConverters.createToScalaConverter(schema) rows.map(converter(_).asInstanceOf[Row]) } } + private[sql] def internalRowRdd = queryExecution.executedPlan.execute() + /** * Returns the content of the [[DataFrame]] as a [[JavaRDD]] of [[Row]]s. * @group rdd diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala index 8df1da037c434..3ebbf96090a55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala @@ -90,7 +90,7 @@ private[sql] object FrequentItems extends Logging { (name, originalSchema.fields(index).dataType) } - val freqItems = df.select(cols.map(Column(_)) : _*).rdd.aggregate(countMaps)( + val freqItems = df.select(cols.map(Column(_)) : _*).internalRowRdd.aggregate(countMaps)( seqOp = (counts, row) => { var i = 0 while (i < numCols) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 93383e5a62f11..252c611d02ebc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -81,7 +81,7 @@ private[sql] object StatFunctions extends Logging { s"with dataType ${data.get.dataType} not supported.") } val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType))) - df.select(columns: _*).rdd.aggregate(new CovarianceCounter)( + df.select(columns: _*).internalRowRdd.aggregate(new CovarianceCounter)( seqOp = (counter, row) => { counter.add(row.getDouble(0), row.getDouble(1)) }, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala index a8f56f4767407..ce16e050c56ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -313,7 +313,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { output: Seq[Attribute], rdd: RDD[Row]): RDD[InternalRow] = { if (relation.relation.needConversion) { - execution.RDDConversions.rowToRowRdd(rdd.asInstanceOf[RDD[Row]], output.map(_.dataType)) + execution.RDDConversions.rowToRowRdd(rdd, output.map(_.dataType)) } else { rdd.map(_.asInstanceOf[InternalRow]) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index fb6173f58ece6..dbb369cf45502 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -154,7 +154,7 @@ private[sql] case class InsertIntoHadoopFsRelation( writerContainer.driverSideSetup() try { - df.sqlContext.sparkContext.runJob(df.queryExecution.executedPlan.execute(), writeRows _) + df.sqlContext.sparkContext.runJob(df.internalRowRdd, writeRows _) writerContainer.commitJob() relation.refresh() } catch { case cause: Throwable => @@ -220,7 +220,7 @@ private[sql] case class InsertIntoHadoopFsRelation( writerContainer.driverSideSetup() try { - df.sqlContext.sparkContext.runJob(df.queryExecution.executedPlan.execute(), writeRows _) + df.sqlContext.sparkContext.runJob(df.internalRowRdd, writeRows _) writerContainer.commitJob() relation.refresh() } catch { case cause: Throwable => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 79eac930e54f7..de0ed0c0427a6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -88,9 +88,9 @@ case class AllDataTypesScan( UTF8String.fromString(s"varchar_$i"), Seq(i, i + 1), Seq(Map(UTF8String.fromString(s"str_$i") -> InternalRow(i.toLong))), - Map(i -> i.toString), + Map(i -> UTF8String.fromString(i.toString)), Map(Map(UTF8String.fromString(s"str_$i") -> i.toFloat) -> InternalRow(i.toLong)), - Row(i, i.toString), + Row(i, UTF8String.fromString(i.toString)), Row(Seq(UTF8String.fromString(s"str_$i"), UTF8String.fromString(s"str_${i + 1}")), InternalRow(Seq(DateTimeUtils.fromJavaDate(new Date(1970, 1, i + 1)))))) } From 9fed6abfdcb7afcf92be56e5ccbed6599fe66bc4 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 26 Jun 2015 00:12:05 -0700 Subject: [PATCH 33/44] [SPARK-8344] Add message processing time metric to DAGScheduler This commit adds a new metric, `messageProcessingTime`, to the DAGScheduler metrics source. This metrics tracks the time taken to process messages in the scheduler's event processing loop, which is a helpful debugging aid for diagnosing performance issues in the scheduler (such as SPARK-4961). In order to do this, I moved the creation of the DAGSchedulerSource metrics source into DAGScheduler itself, similar to how MasterSource is created and registered in Master. Author: Josh Rosen Closes #7002 from JoshRosen/SPARK-8344 and squashes the following commits: 57f914b [Josh Rosen] Fix import ordering 7d6bb83 [Josh Rosen] Add message processing time metrics to DAGScheduler --- .../scala/org/apache/spark/SparkContext.scala | 1 - .../apache/spark/scheduler/DAGScheduler.scala | 18 ++++++++++++++++-- .../spark/scheduler/DAGSchedulerSource.scala | 8 ++++++-- 3 files changed, 22 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 141276ac901fb..c7a7436462083 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -545,7 +545,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Post init _taskScheduler.postStartHook() - _env.metricsSystem.registerSource(new DAGSchedulerSource(dagScheduler)) _env.metricsSystem.registerSource(new BlockManagerSource(_env.blockManager)) _executorAllocationManager.foreach { e => _env.metricsSystem.registerSource(e.executorAllocationManagerSource) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index aea6674ed20be..b00a5fee09bf2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -81,6 +81,8 @@ class DAGScheduler( def this(sc: SparkContext) = this(sc, sc.taskScheduler) + private[scheduler] val metricsSource: DAGSchedulerSource = new DAGSchedulerSource(this) + private[scheduler] val nextJobId = new AtomicInteger(0) private[scheduler] def numTotalJobs: Int = nextJobId.get() private val nextStageId = new AtomicInteger(0) @@ -1438,17 +1440,29 @@ class DAGScheduler( taskScheduler.stop() } - // Start the event thread at the end of the constructor + // Start the event thread and register the metrics source at the end of the constructor + env.metricsSystem.registerSource(metricsSource) eventProcessLoop.start() } private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler) extends EventLoop[DAGSchedulerEvent]("dag-scheduler-event-loop") with Logging { + private[this] val timer = dagScheduler.metricsSource.messageProcessingTimer + /** * The main event loop of the DAG scheduler. */ - override def onReceive(event: DAGSchedulerEvent): Unit = event match { + override def onReceive(event: DAGSchedulerEvent): Unit = { + val timerContext = timer.time() + try { + doOnReceive(event) + } finally { + timerContext.stop() + } + } + + private def doOnReceive(event: DAGSchedulerEvent): Unit = event match { case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) => dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala index 02c67073af6a0..6b667d5d7645b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala @@ -17,11 +17,11 @@ package org.apache.spark.scheduler -import com.codahale.metrics.{Gauge, MetricRegistry} +import com.codahale.metrics.{Gauge, MetricRegistry, Timer} import org.apache.spark.metrics.source.Source -private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler) +private[scheduler] class DAGSchedulerSource(val dagScheduler: DAGScheduler) extends Source { override val metricRegistry = new MetricRegistry() override val sourceName = "DAGScheduler" @@ -45,4 +45,8 @@ private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler) metricRegistry.register(MetricRegistry.name("job", "activeJobs"), new Gauge[Int] { override def getValue: Int = dagScheduler.activeJobs.size }) + + /** Timer that tracks the time to process messages in the DAGScheduler's event loop */ + val messageProcessingTimer: Timer = + metricRegistry.timer(MetricRegistry.name("messageProcessingTime")) } From c9e05a315a96fbf3026a2b3c6934dd2dec420099 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 26 Jun 2015 01:19:05 -0700 Subject: [PATCH 34/44] [SPARK-8613] [ML] [TRIVIAL] add param to disable linear feature scaling Add a param to disable linear feature scaling (to be implemented later in linear & logistic regression). Done as a seperate PR so we can use same param & not conflict while working on the sub-tasks. Author: Holden Karau Closes #7024 from holdenk/SPARK-8522-Disable-Linear_featureScaling-Spark-8613-Add-param and squashes the following commits: ce8931a [Holden Karau] Regenerate the sharedParams code fa6427e [Holden Karau] update text for standardization param. 7b24a2b [Holden Karau] generate the new standardization param 3c190af [Holden Karau] Add the standardization param to sharedparamscodegen --- .../ml/param/shared/SharedParamsCodeGen.scala | 3 +++ .../spark/ml/param/shared/sharedParams.scala | 17 +++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 8ffbcf0d8bc71..b0a6af171c01f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -53,6 +53,9 @@ private[shared] object SharedParamsCodeGen { ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1)", isValid = "ParamValidators.gtEq(1)"), ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")), + ParamDesc[Boolean]("standardization", "whether to standardize the training features" + + " prior to fitting the model sequence. Note that the coefficients of models are" + + " always returned on the original scale.", Some("true")), ParamDesc[Long]("seed", "random seed", Some("this.getClass.getName.hashCode.toLong")), ParamDesc[Double]("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]." + " For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index a0c8ccdac9ad9..bbe08939b6d75 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -233,6 +233,23 @@ private[ml] trait HasFitIntercept extends Params { final def getFitIntercept: Boolean = $(fitIntercept) } +/** + * (private[ml]) Trait for shared param standardization (default: true). + */ +private[ml] trait HasStandardization extends Params { + + /** + * Param for whether to standardize the training features prior to fitting the model sequence. Note that the coefficients of models are always returned on the original scale.. + * @group param + */ + final val standardization: BooleanParam = new BooleanParam(this, "standardization", "whether to standardize the training features prior to fitting the model sequence. Note that the coefficients of models are always returned on the original scale.") + + setDefault(standardization, true) + + /** @group getParam */ + final def getStandardization: Boolean = $(standardization) +} + /** * (private[ml]) Trait for shared param seed (default: this.getClass.getName.hashCode.toLong). */ From 37bf76a2de2143ec6348a3d43b782227849520cc Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 26 Jun 2015 08:45:22 -0500 Subject: [PATCH 35/44] [SPARK-8302] Support heterogeneous cluster install paths on YARN. Some users have Hadoop installations on different paths across their cluster. Currently, that makes it hard to set up some configuration in Spark since that requires hardcoding paths to jar files or native libraries, which wouldn't work on such a cluster. This change introduces a couple of YARN-specific configurations that instruct the backend to replace certain paths when launching remote processes. That way, if the configuration says the Spark jar is in "/spark/spark.jar", and also says that "/spark" should be replaced with "{{SPARK_INSTALL_DIR}}", YARN will start containers in the NMs with "{{SPARK_INSTALL_DIR}}/spark.jar" as the location of the jar. Coupled with YARN's environment whitelist (which allows certain env variables to be exposed to containers), this allows users to support such heterogeneous environments, as long as a single replacement is enough. (Otherwise, this feature would need to be extended to support multiple path replacements.) Author: Marcelo Vanzin Closes #6752 from vanzin/SPARK-8302 and squashes the following commits: 4bff8d4 [Marcelo Vanzin] Add docs, rename configs. 0aa2a02 [Marcelo Vanzin] Only do replacement for paths that need it. 2e9cc9d [Marcelo Vanzin] Style. a5e1f68 [Marcelo Vanzin] [SPARK-8302] Support heterogeneous cluster install paths on YARN. --- docs/running-on-yarn.md | 26 ++++++++++ .../org/apache/spark/deploy/yarn/Client.scala | 47 +++++++++++++++---- .../spark/deploy/yarn/ExecutorRunnable.scala | 4 +- .../spark/deploy/yarn/ClientSuite.scala | 19 ++++++++ 4 files changed, 84 insertions(+), 12 deletions(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 96cf612c54fdd..3f8a093bbe957 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -258,6 +258,32 @@ Most of the configs are the same for Spark on YARN as for other deployment modes Principal to be used to login to KDC, while running on secure HDFS. + + spark.yarn.config.gatewayPath + (none) + + A path that is valid on the gateway host (the host where a Spark application is started) but may + differ for paths for the same resource in other nodes in the cluster. Coupled with + spark.yarn.config.replacementPath, this is used to support clusters with + heterogeneous configurations, so that Spark can correctly launch remote processes. +

+ The replacement path normally will contain a reference to some environment variable exported by + YARN (and, thus, visible to Spark containers). +

+ For example, if the gateway node has Hadoop libraries installed on /disk1/hadoop, and + the location of the Hadoop install is exported by YARN as the HADOOP_HOME + environment variable, setting this value to /disk1/hadoop and the replacement path to + $HADOOP_HOME will make sure that paths used to launch remote processes properly + reference the local YARN configuration. + + + + spark.yarn.config.replacementPath + (none) + + See spark.yarn.config.gatewayPath. + + # Launching Spark on YARN diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index da1ec2a0fe2e9..67a5c95400e53 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -676,7 +676,7 @@ private[spark] class Client( val libraryPaths = Seq(sys.props.get("spark.driver.extraLibraryPath"), sys.props.get("spark.driver.libraryPath")).flatten if (libraryPaths.nonEmpty) { - prefixEnv = Some(Utils.libraryPathEnvPrefix(libraryPaths)) + prefixEnv = Some(getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(libraryPaths))) } if (sparkConf.getOption("spark.yarn.am.extraJavaOptions").isDefined) { logWarning("spark.yarn.am.extraJavaOptions will not take effect in cluster mode") @@ -698,7 +698,7 @@ private[spark] class Client( } sparkConf.getOption("spark.yarn.am.extraLibraryPath").foreach { paths => - prefixEnv = Some(Utils.libraryPathEnvPrefix(Seq(paths))) + prefixEnv = Some(getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(paths)))) } } @@ -1106,10 +1106,10 @@ object Client extends Logging { env: HashMap[String, String], isAM: Boolean, extraClassPath: Option[String] = None): Unit = { - extraClassPath.foreach(addClasspathEntry(_, env)) - addClasspathEntry( - YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), env - ) + extraClassPath.foreach { cp => + addClasspathEntry(getClusterPath(sparkConf, cp), env) + } + addClasspathEntry(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), env) if (isAM) { addClasspathEntry( @@ -1125,12 +1125,14 @@ object Client extends Logging { getUserClasspath(sparkConf) } userClassPath.foreach { x => - addFileToClasspath(x, null, env) + addFileToClasspath(sparkConf, x, null, env) } } - addFileToClasspath(new URI(sparkJar(sparkConf)), SPARK_JAR, env) + addFileToClasspath(sparkConf, new URI(sparkJar(sparkConf)), SPARK_JAR, env) populateHadoopClasspath(conf, env) - sys.env.get(ENV_DIST_CLASSPATH).foreach(addClasspathEntry(_, env)) + sys.env.get(ENV_DIST_CLASSPATH).foreach { cp => + addClasspathEntry(getClusterPath(sparkConf, cp), env) + } } /** @@ -1159,16 +1161,18 @@ object Client extends Logging { * * If not a "local:" file and no alternate name, the environment is not modified. * + * @parma conf Spark configuration. * @param uri URI to add to classpath (optional). * @param fileName Alternate name for the file (optional). * @param env Map holding the environment variables. */ private def addFileToClasspath( + conf: SparkConf, uri: URI, fileName: String, env: HashMap[String, String]): Unit = { if (uri != null && uri.getScheme == LOCAL_SCHEME) { - addClasspathEntry(uri.getPath, env) + addClasspathEntry(getClusterPath(conf, uri.getPath), env) } else if (fileName != null) { addClasspathEntry(buildPath( YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), fileName), env) @@ -1182,6 +1186,29 @@ object Client extends Logging { private def addClasspathEntry(path: String, env: HashMap[String, String]): Unit = YarnSparkHadoopUtil.addPathToEnvironment(env, Environment.CLASSPATH.name, path) + /** + * Returns the path to be sent to the NM for a path that is valid on the gateway. + * + * This method uses two configuration values: + * + * - spark.yarn.config.gatewayPath: a string that identifies a portion of the input path that may + * only be valid in the gateway node. + * - spark.yarn.config.replacementPath: a string with which to replace the gateway path. This may + * contain, for example, env variable references, which will be expanded by the NMs when + * starting containers. + * + * If either config is not available, the input path is returned. + */ + def getClusterPath(conf: SparkConf, path: String): String = { + val localPath = conf.get("spark.yarn.config.gatewayPath", null) + val clusterPath = conf.get("spark.yarn.config.replacementPath", null) + if (localPath != null && clusterPath != null) { + path.replace(localPath, clusterPath) + } else { + path + } + } + /** * Obtains token for the Hive metastore and adds them to the credentials. */ diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index b0937083bc536..78e27fb7f3337 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -146,7 +146,7 @@ class ExecutorRunnable( javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) } sys.props.get("spark.executor.extraLibraryPath").foreach { p => - prefixEnv = Some(Utils.libraryPathEnvPrefix(Seq(p))) + prefixEnv = Some(Client.getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(p)))) } javaOpts += "-Djava.io.tmpdir=" + @@ -195,7 +195,7 @@ class ExecutorRunnable( val userClassPath = Client.getUserClasspath(sparkConf).flatMap { uri => val absPath = if (new File(uri.getPath()).isAbsolute()) { - uri.getPath() + Client.getClusterPath(sparkConf, uri.getPath()) } else { Client.buildPath(Environment.PWD.$(), uri.getPath()) } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index 4ec976aa31387..837f8d3fa55a7 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -151,6 +151,25 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll { } } + test("Cluster path translation") { + val conf = new Configuration() + val sparkConf = new SparkConf() + .set(Client.CONF_SPARK_JAR, "local:/localPath/spark.jar") + .set("spark.yarn.config.gatewayPath", "/localPath") + .set("spark.yarn.config.replacementPath", "/remotePath") + + Client.getClusterPath(sparkConf, "/localPath") should be ("/remotePath") + Client.getClusterPath(sparkConf, "/localPath/1:/localPath/2") should be ( + "/remotePath/1:/remotePath/2") + + val env = new MutableHashMap[String, String]() + Client.populateClasspath(null, conf, sparkConf, env, false, + extraClassPath = Some("/localPath/my1.jar")) + val cp = classpath(env) + cp should contain ("/remotePath/spark.jar") + cp should contain ("/remotePath/my1.jar") + } + object Fixtures { val knownDefYarnAppCP: Seq[String] = From 41afa16500e682475eaa80e31c0434b7ab66abcb Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 26 Jun 2015 08:12:22 -0700 Subject: [PATCH 36/44] [SPARK-8652] [PYSPARK] Check return value for all uses of doctest.testmod() This patch addresses a critical issue in the PySpark tests: Several of our Python modules' `__main__` methods call `doctest.testmod()` in order to run doctests but forget to check and handle its return value. As a result, some PySpark test failures can go unnoticed because they will not fail the build. Fortunately, there was only one test failure which was masked by this bug: a `pyspark.profiler` doctest was failing due to changes in RDD pipelining. Author: Josh Rosen Closes #7032 from JoshRosen/testmod-fix and squashes the following commits: 60dbdc0 [Josh Rosen] Account for int vs. long formatting change in Python 3 8b8d80a [Josh Rosen] Fix failing test. e6423f9 [Josh Rosen] Check return code for all uses of doctest.testmod(). --- dev/merge_spark_pr.py | 4 +++- python/pyspark/accumulators.py | 4 +++- python/pyspark/broadcast.py | 4 +++- python/pyspark/heapq3.py | 5 +++-- python/pyspark/profiler.py | 8 ++++++-- python/pyspark/serializers.py | 8 +++++--- python/pyspark/shuffle.py | 4 +++- python/pyspark/streaming/util.py | 4 +++- 8 files changed, 29 insertions(+), 12 deletions(-) diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index cd83b352c1bfb..cf827ce89b857 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -431,6 +431,8 @@ def main(): if __name__ == "__main__": import doctest - doctest.testmod() + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1) main() diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index adca90ddaf397..6ef8cf53cc747 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -264,4 +264,6 @@ def _start_update_server(): if __name__ == "__main__": import doctest - doctest.testmod() + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1) diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 3de4615428bb6..663c9abe0881e 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -115,4 +115,6 @@ def __reduce__(self): if __name__ == "__main__": import doctest - doctest.testmod() + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1) diff --git a/python/pyspark/heapq3.py b/python/pyspark/heapq3.py index 4ef2afe03544f..b27e91a4cc251 100644 --- a/python/pyspark/heapq3.py +++ b/python/pyspark/heapq3.py @@ -883,6 +883,7 @@ def nlargest(n, iterable, key=None): if __name__ == "__main__": - import doctest - print(doctest.testmod()) + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1) diff --git a/python/pyspark/profiler.py b/python/pyspark/profiler.py index d18daaabfcb3c..44d17bd629473 100644 --- a/python/pyspark/profiler.py +++ b/python/pyspark/profiler.py @@ -90,9 +90,11 @@ class Profiler(object): >>> sc = SparkContext('local', 'test', conf=conf, profiler_cls=MyCustomProfiler) >>> sc.parallelize(range(1000)).map(lambda x: 2 * x).take(10) [0, 2, 4, 6, 8, 10, 12, 14, 16, 18] + >>> sc.parallelize(range(1000)).count() + 1000 >>> sc.show_profiles() My custom profiles for RDD:1 - My custom profiles for RDD:2 + My custom profiles for RDD:3 >>> sc.stop() """ @@ -169,4 +171,6 @@ def stats(self): if __name__ == "__main__": import doctest - doctest.testmod() + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 7f9d0a338d31e..411b4dbf481f1 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -44,8 +44,8 @@ >>> rdd.glom().collect() [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]] ->>> rdd._jrdd.count() -8L +>>> int(rdd._jrdd.count()) +8 >>> sc.stop() """ @@ -556,4 +556,6 @@ def write_with_length(obj, stream): if __name__ == '__main__': import doctest - doctest.testmod() + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1) diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 67752c0d150b9..8fb71bac64a5e 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -838,4 +838,6 @@ def load_partition(j): if __name__ == "__main__": import doctest - doctest.testmod() + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1) diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index 34291f30a5652..a9bfec2aab8fc 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -125,4 +125,6 @@ def rddToFileName(prefix, suffix, timestamp): if __name__ == "__main__": import doctest - doctest.testmod() + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1) From a56516fc9280724db8fdef8e7d109ed7e28e427d Mon Sep 17 00:00:00 2001 From: cafreeman Date: Fri, 26 Jun 2015 10:07:35 -0700 Subject: [PATCH 37/44] [SPARK-8662] SparkR Update SparkSQL Test Test `infer_type` using a more fine-grained approach rather than comparing environments. Since `all.equal`'s behavior has changed in R 3.2, the test became unpassable. JIRA here: https://issues.apache.org/jira/browse/SPARK-8662 Author: cafreeman Closes #7045 from cafreeman/R32_Test and squashes the following commits: b97cc52 [cafreeman] Add `checkStructField` utility 3381e5c [cafreeman] Update SparkSQL Test (cherry picked from commit 78b31a2a630c2178987322d0221aeea183ec565f) Signed-off-by: Shivaram Venkataraman --- R/pkg/inst/tests/test_sparkSQL.R | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 417153dc0985c..6a08f894313c4 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -19,6 +19,14 @@ library(testthat) context("SparkSQL functions") +# Utility function for easily checking the values of a StructField +checkStructField <- function(actual, expectedName, expectedType, expectedNullable) { + expect_equal(class(actual), "structField") + expect_equal(actual$name(), expectedName) + expect_equal(actual$dataType.toString(), expectedType) + expect_equal(actual$nullable(), expectedNullable) +} + # Tests for SparkSQL functions in SparkR sc <- sparkR.init() @@ -52,9 +60,10 @@ test_that("infer types", { list(type = 'array', elementType = "integer", containsNull = TRUE)) expect_equal(infer_type(list(1L, 2L)), list(type = 'array', elementType = "integer", containsNull = TRUE)) - expect_equal(infer_type(list(a = 1L, b = "2")), - structType(structField(x = "a", type = "integer", nullable = TRUE), - structField(x = "b", type = "string", nullable = TRUE))) + testStruct <- infer_type(list(a = 1L, b = "2")) + expect_true(class(testStruct) == "structType") + checkStructField(testStruct$fields()[[1]], "a", "IntegerType", TRUE) + checkStructField(testStruct$fields()[[2]], "b", "StringType", TRUE) e <- new.env() assign("a", 1L, envir = e) expect_equal(infer_type(e), From 9d11817765e2817b11b73c61bae3b32c9f119cfd Mon Sep 17 00:00:00 2001 From: cafreeman Date: Fri, 26 Jun 2015 17:06:02 -0700 Subject: [PATCH 38/44] [SPARK-8607] SparkR -- jars not being added to application classpath correctly Add `getStaticClass` method in SparkR's `RBackendHandler` This is a fix for the problem referenced in [SPARK-5185](https://issues.apache.org/jira/browse/SPARK-5185). cc shivaram Author: cafreeman Closes #7001 from cafreeman/branch-1.4 and squashes the following commits: 8f81194 [cafreeman] Add missing license 31aedcf [cafreeman] Refactor test to call an external R script 2c22073 [cafreeman] Merge branch 'branch-1.4' of github.com:apache/spark into branch-1.4 0bea809 [cafreeman] Fixed relative path issue and added smaller JAR ee25e60 [cafreeman] Merge branch 'branch-1.4' of github.com:apache/spark into branch-1.4 9a5c362 [cafreeman] test for including JAR when launching sparkContext 9101223 [cafreeman] Merge branch 'branch-1.4' of github.com:apache/spark into branch-1.4 5a80844 [cafreeman] Fix style nits 7c6bd0c [cafreeman] [SPARK-8607] SparkR (cherry picked from commit 2579948bf5d89ac2d822ace605a6a4afce5258d6) Signed-off-by: Shivaram Venkataraman --- .../test_support/sparktestjar_2.10-1.0.jar | Bin 0 -> 2886 bytes R/pkg/inst/tests/jarTest.R | 32 +++++++++++++++ R/pkg/inst/tests/test_includeJAR.R | 37 ++++++++++++++++++ .../apache/spark/api/r/RBackendHandler.scala | 17 +++++++- 4 files changed, 85 insertions(+), 1 deletion(-) create mode 100644 R/pkg/inst/test_support/sparktestjar_2.10-1.0.jar create mode 100644 R/pkg/inst/tests/jarTest.R create mode 100644 R/pkg/inst/tests/test_includeJAR.R diff --git a/R/pkg/inst/test_support/sparktestjar_2.10-1.0.jar b/R/pkg/inst/test_support/sparktestjar_2.10-1.0.jar new file mode 100644 index 0000000000000000000000000000000000000000..1d5c2af631aa3ae88aa7836e8db598e59cbcf1b7 GIT binary patch literal 2886 zcmaJ@2T)UK7Y!kyL@*d49Sj{JRS@YAsv!nq=*S|~fYJj-g+=LYfzbORNV8FN{bdCk zg0zIbAfgZyq{;FFilHpKi8@1Yb?=)u^SycZJMX@^=brE2Fzg^WfQyR@ut^-V0I&oc z00Lmm?NG{SYYSB@${KB9ZfmE4wbuIiGP0M_cNecVtU;Sur6_lz zsaWb^v=SR+A;CLuy3$1vE+`}5lL)YnkQMN#MIA*nA&%;3C~FAI<>zOe_O2pl73Db< z*~X|~OI8rlVFtFrMKPnCLMpTwAOMHqFeX~AEe^t??EK`oE#4vGUh9FCIkePXoe5stL%}& zNVAhqLvP+N2Ki!4axnH{pCBmT>;l@Gm+pFq9mvWPc4#D^ejDn^&%HvUK^OB z)EgN^0iUSck}R0#%xqa6FDyH=u4lyK#mnchVHF8GuTYWvvwq8}Wg!P7M*Y|;8%rrT zSMkG(9`ZZmI9>ms=PmBAWd9<1I7oL}szYkqA=5BaS>~7Fg4Fu;I@PigC*Z->SEaFo z)mA(U+Et_0mnO!HE930f3)`UwV_k-Irmgy4i(gJx=yLphKC9Z|?D+e)*g9|J@Rsv7 zk~1Pi$~vua^rxVv3bbNK;S<5F#?d~J`xnxM^8*4`fzJ2NyI^8D1%q_oI-HdKqvn++ z@j;zGC;;QDtGL>2zA`gW*&5@)PLxO097`UtpP@;T8Uuo*gl5_07hrMQ`oFAcb9Zpv zU;eNo|IP=^hg}#?g_&HGTvq1g0WHlng}gi$D0%CHfbvHm+#?yqW(9U9_w{s}bsD60 zJbKlZTGQ2eSt2gWZfi(kU)^7K5xfcIeb*Fv%>>#q_2YJy&1isVF^WX@#n6hj6z{`$ z%jC^tPLH<4yHC$bm@gpo;104PgXP`zRac*-ITr)B!AzcE1SyllYyf~A@C(nrPYj{& z5kuw+Gpg{PnPDKR7ye#IBnz)F3M|+5yB;wmr()vg;(b}QHGS>#( z3vS3(@8`Xdz_ZHZ5;L)1`Z2Y^eK=$BWymyG{N|B)PX$>RK0=zub0-sStx+Pv>1NSF zCCw=kgsKW_ELQGn0T9Y&+I!bpHCO`y5xKugYBh#)_91|y|bv3&bbNn zhl?VRCe)^$J}q+=Qe+qHvz>9J*ohAqsNP^``th`sTtvmSm>alzorsQR37?+mp&3oj zY& z-k5K3sz5^m;*OY4n|5OT1|(XAJm)oLyn_?h$3hGPO(hQPO*jb4N2-&N?h;NAI z=|S}KvIr0K3iAsJ{7SfRa*uVZF+Ab#XH`{T@KaGilFF5MkgL^;S{WGyZiRC-RlWyO z=G$vROpZ`kR$=NeU#AyYCMvqt$Fy_Aie4LY#jBp$cw7PqDEPtTh@#xpf`=2-BF zNjKdZWD>bJ*>Rv?-)y70+L^rhn@RkT%g8=`TM8KUP}ueLr+ORi+?~H)>@&F$&>PUP zt`~RVo|c@#Qa{x=v06Gob2O>Tdzy1dQPO$o<5>ff*2{Mcw5%I^VYT-&l@prsYNVsX zI3RXuY@8*{uGUH>j~-8trp2Gz=N0(6E0(-8o*pi$#M6@|wdpI|RQ<=js%)_?XnEIF z5HtmslKq+Jh;CpZU$T;;eF? zxpCQ&39>;eWaswZk0LCAGCDUJ2Jb_)35SFHJ|10PT-J6CIOVOCTx+gwW=oG1vhuKp z`HXD|GMhCjwbXK&p$fcvRN>oZ9r~|keG!N7%#%Sd6ki7+&^etwJrwvqoSM$1N*KS2 zyVq^A?D&|c{p8Sqd+4LiG{mX-(qIxc^8cb6)3H!PI@>lVO62We8|R@1UGFJIYdx3m zYCPWArgH<7Usxz`l+icezS#d7@teFp-%!*v)^sWYi7;6&nhcHTdKhm|MDoF3kZT-3_HO6kWVS)v`toS%e?3*2|~%cZro` z)VXA?Azv-4#lxOOl5G8Ij-%eTgldL}ZV-}0)X%oM3#AHbR)oLKXTqvz+lZ{3N@(p3yvbfn5G-99m*{uhm!9X}zvQGyzicGsj=}Bjwy);9H?Ff==SdETklfOmO6)|W6!QWPkfel(kTtipW5d1 z@LrE>9H@+1qE^&56Lbj5@sZ7BeuDcHvXXB&dthawfgqWa`03h`<{Wtd!F$TDA5oZ- z8_l=4sUaviPan9nu-=YOd855*67tq9$@oN`%9_5>v?#yMX4~T`V}5wj(|1=tWTzjQ ztwaJc#Q{sA-k9desi&iQNycsy5F9t4QbsJloP&HNu~-XC-^XN1_zY>^YX(ysKQo05 z2nXz*AgmsSX{+|ek4zR0!v=%^e(ZO4QEu<@@4q%N{m*U;GL~O0(^ogNw`kS_k?Dta zW1F#L-O1vPn4f3;b5^lqo}IgKkRgBn0{JRztSHP`W1T|8E(Bv01>TGDJ(>I#jkQzE i$=!{^3>(Q>(;l=hbBx1)IhY$b8P_ + val clsContext = Class.forName(objId, true, Thread.currentThread().getContextClassLoader) + clsContext + } + } + def handleMethodCall( isStatic: Boolean, objId: String, @@ -98,7 +113,7 @@ private[r] class RBackendHandler(server: RBackend) var obj: Object = null try { val cls = if (isStatic) { - Class.forName(objId) + getStaticClass(objId) } else { JVMObjectTracker.get(objId) match { case None => throw new IllegalArgumentException("Object not found " + objId) From b5a6663da28198c905df27534cd123360a9bbef1 Mon Sep 17 00:00:00 2001 From: Rosstin Date: Sat, 27 Jun 2015 08:47:00 +0300 Subject: [PATCH 39/44] [SPARK-8639] [DOCS] Fixed Minor Typos in Documentation Ticket: [SPARK-8639](https://issues.apache.org/jira/browse/SPARK-8639) fixed minor typos in docs/README.md and docs/api.md Author: Rosstin Closes #7046 from Rosstin/SPARK-8639 and squashes the following commits: 6c18058 [Rosstin] fixed minor typos in docs/README.md and docs/api.md --- docs/README.md | 2 +- docs/api.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/README.md b/docs/README.md index 5852f972a051d..d7652e921f7df 100644 --- a/docs/README.md +++ b/docs/README.md @@ -28,7 +28,7 @@ in some cases: $ sudo gem install jekyll $ sudo gem install jekyll-redirect-from -Execute `jekyll` from the `docs/` directory. Compiling the site with Jekyll will create a directory +Execute `jekyll build` from the `docs/` directory to compile the site. Compiling the site with Jekyll will create a directory called `_site` containing index.html as well as the rest of the compiled files. You can modify the default Jekyll build as follows: diff --git a/docs/api.md b/docs/api.md index 45df77ac05f78..ae7d51c2aefbf 100644 --- a/docs/api.md +++ b/docs/api.md @@ -3,7 +3,7 @@ layout: global title: Spark API Documentation --- -Here you can API docs for Spark and its submodules. +Here you can read API docs for Spark and its submodules. - [Spark Scala API (Scaladoc)](api/scala/index.html) - [Spark Java API (Javadoc)](api/java/index.html) From d48e78934a346f023bd5cf44a34320f4d5a88e12 Mon Sep 17 00:00:00 2001 From: Neelesh Srinivas Salian Date: Sat, 27 Jun 2015 09:07:10 +0300 Subject: [PATCH 40/44] [SPARK-3629] [YARN] [DOCS]: Improvement of the "Running Spark on YARN" document As per the description in the JIRA, I moved the contents of the page and added a few additional content. Author: Neelesh Srinivas Salian Closes #6924 from nssalian/SPARK-3629 and squashes the following commits: 944b7a0 [Neelesh Srinivas Salian] Changed the lines about deploy-mode and added backticks to all parameters 40dbc0b [Neelesh Srinivas Salian] Changed dfs to HDFS, deploy-mode in backticks and updated the master yarn line 9cbc072 [Neelesh Srinivas Salian] Updated a few lines in the Launching Spark on YARN Section 8e8db7f [Neelesh Srinivas Salian] Removed the changes in this commit to help clearly distinguish movement from update 151c298 [Neelesh Srinivas Salian] SPARK-3629: Improvement of the Spark on YARN document --- docs/running-on-yarn.md | 164 ++++++++++++++++++++-------------------- 1 file changed, 82 insertions(+), 82 deletions(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 3f8a093bbe957..de22ab557cacf 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -7,6 +7,51 @@ Support for running on [YARN (Hadoop NextGen)](http://hadoop.apache.org/docs/stable/hadoop-yarn/hadoop-yarn-site/YARN.html) was added to Spark in version 0.6.0, and improved in subsequent releases. +# Launching Spark on YARN + +Ensure that `HADOOP_CONF_DIR` or `YARN_CONF_DIR` points to the directory which contains the (client side) configuration files for the Hadoop cluster. +These configs are used to write to HDFS and connect to the YARN ResourceManager. The +configuration contained in this directory will be distributed to the YARN cluster so that all +containers used by the application use the same configuration. If the configuration references +Java system properties or environment variables not managed by YARN, they should also be set in the +Spark application's configuration (driver, executors, and the AM when running in client mode). + +There are two deploy modes that can be used to launch Spark applications on YARN. In `yarn-cluster` mode, the Spark driver runs inside an application master process which is managed by YARN on the cluster, and the client can go away after initiating the application. In `yarn-client` mode, the driver runs in the client process, and the application master is only used for requesting resources from YARN. + +Unlike in Spark standalone and Mesos mode, in which the master's address is specified in the `--master` parameter, in YARN mode the ResourceManager's address is picked up from the Hadoop configuration. Thus, the `--master` parameter is `yarn-client` or `yarn-cluster`. +To launch a Spark application in `yarn-cluster` mode: + + `$ ./bin/spark-submit --class path.to.your.Class --master yarn-cluster [options] [app options]` + +For example: + + $ ./bin/spark-submit --class org.apache.spark.examples.SparkPi \ + --master yarn-cluster \ + --num-executors 3 \ + --driver-memory 4g \ + --executor-memory 2g \ + --executor-cores 1 \ + --queue thequeue \ + lib/spark-examples*.jar \ + 10 + +The above starts a YARN client program which starts the default Application Master. Then SparkPi will be run as a child thread of Application Master. The client will periodically poll the Application Master for status updates and display them in the console. The client will exit once your application has finished running. Refer to the "Debugging your Application" section below for how to see driver and executor logs. + +To launch a Spark application in `yarn-client` mode, do the same, but replace `yarn-cluster` with `yarn-client`. To run spark-shell: + + $ ./bin/spark-shell --master yarn-client + +## Adding Other JARs + +In `yarn-cluster` mode, the driver runs on a different machine than the client, so `SparkContext.addJar` won't work out of the box with files that are local to the client. To make files on the client available to `SparkContext.addJar`, include them with the `--jars` option in the launch command. + + $ ./bin/spark-submit --class my.main.Class \ + --master yarn-cluster \ + --jars my-other-jar.jar,my-other-other-jar.jar + my-main-jar.jar + app_arg1 app_arg2 + + # Preparations Running Spark-on-YARN requires a binary distribution of Spark which is built with YARN support. @@ -17,6 +62,38 @@ To build Spark yourself, refer to [Building Spark](building-spark.html). Most of the configs are the same for Spark on YARN as for other deployment modes. See the [configuration page](configuration.html) for more information on those. These are configs that are specific to Spark on YARN. +# Debugging your Application + +In YARN terminology, executors and application masters run inside "containers". YARN has two modes for handling container logs after an application has completed. If log aggregation is turned on (with the `yarn.log-aggregation-enable` config), container logs are copied to HDFS and deleted on the local machine. These logs can be viewed from anywhere on the cluster with the "yarn logs" command. + + yarn logs -applicationId + +will print out the contents of all log files from all containers from the given application. You can also view the container log files directly in HDFS using the HDFS shell or API. The directory where they are located can be found by looking at your YARN configs (`yarn.nodemanager.remote-app-log-dir` and `yarn.nodemanager.remote-app-log-dir-suffix`). + +When log aggregation isn't turned on, logs are retained locally on each machine under `YARN_APP_LOGS_DIR`, which is usually configured to `/tmp/logs` or `$HADOOP_HOME/logs/userlogs` depending on the Hadoop version and installation. Viewing logs for a container requires going to the host that contains them and looking in this directory. Subdirectories organize log files by application ID and container ID. + +To review per-container launch environment, increase `yarn.nodemanager.delete.debug-delay-sec` to a +large value (e.g. 36000), and then access the application cache through `yarn.nodemanager.local-dirs` +on the nodes on which containers are launched. This directory contains the launch script, JARs, and +all environment variables used for launching each container. This process is useful for debugging +classpath problems in particular. (Note that enabling this requires admin privileges on cluster +settings and a restart of all node managers. Thus, this is not applicable to hosted clusters). + +To use a custom log4j configuration for the application master or executors, there are two options: + +- upload a custom `log4j.properties` using `spark-submit`, by adding it to the `--files` list of files + to be uploaded with the application. +- add `-Dlog4j.configuration=` to `spark.driver.extraJavaOptions` + (for the driver) or `spark.executor.extraJavaOptions` (for executors). Note that if using a file, + the `file:` protocol should be explicitly provided, and the file needs to exist locally on all + the nodes. + +Note that for the first option, both executors and the application master will share the same +log4j configuration, which may cause issues when they run on the same node (e.g. trying to write +to the same log file). + +If you need a reference to the proper location to put log files in the YARN so that YARN can properly display and aggregate them, use `spark.yarn.app.container.log.dir` in your log4j.properties. For example, `log4j.appender.file_appender.File=${spark.yarn.app.container.log.dir}/spark.log`. For streaming application, configuring `RollingFileAppender` and setting file location to YARN's log directory will avoid disk overflow caused by large log file, and logs can be accessed using YARN's log utility. + #### Spark Properties @@ -50,8 +127,8 @@ Most of the configs are the same for Spark on YARN as for other deployment modes @@ -189,8 +266,8 @@ Most of the configs are the same for Spark on YARN as for other deployment modes @@ -206,7 +283,7 @@ Most of the configs are the same for Spark on YARN as for other deployment modes @@ -286,83 +363,6 @@ Most of the configs are the same for Spark on YARN as for other deployment modes
spark.yarn.am.waitTime 100s - In yarn-cluster mode, time for the application master to wait for the - SparkContext to be initialized. In yarn-client mode, time for the application master to wait + In `yarn-cluster` mode, time for the application master to wait for the + SparkContext to be initialized. In `yarn-client` mode, time for the application master to wait for the driver to connect to it.
Add the environment variable specified by EnvironmentVariableName to the Application Master process launched on YARN. The user can specify multiple of - these and to set multiple environment variables. In yarn-cluster mode this controls - the environment of the SPARK driver and in yarn-client mode it only controls + these and to set multiple environment variables. In `yarn-cluster` mode this controls + the environment of the SPARK driver and in `yarn-client` mode it only controls the environment of the executor launcher.
(none) A string of extra JVM options to pass to the YARN Application Master in client mode. - In cluster mode, use spark.driver.extraJavaOptions instead. + In cluster mode, use `spark.driver.extraJavaOptions` instead.
-# Launching Spark on YARN - -Ensure that `HADOOP_CONF_DIR` or `YARN_CONF_DIR` points to the directory which contains the (client side) configuration files for the Hadoop cluster. -These configs are used to write to the dfs and connect to the YARN ResourceManager. The -configuration contained in this directory will be distributed to the YARN cluster so that all -containers used by the application use the same configuration. If the configuration references -Java system properties or environment variables not managed by YARN, they should also be set in the -Spark application's configuration (driver, executors, and the AM when running in client mode). - -There are two deploy modes that can be used to launch Spark applications on YARN. In yarn-cluster mode, the Spark driver runs inside an application master process which is managed by YARN on the cluster, and the client can go away after initiating the application. In yarn-client mode, the driver runs in the client process, and the application master is only used for requesting resources from YARN. - -Unlike in Spark standalone and Mesos mode, in which the master's address is specified in the "master" parameter, in YARN mode the ResourceManager's address is picked up from the Hadoop configuration. Thus, the master parameter is simply "yarn-client" or "yarn-cluster". - -To launch a Spark application in yarn-cluster mode: - - ./bin/spark-submit --class path.to.your.Class --master yarn-cluster [options] [app options] - -For example: - - $ ./bin/spark-submit --class org.apache.spark.examples.SparkPi \ - --master yarn-cluster \ - --num-executors 3 \ - --driver-memory 4g \ - --executor-memory 2g \ - --executor-cores 1 \ - --queue thequeue \ - lib/spark-examples*.jar \ - 10 - -The above starts a YARN client program which starts the default Application Master. Then SparkPi will be run as a child thread of Application Master. The client will periodically poll the Application Master for status updates and display them in the console. The client will exit once your application has finished running. Refer to the "Debugging your Application" section below for how to see driver and executor logs. - -To launch a Spark application in yarn-client mode, do the same, but replace "yarn-cluster" with "yarn-client". To run spark-shell: - - $ ./bin/spark-shell --master yarn-client - -## Adding Other JARs - -In yarn-cluster mode, the driver runs on a different machine than the client, so `SparkContext.addJar` won't work out of the box with files that are local to the client. To make files on the client available to `SparkContext.addJar`, include them with the `--jars` option in the launch command. - - $ ./bin/spark-submit --class my.main.Class \ - --master yarn-cluster \ - --jars my-other-jar.jar,my-other-other-jar.jar - my-main-jar.jar - app_arg1 app_arg2 - -# Debugging your Application - -In YARN terminology, executors and application masters run inside "containers". YARN has two modes for handling container logs after an application has completed. If log aggregation is turned on (with the `yarn.log-aggregation-enable` config), container logs are copied to HDFS and deleted on the local machine. These logs can be viewed from anywhere on the cluster with the "yarn logs" command. - - yarn logs -applicationId - -will print out the contents of all log files from all containers from the given application. You can also view the container log files directly in HDFS using the HDFS shell or API. The directory where they are located can be found by looking at your YARN configs (`yarn.nodemanager.remote-app-log-dir` and `yarn.nodemanager.remote-app-log-dir-suffix`). - -When log aggregation isn't turned on, logs are retained locally on each machine under `YARN_APP_LOGS_DIR`, which is usually configured to `/tmp/logs` or `$HADOOP_HOME/logs/userlogs` depending on the Hadoop version and installation. Viewing logs for a container requires going to the host that contains them and looking in this directory. Subdirectories organize log files by application ID and container ID. - -To review per-container launch environment, increase `yarn.nodemanager.delete.debug-delay-sec` to a -large value (e.g. 36000), and then access the application cache through `yarn.nodemanager.local-dirs` -on the nodes on which containers are launched. This directory contains the launch script, JARs, and -all environment variables used for launching each container. This process is useful for debugging -classpath problems in particular. (Note that enabling this requires admin privileges on cluster -settings and a restart of all node managers. Thus, this is not applicable to hosted clusters). - -To use a custom log4j configuration for the application master or executors, there are two options: - -- upload a custom `log4j.properties` using `spark-submit`, by adding it to the `--files` list of files - to be uploaded with the application. -- add `-Dlog4j.configuration=` to `spark.driver.extraJavaOptions` - (for the driver) or `spark.executor.extraJavaOptions` (for executors). Note that if using a file, - the `file:` protocol should be explicitly provided, and the file needs to exist locally on all - the nodes. - -Note that for the first option, both executors and the application master will share the same -log4j configuration, which may cause issues when they run on the same node (e.g. trying to write -to the same log file). - -If you need a reference to the proper location to put log files in the YARN so that YARN can properly display and aggregate them, use `spark.yarn.app.container.log.dir` in your log4j.properties. For example, `log4j.appender.file_appender.File=${spark.yarn.app.container.log.dir}/spark.log`. For streaming application, configuring `RollingFileAppender` and setting file location to YARN's log directory will avoid disk overflow caused by large log file, and logs can be accessed using YARN's log utility. - # Important notes - Whether core requests are honored in scheduling decisions depends on which scheduler is in use and how it is configured. From 4153776fd840ae075e6bb608f054091b6d3ec0c4 Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Sat, 27 Jun 2015 14:33:31 -0700 Subject: [PATCH 41/44] [SPARK-8623] Hadoop RDDs fail to properly serialize configuration Author: Sandy Ryza Closes #7050 from sryza/sandy-spark-8623 and squashes the following commits: 58a8079 [Sandy Ryza] SPARK-8623. Hadoop RDDs fail to properly serialize configuration --- .../scala/org/apache/spark/serializer/KryoSerializer.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index cd8a82347a1e9..ed35cffe968f8 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -36,7 +36,7 @@ import org.apache.spark.network.nio.{GetBlock, GotBlock, PutBlock} import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus} import org.apache.spark.storage._ -import org.apache.spark.util.BoundedPriorityQueue +import org.apache.spark.util.{BoundedPriorityQueue, SerializableConfiguration, SerializableJobConf} import org.apache.spark.util.collection.CompactBuffer /** @@ -94,8 +94,10 @@ class KryoSerializer(conf: SparkConf) // For results returned by asJavaIterable. See JavaIterableWrapperSerializer. kryo.register(JavaIterableWrapperSerializer.wrapperClass, new JavaIterableWrapperSerializer) - // Allow sending SerializableWritable + // Allow sending classes with custom Java serializers kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer()) + kryo.register(classOf[SerializableConfiguration], new KryoJavaSerializer()) + kryo.register(classOf[SerializableJobConf], new KryoJavaSerializer()) kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer()) kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer()) From 0b5abbf5f96a5f6bfd15a65e8788cf3fa96fe54c Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 27 Jun 2015 14:40:45 -0700 Subject: [PATCH 42/44] [SPARK-8606] Prevent exceptions in RDD.getPreferredLocations() from crashing DAGScheduler If `RDD.getPreferredLocations()` throws an exception it may crash the DAGScheduler and SparkContext. This patch addresses this by adding a try-catch block. Author: Josh Rosen Closes #7023 from JoshRosen/SPARK-8606 and squashes the following commits: 770b169 [Josh Rosen] Fix getPreferredLocations() DAGScheduler crash with try block. 44a9b55 [Josh Rosen] Add test of a buggy getPartitions() method 19aa9f7 [Josh Rosen] Add (failing) regression test for getPreferredLocations() DAGScheduler crash --- .../apache/spark/scheduler/DAGScheduler.scala | 37 +++++++++++-------- .../spark/scheduler/DAGSchedulerSuite.scala | 31 ++++++++++++++++ 2 files changed, 53 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index b00a5fee09bf2..a7cf0c23d9613 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -907,22 +907,29 @@ class DAGScheduler( return } - val tasks: Seq[Task[_]] = stage match { - case stage: ShuffleMapStage => - partitionsToCompute.map { id => - val locs = getPreferredLocs(stage.rdd, id) - val part = stage.rdd.partitions(id) - new ShuffleMapTask(stage.id, taskBinary, part, locs) - } + val tasks: Seq[Task[_]] = try { + stage match { + case stage: ShuffleMapStage => + partitionsToCompute.map { id => + val locs = getPreferredLocs(stage.rdd, id) + val part = stage.rdd.partitions(id) + new ShuffleMapTask(stage.id, taskBinary, part, locs) + } - case stage: ResultStage => - val job = stage.resultOfJob.get - partitionsToCompute.map { id => - val p: Int = job.partitions(id) - val part = stage.rdd.partitions(p) - val locs = getPreferredLocs(stage.rdd, p) - new ResultTask(stage.id, taskBinary, part, locs, id) - } + case stage: ResultStage => + val job = stage.resultOfJob.get + partitionsToCompute.map { id => + val p: Int = job.partitions(id) + val part = stage.rdd.partitions(p) + val locs = getPreferredLocs(stage.rdd, p) + new ResultTask(stage.id, taskBinary, part, locs, id) + } + } + } catch { + case NonFatal(e) => + abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}") + runningStages -= stage + return } if (tasks.size > 0) { diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 833b600746e90..6bc45f249f975 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -784,6 +784,37 @@ class DAGSchedulerSuite assert(sc.parallelize(1 to 10, 2).first() === 1) } + test("getPartitions exceptions should not crash DAGScheduler and SparkContext (SPARK-8606)") { + val e1 = intercept[DAGSchedulerSuiteDummyException] { + val rdd = new MyRDD(sc, 2, Nil) { + override def getPartitions: Array[Partition] = { + throw new DAGSchedulerSuiteDummyException + } + } + rdd.reduceByKey(_ + _, 1).count() + } + + // Make sure we can still run local commands as well as cluster commands. + assert(sc.parallelize(1 to 10, 2).count() === 10) + assert(sc.parallelize(1 to 10, 2).first() === 1) + } + + test("getPreferredLocations errors should not crash DAGScheduler and SparkContext (SPARK-8606)") { + val e1 = intercept[SparkException] { + val rdd = new MyRDD(sc, 2, Nil) { + override def getPreferredLocations(split: Partition): Seq[String] = { + throw new DAGSchedulerSuiteDummyException + } + } + rdd.count() + } + assert(e1.getMessage.contains(classOf[DAGSchedulerSuiteDummyException].getName)) + + // Make sure we can still run local commands as well as cluster commands. + assert(sc.parallelize(1 to 10, 2).count() === 10) + assert(sc.parallelize(1 to 10, 2).first() === 1) + } + test("accumulator not calculated for resubmitted result stage") { // just for register val accum = new Accumulator[Int](0, AccumulatorParam.IntAccumulatorParam) From 40648c56cdaa52058a4771082f8f44a2d8e5a1ec Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 27 Jun 2015 20:24:34 -0700 Subject: [PATCH 43/44] [SPARK-8583] [SPARK-5482] [BUILD] Refactor python/run-tests to integrate with dev/run-tests module system This patch refactors the `python/run-tests` script: - It's now written in Python instead of Bash. - The descriptions of the tests to run are now stored in `dev/run-tests`'s modules. This allows the pull request builder to skip Python tests suites that were not affected by the pull request's changes. For example, we can now skip the PySpark Streaming test cases when only SQL files are changed. - `python/run-tests` now supports command-line flags to make it easier to run individual test suites (this addresses SPARK-5482): ``` Usage: run-tests [options] Options: -h, --help show this help message and exit --python-executables=PYTHON_EXECUTABLES A comma-separated list of Python executables to test against (default: python2.6,python3.4,pypy) --modules=MODULES A comma-separated list of Python modules to test (default: pyspark-core,pyspark-ml,pyspark-mllib ,pyspark-sql,pyspark-streaming) ``` - `dev/run-tests` has been split into multiple files: the module definitions and test utility functions are now stored inside of a `dev/sparktestsupport` Python module, allowing them to be re-used from the Python test runner script. Author: Josh Rosen Closes #6967 from JoshRosen/run-tests-python-modules and squashes the following commits: f578d6d [Josh Rosen] Fix print for Python 2.x 8233d61 [Josh Rosen] Add python/run-tests.py to Python lint checks 34c98d2 [Josh Rosen] Fix universal_newlines for Python 3 8f65ed0 [Josh Rosen] Fix handling of module in python/run-tests 37aff00 [Josh Rosen] Python 3 fix 27a389f [Josh Rosen] Skip MLLib tests for PyPy c364ccf [Josh Rosen] Use which() to convert PYSPARK_PYTHON to an absolute path before shelling out to run tests 568a3fd [Josh Rosen] Fix hashbang 3b852ae [Josh Rosen] Fall back to PYSPARK_PYTHON when sys.executable is None (fixes a test) f53db55 [Josh Rosen] Remove python2 flag, since the test runner script also works fine under Python 3 9c80469 [Josh Rosen] Fix passing of PYSPARK_PYTHON d33e525 [Josh Rosen] Merge remote-tracking branch 'origin/master' into run-tests-python-modules 4f8902c [Josh Rosen] Python lint fixes. 8f3244c [Josh Rosen] Use universal_newlines to fix dev/run-tests doctest failures on Python 3. f542ac5 [Josh Rosen] Fix lint check for Python 3 fff4d09 [Josh Rosen] Add dev/sparktestsupport to pep8 checks 2efd594 [Josh Rosen] Update dev/run-tests to use new Python test runner flags b2ab027 [Josh Rosen] Add command-line options for running individual suites in python/run-tests caeb040 [Josh Rosen] Fixes to PySpark test module definitions d6a77d3 [Josh Rosen] Fix the tests of dev/run-tests def2d8a [Josh Rosen] Two minor fixes aec0b8f [Josh Rosen] Actually get the Kafka stuff to run properly 04015b9 [Josh Rosen] First attempt at getting PySpark Kafka test to work in new runner script 4c97136 [Josh Rosen] PYTHONPATH fixes dcc9c09 [Josh Rosen] Fix time division 32660fc [Josh Rosen] Initial cut at Python test runner refactoring 311c6a9 [Josh Rosen] Move shell utility functions to own module. 1bdeb87 [Josh Rosen] Move module definitions to separate file. --- dev/lint-python | 3 +- dev/run-tests.py | 435 ++++------------------------- dev/sparktestsupport/__init__.py | 21 ++ dev/sparktestsupport/modules.py | 385 +++++++++++++++++++++++++ dev/sparktestsupport/shellutils.py | 81 ++++++ python/pyspark/streaming/tests.py | 16 ++ python/pyspark/tests.py | 3 +- python/run-tests | 164 +---------- python/run-tests.py | 132 +++++++++ 9 files changed, 700 insertions(+), 540 deletions(-) create mode 100644 dev/sparktestsupport/__init__.py create mode 100644 dev/sparktestsupport/modules.py create mode 100644 dev/sparktestsupport/shellutils.py create mode 100755 python/run-tests.py diff --git a/dev/lint-python b/dev/lint-python index f50d149dc4d44..0c3586462cb37 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -19,7 +19,8 @@ SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" SPARK_ROOT_DIR="$(dirname "$SCRIPT_DIR")" -PATHS_TO_CHECK="./python/pyspark/ ./ec2/spark_ec2.py ./examples/src/main/python/" +PATHS_TO_CHECK="./python/pyspark/ ./ec2/spark_ec2.py ./examples/src/main/python/ ./dev/sparktestsupport" +PATHS_TO_CHECK="$PATHS_TO_CHECK ./dev/run-tests.py ./python/run-tests.py" PYTHON_LINT_REPORT_PATH="$SPARK_ROOT_DIR/dev/python-lint-report.txt" cd "$SPARK_ROOT_DIR" diff --git a/dev/run-tests.py b/dev/run-tests.py index e7c09b0f40cdc..c51b0d3010a0f 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -17,297 +17,23 @@ # limitations under the License. # +from __future__ import print_function import itertools import os import re import sys -import shutil import subprocess from collections import namedtuple -SPARK_HOME = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..") -USER_HOME = os.environ.get("HOME") - +from sparktestsupport import SPARK_HOME, USER_HOME +from sparktestsupport.shellutils import exit_from_command_with_retcode, run_cmd, rm_r, which +import sparktestsupport.modules as modules # ------------------------------------------------------------------------------------------------- -# Test module definitions and functions for traversing module dependency graph +# Functions for traversing module dependency graph # ------------------------------------------------------------------------------------------------- -all_modules = [] - - -class Module(object): - """ - A module is the basic abstraction in our test runner script. Each module consists of a set of - source files, a set of test commands, and a set of dependencies on other modules. We use modules - to define a dependency graph that lets determine which tests to run based on which files have - changed. - """ - - def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(), - sbt_test_goals=(), should_run_python_tests=False, should_run_r_tests=False): - """ - Define a new module. - - :param name: A short module name, for display in logging and error messages. - :param dependencies: A set of dependencies for this module. This should only include direct - dependencies; transitive dependencies are resolved automatically. - :param source_file_regexes: a set of regexes that match source files belonging to this - module. These regexes are applied by attempting to match at the beginning of the - filename strings. - :param build_profile_flags: A set of profile flags that should be passed to Maven or SBT in - order to build and test this module (e.g. '-PprofileName'). - :param sbt_test_goals: A set of SBT test goals for testing this module. - :param should_run_python_tests: If true, changes in this module will trigger Python tests. - For now, this has the effect of causing _all_ Python tests to be run, although in the - future this should be changed to run only a subset of the Python tests that depend - on this module. - :param should_run_r_tests: If true, changes in this module will trigger all R tests. - """ - self.name = name - self.dependencies = dependencies - self.source_file_prefixes = source_file_regexes - self.sbt_test_goals = sbt_test_goals - self.build_profile_flags = build_profile_flags - self.should_run_python_tests = should_run_python_tests - self.should_run_r_tests = should_run_r_tests - - self.dependent_modules = set() - for dep in dependencies: - dep.dependent_modules.add(self) - all_modules.append(self) - - def contains_file(self, filename): - return any(re.match(p, filename) for p in self.source_file_prefixes) - - -sql = Module( - name="sql", - dependencies=[], - source_file_regexes=[ - "sql/(?!hive-thriftserver)", - "bin/spark-sql", - ], - build_profile_flags=[ - "-Phive", - ], - sbt_test_goals=[ - "catalyst/test", - "sql/test", - "hive/test", - ] -) - - -hive_thriftserver = Module( - name="hive-thriftserver", - dependencies=[sql], - source_file_regexes=[ - "sql/hive-thriftserver", - "sbin/start-thriftserver.sh", - ], - build_profile_flags=[ - "-Phive-thriftserver", - ], - sbt_test_goals=[ - "hive-thriftserver/test", - ] -) - - -graphx = Module( - name="graphx", - dependencies=[], - source_file_regexes=[ - "graphx/", - ], - sbt_test_goals=[ - "graphx/test" - ] -) - - -streaming = Module( - name="streaming", - dependencies=[], - source_file_regexes=[ - "streaming", - ], - sbt_test_goals=[ - "streaming/test", - ] -) - - -streaming_kinesis_asl = Module( - name="kinesis-asl", - dependencies=[streaming], - source_file_regexes=[ - "extras/kinesis-asl/", - ], - build_profile_flags=[ - "-Pkinesis-asl", - ], - sbt_test_goals=[ - "kinesis-asl/test", - ] -) - - -streaming_zeromq = Module( - name="streaming-zeromq", - dependencies=[streaming], - source_file_regexes=[ - "external/zeromq", - ], - sbt_test_goals=[ - "streaming-zeromq/test", - ] -) - - -streaming_twitter = Module( - name="streaming-twitter", - dependencies=[streaming], - source_file_regexes=[ - "external/twitter", - ], - sbt_test_goals=[ - "streaming-twitter/test", - ] -) - - -streaming_mqtt = Module( - name="streaming-mqtt", - dependencies=[streaming], - source_file_regexes=[ - "external/mqtt", - ], - sbt_test_goals=[ - "streaming-mqtt/test", - ] -) - - -streaming_kafka = Module( - name="streaming-kafka", - dependencies=[streaming], - source_file_regexes=[ - "external/kafka", - "external/kafka-assembly", - ], - sbt_test_goals=[ - "streaming-kafka/test", - ] -) - - -streaming_flume_sink = Module( - name="streaming-flume-sink", - dependencies=[streaming], - source_file_regexes=[ - "external/flume-sink", - ], - sbt_test_goals=[ - "streaming-flume-sink/test", - ] -) - - -streaming_flume = Module( - name="streaming_flume", - dependencies=[streaming], - source_file_regexes=[ - "external/flume", - ], - sbt_test_goals=[ - "streaming-flume/test", - ] -) - - -mllib = Module( - name="mllib", - dependencies=[streaming, sql], - source_file_regexes=[ - "data/mllib/", - "mllib/", - ], - sbt_test_goals=[ - "mllib/test", - ] -) - - -examples = Module( - name="examples", - dependencies=[graphx, mllib, streaming, sql], - source_file_regexes=[ - "examples/", - ], - sbt_test_goals=[ - "examples/test", - ] -) - - -pyspark = Module( - name="pyspark", - dependencies=[mllib, streaming, streaming_kafka, sql], - source_file_regexes=[ - "python/" - ], - should_run_python_tests=True -) - - -sparkr = Module( - name="sparkr", - dependencies=[sql, mllib], - source_file_regexes=[ - "R/", - ], - should_run_r_tests=True -) - - -docs = Module( - name="docs", - dependencies=[], - source_file_regexes=[ - "docs/", - ] -) - - -ec2 = Module( - name="ec2", - dependencies=[], - source_file_regexes=[ - "ec2/", - ] -) - - -# The root module is a dummy module which is used to run all of the tests. -# No other modules should directly depend on this module. -root = Module( - name="root", - dependencies=[], - source_file_regexes=[], - # In order to run all of the tests, enable every test profile: - build_profile_flags= - list(set(itertools.chain.from_iterable(m.build_profile_flags for m in all_modules))), - sbt_test_goals=[ - "test", - ], - should_run_python_tests=True, - should_run_r_tests=True -) - - def determine_modules_for_files(filenames): """ Given a list of filenames, return the set of modules that contain those files. @@ -315,19 +41,19 @@ def determine_modules_for_files(filenames): file to belong to the 'root' module. >>> sorted(x.name for x in determine_modules_for_files(["python/pyspark/a.py", "sql/test/foo"])) - ['pyspark', 'sql'] + ['pyspark-core', 'sql'] >>> [x.name for x in determine_modules_for_files(["file_not_matched_by_any_subproject"])] ['root'] """ changed_modules = set() for filename in filenames: matched_at_least_one_module = False - for module in all_modules: + for module in modules.all_modules: if module.contains_file(filename): changed_modules.add(module) matched_at_least_one_module = True if not matched_at_least_one_module: - changed_modules.add(root) + changed_modules.add(modules.root) return changed_modules @@ -352,7 +78,8 @@ def identify_changed_files_from_git_commits(patch_sha, target_branch=None, targe run_cmd(['git', 'fetch', 'origin', str(target_branch+':'+target_branch)]) else: diff_target = target_ref - raw_output = subprocess.check_output(['git', 'diff', '--name-only', patch_sha, diff_target]) + raw_output = subprocess.check_output(['git', 'diff', '--name-only', patch_sha, diff_target], + universal_newlines=True) # Remove any empty strings return [f for f in raw_output.split('\n') if f] @@ -362,18 +89,20 @@ def determine_modules_to_test(changed_modules): Given a set of modules that have changed, compute the transitive closure of those modules' dependent modules in order to determine the set of modules that should be tested. - >>> sorted(x.name for x in determine_modules_to_test([root])) + >>> sorted(x.name for x in determine_modules_to_test([modules.root])) ['root'] - >>> sorted(x.name for x in determine_modules_to_test([graphx])) + >>> sorted(x.name for x in determine_modules_to_test([modules.graphx])) ['examples', 'graphx'] - >>> sorted(x.name for x in determine_modules_to_test([sql])) - ['examples', 'hive-thriftserver', 'mllib', 'pyspark', 'sparkr', 'sql'] + >>> x = sorted(x.name for x in determine_modules_to_test([modules.sql])) + >>> x # doctest: +NORMALIZE_WHITESPACE + ['examples', 'hive-thriftserver', 'mllib', 'pyspark-core', 'pyspark-ml', \ + 'pyspark-mllib', 'pyspark-sql', 'pyspark-streaming', 'sparkr', 'sql'] """ # If we're going to have to run all of the tests, then we can just short-circuit # and return 'root'. No module depends on root, so if it appears then it will be # in changed_modules. - if root in changed_modules: - return [root] + if modules.root in changed_modules: + return [modules.root] modules_to_test = set() for module in changed_modules: modules_to_test = modules_to_test.union(determine_modules_to_test(module.dependent_modules)) @@ -398,60 +127,6 @@ def get_error_codes(err_code_file): ERROR_CODES = get_error_codes(os.path.join(SPARK_HOME, "dev/run-tests-codes.sh")) -def exit_from_command_with_retcode(cmd, retcode): - print "[error] running", ' '.join(cmd), "; received return code", retcode - sys.exit(int(os.environ.get("CURRENT_BLOCK", 255))) - - -def rm_r(path): - """Given an arbitrary path properly remove it with the correct python - construct if it exists - - from: http://stackoverflow.com/a/9559881""" - - if os.path.isdir(path): - shutil.rmtree(path) - elif os.path.exists(path): - os.remove(path) - - -def run_cmd(cmd): - """Given a command as a list of arguments will attempt to execute the - command from the determined SPARK_HOME directory and, on failure, print - an error message""" - - if not isinstance(cmd, list): - cmd = cmd.split() - try: - subprocess.check_call(cmd) - except subprocess.CalledProcessError as e: - exit_from_command_with_retcode(e.cmd, e.returncode) - - -def is_exe(path): - """Check if a given path is an executable file - - from: http://stackoverflow.com/a/377028""" - - return os.path.isfile(path) and os.access(path, os.X_OK) - - -def which(program): - """Find and return the given program by its absolute path or 'None' - - from: http://stackoverflow.com/a/377028""" - - fpath = os.path.split(program)[0] - - if fpath: - if is_exe(program): - return program - else: - for path in os.environ.get("PATH").split(os.pathsep): - path = path.strip('"') - exe_file = os.path.join(path, program) - if is_exe(exe_file): - return exe_file - return None - - def determine_java_executable(): """Will return the path of the java executable that will be used by Spark's tests or `None`""" @@ -476,7 +151,8 @@ def determine_java_version(java_exe): with accessors '.major', '.minor', '.patch', '.update'""" raw_output = subprocess.check_output([java_exe, "-version"], - stderr=subprocess.STDOUT) + stderr=subprocess.STDOUT, + universal_newlines=True) raw_output_lines = raw_output.split('\n') @@ -504,10 +180,10 @@ def set_title_and_block(title, err_block): os.environ["CURRENT_BLOCK"] = ERROR_CODES[err_block] line_str = '=' * 72 - print - print line_str - print title - print line_str + print('') + print(line_str) + print(title) + print(line_str) def run_apache_rat_checks(): @@ -534,8 +210,8 @@ def build_spark_documentation(): jekyll_bin = which("jekyll") if not jekyll_bin: - print "[error] Cannot find a version of `jekyll` on the system; please", - print "install one and retry to build documentation." + print("[error] Cannot find a version of `jekyll` on the system; please" + " install one and retry to build documentation.") sys.exit(int(os.environ.get("CURRENT_BLOCK", 255))) else: run_cmd([jekyll_bin, "build"]) @@ -571,7 +247,7 @@ def exec_sbt(sbt_args=()): echo_proc.wait() for line in iter(sbt_proc.stdout.readline, ''): if not sbt_output_filter.match(line): - print line, + print(line, end='') retcode = sbt_proc.wait() if retcode > 0: @@ -594,33 +270,33 @@ def get_hadoop_profiles(hadoop_version): if hadoop_version in sbt_maven_hadoop_profiles: return sbt_maven_hadoop_profiles[hadoop_version] else: - print "[error] Could not find", hadoop_version, "in the list. Valid options", - print "are", sbt_maven_hadoop_profiles.keys() + print("[error] Could not find", hadoop_version, "in the list. Valid options" + " are", sbt_maven_hadoop_profiles.keys()) sys.exit(int(os.environ.get("CURRENT_BLOCK", 255))) def build_spark_maven(hadoop_version): # Enable all of the profiles for the build: - build_profiles = get_hadoop_profiles(hadoop_version) + root.build_profile_flags + build_profiles = get_hadoop_profiles(hadoop_version) + modules.root.build_profile_flags mvn_goals = ["clean", "package", "-DskipTests"] profiles_and_goals = build_profiles + mvn_goals - print "[info] Building Spark (w/Hive 0.13.1) using Maven with these arguments:", - print " ".join(profiles_and_goals) + print("[info] Building Spark (w/Hive 0.13.1) using Maven with these arguments: " + " ".join(profiles_and_goals)) exec_maven(profiles_and_goals) def build_spark_sbt(hadoop_version): # Enable all of the profiles for the build: - build_profiles = get_hadoop_profiles(hadoop_version) + root.build_profile_flags + build_profiles = get_hadoop_profiles(hadoop_version) + modules.root.build_profile_flags sbt_goals = ["package", "assembly/assembly", "streaming-kafka-assembly/assembly"] profiles_and_goals = build_profiles + sbt_goals - print "[info] Building Spark (w/Hive 0.13.1) using SBT with these arguments:", - print " ".join(profiles_and_goals) + print("[info] Building Spark (w/Hive 0.13.1) using SBT with these arguments: " + " ".join(profiles_and_goals)) exec_sbt(profiles_and_goals) @@ -648,8 +324,8 @@ def run_scala_tests_maven(test_profiles): mvn_test_goals = ["test", "--fail-at-end"] profiles_and_goals = test_profiles + mvn_test_goals - print "[info] Running Spark tests using Maven with these arguments:", - print " ".join(profiles_and_goals) + print("[info] Running Spark tests using Maven with these arguments: " + " ".join(profiles_and_goals)) exec_maven(profiles_and_goals) @@ -663,8 +339,8 @@ def run_scala_tests_sbt(test_modules, test_profiles): profiles_and_goals = test_profiles + list(sbt_test_goals) - print "[info] Running Spark tests using SBT with these arguments:", - print " ".join(profiles_and_goals) + print("[info] Running Spark tests using SBT with these arguments: " + " ".join(profiles_and_goals)) exec_sbt(profiles_and_goals) @@ -684,10 +360,13 @@ def run_scala_tests(build_tool, hadoop_version, test_modules): run_scala_tests_sbt(test_modules, test_profiles) -def run_python_tests(): +def run_python_tests(test_modules): set_title_and_block("Running PySpark tests", "BLOCK_PYSPARK_UNIT_TESTS") - run_cmd([os.path.join(SPARK_HOME, "python", "run-tests")]) + command = [os.path.join(SPARK_HOME, "python", "run-tests")] + if test_modules != [modules.root]: + command.append("--modules=%s" % ','.join(m.name for m in modules)) + run_cmd(command) def run_sparkr_tests(): @@ -697,14 +376,14 @@ def run_sparkr_tests(): run_cmd([os.path.join(SPARK_HOME, "R", "install-dev.sh")]) run_cmd([os.path.join(SPARK_HOME, "R", "run-tests.sh")]) else: - print "Ignoring SparkR tests as R was not found in PATH" + print("Ignoring SparkR tests as R was not found in PATH") def main(): # Ensure the user home directory (HOME) is valid and is an absolute directory if not USER_HOME or not os.path.isabs(USER_HOME): - print "[error] Cannot determine your home directory as an absolute path;", - print "ensure the $HOME environment variable is set properly." + print("[error] Cannot determine your home directory as an absolute path;" + " ensure the $HOME environment variable is set properly.") sys.exit(1) os.chdir(SPARK_HOME) @@ -718,14 +397,14 @@ def main(): java_exe = determine_java_executable() if not java_exe: - print "[error] Cannot find a version of `java` on the system; please", - print "install one and retry." + print("[error] Cannot find a version of `java` on the system; please" + " install one and retry.") sys.exit(2) java_version = determine_java_version(java_exe) if java_version.minor < 8: - print "[warn] Java 8 tests will not run because JDK version is < 1.8." + print("[warn] Java 8 tests will not run because JDK version is < 1.8.") if os.environ.get("AMPLAB_JENKINS"): # if we're on the Amplab Jenkins build servers setup variables @@ -741,8 +420,8 @@ def main(): hadoop_version = "hadoop2.3" test_env = "local" - print "[info] Using build tool", build_tool, "with Hadoop profile", hadoop_version, - print "under environment", test_env + print("[info] Using build tool", build_tool, "with Hadoop profile", hadoop_version, + "under environment", test_env) changed_modules = None changed_files = None @@ -751,8 +430,9 @@ def main(): changed_files = identify_changed_files_from_git_commits("HEAD", target_branch=target_branch) changed_modules = determine_modules_for_files(changed_files) if not changed_modules: - changed_modules = [root] - print "[info] Found the following changed modules:", ", ".join(x.name for x in changed_modules) + changed_modules = [modules.root] + print("[info] Found the following changed modules:", + ", ".join(x.name for x in changed_modules)) test_modules = determine_modules_to_test(changed_modules) @@ -779,8 +459,9 @@ def main(): # run the test suites run_scala_tests(build_tool, hadoop_version, test_modules) - if any(m.should_run_python_tests for m in test_modules): - run_python_tests() + modules_with_python_tests = [m for m in test_modules if m.python_test_goals] + if modules_with_python_tests: + run_python_tests(modules_with_python_tests) if any(m.should_run_r_tests for m in test_modules): run_sparkr_tests() diff --git a/dev/sparktestsupport/__init__.py b/dev/sparktestsupport/__init__.py new file mode 100644 index 0000000000000..12696d98fb988 --- /dev/null +++ b/dev/sparktestsupport/__init__.py @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os + +SPARK_HOME = os.path.abspath(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../")) +USER_HOME = os.environ.get("HOME") diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py new file mode 100644 index 0000000000000..efe3a897e9c10 --- /dev/null +++ b/dev/sparktestsupport/modules.py @@ -0,0 +1,385 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import itertools +import re + +all_modules = [] + + +class Module(object): + """ + A module is the basic abstraction in our test runner script. Each module consists of a set of + source files, a set of test commands, and a set of dependencies on other modules. We use modules + to define a dependency graph that lets determine which tests to run based on which files have + changed. + """ + + def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(), + sbt_test_goals=(), python_test_goals=(), blacklisted_python_implementations=(), + should_run_r_tests=False): + """ + Define a new module. + + :param name: A short module name, for display in logging and error messages. + :param dependencies: A set of dependencies for this module. This should only include direct + dependencies; transitive dependencies are resolved automatically. + :param source_file_regexes: a set of regexes that match source files belonging to this + module. These regexes are applied by attempting to match at the beginning of the + filename strings. + :param build_profile_flags: A set of profile flags that should be passed to Maven or SBT in + order to build and test this module (e.g. '-PprofileName'). + :param sbt_test_goals: A set of SBT test goals for testing this module. + :param python_test_goals: A set of Python test goals for testing this module. + :param blacklisted_python_implementations: A set of Python implementations that are not + supported by this module's Python components. The values in this set should match + strings returned by Python's `platform.python_implementation()`. + :param should_run_r_tests: If true, changes in this module will trigger all R tests. + """ + self.name = name + self.dependencies = dependencies + self.source_file_prefixes = source_file_regexes + self.sbt_test_goals = sbt_test_goals + self.build_profile_flags = build_profile_flags + self.python_test_goals = python_test_goals + self.blacklisted_python_implementations = blacklisted_python_implementations + self.should_run_r_tests = should_run_r_tests + + self.dependent_modules = set() + for dep in dependencies: + dep.dependent_modules.add(self) + all_modules.append(self) + + def contains_file(self, filename): + return any(re.match(p, filename) for p in self.source_file_prefixes) + + +sql = Module( + name="sql", + dependencies=[], + source_file_regexes=[ + "sql/(?!hive-thriftserver)", + "bin/spark-sql", + ], + build_profile_flags=[ + "-Phive", + ], + sbt_test_goals=[ + "catalyst/test", + "sql/test", + "hive/test", + ] +) + + +hive_thriftserver = Module( + name="hive-thriftserver", + dependencies=[sql], + source_file_regexes=[ + "sql/hive-thriftserver", + "sbin/start-thriftserver.sh", + ], + build_profile_flags=[ + "-Phive-thriftserver", + ], + sbt_test_goals=[ + "hive-thriftserver/test", + ] +) + + +graphx = Module( + name="graphx", + dependencies=[], + source_file_regexes=[ + "graphx/", + ], + sbt_test_goals=[ + "graphx/test" + ] +) + + +streaming = Module( + name="streaming", + dependencies=[], + source_file_regexes=[ + "streaming", + ], + sbt_test_goals=[ + "streaming/test", + ] +) + + +streaming_kinesis_asl = Module( + name="kinesis-asl", + dependencies=[streaming], + source_file_regexes=[ + "extras/kinesis-asl/", + ], + build_profile_flags=[ + "-Pkinesis-asl", + ], + sbt_test_goals=[ + "kinesis-asl/test", + ] +) + + +streaming_zeromq = Module( + name="streaming-zeromq", + dependencies=[streaming], + source_file_regexes=[ + "external/zeromq", + ], + sbt_test_goals=[ + "streaming-zeromq/test", + ] +) + + +streaming_twitter = Module( + name="streaming-twitter", + dependencies=[streaming], + source_file_regexes=[ + "external/twitter", + ], + sbt_test_goals=[ + "streaming-twitter/test", + ] +) + + +streaming_mqtt = Module( + name="streaming-mqtt", + dependencies=[streaming], + source_file_regexes=[ + "external/mqtt", + ], + sbt_test_goals=[ + "streaming-mqtt/test", + ] +) + + +streaming_kafka = Module( + name="streaming-kafka", + dependencies=[streaming], + source_file_regexes=[ + "external/kafka", + "external/kafka-assembly", + ], + sbt_test_goals=[ + "streaming-kafka/test", + ] +) + + +streaming_flume_sink = Module( + name="streaming-flume-sink", + dependencies=[streaming], + source_file_regexes=[ + "external/flume-sink", + ], + sbt_test_goals=[ + "streaming-flume-sink/test", + ] +) + + +streaming_flume = Module( + name="streaming_flume", + dependencies=[streaming], + source_file_regexes=[ + "external/flume", + ], + sbt_test_goals=[ + "streaming-flume/test", + ] +) + + +mllib = Module( + name="mllib", + dependencies=[streaming, sql], + source_file_regexes=[ + "data/mllib/", + "mllib/", + ], + sbt_test_goals=[ + "mllib/test", + ] +) + + +examples = Module( + name="examples", + dependencies=[graphx, mllib, streaming, sql], + source_file_regexes=[ + "examples/", + ], + sbt_test_goals=[ + "examples/test", + ] +) + + +pyspark_core = Module( + name="pyspark-core", + dependencies=[mllib, streaming, streaming_kafka], + source_file_regexes=[ + "python/(?!pyspark/(ml|mllib|sql|streaming))" + ], + python_test_goals=[ + "pyspark.rdd", + "pyspark.context", + "pyspark.conf", + "pyspark.broadcast", + "pyspark.accumulators", + "pyspark.serializers", + "pyspark.profiler", + "pyspark.shuffle", + "pyspark.tests", + ] +) + + +pyspark_sql = Module( + name="pyspark-sql", + dependencies=[pyspark_core, sql], + source_file_regexes=[ + "python/pyspark/sql" + ], + python_test_goals=[ + "pyspark.sql.types", + "pyspark.sql.context", + "pyspark.sql.column", + "pyspark.sql.dataframe", + "pyspark.sql.group", + "pyspark.sql.functions", + "pyspark.sql.readwriter", + "pyspark.sql.window", + "pyspark.sql.tests", + ] +) + + +pyspark_streaming = Module( + name="pyspark-streaming", + dependencies=[pyspark_core, streaming, streaming_kafka], + source_file_regexes=[ + "python/pyspark/streaming" + ], + python_test_goals=[ + "pyspark.streaming.util", + "pyspark.streaming.tests", + ] +) + + +pyspark_mllib = Module( + name="pyspark-mllib", + dependencies=[pyspark_core, pyspark_streaming, pyspark_sql, mllib], + source_file_regexes=[ + "python/pyspark/mllib" + ], + python_test_goals=[ + "pyspark.mllib.classification", + "pyspark.mllib.clustering", + "pyspark.mllib.evaluation", + "pyspark.mllib.feature", + "pyspark.mllib.fpm", + "pyspark.mllib.linalg", + "pyspark.mllib.random", + "pyspark.mllib.recommendation", + "pyspark.mllib.regression", + "pyspark.mllib.stat._statistics", + "pyspark.mllib.stat.KernelDensity", + "pyspark.mllib.tree", + "pyspark.mllib.util", + "pyspark.mllib.tests", + ], + blacklisted_python_implementations=[ + "PyPy" # Skip these tests under PyPy since they require numpy and it isn't available there + ] +) + + +pyspark_ml = Module( + name="pyspark-ml", + dependencies=[pyspark_core, pyspark_mllib], + source_file_regexes=[ + "python/pyspark/ml/" + ], + python_test_goals=[ + "pyspark.ml.feature", + "pyspark.ml.classification", + "pyspark.ml.recommendation", + "pyspark.ml.regression", + "pyspark.ml.tuning", + "pyspark.ml.tests", + "pyspark.ml.evaluation", + ], + blacklisted_python_implementations=[ + "PyPy" # Skip these tests under PyPy since they require numpy and it isn't available there + ] +) + +sparkr = Module( + name="sparkr", + dependencies=[sql, mllib], + source_file_regexes=[ + "R/", + ], + should_run_r_tests=True +) + + +docs = Module( + name="docs", + dependencies=[], + source_file_regexes=[ + "docs/", + ] +) + + +ec2 = Module( + name="ec2", + dependencies=[], + source_file_regexes=[ + "ec2/", + ] +) + + +# The root module is a dummy module which is used to run all of the tests. +# No other modules should directly depend on this module. +root = Module( + name="root", + dependencies=[], + source_file_regexes=[], + # In order to run all of the tests, enable every test profile: + build_profile_flags=list(set( + itertools.chain.from_iterable(m.build_profile_flags for m in all_modules))), + sbt_test_goals=[ + "test", + ], + python_test_goals=list(itertools.chain.from_iterable(m.python_test_goals for m in all_modules)), + should_run_r_tests=True +) diff --git a/dev/sparktestsupport/shellutils.py b/dev/sparktestsupport/shellutils.py new file mode 100644 index 0000000000000..ad9b0cc89e4ab --- /dev/null +++ b/dev/sparktestsupport/shellutils.py @@ -0,0 +1,81 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import shutil +import subprocess +import sys + + +def exit_from_command_with_retcode(cmd, retcode): + print("[error] running", ' '.join(cmd), "; received return code", retcode) + sys.exit(int(os.environ.get("CURRENT_BLOCK", 255))) + + +def rm_r(path): + """ + Given an arbitrary path, properly remove it with the correct Python construct if it exists. + From: http://stackoverflow.com/a/9559881 + """ + + if os.path.isdir(path): + shutil.rmtree(path) + elif os.path.exists(path): + os.remove(path) + + +def run_cmd(cmd): + """ + Given a command as a list of arguments will attempt to execute the command + and, on failure, print an error message and exit. + """ + + if not isinstance(cmd, list): + cmd = cmd.split() + try: + subprocess.check_call(cmd) + except subprocess.CalledProcessError as e: + exit_from_command_with_retcode(e.cmd, e.returncode) + + +def is_exe(path): + """ + Check if a given path is an executable file. + From: http://stackoverflow.com/a/377028 + """ + + return os.path.isfile(path) and os.access(path, os.X_OK) + + +def which(program): + """ + Find and return the given program by its absolute path or 'None' if the program cannot be found. + From: http://stackoverflow.com/a/377028 + """ + + fpath = os.path.split(program)[0] + + if fpath: + if is_exe(program): + return program + else: + for path in os.environ.get("PATH").split(os.pathsep): + path = path.strip('"') + exe_file = os.path.join(path, program) + if is_exe(exe_file): + return exe_file + return None diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 57049beea4dba..91ce681fbe169 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -15,6 +15,7 @@ # limitations under the License. # +import glob import os import sys from itertools import chain @@ -677,4 +678,19 @@ def test_kafka_rdd_with_leaders(self): self._validateRddResult(sendData, rdd) if __name__ == "__main__": + SPARK_HOME = os.environ["SPARK_HOME"] + kafka_assembly_dir = os.path.join(SPARK_HOME, "external/kafka-assembly") + jars = glob.glob( + os.path.join(kafka_assembly_dir, "target/scala-*/spark-streaming-kafka-assembly-*.jar")) + if not jars: + raise Exception( + ("Failed to find Spark Streaming kafka assembly jar in %s. " % kafka_assembly_dir) + + "You need to build Spark with " + "'build/sbt assembly/assembly streaming-kafka-assembly/assembly' or " + "'build/mvn package' before running this test") + elif len(jars) > 1: + raise Exception(("Found multiple Spark Streaming Kafka assembly JARs in %s; please " + "remove all but one") % kafka_assembly_dir) + else: + os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars[0] unittest.main() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 78265423682b0..17256dfc95744 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1421,7 +1421,8 @@ def do_termination_test(self, terminator): # start daemon daemon_path = os.path.join(os.path.dirname(__file__), "daemon.py") - daemon = Popen([sys.executable, daemon_path], stdin=PIPE, stdout=PIPE) + python_exec = sys.executable or os.environ.get("PYSPARK_PYTHON") + daemon = Popen([python_exec, daemon_path], stdin=PIPE, stdout=PIPE) # read the port number port = read_int(daemon.stdout) diff --git a/python/run-tests b/python/run-tests index 4468fdb3f267e..24949657ed7ab 100755 --- a/python/run-tests +++ b/python/run-tests @@ -18,165 +18,7 @@ # -# Figure out where the Spark framework is installed -FWDIR="$(cd "`dirname "$0"`"; cd ../; pwd)" +FWDIR="$(cd "`dirname $0`"/..; pwd)" +cd "$FWDIR" -. "$FWDIR"/bin/load-spark-env.sh - -# CD into the python directory to find things on the right path -cd "$FWDIR/python" - -FAILED=0 -LOG_FILE=unit-tests.log -START=$(date +"%s") - -rm -f $LOG_FILE - -# Remove the metastore and warehouse directory created by the HiveContext tests in Spark SQL -rm -rf metastore warehouse - -function run_test() { - echo -en "Running test: $1 ... " | tee -a $LOG_FILE - start=$(date +"%s") - SPARK_TESTING=1 time "$FWDIR"/bin/pyspark $1 > $LOG_FILE 2>&1 - - FAILED=$((PIPESTATUS[0]||$FAILED)) - - # Fail and exit on the first test failure. - if [[ $FAILED != 0 ]]; then - cat $LOG_FILE | grep -v "^[0-9][0-9]*" # filter all lines starting with a number. - echo -en "\033[31m" # Red - echo "Had test failures; see logs." - echo -en "\033[0m" # No color - exit -1 - else - now=$(date +"%s") - echo "ok ($(($now - $start))s)" - fi -} - -function run_core_tests() { - echo "Run core tests ..." - run_test "pyspark.rdd" - run_test "pyspark.context" - run_test "pyspark.conf" - run_test "pyspark.broadcast" - run_test "pyspark.accumulators" - run_test "pyspark.serializers" - run_test "pyspark.profiler" - run_test "pyspark.shuffle" - run_test "pyspark.tests" -} - -function run_sql_tests() { - echo "Run sql tests ..." - run_test "pyspark.sql.types" - run_test "pyspark.sql.context" - run_test "pyspark.sql.column" - run_test "pyspark.sql.dataframe" - run_test "pyspark.sql.group" - run_test "pyspark.sql.functions" - run_test "pyspark.sql.readwriter" - run_test "pyspark.sql.window" - run_test "pyspark.sql.tests" -} - -function run_mllib_tests() { - echo "Run mllib tests ..." - run_test "pyspark.mllib.classification" - run_test "pyspark.mllib.clustering" - run_test "pyspark.mllib.evaluation" - run_test "pyspark.mllib.feature" - run_test "pyspark.mllib.fpm" - run_test "pyspark.mllib.linalg" - run_test "pyspark.mllib.random" - run_test "pyspark.mllib.recommendation" - run_test "pyspark.mllib.regression" - run_test "pyspark.mllib.stat._statistics" - run_test "pyspark.mllib.stat.KernelDensity" - run_test "pyspark.mllib.tree" - run_test "pyspark.mllib.util" - run_test "pyspark.mllib.tests" -} - -function run_ml_tests() { - echo "Run ml tests ..." - run_test "pyspark.ml.feature" - run_test "pyspark.ml.classification" - run_test "pyspark.ml.recommendation" - run_test "pyspark.ml.regression" - run_test "pyspark.ml.tuning" - run_test "pyspark.ml.tests" - run_test "pyspark.ml.evaluation" -} - -function run_streaming_tests() { - echo "Run streaming tests ..." - - KAFKA_ASSEMBLY_DIR="$FWDIR"/external/kafka-assembly - JAR_PATH="${KAFKA_ASSEMBLY_DIR}/target/scala-${SPARK_SCALA_VERSION}" - for f in "${JAR_PATH}"/spark-streaming-kafka-assembly-*.jar; do - if [[ ! -e "$f" ]]; then - echo "Failed to find Spark Streaming Kafka assembly jar in $KAFKA_ASSEMBLY_DIR" 1>&2 - echo "You need to build Spark with " \ - "'build/sbt assembly/assembly streaming-kafka-assembly/assembly' or" \ - "'build/mvn package' before running this program" 1>&2 - exit 1 - fi - KAFKA_ASSEMBLY_JAR="$f" - done - - export PYSPARK_SUBMIT_ARGS="--jars ${KAFKA_ASSEMBLY_JAR} pyspark-shell" - run_test "pyspark.streaming.util" - run_test "pyspark.streaming.tests" -} - -echo "Running PySpark tests. Output is in python/$LOG_FILE." - -export PYSPARK_PYTHON="python" - -# Try to test with Python 2.6, since that's the minimum version that we support: -if [ $(which python2.6) ]; then - export PYSPARK_PYTHON="python2.6" -fi - -echo "Testing with Python version:" -$PYSPARK_PYTHON --version - -run_core_tests -run_sql_tests -run_mllib_tests -run_ml_tests -run_streaming_tests - -# Try to test with Python 3 -if [ $(which python3.4) ]; then - export PYSPARK_PYTHON="python3.4" - echo "Testing with Python3.4 version:" - $PYSPARK_PYTHON --version - - run_core_tests - run_sql_tests - run_mllib_tests - run_ml_tests - run_streaming_tests -fi - -# Try to test with PyPy -if [ $(which pypy) ]; then - export PYSPARK_PYTHON="pypy" - echo "Testing with PyPy version:" - $PYSPARK_PYTHON --version - - run_core_tests - run_sql_tests - run_streaming_tests -fi - -if [[ $FAILED == 0 ]]; then - now=$(date +"%s") - echo -e "\033[32mTests passed \033[0min $(($now - $START)) seconds" -fi - -# TODO: in the long-run, it would be nice to use a test runner like `nose`. -# The doctest fixtures are the current barrier to doing this. +exec python -u ./python/run-tests.py "$@" diff --git a/python/run-tests.py b/python/run-tests.py new file mode 100755 index 0000000000000..7d485b500ee3a --- /dev/null +++ b/python/run-tests.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function +from optparse import OptionParser +import os +import re +import subprocess +import sys +import time + + +# Append `SPARK_HOME/dev` to the Python path so that we can import the sparktestsupport module +sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../dev/")) + + +from sparktestsupport import SPARK_HOME # noqa (suppress pep8 warnings) +from sparktestsupport.shellutils import which # noqa +from sparktestsupport.modules import all_modules # noqa + + +python_modules = dict((m.name, m) for m in all_modules if m.python_test_goals if m.name != 'root') + + +def print_red(text): + print('\033[31m' + text + '\033[0m') + + +LOG_FILE = os.path.join(SPARK_HOME, "python/unit-tests.log") + + +def run_individual_python_test(test_name, pyspark_python): + env = {'SPARK_TESTING': '1', 'PYSPARK_PYTHON': which(pyspark_python)} + print(" Running test: %s ..." % test_name, end='') + start_time = time.time() + with open(LOG_FILE, 'a') as log_file: + retcode = subprocess.call( + [os.path.join(SPARK_HOME, "bin/pyspark"), test_name], + stderr=log_file, stdout=log_file, env=env) + duration = time.time() - start_time + # Exit on the first failure. + if retcode != 0: + with open(LOG_FILE, 'r') as log_file: + for line in log_file: + if not re.match('[0-9]+', line): + print(line, end='') + print_red("\nHad test failures in %s; see logs." % test_name) + exit(-1) + else: + print("ok (%is)" % duration) + + +def get_default_python_executables(): + python_execs = [x for x in ["python2.6", "python3.4", "pypy"] if which(x)] + if "python2.6" not in python_execs: + print("WARNING: Not testing against `python2.6` because it could not be found; falling" + " back to `python` instead") + python_execs.insert(0, "python") + return python_execs + + +def parse_opts(): + parser = OptionParser( + prog="run-tests" + ) + parser.add_option( + "--python-executables", type="string", default=','.join(get_default_python_executables()), + help="A comma-separated list of Python executables to test against (default: %default)" + ) + parser.add_option( + "--modules", type="string", + default=",".join(sorted(python_modules.keys())), + help="A comma-separated list of Python modules to test (default: %default)" + ) + + (opts, args) = parser.parse_args() + if args: + parser.error("Unsupported arguments: %s" % ' '.join(args)) + return opts + + +def main(): + opts = parse_opts() + print("Running PySpark tests. Output is in python/%s" % LOG_FILE) + if os.path.exists(LOG_FILE): + os.remove(LOG_FILE) + python_execs = opts.python_executables.split(',') + modules_to_test = [] + for module_name in opts.modules.split(','): + if module_name in python_modules: + modules_to_test.append(python_modules[module_name]) + else: + print("Error: unrecognized module %s" % module_name) + sys.exit(-1) + print("Will test against the following Python executables: %s" % python_execs) + print("Will test the following Python modules: %s" % [x.name for x in modules_to_test]) + + start_time = time.time() + for python_exec in python_execs: + python_implementation = subprocess.check_output( + [python_exec, "-c", "import platform; print(platform.python_implementation())"], + universal_newlines=True).strip() + print("Testing with `%s`: " % python_exec, end='') + subprocess.call([python_exec, "--version"]) + + for module in modules_to_test: + if python_implementation not in module.blacklisted_python_implementations: + print("Running %s tests ..." % module.name) + for test_goal in module.python_test_goals: + run_individual_python_test(test_goal, python_exec) + total_duration = time.time() - start_time + print("Tests passed in %i seconds" % total_duration) + + +if __name__ == "__main__": + main() From 42db3a1c2fb6db61e01756be7fe88c4110ae638e Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 27 Jun 2015 23:07:20 -0700 Subject: [PATCH 44/44] [HOTFIX] Fix pull request builder bug in #6967 --- dev/run-tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/run-tests.py b/dev/run-tests.py index c51b0d3010a0f..3533e0c857b9b 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -365,7 +365,7 @@ def run_python_tests(test_modules): command = [os.path.join(SPARK_HOME, "python", "run-tests")] if test_modules != [modules.root]: - command.append("--modules=%s" % ','.join(m.name for m in modules)) + command.append("--modules=%s" % ','.join(m.name for m in test_modules)) run_cmd(command)