From 78a430ea4d2aef58a8bf38ce488553ca6acea428 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Wed, 17 Jun 2015 23:22:54 -0700 Subject: [PATCH 01/22] [SPARK-7961][SQL]Refactor SQLConf to display better error message 1. Add `SQLConfEntry` to store the information about a configuration. For those configurations that cannot be found in `sql-programming-guide.md`, I left the doc as ``. 2. Verify the value when setting a configuration if this is in SQLConf. 3. Use `SET -v` to display all public configurations. Author: zsxwing Closes #6747 from zsxwing/sqlconf and squashes the following commits: 7d09bad [zsxwing] Use SQLConfEntry in HiveContext 49f6213 [zsxwing] Add getConf, setConf to SQLContext and HiveContext e014f53 [zsxwing] Merge branch 'master' into sqlconf 93dad8e [zsxwing] Fix the unit tests cf950c1 [zsxwing] Fix the code style and tests 3c5f03e [zsxwing] Add unsetConf(SQLConfEntry) and fix the code style a2f4add [zsxwing] getConf will return the default value if a config is not set 037b1db [zsxwing] Add schema to SetCommand 0520c3c [zsxwing] Merge branch 'master' into sqlconf 7afb0ec [zsxwing] Fix the configurations about HiveThriftServer 7e728e3 [zsxwing] Add doc for SQLConfEntry and fix 'toString' 5e95b10 [zsxwing] Add enumConf c6ba76d [zsxwing] setRawString => setConfString, getRawString => getConfString 4abd807 [zsxwing] Fix the test for 'set -v' 6e47e56 [zsxwing] Fix the compilation error 8973ced [zsxwing] Remove floatConf 1fc3a8b [zsxwing] Remove the 'conf' command and use 'set -v' instead 99c9c16 [zsxwing] Fix tests that use SQLConfEntry as a string 88a03cc [zsxwing] Add new lines between confs and return types ce7c6c8 [zsxwing] Remove seqConf f3c1b33 [zsxwing] Refactor SQLConf to display better error message --- docs/sql-programming-guide.md | 4 +- .../scala/org/apache/spark/sql/SQLConf.scala | 493 ++++++++++++++---- .../org/apache/spark/sql/SQLContext.scala | 25 +- .../org/apache/spark/sql/SparkSQLParser.scala | 4 +- .../apache/spark/sql/execution/commands.scala | 98 +++- .../spark/sql/execution/debug/package.scala | 2 +- .../sql/parquet/ParquetTableOperations.scala | 8 +- .../apache/spark/sql/parquet/newParquet.scala | 6 +- .../apache/spark/sql/sources/commands.scala | 2 +- .../spark/sql/test/TestSQLContext.scala | 2 +- .../spark/sql/DataFrameAggregateSuite.scala | 4 +- .../org/apache/spark/sql/DataFrameSuite.scala | 14 +- .../org/apache/spark/sql/JoinSuite.scala | 14 +- .../apache/spark/sql/SQLConfEntrySuite.scala | 150 ++++++ .../org/apache/spark/sql/SQLConfSuite.scala | 10 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 42 +- .../columnar/PartitionBatchPruningSuite.scala | 8 +- .../spark/sql/execution/PlannerSuite.scala | 8 +- .../org/apache/spark/sql/json/JsonSuite.scala | 4 +- .../sql/parquet/ParquetFilterSuite.scala | 14 +- .../spark/sql/parquet/ParquetIOSuite.scala | 16 +- .../spark/sql/parquet/ParquetQuerySuite.scala | 8 +- .../spark/sql/sources/DataSourceTest.scala | 2 +- .../apache/spark/sql/test/SQLTestUtils.scala | 6 +- .../hive/thriftserver/HiveThriftServer2.scala | 4 +- .../SparkExecuteStatementOperation.scala | 2 +- .../HiveThriftServer2Suites.scala | 22 +- .../execution/HiveCompatibilitySuite.scala | 8 +- .../SortMergeCompatibilitySuite.scala | 4 +- .../apache/spark/sql/hive/HiveContext.scala | 88 +++- .../apache/spark/sql/hive/test/TestHive.scala | 5 +- .../spark/sql/hive/HiveParquetSuite.scala | 4 +- .../sql/hive/MetastoreDataSourcesSuite.scala | 16 +- .../spark/sql/hive/StatisticsSuite.scala | 8 +- .../sql/hive/execution/HiveQuerySuite.scala | 12 +- 
.../sql/hive/execution/SQLQuerySuite.scala | 20 +- .../apache/spark/sql/hive/parquetSuites.scala | 20 +- 37 files changed, 861 insertions(+), 296 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/SQLConfEntrySuite.scala diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 61f9c5f02ac72..c6e6ec88a205f 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1220,7 +1220,7 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext` false Some other Parquet-producing systems, in particular Impala and older versions of Spark SQL, do - not differentiate between binary data and strings when writing out the Parquet schema. This + not differentiate between binary data and strings when writing out the Parquet schema. This flag tells Spark SQL to interpret binary data as a string to provide compatibility with these systems. @@ -1237,7 +1237,7 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext` spark.sql.parquet.cacheMetadata true - Turns on caching of Parquet schema metadata. Can speed up querying of static data. + Turns on caching of Parquet schema metadata. Can speed up querying of static data. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 55ab6b3358e3c..16493c3d7c19c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -25,74 +25,333 @@ import scala.collection.JavaConversions._ import org.apache.spark.sql.catalyst.CatalystConf private[spark] object SQLConf { - val COMPRESS_CACHED = "spark.sql.inMemoryColumnarStorage.compressed" - val COLUMN_BATCH_SIZE = "spark.sql.inMemoryColumnarStorage.batchSize" - val IN_MEMORY_PARTITION_PRUNING = "spark.sql.inMemoryColumnarStorage.partitionPruning" - val AUTO_BROADCASTJOIN_THRESHOLD = "spark.sql.autoBroadcastJoinThreshold" - val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes" - val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions" - val CODEGEN_ENABLED = "spark.sql.codegen" - val UNSAFE_ENABLED = "spark.sql.unsafe.enabled" - val DIALECT = "spark.sql.dialect" - val CASE_SENSITIVE = "spark.sql.caseSensitive" - - val PARQUET_BINARY_AS_STRING = "spark.sql.parquet.binaryAsString" - val PARQUET_INT96_AS_TIMESTAMP = "spark.sql.parquet.int96AsTimestamp" - val PARQUET_CACHE_METADATA = "spark.sql.parquet.cacheMetadata" - val PARQUET_COMPRESSION = "spark.sql.parquet.compression.codec" - val PARQUET_FILTER_PUSHDOWN_ENABLED = "spark.sql.parquet.filterPushdown" - val PARQUET_USE_DATA_SOURCE_API = "spark.sql.parquet.useDataSourceApi" - - val ORC_FILTER_PUSHDOWN_ENABLED = "spark.sql.orc.filterPushdown" - - val HIVE_VERIFY_PARTITIONPATH = "spark.sql.hive.verifyPartitionPath" - - val COLUMN_NAME_OF_CORRUPT_RECORD = "spark.sql.columnNameOfCorruptRecord" - val BROADCAST_TIMEOUT = "spark.sql.broadcastTimeout" + + private val sqlConfEntries = java.util.Collections.synchronizedMap( + new java.util.HashMap[String, SQLConfEntry[_]]()) + + /** + * An entry contains all meta information for a configuration. + * + * @param key the key for the configuration + * @param defaultValue the default value for the configuration + * @param valueConverter how to convert a string to the value. It should throw an exception if the + * string does not have the required format. 
+ * @param stringConverter how to convert a value to a string that the user can use it as a valid + * string value. It's usually `toString`. But sometimes, a custom converter + * is necessary. E.g., if T is List[String], `a, b, c` is better than + * `List(a, b, c)`. + * @param doc the document for the configuration + * @param isPublic if this configuration is public to the user. If it's `false`, this + * configuration is only used internally and we should not expose it to the user. + * @tparam T the value type + */ + private[sql] class SQLConfEntry[T] private( + val key: String, + val defaultValue: Option[T], + val valueConverter: String => T, + val stringConverter: T => String, + val doc: String, + val isPublic: Boolean) { + + def defaultValueString: String = defaultValue.map(stringConverter).getOrElse("") + + override def toString: String = { + s"SQLConfEntry(key = $key, defaultValue=$defaultValueString, doc=$doc, isPublic = $isPublic)" + } + } + + private[sql] object SQLConfEntry { + + private def apply[T]( + key: String, + defaultValue: Option[T], + valueConverter: String => T, + stringConverter: T => String, + doc: String, + isPublic: Boolean): SQLConfEntry[T] = + sqlConfEntries.synchronized { + if (sqlConfEntries.containsKey(key)) { + throw new IllegalArgumentException(s"Duplicate SQLConfEntry. $key has been registered") + } + val entry = + new SQLConfEntry[T](key, defaultValue, valueConverter, stringConverter, doc, isPublic) + sqlConfEntries.put(key, entry) + entry + } + + def intConf( + key: String, + defaultValue: Option[Int] = None, + doc: String = "", + isPublic: Boolean = true): SQLConfEntry[Int] = + SQLConfEntry(key, defaultValue, { v => + try { + v.toInt + } catch { + case _: NumberFormatException => + throw new IllegalArgumentException(s"$key should be int, but was $v") + } + }, _.toString, doc, isPublic) + + def longConf( + key: String, + defaultValue: Option[Long] = None, + doc: String = "", + isPublic: Boolean = true): SQLConfEntry[Long] = + SQLConfEntry(key, defaultValue, { v => + try { + v.toLong + } catch { + case _: NumberFormatException => + throw new IllegalArgumentException(s"$key should be long, but was $v") + } + }, _.toString, doc, isPublic) + + def doubleConf( + key: String, + defaultValue: Option[Double] = None, + doc: String = "", + isPublic: Boolean = true): SQLConfEntry[Double] = + SQLConfEntry(key, defaultValue, { v => + try { + v.toDouble + } catch { + case _: NumberFormatException => + throw new IllegalArgumentException(s"$key should be double, but was $v") + } + }, _.toString, doc, isPublic) + + def booleanConf( + key: String, + defaultValue: Option[Boolean] = None, + doc: String = "", + isPublic: Boolean = true): SQLConfEntry[Boolean] = + SQLConfEntry(key, defaultValue, { v => + try { + v.toBoolean + } catch { + case _: IllegalArgumentException => + throw new IllegalArgumentException(s"$key should be boolean, but was $v") + } + }, _.toString, doc, isPublic) + + def stringConf( + key: String, + defaultValue: Option[String] = None, + doc: String = "", + isPublic: Boolean = true): SQLConfEntry[String] = + SQLConfEntry(key, defaultValue, v => v, v => v, doc, isPublic) + + def enumConf[T]( + key: String, + valueConverter: String => T, + validValues: Set[T], + defaultValue: Option[T] = None, + doc: String = "", + isPublic: Boolean = true): SQLConfEntry[T] = + SQLConfEntry(key, defaultValue, v => { + val _v = valueConverter(v) + if (!validValues.contains(_v)) { + throw new IllegalArgumentException( + s"The value of $key should be one of 
${validValues.mkString(", ")}, but was $v") + } + _v + }, _.toString, doc, isPublic) + + def seqConf[T]( + key: String, + valueConverter: String => T, + defaultValue: Option[Seq[T]] = None, + doc: String = "", + isPublic: Boolean = true): SQLConfEntry[Seq[T]] = { + SQLConfEntry( + key, defaultValue, _.split(",").map(valueConverter), _.mkString(","), doc, isPublic) + } + + def stringSeqConf( + key: String, + defaultValue: Option[Seq[String]] = None, + doc: String = "", + isPublic: Boolean = true): SQLConfEntry[Seq[String]] = { + seqConf(key, s => s, defaultValue, doc, isPublic) + } + } + + import SQLConfEntry._ + + val COMPRESS_CACHED = booleanConf("spark.sql.inMemoryColumnarStorage.compressed", + defaultValue = Some(true), + doc = "When set to true Spark SQL will automatically select a compression codec for each " + + "column based on statistics of the data.") + + val COLUMN_BATCH_SIZE = intConf("spark.sql.inMemoryColumnarStorage.batchSize", + defaultValue = Some(10000), + doc = "Controls the size of batches for columnar caching. Larger batch sizes can improve " + + "memory utilization and compression, but risk OOMs when caching data.") + + val IN_MEMORY_PARTITION_PRUNING = + booleanConf("spark.sql.inMemoryColumnarStorage.partitionPruning", + defaultValue = Some(false), + doc = "") + + val AUTO_BROADCASTJOIN_THRESHOLD = intConf("spark.sql.autoBroadcastJoinThreshold", + defaultValue = Some(10 * 1024 * 1024), + doc = "Configures the maximum size in bytes for a table that will be broadcast to all worker " + + "nodes when performing a join. By setting this value to -1 broadcasting can be disabled. " + + "Note that currently statistics are only supported for Hive Metastore tables where the " + + "commandANALYZE TABLE <tableName> COMPUTE STATISTICS noscan has been run.") + + val DEFAULT_SIZE_IN_BYTES = longConf("spark.sql.defaultSizeInBytes", isPublic = false) + + val SHUFFLE_PARTITIONS = intConf("spark.sql.shuffle.partitions", + defaultValue = Some(200), + doc = "Configures the number of partitions to use when shuffling data for joins or " + + "aggregations.") + + val CODEGEN_ENABLED = booleanConf("spark.sql.codegen", + defaultValue = Some(true), + doc = "When true, code will be dynamically generated at runtime for expression evaluation in" + + " a specific query. For some queries with complicated expression this option can lead to " + + "significant speed-ups. However, for simple queries this can actually slow down query " + + "execution.") + + val UNSAFE_ENABLED = booleanConf("spark.sql.unsafe.enabled", + defaultValue = Some(false), + doc = "") + + val DIALECT = stringConf("spark.sql.dialect", defaultValue = Some("sql"), doc = "") + + val CASE_SENSITIVE = booleanConf("spark.sql.caseSensitive", + defaultValue = Some(true), + doc = "") + + val PARQUET_BINARY_AS_STRING = booleanConf("spark.sql.parquet.binaryAsString", + defaultValue = Some(false), + doc = "Some other Parquet-producing systems, in particular Impala and older versions of " + + "Spark SQL, do not differentiate between binary data and strings when writing out the " + + "Parquet schema. This flag tells Spark SQL to interpret binary data as a string to provide " + + "compatibility with these systems.") + + val PARQUET_INT96_AS_TIMESTAMP = booleanConf("spark.sql.parquet.int96AsTimestamp", + defaultValue = Some(true), + doc = "Some Parquet-producing systems, in particular Impala, store Timestamp into INT96. " + + "Spark would also store Timestamp as INT96 because we need to avoid precision lost of the " + + "nanoseconds field. 
This flag tells Spark SQL to interpret INT96 data as a timestamp to " + + "provide compatibility with these systems.") + + val PARQUET_CACHE_METADATA = booleanConf("spark.sql.parquet.cacheMetadata", + defaultValue = Some(true), + doc = "Turns on caching of Parquet schema metadata. Can speed up querying of static data.") + + val PARQUET_COMPRESSION = enumConf("spark.sql.parquet.compression.codec", + valueConverter = v => v.toLowerCase, + validValues = Set("uncompressed", "snappy", "gzip", "lzo"), + defaultValue = Some("gzip"), + doc = "Sets the compression codec use when writing Parquet files. Acceptable values include: " + + "uncompressed, snappy, gzip, lzo.") + + val PARQUET_FILTER_PUSHDOWN_ENABLED = booleanConf("spark.sql.parquet.filterPushdown", + defaultValue = Some(false), + doc = "Turn on Parquet filter pushdown optimization. This feature is turned off by default" + + " because of a known bug in Paruet 1.6.0rc3 " + + "(PARQUET-136). However, " + + "if your table doesn't contain any nullable string or binary columns, it's still safe to " + + "turn this feature on.") + + val PARQUET_USE_DATA_SOURCE_API = booleanConf("spark.sql.parquet.useDataSourceApi", + defaultValue = Some(true), + doc = "") + + val ORC_FILTER_PUSHDOWN_ENABLED = booleanConf("spark.sql.orc.filterPushdown", + defaultValue = Some(false), + doc = "") + + val HIVE_VERIFY_PARTITIONPATH = booleanConf("spark.sql.hive.verifyPartitionPath", + defaultValue = Some(true), + doc = "") + + val COLUMN_NAME_OF_CORRUPT_RECORD = stringConf("spark.sql.columnNameOfCorruptRecord", + defaultValue = Some("_corrupt_record"), + doc = "") + + val BROADCAST_TIMEOUT = intConf("spark.sql.broadcastTimeout", + defaultValue = Some(5 * 60), + doc = "") // Options that control which operators can be chosen by the query planner. These should be // considered hints and may be ignored by future versions of Spark SQL. - val EXTERNAL_SORT = "spark.sql.planner.externalSort" - val SORTMERGE_JOIN = "spark.sql.planner.sortMergeJoin" + val EXTERNAL_SORT = booleanConf("spark.sql.planner.externalSort", + defaultValue = Some(true), + doc = "When true, performs sorts spilling to disk as needed otherwise sort each partition in" + + " memory.") + + val SORTMERGE_JOIN = booleanConf("spark.sql.planner.sortMergeJoin", + defaultValue = Some(false), + doc = "") // This is only used for the thriftserver - val THRIFTSERVER_POOL = "spark.sql.thriftserver.scheduler.pool" - val THRIFTSERVER_UI_STATEMENT_LIMIT = "spark.sql.thriftserver.ui.retainedStatements" - val THRIFTSERVER_UI_SESSION_LIMIT = "spark.sql.thriftserver.ui.retainedSessions" + val THRIFTSERVER_POOL = stringConf("spark.sql.thriftserver.scheduler.pool", + doc = "Set a Fair Scheduler pool for a JDBC client session") + + val THRIFTSERVER_UI_STATEMENT_LIMIT = intConf("spark.sql.thriftserver.ui.retainedStatements", + defaultValue = Some(200), + doc = "") + + val THRIFTSERVER_UI_SESSION_LIMIT = intConf("spark.sql.thriftserver.ui.retainedSessions", + defaultValue = Some(200), + doc = "") // This is used to set the default data source - val DEFAULT_DATA_SOURCE_NAME = "spark.sql.sources.default" + val DEFAULT_DATA_SOURCE_NAME = stringConf("spark.sql.sources.default", + defaultValue = Some("org.apache.spark.sql.parquet"), + doc = "") + // This is used to control the when we will split a schema's JSON string to multiple pieces // in order to fit the JSON string in metastore's table property (by default, the value has // a length restriction of 4000 characters). 
We will split the JSON string of a schema // to its length exceeds the threshold. - val SCHEMA_STRING_LENGTH_THRESHOLD = "spark.sql.sources.schemaStringLengthThreshold" + val SCHEMA_STRING_LENGTH_THRESHOLD = intConf("spark.sql.sources.schemaStringLengthThreshold", + defaultValue = Some(4000), + doc = "") // Whether to perform partition discovery when loading external data sources. Default to true. - val PARTITION_DISCOVERY_ENABLED = "spark.sql.sources.partitionDiscovery.enabled" + val PARTITION_DISCOVERY_ENABLED = booleanConf("spark.sql.sources.partitionDiscovery.enabled", + defaultValue = Some(true), + doc = "") // Whether to perform partition column type inference. Default to true. - val PARTITION_COLUMN_TYPE_INFERENCE = "spark.sql.sources.partitionColumnTypeInference.enabled" + val PARTITION_COLUMN_TYPE_INFERENCE = + booleanConf("spark.sql.sources.partitionColumnTypeInference.enabled", + defaultValue = Some(true), + doc = "") // The output committer class used by FSBasedRelation. The specified class needs to be a // subclass of org.apache.hadoop.mapreduce.OutputCommitter. // NOTE: This property should be set in Hadoop `Configuration` rather than Spark `SQLConf` - val OUTPUT_COMMITTER_CLASS = "spark.sql.sources.outputCommitterClass" + val OUTPUT_COMMITTER_CLASS = + stringConf("spark.sql.sources.outputCommitterClass", isPublic = false) // Whether to perform eager analysis when constructing a dataframe. // Set to false when debugging requires the ability to look at invalid query plans. - val DATAFRAME_EAGER_ANALYSIS = "spark.sql.eagerAnalysis" + val DATAFRAME_EAGER_ANALYSIS = booleanConf("spark.sql.eagerAnalysis", + defaultValue = Some(true), + doc = "") // Whether to automatically resolve ambiguity in join conditions for self-joins. // See SPARK-6231. - val DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY = "spark.sql.selfJoinAutoResolveAmbiguity" + val DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY = + booleanConf("spark.sql.selfJoinAutoResolveAmbiguity", defaultValue = Some(true), doc = "") // Whether to retain group by columns or not in GroupedData.agg. - val DATAFRAME_RETAIN_GROUP_COLUMNS = "spark.sql.retainGroupColumns" + val DATAFRAME_RETAIN_GROUP_COLUMNS = booleanConf("spark.sql.retainGroupColumns", + defaultValue = Some(true), + doc = "") - val USE_SQL_SERIALIZER2 = "spark.sql.useSerializer2" + val USE_SQL_SERIALIZER2 = booleanConf("spark.sql.useSerializer2", + defaultValue = Some(true), doc = "") - val USE_JACKSON_STREAMING_API = "spark.sql.json.useJacksonStreamingAPI" + val USE_JACKSON_STREAMING_API = booleanConf("spark.sql.json.useJacksonStreamingAPI", + defaultValue = Some(true), doc = "") object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" @@ -131,56 +390,54 @@ private[sql] class SQLConf extends Serializable with CatalystConf { * Note that the choice of dialect does not affect things like what tables are available or * how query execution is performed. */ - private[spark] def dialect: String = getConf(DIALECT, "sql") + private[spark] def dialect: String = getConf(DIALECT) /** When true tables cached using the in-memory columnar caching will be compressed. 
*/ - private[spark] def useCompression: Boolean = getConf(COMPRESS_CACHED, "true").toBoolean + private[spark] def useCompression: Boolean = getConf(COMPRESS_CACHED) /** The compression codec for writing to a Parquetfile */ - private[spark] def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION, "gzip") + private[spark] def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION) + + private[spark] def parquetCacheMetadata: Boolean = getConf(PARQUET_CACHE_METADATA) /** The number of rows that will be */ - private[spark] def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE, "10000").toInt + private[spark] def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE) /** Number of partitions to use for shuffle operators. */ - private[spark] def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS, "200").toInt + private[spark] def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS) /** When true predicates will be passed to the parquet record reader when possible. */ - private[spark] def parquetFilterPushDown = - getConf(PARQUET_FILTER_PUSHDOWN_ENABLED, "false").toBoolean + private[spark] def parquetFilterPushDown: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_ENABLED) /** When true uses Parquet implementation based on data source API */ - private[spark] def parquetUseDataSourceApi = - getConf(PARQUET_USE_DATA_SOURCE_API, "true").toBoolean + private[spark] def parquetUseDataSourceApi: Boolean = getConf(PARQUET_USE_DATA_SOURCE_API) - private[spark] def orcFilterPushDown = - getConf(ORC_FILTER_PUSHDOWN_ENABLED, "false").toBoolean + private[spark] def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED) /** When true uses verifyPartitionPath to prune the path which is not exists. */ - private[spark] def verifyPartitionPath = - getConf(HIVE_VERIFY_PARTITIONPATH, "true").toBoolean + private[spark] def verifyPartitionPath: Boolean = getConf(HIVE_VERIFY_PARTITIONPATH) /** When true the planner will use the external sort, which may spill to disk. */ - private[spark] def externalSortEnabled: Boolean = getConf(EXTERNAL_SORT, "true").toBoolean + private[spark] def externalSortEnabled: Boolean = getConf(EXTERNAL_SORT) /** * Sort merge join would sort the two side of join first, and then iterate both sides together * only once to get all matches. Using sort merge join can save a lot of memory usage compared * to HashJoin. */ - private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN, "false").toBoolean + private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN) /** * When set to true, Spark SQL will use the Janino at runtime to generate custom bytecode * that evaluates expressions found in queries. In general this custom code runs much faster * than interpreted evaluation, but there are some start-up costs (5-10ms) due to compilation. */ - private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, "true").toBoolean + private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED) /** * caseSensitive analysis true by default */ - def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, "true").toBoolean + def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE) /** * When set to true, Spark SQL will use managed memory for certain operations. This option only @@ -188,15 +445,14 @@ private[sql] class SQLConf extends Serializable with CatalystConf { * * Defaults to false as this feature is currently experimental. 
*/ - private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED, "false").toBoolean + private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED) - private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2, "true").toBoolean + private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2) /** * Selects between the new (true) and old (false) JSON handlers, to be removed in Spark 1.5.0 */ - private[spark] def useJacksonStreamingAPI: Boolean = - getConf(USE_JACKSON_STREAMING_API, "true").toBoolean + private[spark] def useJacksonStreamingAPI: Boolean = getConf(USE_JACKSON_STREAMING_API) /** * Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to @@ -205,8 +461,7 @@ private[sql] class SQLConf extends Serializable with CatalystConf { * * Hive setting: hive.auto.convert.join.noconditionaltask.size, whose default value is 10000. */ - private[spark] def autoBroadcastJoinThreshold: Int = - getConf(AUTO_BROADCASTJOIN_THRESHOLD, (10 * 1024 * 1024).toString).toInt + private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD) /** * The default size in bytes to assign to a logical operator's estimation statistics. By default, @@ -215,82 +470,116 @@ private[sql] class SQLConf extends Serializable with CatalystConf { * in joins. */ private[spark] def defaultSizeInBytes: Long = - getConf(DEFAULT_SIZE_IN_BYTES, (autoBroadcastJoinThreshold + 1).toString).toLong + getConf(DEFAULT_SIZE_IN_BYTES, autoBroadcastJoinThreshold + 1L) /** * When set to true, we always treat byte arrays in Parquet files as strings. */ - private[spark] def isParquetBinaryAsString: Boolean = - getConf(PARQUET_BINARY_AS_STRING, "false").toBoolean + private[spark] def isParquetBinaryAsString: Boolean = getConf(PARQUET_BINARY_AS_STRING) /** * When set to true, we always treat INT96Values in Parquet files as timestamp. */ - private[spark] def isParquetINT96AsTimestamp: Boolean = - getConf(PARQUET_INT96_AS_TIMESTAMP, "true").toBoolean + private[spark] def isParquetINT96AsTimestamp: Boolean = getConf(PARQUET_INT96_AS_TIMESTAMP) /** * When set to true, partition pruning for in-memory columnar tables is enabled. 
*/ - private[spark] def inMemoryPartitionPruning: Boolean = - getConf(IN_MEMORY_PARTITION_PRUNING, "false").toBoolean + private[spark] def inMemoryPartitionPruning: Boolean = getConf(IN_MEMORY_PARTITION_PRUNING) - private[spark] def columnNameOfCorruptRecord: String = - getConf(COLUMN_NAME_OF_CORRUPT_RECORD, "_corrupt_record") + private[spark] def columnNameOfCorruptRecord: String = getConf(COLUMN_NAME_OF_CORRUPT_RECORD) /** * Timeout in seconds for the broadcast wait time in hash join */ - private[spark] def broadcastTimeout: Int = - getConf(BROADCAST_TIMEOUT, (5 * 60).toString).toInt + private[spark] def broadcastTimeout: Int = getConf(BROADCAST_TIMEOUT) - private[spark] def defaultDataSourceName: String = - getConf(DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.parquet") + private[spark] def defaultDataSourceName: String = getConf(DEFAULT_DATA_SOURCE_NAME) - private[spark] def partitionDiscoveryEnabled() = - getConf(SQLConf.PARTITION_DISCOVERY_ENABLED, "true").toBoolean + private[spark] def partitionDiscoveryEnabled(): Boolean = + getConf(SQLConf.PARTITION_DISCOVERY_ENABLED) - private[spark] def partitionColumnTypeInferenceEnabled() = - getConf(SQLConf.PARTITION_COLUMN_TYPE_INFERENCE, "true").toBoolean + private[spark] def partitionColumnTypeInferenceEnabled(): Boolean = + getConf(SQLConf.PARTITION_COLUMN_TYPE_INFERENCE) // Do not use a value larger than 4000 as the default value of this property. // See the comments of SCHEMA_STRING_LENGTH_THRESHOLD above for more information. - private[spark] def schemaStringLengthThreshold: Int = - getConf(SCHEMA_STRING_LENGTH_THRESHOLD, "4000").toInt + private[spark] def schemaStringLengthThreshold: Int = getConf(SCHEMA_STRING_LENGTH_THRESHOLD) - private[spark] def dataFrameEagerAnalysis: Boolean = - getConf(DATAFRAME_EAGER_ANALYSIS, "true").toBoolean + private[spark] def dataFrameEagerAnalysis: Boolean = getConf(DATAFRAME_EAGER_ANALYSIS) private[spark] def dataFrameSelfJoinAutoResolveAmbiguity: Boolean = - getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY, "true").toBoolean + getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY) - private[spark] def dataFrameRetainGroupColumns: Boolean = - getConf(DATAFRAME_RETAIN_GROUP_COLUMNS, "true").toBoolean + private[spark] def dataFrameRetainGroupColumns: Boolean = getConf(DATAFRAME_RETAIN_GROUP_COLUMNS) /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ def setConf(props: Properties): Unit = settings.synchronized { - props.foreach { case (k, v) => settings.put(k, v) } + props.foreach { case (k, v) => setConfString(k, v) } } - /** Set the given Spark SQL configuration property. */ - def setConf(key: String, value: String): Unit = { + /** Set the given Spark SQL configuration property using a `string` value. */ + def setConfString(key: String, value: String): Unit = { require(key != null, "key cannot be null") require(value != null, s"value cannot be null for key: $key") + val entry = sqlConfEntries.get(key) + if (entry != null) { + // Only verify configs in the SQLConf object + entry.valueConverter(value) + } settings.put(key, value) } + /** Set the given Spark SQL configuration property. 
*/ + def setConf[T](entry: SQLConfEntry[T], value: T): Unit = { + require(entry != null, "entry cannot be null") + require(value != null, s"value cannot be null for key: ${entry.key}") + require(sqlConfEntries.get(entry.key) == entry, s"$entry is not registered") + settings.put(entry.key, entry.stringConverter(value)) + } + /** Return the value of Spark SQL configuration property for the given key. */ - def getConf(key: String): String = { - Option(settings.get(key)).getOrElse(throw new NoSuchElementException(key)) + def getConfString(key: String): String = { + Option(settings.get(key)). + orElse { + // Try to use the default value + Option(sqlConfEntries.get(key)).map(_.defaultValueString) + }. + getOrElse(throw new NoSuchElementException(key)) + } + + /** + * Return the value of Spark SQL configuration property for the given key. If the key is not set + * yet, return `defaultValue`. This is useful when `defaultValue` in SQLConfEntry is not the + * desired one. + */ + def getConf[T](entry: SQLConfEntry[T], defaultValue: T): T = { + require(sqlConfEntries.get(entry.key) == entry, s"$entry is not registered") + Option(settings.get(entry.key)).map(entry.valueConverter).getOrElse(defaultValue) } /** * Return the value of Spark SQL configuration property for the given key. If the key is not set - * yet, return `defaultValue`. + * yet, return `defaultValue` in [[SQLConfEntry]]. + */ + def getConf[T](entry: SQLConfEntry[T]): T = { + require(sqlConfEntries.get(entry.key) == entry, s"$entry is not registered") + Option(settings.get(entry.key)).map(entry.valueConverter).orElse(entry.defaultValue). + getOrElse(throw new NoSuchElementException(entry.key)) + } + + /** + * Return the `string` value of Spark SQL configuration property for the given key. If the key is + * not set yet, return `defaultValue`. */ - def getConf(key: String, defaultValue: String): String = { + def getConfString(key: String, defaultValue: String): String = { + val entry = sqlConfEntries.get(key) + if (entry != null && defaultValue != "") { + // Only verify configs in the SQLConf object + entry.valueConverter(defaultValue) + } Option(settings.get(key)).getOrElse(defaultValue) } @@ -300,11 +589,25 @@ private[sql] class SQLConf extends Serializable with CatalystConf { */ def getAllConfs: immutable.Map[String, String] = settings.synchronized { settings.toMap } - private[spark] def unsetConf(key: String) { + /** + * Return all the configuration definitions that have been defined in [[SQLConf]]. Each + * definition contains key, defaultValue and doc. 
+ */ + def getAllDefinedConfs: Seq[(String, String, String)] = sqlConfEntries.synchronized { + sqlConfEntries.values.filter(_.isPublic).map { entry => + (entry.key, entry.defaultValueString, entry.doc) + }.toSeq + } + + private[spark] def unsetConf(key: String): Unit = { settings -= key } - private[spark] def clear() { + private[spark] def unsetConf(entry: SQLConfEntry[_]): Unit = { + settings -= entry.key + } + + private[spark] def clear(): Unit = { settings.clear() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 6b605f7130167..04fc798bf3738 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -31,6 +31,7 @@ import org.apache.spark.SparkContext import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SQLConf.SQLConfEntry import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.catalyst.expressions._ @@ -79,13 +80,16 @@ class SQLContext(@transient val sparkContext: SparkContext) */ def setConf(props: Properties): Unit = conf.setConf(props) + /** Set the given Spark SQL configuration property. */ + private[sql] def setConf[T](entry: SQLConfEntry[T], value: T): Unit = conf.setConf(entry, value) + /** * Set the given Spark SQL configuration property. * * @group config * @since 1.0.0 */ - def setConf(key: String, value: String): Unit = conf.setConf(key, value) + def setConf(key: String, value: String): Unit = conf.setConfString(key, value) /** * Return the value of Spark SQL configuration property for the given key. @@ -93,7 +97,22 @@ class SQLContext(@transient val sparkContext: SparkContext) * @group config * @since 1.0.0 */ - def getConf(key: String): String = conf.getConf(key) + def getConf(key: String): String = conf.getConfString(key) + + /** + * Return the value of Spark SQL configuration property for the given key. If the key is not set + * yet, return `defaultValue` in [[SQLConfEntry]]. + */ + private[sql] def getConf[T](entry: SQLConfEntry[T]): T = conf.getConf(entry) + + /** + * Return the value of Spark SQL configuration property for the given key. If the key is not set + * yet, return `defaultValue`. This is useful when `defaultValue` in SQLConfEntry is not the + * desired one. + */ + private[sql] def getConf[T](entry: SQLConfEntry[T], defaultValue: T): T = { + conf.getConf(entry, defaultValue) + } /** * Return the value of Spark SQL configuration property for the given key. If the key is not set @@ -102,7 +121,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * @group config * @since 1.0.0 */ - def getConf(key: String, defaultValue: String): String = conf.getConf(key, defaultValue) + def getConf(key: String, defaultValue: String): String = conf.getConfString(key, defaultValue) /** * Return all the configuration properties that have been set (i.e. not the default). 
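To make the new API concrete, the following sketch (not part of the patch) registers a typed entry and reads it back, mirroring the pattern used by the `SQLConfEntrySuite` added further down. The object name and the configuration key are made up for illustration; because `SQLConf` and `SQLConfEntry` are `private[sql]`, code like this has to live in the `org.apache.spark.sql` package.

```scala
package org.apache.spark.sql

import org.apache.spark.sql.SQLConf.SQLConfEntry

// Hypothetical example object; only for illustrating the API above.
object SQLConfEntryExample {
  def main(args: Array[String]): Unit = {
    val conf = new SQLConf

    // Register a typed entry (the key is hypothetical).
    val maxFoo = SQLConfEntry.intConf(
      "spark.sql.example.maxFoo",
      defaultValue = Some(5),
      doc = "Example configuration used only for illustration.")

    println(conf.getConf(maxFoo))   // 5 -- falls back to the entry's default value
    conf.setConf(maxFoo, 10)        // typed setter; stored via the entry's stringConverter
    println(conf.getConf(maxFoo))   // 10

    // String-based setters are now validated against the registered entry, so an
    // invalid value fails fast with a descriptive message:
    //   java.lang.IllegalArgumentException: spark.sql.example.maxFoo should be int, but was abc
    conf.setConfString("spark.sql.example.maxFoo", "abc")
  }
}
```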
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala index 305b306a79871..e59fa6e162900 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala @@ -44,8 +44,8 @@ private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends Abstr private val pair: Parser[LogicalPlan] = (key ~ ("=".r ~> value).?).? ^^ { - case None => SetCommand(None, output) - case Some(k ~ v) => SetCommand(Some(k.trim -> v.map(_.trim)), output) + case None => SetCommand(None) + case Some(k ~ v) => SetCommand(Some(k.trim -> v.map(_.trim))) } def apply(input: String): LogicalPlan = parseAll(pair, input) match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index c9dfcea5d051e..5e9951f248ff2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution +import java.util.NoSuchElementException + import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD @@ -75,48 +77,92 @@ private[sql] case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan * :: DeveloperApi :: */ @DeveloperApi -case class SetCommand( - kv: Option[(String, Option[String])], - override val output: Seq[Attribute]) - extends RunnableCommand with Logging { +case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableCommand with Logging { + + private def keyValueOutput: Seq[Attribute] = { + val schema = StructType( + StructField("key", StringType, false) :: + StructField("value", StringType, false) :: Nil) + schema.toAttributes + } - override def run(sqlContext: SQLContext): Seq[Row] = kv match { + private val (_output, runFunc): (Seq[Attribute], SQLContext => Seq[Row]) = kv match { // Configures the deprecated "mapred.reduce.tasks" property. case Some((SQLConf.Deprecated.MAPRED_REDUCE_TASKS, Some(value))) => - logWarning( - s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + - s"automatically converted to ${SQLConf.SHUFFLE_PARTITIONS} instead.") - if (value.toInt < 1) { - val msg = s"Setting negative ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} for automatically " + - "determining the number of reducers is not supported." - throw new IllegalArgumentException(msg) - } else { - sqlContext.setConf(SQLConf.SHUFFLE_PARTITIONS, value) - Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=$value")) + val runFunc = (sqlContext: SQLContext) => { + logWarning( + s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + + s"automatically converted to ${SQLConf.SHUFFLE_PARTITIONS.key} instead.") + if (value.toInt < 1) { + val msg = + s"Setting negative ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} for automatically " + + "determining the number of reducers is not supported." + throw new IllegalArgumentException(msg) + } else { + sqlContext.setConf(SQLConf.SHUFFLE_PARTITIONS.key, value) + Seq(Row(SQLConf.SHUFFLE_PARTITIONS.key, value)) + } } + (keyValueOutput, runFunc) // Configures a single property. 
case Some((key, Some(value))) => - sqlContext.setConf(key, value) - Seq(Row(s"$key=$value")) + val runFunc = (sqlContext: SQLContext) => { + sqlContext.setConf(key, value) + Seq(Row(key, value)) + } + (keyValueOutput, runFunc) - // Queries all key-value pairs that are set in the SQLConf of the sqlContext. - // Notice that different from Hive, here "SET -v" is an alias of "SET". // (In Hive, "SET" returns all changed properties while "SET -v" returns all properties.) - case Some(("-v", None)) | None => - sqlContext.getAllConfs.map { case (k, v) => Row(s"$k=$v") }.toSeq + // Queries all key-value pairs that are set in the SQLConf of the sqlContext. + case None => + val runFunc = (sqlContext: SQLContext) => { + sqlContext.getAllConfs.map { case (k, v) => Row(k, v) }.toSeq + } + (keyValueOutput, runFunc) + + // Queries all properties along with their default values and docs that are defined in the + // SQLConf of the sqlContext. + case Some(("-v", None)) => + val runFunc = (sqlContext: SQLContext) => { + sqlContext.conf.getAllDefinedConfs.map { case (key, defaultValue, doc) => + Row(key, defaultValue, doc) + } + } + val schema = StructType( + StructField("key", StringType, false) :: + StructField("default", StringType, false) :: + StructField("meaning", StringType, false) :: Nil) + (schema.toAttributes, runFunc) // Queries the deprecated "mapred.reduce.tasks" property. case Some((SQLConf.Deprecated.MAPRED_REDUCE_TASKS, None)) => - logWarning( - s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + - s"showing ${SQLConf.SHUFFLE_PARTITIONS} instead.") - Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=${sqlContext.conf.numShufflePartitions}")) + val runFunc = (sqlContext: SQLContext) => { + logWarning( + s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + + s"showing ${SQLConf.SHUFFLE_PARTITIONS.key} instead.") + Seq(Row(SQLConf.SHUFFLE_PARTITIONS.key, sqlContext.conf.numShufflePartitions.toString)) + } + (keyValueOutput, runFunc) // Queries a single property. 
case Some((key, None)) => - Seq(Row(s"$key=${sqlContext.getConf(key, "")}")) + val runFunc = (sqlContext: SQLContext) => { + val value = + try { + sqlContext.getConf(key) + } catch { + case _: NoSuchElementException => "" + } + Seq(Row(key, value)) + } + (keyValueOutput, runFunc) } + + override val output: Seq[Attribute] = _output + + override def run(sqlContext: SQLContext): Seq[Row] = runFunc(sqlContext) + } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 3ee4033baee2e..2964edac1aba2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -48,7 +48,7 @@ package object debug { */ implicit class DebugSQLContext(sqlContext: SQLContext) { def debug(): Unit = { - sqlContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "false") + sqlContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, false) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 39360e13313a3..65ecad9878f8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -113,12 +113,12 @@ private[sql] case class ParquetTableScan( .foreach(ParquetInputFormat.setFilterPredicate(conf, _)) // Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata - conf.set( - SQLConf.PARQUET_CACHE_METADATA, - sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA, "true")) + conf.setBoolean( + SQLConf.PARQUET_CACHE_METADATA.key, + sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA, true)) // Use task side metadata in parquet - conf.setBoolean(ParquetInputFormat.TASK_SIDE_METADATA, true); + conf.setBoolean(ParquetInputFormat.TASK_SIDE_METADATA, true) val baseRDD = new org.apache.spark.rdd.NewHadoopRDD( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index bba6f1ec96aa8..4c702c3b0d43f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -220,7 +220,7 @@ private[sql] class ParquetRelation2( } conf.setClass( - SQLConf.OUTPUT_COMMITTER_CLASS, + SQLConf.OUTPUT_COMMITTER_CLASS.key, committerClass, classOf[ParquetOutputCommitter]) @@ -259,7 +259,7 @@ private[sql] class ParquetRelation2( filters: Array[Filter], inputFiles: Array[FileStatus], broadcastedConf: Broadcast[SerializableWritable[Configuration]]): RDD[Row] = { - val useMetadataCache = sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA, "true").toBoolean + val useMetadataCache = sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA) val parquetFilterPushDown = sqlContext.conf.parquetFilterPushDown // Create the function to set variable Parquet confs at both driver and executor side. 
val initLocalJobFuncOpt = @@ -498,7 +498,7 @@ private[sql] object ParquetRelation2 extends Logging { ParquetTypesConverter.convertToString(dataSchema.toAttributes)) // Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata - conf.set(SQLConf.PARQUET_CACHE_METADATA, useMetadataCache.toString) + conf.setBoolean(SQLConf.PARQUET_CACHE_METADATA.key, useMetadataCache) } /** This closure sets input paths at the driver side. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index 3dbe6faabf453..d39a20b388375 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -323,7 +323,7 @@ private[sql] abstract class BaseWriterContainer( private def newOutputCommitter(context: TaskAttemptContext): OutputCommitter = { val committerClass = context.getConfiguration.getClass( - SQLConf.OUTPUT_COMMITTER_CLASS, null, classOf[OutputCommitter]) + SQLConf.OUTPUT_COMMITTER_CLASS.key, null, classOf[OutputCommitter]) Option(committerClass).map { clazz => logInfo(s"Using user defined output committer class ${clazz.getCanonicalName}") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala index 356a6100d2cf5..9fa394525d65c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -38,7 +38,7 @@ class LocalSQLContext protected[sql] class SQLSession extends super.SQLSession { protected[sql] override lazy val conf: SQLConf = new SQLConf { /** Fewer partitions to speed up testing. */ - override def numShufflePartitions: Int = this.getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt + override def numShufflePartitions: Int = this.getConf(SQLConf.SHUFFLE_PARTITIONS, 5) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 790b405c72697..b26d3ab253a1d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -68,12 +68,12 @@ class DataFrameAggregateSuite extends QueryTest { Seq(Row(1, 3), Row(2, 3), Row(3, 3)) ) - ctx.conf.setConf("spark.sql.retainGroupColumns", "false") + ctx.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, false) checkAnswer( testData2.groupBy("a").agg(sum($"b")), Seq(Row(3), Row(3), Row(3)) ) - ctx.conf.setConf("spark.sql.retainGroupColumns", "true") + ctx.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, true) } test("agg without groups") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index fa98e23e3d147..ba1d020f22f11 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -33,7 +33,7 @@ class DataFrameSuite extends QueryTest { test("analysis error should be eagerly reported") { val oldSetting = ctx.conf.dataFrameEagerAnalysis // Eager analysis. 
- ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "true") + ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, true) intercept[Exception] { testData.select('nonExistentName) } intercept[Exception] { @@ -47,11 +47,11 @@ class DataFrameSuite extends QueryTest { } // No more eager analysis once the flag is turned off - ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "false") + ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, false) testData.select('nonExistentName) // Set the flag back to original value before this test. - ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString) + ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting) } test("dataframe toString") { @@ -70,7 +70,7 @@ class DataFrameSuite extends QueryTest { test("invalid plan toString, debug mode") { val oldSetting = ctx.conf.dataFrameEagerAnalysis - ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "true") + ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, true) // Turn on debug mode so we can see invalid query plans. import org.apache.spark.sql.execution.debug._ @@ -83,7 +83,7 @@ class DataFrameSuite extends QueryTest { badPlan.toString) // Set the flag back to original value before this test. - ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString) + ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting) } test("access complex data") { @@ -556,13 +556,13 @@ class DataFrameSuite extends QueryTest { test("SPARK-6899") { val originalValue = ctx.conf.codegenEnabled - ctx.setConf(SQLConf.CODEGEN_ENABLED, "true") + ctx.setConf(SQLConf.CODEGEN_ENABLED, true) try{ checkAnswer( decimalData.agg(avg('a)), Row(new java.math.BigDecimal(2.0))) } finally { - ctx.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) + ctx.setConf(SQLConf.CODEGEN_ENABLED, originalValue) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index ffd26c4f5a7c2..20390a5544304 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -95,14 +95,14 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { classOf[BroadcastNestedLoopJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } try { - ctx.conf.setConf("spark.sql.planner.sortMergeJoin", "true") + ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) Seq( ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } } finally { - ctx.conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString) + ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) } } @@ -118,7 +118,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { classOf[BroadcastHashJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } try { - ctx.conf.setConf("spark.sql.planner.sortMergeJoin", "true") + ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) Seq( ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), ("SELECT * FROM testData join testData2 ON key = a and key = 2", @@ -127,7 +127,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { classOf[BroadcastHashJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } } finally { - 
ctx.conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString) + ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) } ctx.sql("UNCACHE TABLE testData") @@ -416,7 +416,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ctx.sql("CACHE TABLE testData") val tmp = ctx.conf.autoBroadcastJoinThreshold - ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=1000000000") + ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=1000000000") Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[BroadcastLeftSemiJoinHash]) @@ -424,7 +424,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { case (query, joinClass) => assertJoin(query, joinClass) } - ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1") + ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1") Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]) @@ -432,7 +432,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { case (query, joinClass) => assertJoin(query, joinClass) } - ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, tmp.toString) + ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, tmp) ctx.sql("UNCACHE TABLE testData") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfEntrySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfEntrySuite.scala new file mode 100644 index 0000000000000..2e33777f14adc --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfEntrySuite.scala @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.SQLConf._ + +class SQLConfEntrySuite extends SparkFunSuite { + + val conf = new SQLConf + + test("intConf") { + val key = "spark.sql.SQLConfEntrySuite.int" + val confEntry = SQLConfEntry.intConf(key) + assert(conf.getConf(confEntry, 5) === 5) + + conf.setConf(confEntry, 10) + assert(conf.getConf(confEntry, 5) === 10) + + conf.setConfString(key, "20") + assert(conf.getConfString(key, "5") === "20") + assert(conf.getConfString(key) === "20") + assert(conf.getConf(confEntry, 5) === 20) + + val e = intercept[IllegalArgumentException] { + conf.setConfString(key, "abc") + } + assert(e.getMessage === s"$key should be int, but was abc") + } + + test("longConf") { + val key = "spark.sql.SQLConfEntrySuite.long" + val confEntry = SQLConfEntry.longConf(key) + assert(conf.getConf(confEntry, 5L) === 5L) + + conf.setConf(confEntry, 10L) + assert(conf.getConf(confEntry, 5L) === 10L) + + conf.setConfString(key, "20") + assert(conf.getConfString(key, "5") === "20") + assert(conf.getConfString(key) === "20") + assert(conf.getConf(confEntry, 5L) === 20L) + + val e = intercept[IllegalArgumentException] { + conf.setConfString(key, "abc") + } + assert(e.getMessage === s"$key should be long, but was abc") + } + + test("booleanConf") { + val key = "spark.sql.SQLConfEntrySuite.boolean" + val confEntry = SQLConfEntry.booleanConf(key) + assert(conf.getConf(confEntry, false) === false) + + conf.setConf(confEntry, true) + assert(conf.getConf(confEntry, false) === true) + + conf.setConfString(key, "true") + assert(conf.getConfString(key, "false") === "true") + assert(conf.getConfString(key) === "true") + assert(conf.getConf(confEntry, false) === true) + + val e = intercept[IllegalArgumentException] { + conf.setConfString(key, "abc") + } + assert(e.getMessage === s"$key should be boolean, but was abc") + } + + test("doubleConf") { + val key = "spark.sql.SQLConfEntrySuite.double" + val confEntry = SQLConfEntry.doubleConf(key) + assert(conf.getConf(confEntry, 5.0) === 5.0) + + conf.setConf(confEntry, 10.0) + assert(conf.getConf(confEntry, 5.0) === 10.0) + + conf.setConfString(key, "20.0") + assert(conf.getConfString(key, "5.0") === "20.0") + assert(conf.getConfString(key) === "20.0") + assert(conf.getConf(confEntry, 5.0) === 20.0) + + val e = intercept[IllegalArgumentException] { + conf.setConfString(key, "abc") + } + assert(e.getMessage === s"$key should be double, but was abc") + } + + test("stringConf") { + val key = "spark.sql.SQLConfEntrySuite.string" + val confEntry = SQLConfEntry.stringConf(key) + assert(conf.getConf(confEntry, "abc") === "abc") + + conf.setConf(confEntry, "abcd") + assert(conf.getConf(confEntry, "abc") === "abcd") + + conf.setConfString(key, "abcde") + assert(conf.getConfString(key, "abc") === "abcde") + assert(conf.getConfString(key) === "abcde") + assert(conf.getConf(confEntry, "abc") === "abcde") + } + + test("enumConf") { + val key = "spark.sql.SQLConfEntrySuite.enum" + val confEntry = SQLConfEntry.enumConf(key, v => v, Set("a", "b", "c"), defaultValue = Some("a")) + assert(conf.getConf(confEntry) === "a") + + conf.setConf(confEntry, "b") + assert(conf.getConf(confEntry) === "b") + + conf.setConfString(key, "c") + assert(conf.getConfString(key, "a") === "c") + assert(conf.getConfString(key) === "c") + assert(conf.getConf(confEntry) === "c") + + val e = intercept[IllegalArgumentException] { + conf.setConfString(key, "d") + } + assert(e.getMessage === s"The value of $key should be one of a, 
b, c, but was d") + } + + test("stringSeqConf") { + val key = "spark.sql.SQLConfEntrySuite.stringSeq" + val confEntry = SQLConfEntry.stringSeqConf("spark.sql.SQLConfEntrySuite.stringSeq", + defaultValue = Some(Nil)) + assert(conf.getConf(confEntry, Seq("a", "b", "c")) === Seq("a", "b", "c")) + + conf.setConf(confEntry, Seq("a", "b", "c", "d")) + assert(conf.getConf(confEntry, Seq("a", "b", "c")) === Seq("a", "b", "c", "d")) + + conf.setConfString(key, "a,b,c,d,e") + assert(conf.getConfString(key, "a,b,c") === "a,b,c,d,e") + assert(conf.getConfString(key) === "a,b,c,d,e") + assert(conf.getConf(confEntry, Seq("a", "b", "c")) === Seq("a", "b", "c", "d", "e")) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index 76d0dd1744a41..75791e9d53c20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -75,6 +75,14 @@ class SQLConfSuite extends QueryTest { test("deprecated property") { ctx.conf.clear() ctx.sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") - assert(ctx.getConf(SQLConf.SHUFFLE_PARTITIONS) === "10") + assert(ctx.conf.numShufflePartitions === 10) + } + + test("invalid conf value") { + ctx.conf.clear() + val e = intercept[IllegalArgumentException] { + ctx.sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10") + } + assert(e.getMessage === s"${SQLConf.CASE_SENSITIVE.key} should be boolean, but was 10") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 30db840166ca6..82f3fdb48b557 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -190,7 +190,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { test("aggregation with codegen") { val originalValue = sqlContext.conf.codegenEnabled - sqlContext.setConf(SQLConf.CODEGEN_ENABLED, "true") + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) // Prepare a table that we can group some rows. 
sqlContext.table("testData") .unionAll(sqlContext.table("testData")) @@ -287,7 +287,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Row(0, null, 0) :: Nil) } finally { sqlContext.dropTempTable("testData3x") - sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue) } } @@ -480,41 +480,41 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { test("sorting") { val before = sqlContext.conf.externalSortEnabled - sqlContext.setConf(SQLConf.EXTERNAL_SORT, "false") + sqlContext.setConf(SQLConf.EXTERNAL_SORT, false) sortTest() - sqlContext.setConf(SQLConf.EXTERNAL_SORT, before.toString) + sqlContext.setConf(SQLConf.EXTERNAL_SORT, before) } test("external sorting") { val before = sqlContext.conf.externalSortEnabled - sqlContext.setConf(SQLConf.EXTERNAL_SORT, "true") + sqlContext.setConf(SQLConf.EXTERNAL_SORT, true) sortTest() - sqlContext.setConf(SQLConf.EXTERNAL_SORT, before.toString) + sqlContext.setConf(SQLConf.EXTERNAL_SORT, before) } test("SPARK-6927 sorting with codegen on") { val externalbefore = sqlContext.conf.externalSortEnabled val codegenbefore = sqlContext.conf.codegenEnabled - sqlContext.setConf(SQLConf.EXTERNAL_SORT, "false") - sqlContext.setConf(SQLConf.CODEGEN_ENABLED, "true") + sqlContext.setConf(SQLConf.EXTERNAL_SORT, false) + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) try{ sortTest() } finally { - sqlContext.setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) - sqlContext.setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) + sqlContext.setConf(SQLConf.EXTERNAL_SORT, externalbefore) + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, codegenbefore) } } test("SPARK-6927 external sorting with codegen on") { val externalbefore = sqlContext.conf.externalSortEnabled val codegenbefore = sqlContext.conf.codegenEnabled - sqlContext.setConf(SQLConf.CODEGEN_ENABLED, "true") - sqlContext.setConf(SQLConf.EXTERNAL_SORT, "true") + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) + sqlContext.setConf(SQLConf.EXTERNAL_SORT, true) try { sortTest() } finally { - sqlContext.setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) - sqlContext.setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) + sqlContext.setConf(SQLConf.EXTERNAL_SORT, externalbefore) + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, codegenbefore) } } @@ -908,25 +908,25 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { sql(s"SET $testKey=$testVal") checkAnswer( sql("SET"), - Row(s"$testKey=$testVal") + Row(testKey, testVal) ) sql(s"SET ${testKey + testKey}=${testVal + testVal}") checkAnswer( sql("set"), Seq( - Row(s"$testKey=$testVal"), - Row(s"${testKey + testKey}=${testVal + testVal}")) + Row(testKey, testVal), + Row(testKey + testKey, testVal + testVal)) ) // "set key" checkAnswer( sql(s"SET $testKey"), - Row(s"$testKey=$testVal") + Row(testKey, testVal) ) checkAnswer( sql(s"SET $nonexistentKey"), - Row(s"$nonexistentKey=") + Row(nonexistentKey, "") ) sqlContext.conf.clear() } @@ -1340,12 +1340,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SPARK-4699 case sensitivity SQL query") { - sqlContext.setConf(SQLConf.CASE_SENSITIVE, "false") + sqlContext.setConf(SQLConf.CASE_SENSITIVE, false) val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil val rdd = sqlContext.sparkContext.parallelize((0 to 1).map(i => data(i))) rdd.toDF().registerTempTable("testTable1") checkAnswer(sql("SELECT VALUE FROM 
TESTTABLE1 where KEY = 1"), Row("val_1")) - sqlContext.setConf(SQLConf.CASE_SENSITIVE, "true") + sqlContext.setConf(SQLConf.CASE_SENSITIVE, true) } test("SPARK-6145: ORDER BY test for nested fields") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index 6545c6b314a4c..2c0879927a129 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -32,7 +32,7 @@ class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll wi override protected def beforeAll(): Unit = { // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch - ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, "10") + ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, 10) val pruningData = ctx.sparkContext.makeRDD((1 to 100).map { key => val string = if (((key - 1) / 10) % 2 == 0) null else key.toString @@ -41,14 +41,14 @@ class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll wi pruningData.registerTempTable("pruningData") // Enable in-memory partition pruning - ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true") + ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) // Enable in-memory table scan accumulators ctx.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true") } override protected def afterAll(): Unit = { - ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString) - ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString) + ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) + ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) } before { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 3e27f58a92d01..5854ab48db552 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -63,7 +63,7 @@ class PlannerSuite extends SparkFunSuite { test("sizeInBytes estimation of limit operator for broadcast hash join optimization") { def checkPlan(fieldTypes: Seq[DataType], newThreshold: Int): Unit = { - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, newThreshold.toString) + setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, newThreshold) val fields = fieldTypes.zipWithIndex.map { case (dataType, index) => StructField(s"c${index}", dataType, true) } :+ StructField("key", IntegerType, true) @@ -119,12 +119,12 @@ class PlannerSuite extends SparkFunSuite { checkPlan(complexTypes, newThreshold = 901617) - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold.toString) + setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold) } test("InMemoryRelation statistics propagation") { val origThreshold = conf.autoBroadcastJoinThreshold - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920.toString) + setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920) testData.limit(3).registerTempTable("tiny") sql("CACHE TABLE tiny") @@ -139,6 +139,6 @@ class PlannerSuite extends SparkFunSuite { assert(broadcastHashJoins.size === 1, "Should use broadcast hash join") assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join") - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold.toString) + 
setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index fca24364fe6ec..945d4375035fd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -1077,14 +1077,14 @@ class JsonSuite extends QueryTest with TestJsonData { } test("SPARK-7565 MapType in JsonRDD") { - val useStreaming = ctx.getConf(SQLConf.USE_JACKSON_STREAMING_API, "true") + val useStreaming = ctx.conf.useJacksonStreamingAPI val oldColumnNameOfCorruptRecord = ctx.conf.columnNameOfCorruptRecord ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") val schemaWithSimpleMap = StructType( StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) try{ - for (useStreaming <- List("true", "false")) { + for (useStreaming <- List(true, false)) { ctx.setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming) val temp = Utils.createTempDir().getPath diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala index fa5d4eca05d9f..a2763c78b6450 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala @@ -51,7 +51,7 @@ class ParquetFilterSuiteBase extends QueryTest with ParquetTest { expected: Seq[Row]): Unit = { val output = predicate.collect { case a: Attribute => a }.distinct - withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") { + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { val query = df .select(output.map(e => Column(e)): _*) .where(Column(predicate)) @@ -314,17 +314,17 @@ class ParquetDataSourceOnFilterSuite extends ParquetFilterSuiteBase with BeforeA lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi override protected def beforeAll(): Unit = { - sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") + sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, true) } override protected def afterAll(): Unit = { - sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) } test("SPARK-6554: don't push down predicates which reference partition columns") { import sqlContext.implicits._ - withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") { + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { withTempPath { dir => val path = s"${dir.getCanonicalPath}/part=1" (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path) @@ -343,17 +343,17 @@ class ParquetDataSourceOffFilterSuite extends ParquetFilterSuiteBase with Before lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi override protected def beforeAll(): Unit = { - sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") + sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, false) } override protected def afterAll(): Unit = { - sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) } test("SPARK-6742: don't push down predicates which reference partition columns") { import sqlContext.implicits._ - withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") { + 
withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { withTempPath { dir => val path = s"${dir.getCanonicalPath}/part=1" (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index fc827bc4ca11b..284d99d4938d1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -94,8 +94,8 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { val data = (1 to 4).map(i => Tuple1(i.toString)) // Property spark.sql.parquet.binaryAsString shouldn't affect Parquet files written by Spark SQL // as we store Spark SQL schema in the extra metadata. - withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING -> "false")(checkParquetFile(data)) - withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING -> "true")(checkParquetFile(data)) + withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING.key -> "false")(checkParquetFile(data)) + withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING.key -> "true")(checkParquetFile(data)) } test("fixed-length decimals") { @@ -231,7 +231,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { val data = (0 until 10).map(i => (i, i.toString)) def checkCompressionCodec(codec: CompressionCodecName): Unit = { - withSQLConf(SQLConf.PARQUET_COMPRESSION -> codec.name()) { + withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> codec.name()) { withParquetFile(data) { path => assertResult(sqlContext.conf.parquetCompressionCodec.toUpperCase) { compressionCodecFor(path) @@ -408,7 +408,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { val clonedConf = new Configuration(configuration) configuration.set( - SQLConf.OUTPUT_COMMITTER_CLASS, classOf[ParquetOutputCommitter].getCanonicalName) + SQLConf.OUTPUT_COMMITTER_CLASS.key, classOf[ParquetOutputCommitter].getCanonicalName) configuration.set( "spark.sql.parquet.output.committer.class", @@ -440,11 +440,11 @@ class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterA private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi override protected def beforeAll(): Unit = { - sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") + sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, true) } override protected def afterAll(): Unit = { - sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API.key, originalConf.toString) } test("SPARK-6330 regression test") { @@ -464,10 +464,10 @@ class ParquetDataSourceOffIOSuite extends ParquetIOSuiteBase with BeforeAndAfter private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi override protected def beforeAll(): Unit = { - sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") + sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, false) } override protected def afterAll(): Unit = { - sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index be3b34d5b9b70..fafad67dde3a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -128,11 +128,11 @@ class ParquetDataSourceOnQuerySuite extends ParquetQuerySuiteBase with BeforeAnd private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi override protected def beforeAll(): Unit = { - sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") + sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, true) } override protected def afterAll(): Unit = { - sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) } } @@ -140,10 +140,10 @@ class ParquetDataSourceOffQuerySuite extends ParquetQuerySuiteBase with BeforeAn private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi override protected def beforeAll(): Unit = { - sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") + sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, false) } override protected def afterAll(): Unit = { - sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala index 3f77960d09246..00cc7d5ea580f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -27,7 +27,7 @@ abstract class DataSourceTest extends QueryTest with BeforeAndAfter { // We want to test some edge cases. protected implicit lazy val caseInsensitiveContext = { val ctx = new SQLContext(TestSQLContext.sparkContext) - ctx.setConf(SQLConf.CASE_SENSITIVE, "false") + ctx.setConf(SQLConf.CASE_SENSITIVE, false) ctx } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index ac4a00a6f3dac..fa01823e9417c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -37,11 +37,11 @@ trait SQLTestUtils { */ protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(sqlContext.conf.getConf(key)).toOption) - (keys, values).zipped.foreach(sqlContext.conf.setConf) + val currentValues = keys.map(key => Try(sqlContext.conf.getConfString(key)).toOption) + (keys, values).zipped.foreach(sqlContext.conf.setConfString) try f finally { keys.zip(currentValues).foreach { - case (key, Some(value)) => sqlContext.conf.setConf(key, value) + case (key, Some(value)) => sqlContext.conf.setConfString(key, value) case (key, None) => sqlContext.conf.unsetConf(key) } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index c9da25253e13f..700d994bb6a83 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -153,9 +153,9 @@ object HiveThriftServer2 extends Logging { val sessionList = new mutable.LinkedHashMap[String, SessionInfo] val executionList = new 
mutable.LinkedHashMap[String, ExecutionInfo] val retainedStatements = - conf.getConf(SQLConf.THRIFTSERVER_UI_STATEMENT_LIMIT, "200").toInt + conf.getConf(SQLConf.THRIFTSERVER_UI_STATEMENT_LIMIT) val retainedSessions = - conf.getConf(SQLConf.THRIFTSERVER_UI_SESSION_LIMIT, "200").toInt + conf.getConf(SQLConf.THRIFTSERVER_UI_SESSION_LIMIT) var totalRunning = 0 override def onJobStart(jobStart: SparkListenerJobStart): Unit = { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index e071103df925c..e8758887ff3a2 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -219,7 +219,7 @@ private[hive] class SparkExecuteStatementOperation( result = hiveContext.sql(statement) logDebug(result.queryExecution.toString()) result.queryExecution.logical match { - case SetCommand(Some((SQLConf.THRIFTSERVER_POOL, Some(value))), _) => + case SetCommand(Some((SQLConf.THRIFTSERVER_POOL.key, Some(value)))) => sessionToActivePool(parentSession.getSessionHandle) = value logInfo(s"Setting spark.scheduler.pool=$value for future statements in this session.") case _ => diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 178bd1f5cb164..301aa5a6411e2 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -113,8 +113,8 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { withJdbcStatement { statement => val resultSet = statement.executeQuery("SET spark.sql.hive.version") resultSet.next() - assert(resultSet.getString(1) === - s"spark.sql.hive.version=${HiveContext.hiveExecutionVersion}") + assert(resultSet.getString(1) === "spark.sql.hive.version") + assert(resultSet.getString(2) === HiveContext.hiveExecutionVersion) } } @@ -238,7 +238,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { // first session, we get the default value of the session status { statement => - val rs1 = statement.executeQuery(s"SET ${SQLConf.SHUFFLE_PARTITIONS}") + val rs1 = statement.executeQuery(s"SET ${SQLConf.SHUFFLE_PARTITIONS.key}") rs1.next() defaultV1 = rs1.getString(1) assert(defaultV1 != "200") @@ -256,19 +256,21 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { { statement => val queries = Seq( - s"SET ${SQLConf.SHUFFLE_PARTITIONS}=291", + s"SET ${SQLConf.SHUFFLE_PARTITIONS.key}=291", "SET hive.cli.print.header=true" ) queries.map(statement.execute) - val rs1 = statement.executeQuery(s"SET ${SQLConf.SHUFFLE_PARTITIONS}") + val rs1 = statement.executeQuery(s"SET ${SQLConf.SHUFFLE_PARTITIONS.key}") rs1.next() - assert("spark.sql.shuffle.partitions=291" === rs1.getString(1)) + assert("spark.sql.shuffle.partitions" === rs1.getString(1)) + assert("291" === rs1.getString(2)) rs1.close() val rs2 = statement.executeQuery("SET hive.cli.print.header") rs2.next() - assert("hive.cli.print.header=true" === rs2.getString(1)) + assert("hive.cli.print.header" === rs2.getString(1)) 
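// Illustrative aside, not taken from this patch: the retainedStatements / retainedSessions
// reads in the HiveThriftServer2 listener above no longer need "200".toInt because each
// setting is declared once as a SQLConfEntry that carries its type, default value and doc.
// The key, default and doc text below are assumptions, and intConf is assumed to accept the
// same optional arguments as the booleanConf/stringConf builders shown later in this patch.
//
//   val THRIFTSERVER_UI_STATEMENT_LIMIT = intConf(
//     "spark.sql.thriftServer.ui.retainedStatements",   // illustrative key
//     defaultValue = Some(200),                          // illustrative default
//     doc = "The number of SQL statements kept in the JDBC/ODBC web UI history.")
//
//   // Call sites then read a typed value, with the default applied automatically:
//   val retainedStatements: Int = conf.getConf(SQLConf.THRIFTSERVER_UI_STATEMENT_LIMIT)
//
//   // and a malformed string fails fast with the improved error message:
//   conf.setConfString(THRIFTSERVER_UI_STATEMENT_LIMIT.key, "abc")
//   // java.lang.IllegalArgumentException:
//   //   spark.sql.thriftServer.ui.retainedStatements should be int, but was abc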
+ assert("true" === rs2.getString(2)) rs2.close() }, @@ -276,7 +278,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { // default value { statement => - val rs1 = statement.executeQuery(s"SET ${SQLConf.SHUFFLE_PARTITIONS}") + val rs1 = statement.executeQuery(s"SET ${SQLConf.SHUFFLE_PARTITIONS.key}") rs1.next() assert(defaultV1 === rs1.getString(1)) rs1.close() @@ -404,8 +406,8 @@ class HiveThriftHttpServerSuite extends HiveThriftJdbcTest { withJdbcStatement { statement => val resultSet = statement.executeQuery("SET spark.sql.hive.version") resultSet.next() - assert(resultSet.getString(1) === - s"spark.sql.hive.version=${HiveContext.hiveExecutionVersion}") + assert(resultSet.getString(1) === "spark.sql.hive.version") + assert(resultSet.getString(2) === HiveContext.hiveExecutionVersion) } } } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 82c0b494598a8..432de2564d080 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -47,17 +47,17 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // Add Locale setting Locale.setDefault(Locale.US) // Set a relatively small column batch size for testing purposes - TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, "5") + TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, 5) // Enable in-memory partition pruning for testing purposes - TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true") + TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) } override def afterAll() { TestHive.cacheTables = false TimeZone.setDefault(originalTimeZone) Locale.setDefault(originalLocale) - TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString) - TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString) + TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) + TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) } /** A list of tests deemed out of scope currently and thus completely disregarded. 
*/ diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala index 65d070bd3cbde..f458567e5d7ea 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala @@ -26,11 +26,11 @@ import org.apache.spark.sql.hive.test.TestHive class SortMergeCompatibilitySuite extends HiveCompatibilitySuite { override def beforeAll() { super.beforeAll() - TestHive.setConf(SQLConf.SORTMERGE_JOIN, "true") + TestHive.setConf(SQLConf.SORTMERGE_JOIN, true) } override def afterAll() { - TestHive.setConf(SQLConf.SORTMERGE_JOIN, "false") + TestHive.setConf(SQLConf.SORTMERGE_JOIN, false) super.afterAll() } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index c50835dd8f11d..4a66d6508ae0a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -21,15 +21,13 @@ import java.io.File import java.net.{URL, URLClassLoader} import java.sql.Timestamp -import org.apache.hadoop.hive.common.StatsSetupConst -import org.apache.hadoop.hive.common.`type`.HiveDecimal -import org.apache.spark.sql.catalyst.ParserDialect - import scala.collection.JavaConversions._ import scala.collection.mutable.HashMap import scala.language.implicitConversions import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.hive.common.StatsSetupConst +import org.apache.hadoop.hive.common.`type`.HiveDecimal import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.metadata.Table import org.apache.hadoop.hive.ql.parse.VariableSubstitution @@ -39,6 +37,9 @@ import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable} import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ +import org.apache.spark.sql.SQLConf.SQLConfEntry +import org.apache.spark.sql.SQLConf.SQLConfEntry._ +import org.apache.spark.sql.catalyst.ParserDialect import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUdfs, SetCommand} @@ -69,13 +70,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { import HiveContext._ + println("create HiveContext") + /** * When true, enables an experimental feature where metastore tables that use the parquet SerDe * are automatically converted to use the Spark SQL parquet table scan, instead of the Hive * SerDe. */ - protected[sql] def convertMetastoreParquet: Boolean = - getConf("spark.sql.hive.convertMetastoreParquet", "true") == "true" + protected[sql] def convertMetastoreParquet: Boolean = getConf(CONVERT_METASTORE_PARQUET) /** * When true, also tries to merge possibly different but compatible Parquet schemas in different @@ -84,7 +86,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * This configuration is only effective when "spark.sql.hive.convertMetastoreParquet" is true. 
*/ protected[sql] def convertMetastoreParquetWithSchemaMerging: Boolean = - getConf("spark.sql.hive.convertMetastoreParquet.mergeSchema", "false") == "true" + getConf(CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING) /** * When true, a table created by a Hive CTAS statement (no USING clause) will be @@ -98,8 +100,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * - The CTAS statement specifies SequenceFile (STORED AS SEQUENCEFILE) as the file format * and no SerDe is specified (no ROW FORMAT SERDE clause). */ - protected[sql] def convertCTAS: Boolean = - getConf("spark.sql.hive.convertCTAS", "false").toBoolean + protected[sql] def convertCTAS: Boolean = getConf(CONVERT_CTAS) /** * The version of the hive client that will be used to communicate with the metastore. Note that @@ -117,8 +118,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * option is only valid when using the execution version of Hive. * - maven - download the correct version of hive on demand from maven. */ - protected[hive] def hiveMetastoreJars: String = - getConf(HIVE_METASTORE_JARS, "builtin") + protected[hive] def hiveMetastoreJars: String = getConf(HIVE_METASTORE_JARS) /** * A comma separated list of class prefixes that should be loaded using the classloader that @@ -128,11 +128,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * custom appenders that are used by log4j. */ protected[hive] def hiveMetastoreSharedPrefixes: Seq[String] = - getConf("spark.sql.hive.metastore.sharedPrefixes", jdbcPrefixes) - .split(",").filterNot(_ == "") - - private def jdbcPrefixes = Seq( - "com.mysql.jdbc", "org.postgresql", "com.microsoft.sqlserver", "oracle.jdbc").mkString(",") + getConf(HIVE_METASTORE_SHARED_PREFIXES).filterNot(_ == "") /** * A comma separated list of class prefixes that should explicitly be reloaded for each version @@ -140,14 +136,12 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * prefix that typically would be shared (i.e. org.apache.spark.*) */ protected[hive] def hiveMetastoreBarrierPrefixes: Seq[String] = - getConf("spark.sql.hive.metastore.barrierPrefixes", "") - .split(",").filterNot(_ == "") + getConf(HIVE_METASTORE_BARRIER_PREFIXES).filterNot(_ == "") /* * hive thrift server use background spark sql thread pool to execute sql queries */ - protected[hive] def hiveThriftServerAsync: Boolean = - getConf("spark.sql.hive.thriftServer.async", "true").toBoolean + protected[hive] def hiveThriftServerAsync: Boolean = getConf(HIVE_THRIFT_SERVER_ASYNC) @transient protected[sql] lazy val substitutor = new VariableSubstitution() @@ -364,7 +358,11 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { hiveconf.set(key, value) } - /* A catalyst metadata catalog that points to the Hive Metastore. */ + private[sql] override def setConf[T](entry: SQLConfEntry[T], value: T): Unit = { + setConf(entry.key, entry.stringConverter(value)) + } + + /* A catalyst metadata catalog that points to the Hive Metastore. 
*/ @transient override protected[sql] lazy val catalog = new HiveMetastoreCatalog(metadataHive, this) with OverrideCatalog @@ -402,8 +400,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { protected[hive] class SQLSession extends super.SQLSession { protected[sql] override lazy val conf: SQLConf = new SQLConf { override def dialect: String = getConf(SQLConf.DIALECT, "hiveql") - override def caseSensitiveAnalysis: Boolean = - getConf(SQLConf.CASE_SENSITIVE, "false").toBoolean + override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false) } /** @@ -519,7 +516,50 @@ private[hive] object HiveContext { val hiveExecutionVersion: String = "0.13.1" val HIVE_METASTORE_VERSION: String = "spark.sql.hive.metastore.version" - val HIVE_METASTORE_JARS: String = "spark.sql.hive.metastore.jars" + val HIVE_METASTORE_JARS = stringConf("spark.sql.hive.metastore.jars", + defaultValue = Some("builtin"), + doc = "Location of the jars that should be used to instantiate the HiveMetastoreClient. This" + + " property can be one of three options: " + + "1. \"builtin\" Use Hive 0.13.1, which is bundled with the Spark assembly jar when " + + "-Phive is enabled. When this option is chosen, " + + "spark.sql.hive.metastore.version must be either 0.13.1 or not defined. " + + "2. \"maven\" Use Hive jars of specified version downloaded from Maven repositories." + + "3. A classpath in the standard format for both Hive and Hadoop.") + + val CONVERT_METASTORE_PARQUET = booleanConf("spark.sql.hive.convertMetastoreParquet", + defaultValue = Some(true), + doc = "When set to false, Spark SQL will use the Hive SerDe for parquet tables instead of " + + "the built in support.") + + val CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING = booleanConf( + "spark.sql.hive.convertMetastoreParquet.mergeSchema", + defaultValue = Some(false), + doc = "TODO") + + val CONVERT_CTAS = booleanConf("spark.sql.hive.convertCTAS", + defaultValue = Some(false), + doc = "TODO") + + val HIVE_METASTORE_SHARED_PREFIXES = stringSeqConf("spark.sql.hive.metastore.sharedPrefixes", + defaultValue = Some(jdbcPrefixes), + doc = "A comma separated list of class prefixes that should be loaded using the classloader " + + "that is shared between Spark SQL and a specific version of Hive. An example of classes " + + "that should be shared is JDBC drivers that are needed to talk to the metastore. Other " + + "classes that need to be shared are those that interact with classes that are already " + + "shared. For example, custom appenders that are used by log4j.") + + private def jdbcPrefixes = Seq( + "com.mysql.jdbc", "org.postgresql", "com.microsoft.sqlserver", "oracle.jdbc") + + val HIVE_METASTORE_BARRIER_PREFIXES = stringSeqConf("spark.sql.hive.metastore.barrierPrefixes", + defaultValue = Some(Seq()), + doc = "A comma separated list of class prefixes that should explicitly be reloaded for each " + + "version of Hive that Spark SQL is communicating with. For example, Hive UDFs that are " + + "declared in a prefix that typically would be shared (i.e. org.apache.spark.*).") + + val HIVE_THRIFT_SERVER_ASYNC = booleanConf("spark.sql.hive.thriftServer.async", + defaultValue = Some(true), + doc = "TODO") /** Constructs a configuration for hive, where the metastore is located in a temp directory. 
*/ def newTemporaryConfiguration(): Map[String, String] = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 92155096202b3..f901bd8171508 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -112,12 +112,11 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { protected[hive] class SQLSession extends super.SQLSession { /** Fewer partitions to speed up testing. */ protected[sql] override lazy val conf: SQLConf = new SQLConf { - override def numShufflePartitions: Int = getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt + override def numShufflePartitions: Int = getConf(SQLConf.SHUFFLE_PARTITIONS, 5) // TODO as in unit test, conf.clear() probably be called, all of the value will be cleared. // The super.getConf(SQLConf.DIALECT) is "sql" by default, we need to set it as "hiveql" override def dialect: String = super.getConf(SQLConf.DIALECT, "hiveql") - override def caseSensitiveAnalysis: Boolean = - getConf(SQLConf.CASE_SENSITIVE, "false").toBoolean + override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala index a0d80dc39c108..af68615e8e9d6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala @@ -81,11 +81,11 @@ class HiveParquetSuite extends QueryTest with ParquetTest { } } - withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API -> "true") { + withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API.key -> "true") { run("Parquet data source enabled") } - withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API -> "false") { + withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API.key -> "false") { run("Parquet data source disabled") } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 79a85b24d2f60..cc294bc3e8bc3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -456,7 +456,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA withTable("savedJsonTable") { val df = (1 to 10).map(i => i -> s"str$i").toDF("a", "b") - withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME -> "json") { + withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> "json") { // Save the df as a managed table (by not specifying the path). df.write.saveAsTable("savedJsonTable") @@ -484,7 +484,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA } // Create an external table by specifying the path. 
- withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME -> "not a source name") { + withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> "not a source name") { df.write .format("org.apache.spark.sql.json") .mode(SaveMode.Append) @@ -508,7 +508,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA s"""{ "a": $i, "b": "str$i" }""" })) - withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME -> "not a source name") { + withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> "not a source name") { df.write .format("json") .mode(SaveMode.Append) @@ -516,7 +516,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA .saveAsTable("savedJsonTable") } - withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME -> "json") { + withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> "json") { createExternalTable("createdJsonTable", tempPath.toString) assert(table("createdJsonTable").schema === df.schema) checkAnswer(sql("SELECT * FROM createdJsonTable"), df) @@ -533,7 +533,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA checkAnswer(read.json(tempPath.toString), df) // Try to specify the schema. - withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME -> "not a source name") { + withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> "not a source name") { val schema = StructType(StructField("b", StringType, true) :: Nil) createExternalTable( "createdJsonTable", @@ -563,8 +563,8 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA test("scan a parquet table created through a CTAS statement") { withSQLConf( - "spark.sql.hive.convertMetastoreParquet" -> "true", - SQLConf.PARQUET_USE_DATA_SOURCE_API -> "true") { + HiveContext.CONVERT_METASTORE_PARQUET.key -> "true", + SQLConf.PARQUET_USE_DATA_SOURCE_API.key -> "true") { withTempTable("jt") { (1 to 10).map(i => i -> s"str$i").toDF("a", "b").registerTempTable("jt") @@ -706,7 +706,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA } test("SPARK-6024 wide schema support") { - withSQLConf(SQLConf.SCHEMA_STRING_LENGTH_THRESHOLD -> "4000") { + withSQLConf(SQLConf.SCHEMA_STRING_LENGTH_THRESHOLD.key -> "4000") { withTable("wide_schema") { // We will need 80 splits for this schema if the threshold is 4000. 
val schema = StructType((1 to 5000).map(i => StructField(s"c_$i", StringType, true))) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 78c94e6490e36..f067ea0d4fc75 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -167,7 +167,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { ctx.conf.settings.synchronized { val tmp = ctx.conf.autoBroadcastJoinThreshold - sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1""") + sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1""") df = sql(query) bhj = df.queryExecution.sparkPlan.collect { case j: BroadcastHashJoin => j } assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off") @@ -176,7 +176,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { assert(shj.size === 1, "ShuffledHashJoin should be planned when BroadcastHashJoin is turned off") - sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=$tmp""") + sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=$tmp""") } after() @@ -225,7 +225,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { ctx.conf.settings.synchronized { val tmp = ctx.conf.autoBroadcastJoinThreshold - sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1") + sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1") df = sql(leftSemiJoinQuery) bhj = df.queryExecution.sparkPlan.collect { case j: BroadcastLeftSemiJoinHash => j @@ -238,7 +238,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { assert(shj.size === 1, "LeftSemiJoinHash should be planned when BroadcastHashJoin is turned off") - sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=$tmp") + sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=$tmp") } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 6d8d99ebc8164..51dabc67fa7c1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -1084,14 +1084,16 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { val testKey = "spark.sql.key.usedfortestonly" val testVal = "test.val.0" val nonexistentKey = "nonexistent" - val KV = "([^=]+)=([^=]*)".r - def collectResults(df: DataFrame): Set[(String, String)] = + def collectResults(df: DataFrame): Set[Any] = df.collect().map { case Row(key: String, value: String) => key -> value - case Row(KV(key, value)) => key -> value + case Row(key: String, defaultValue: String, doc: String) => (key, defaultValue, doc) }.toSet conf.clear() + val expectedConfs = conf.getAllDefinedConfs.toSet + assertResult(expectedConfs)(collectResults(sql("SET -v"))) + // "SET" itself returns all config variables currently specified in SQLConf. // TODO: Should we be listing the default here always? probably... 
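// "SET -v" is now backed by the registered SQLConfEntry metadata, so each defined entry
// comes back as a (key, default value, doc) row, which is what getAllDefinedConfs and the
// pattern match above expect. A small consumption sketch, assuming the usual
// org.apache.spark.sql.Row import is in scope:
//
//   sql("SET -v").collect().foreach {
//     case Row(key: String, defaultValue: String, doc: String) =>
//       println(s"$key  (default: $defaultValue)  $doc")
//   }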
assert(sql("SET").collect().size == 0) @@ -1102,16 +1104,12 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { assert(hiveconf.get(testKey, "") == testVal) assertResult(Set(testKey -> testVal))(collectResults(sql("SET"))) - assertResult(Set(testKey -> testVal))(collectResults(sql("SET -v"))) sql(s"SET ${testKey + testKey}=${testVal + testVal}") assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { collectResults(sql("SET")) } - assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { - collectResults(sql("SET -v")) - } // "SET key" assertResult(Set(testKey -> testVal)) { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 984d97d27bf54..e1c9926bed524 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ -import org.apache.spark.sql.hive.{HiveQLDialect, MetastoreRelation} +import org.apache.spark.sql.hive.{HiveContext, HiveQLDialect, MetastoreRelation} import org.apache.spark.sql.parquet.ParquetRelation2 import org.apache.spark.sql.sources.LogicalRelation import org.apache.spark.sql.types._ @@ -191,9 +191,9 @@ class SQLQuerySuite extends QueryTest { } } - val originalConf = getConf("spark.sql.hive.convertCTAS", "false") + val originalConf = convertCTAS - setConf("spark.sql.hive.convertCTAS", "true") + setConf(HiveContext.CONVERT_CTAS, true) sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") sql("CREATE TABLE IF NOT EXISTS ctas1 AS SELECT key k, value FROM src ORDER BY k, value") @@ -235,7 +235,7 @@ class SQLQuerySuite extends QueryTest { checkRelation("ctas1", false) sql("DROP TABLE ctas1") - setConf("spark.sql.hive.convertCTAS", originalConf) + setConf(HiveContext.CONVERT_CTAS, originalConf) } test("SQL Dialect Switching") { @@ -332,7 +332,7 @@ class SQLQuerySuite extends QueryTest { val origUseParquetDataSource = conf.parquetUseDataSourceApi try { - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, false) sql( """CREATE TABLE ctas5 | STORED AS parquet AS @@ -348,7 +348,7 @@ class SQLQuerySuite extends QueryTest { "MANAGED_TABLE" ) - val default = getConf("spark.sql.hive.convertMetastoreParquet", "true") + val default = convertMetastoreParquet // use the Hive SerDe for parquet tables sql("set spark.sql.hive.convertMetastoreParquet = false") checkAnswer( @@ -356,7 +356,7 @@ class SQLQuerySuite extends QueryTest { sql("SELECT key, value FROM src ORDER BY key, value").collect().toSeq) sql(s"set spark.sql.hive.convertMetastoreParquet = $default") } finally { - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, origUseParquetDataSource.toString) + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, origUseParquetDataSource) } } @@ -603,8 +603,8 @@ class SQLQuerySuite extends QueryTest { // generates an invalid query plan. 
val rdd = sparkContext.makeRDD((1 to 5).map(i => s"""{"a":[$i, ${i + 1}]}""")) read.json(rdd).registerTempTable("data") - val originalConf = getConf("spark.sql.hive.convertCTAS", "false") - setConf("spark.sql.hive.convertCTAS", "false") + val originalConf = convertCTAS + setConf(HiveContext.CONVERT_CTAS, false) sql("CREATE TABLE explodeTest (key bigInt)") table("explodeTest").queryExecution.analyzed match { @@ -621,7 +621,7 @@ class SQLQuerySuite extends QueryTest { sql("DROP TABLE explodeTest") dropTempTable("data") - setConf("spark.sql.hive.convertCTAS", originalConf) + setConf(HiveContext.CONVERT_CTAS, originalConf) } test("sanity test for SPARK-6618") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 3864349cdbd89..c2e09800933b5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -153,7 +153,7 @@ class ParquetMetastoreSuiteBase extends ParquetPartitioningTest { val rdd2 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":[$i, null]}""")) read.json(rdd2).registerTempTable("jt_array") - setConf("spark.sql.hive.convertMetastoreParquet", "true") + setConf(HiveContext.CONVERT_METASTORE_PARQUET, true) } override def afterAll(): Unit = { @@ -164,7 +164,7 @@ class ParquetMetastoreSuiteBase extends ParquetPartitioningTest { sql("DROP TABLE normal_parquet") sql("DROP TABLE IF EXISTS jt") sql("DROP TABLE IF EXISTS jt_array") - setConf("spark.sql.hive.convertMetastoreParquet", "false") + setConf(HiveContext.CONVERT_METASTORE_PARQUET, false) } test(s"conversion is working") { @@ -199,14 +199,14 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' """.stripMargin) - conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") + conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, true) } override def afterAll(): Unit = { super.afterAll() sql("DROP TABLE IF EXISTS test_parquet") - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) } test("scan an empty parquet table") { @@ -546,12 +546,12 @@ class ParquetDataSourceOffMetastoreSuite extends ParquetMetastoreSuiteBase { override def beforeAll(): Unit = { super.beforeAll() - conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") + conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, false) } override def afterAll(): Unit = { super.afterAll() - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) } test("MetastoreRelation in InsertIntoTable will not be converted") { @@ -692,12 +692,12 @@ class ParquetDataSourceOnSourceSuite extends ParquetSourceSuiteBase { override def beforeAll(): Unit = { super.beforeAll() - conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") + conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, true) } override def afterAll(): Unit = { super.afterAll() - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) } test("values in arrays and maps stored in parquet are always nullable") { @@ -750,12 +750,12 @@ class ParquetDataSourceOffSourceSuite extends ParquetSourceSuiteBase { override def beforeAll(): Unit = { super.beforeAll() - conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") + 
conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, false) } override def afterAll(): Unit = { super.afterAll() - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) + setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) } } From fee3438a32136a8edbca71efb566965587a88826 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 17 Jun 2015 23:31:30 -0700 Subject: [PATCH 02/22] [SPARK-8218][SQL] Add binary log math function JIRA: https://issues.apache.org/jira/browse/SPARK-8218 Because there is already `log` unary function defined, the binary log function is called `logarithm` for now. Author: Liang-Chi Hsieh Closes #6725 from viirya/expr_binary_log and squashes the following commits: bf96bd9 [Liang-Chi Hsieh] Compare log result in string. 102070d [Liang-Chi Hsieh] Round log result to better comparing in python test. fd01863 [Liang-Chi Hsieh] For comments. beed631 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into expr_binary_log 6089d11 [Liang-Chi Hsieh] Remove unnecessary override. 8cf37b7 [Liang-Chi Hsieh] For comments. bc89597 [Liang-Chi Hsieh] For comments. db7dc38 [Liang-Chi Hsieh] Use ctor instead of companion object. 0634ef7 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into expr_binary_log 1750034 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into expr_binary_log 3d75bfc [Liang-Chi Hsieh] Fix scala style. 5b39c02 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into expr_binary_log 23c54a3 [Liang-Chi Hsieh] Fix scala style. ebc9929 [Liang-Chi Hsieh] Let Logarithm accept one parameter too. 605574d [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into expr_binary_log 21c3bfd [Liang-Chi Hsieh] Fix scala style. c6c187f [Liang-Chi Hsieh] For comments. c795342 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into expr_binary_log f373bac [Liang-Chi Hsieh] Add binary log expression. --- python/pyspark/sql/functions.py | 18 ++++++++++++++++- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../spark/sql/catalyst/expressions/math.scala | 20 +++++++++++++++++++ .../expressions/MathFunctionsSuite.scala | 18 +++++++++++++++++ .../org/apache/spark/sql/functions.scala | 16 +++++++++++++++ .../spark/sql/MathExpressionsSuite.scala | 13 ++++++++++++ 6 files changed, 85 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index bbf465aca8d4d..177fc196e0834 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -18,6 +18,7 @@ """ A collections of builtin functions """ +import math import sys if sys.version < "3": @@ -143,7 +144,7 @@ def _(): 'atan2': 'Returns the angle theta from the conversion of rectangular coordinates (x, y) to' + 'polar coordinates (r, theta).', 'hypot': 'Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow.', - 'pow': 'Returns the value of the first argument raised to the power of the second argument.' + 'pow': 'Returns the value of the first argument raised to the power of the second argument.', } _window_functions = { @@ -403,6 +404,21 @@ def when(condition, value): return Column(jc) +@since(1.4) +def log(col, base=math.e): + """Returns the first argument-based logarithm of the second argument. 
+ + >>> df.select(log(df.age, 10.0).alias('ten')).map(lambda l: str(l.ten)[:7]).collect() + ['0.30102', '0.69897'] + + >>> df.select(log(df.age).alias('e')).map(lambda l: str(l.e)[:7]).collect() + ['0.69314', '1.60943'] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.log(base, _to_java_column(col)) + return Column(jc) + + @since(1.4) def lag(col, count=1, default=None): """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 97b123ec2f6d9..13b2bb05f5280 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -112,6 +112,7 @@ object FunctionRegistry { expression[Expm1]("expm1"), expression[Floor]("floor"), expression[Hypot]("hypot"), + expression[Logarithm]("log"), expression[Log]("ln"), expression[Log10]("log10"), expression[Log1p]("log1p"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 42c596b5b31ab..67cb0b508ca9e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -255,3 +255,23 @@ case class Pow(left: Expression, right: Expression) """ } } + +case class Logarithm(left: Expression, right: Expression) + extends BinaryMathExpression((c1, c2) => math.log(c2) / math.log(c1), "LOG") { + def this(child: Expression) = { + this(EulerNumber(), child) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val logCode = if (left.isInstanceOf[EulerNumber]) { + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.log($c2)") + } else { + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.log($c2) / java.lang.Math.log($c1)") + } + logCode + s""" + if (Double.valueOf(${ev.primitive}).isNaN()) { + ${ev.isNull} = true; + } + """ + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 864c954ee82cb..0050ad3fe8302 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -204,4 +204,22 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { testBinary(Atan2, math.atan2) } + test("binary log") { + val f = (c1: Double, c2: Double) => math.log(c2) / math.log(c1) + val domain = (1 to 20).map(v => (v * 0.1, v * 0.2)) + + domain.foreach { case (v1, v2) => + checkEvaluation(Logarithm(Literal(v1), Literal(v2)), f(v1 + 0.0, v2 + 0.0), EmptyRow) + checkEvaluation(Logarithm(Literal(v2), Literal(v1)), f(v2 + 0.0, v1 + 0.0), EmptyRow) + checkEvaluation(new Logarithm(Literal(v1)), f(math.E, v1 + 0.0), EmptyRow) + } + checkEvaluation( + Logarithm(Literal.create(null, DoubleType), Literal(1.0)), + null, + create_row(null)) + checkEvaluation( + Logarithm(Literal(1.0), Literal.create(null, DoubleType)), + null, + create_row(null)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 
c5b77724aae17..dff0932c450a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1083,6 +1083,22 @@ object functions { */ def log(columnName: String): Column = log(Column(columnName)) + /** + * Returns the first argument-base logarithm of the second argument. + * + * @group math_funcs + * @since 1.4.0 + */ + def log(base: Double, a: Column): Column = Logarithm(lit(base).expr, a.expr) + + /** + * Returns the first argument-base logarithm of the second argument. + * + * @group math_funcs + * @since 1.4.0 + */ + def log(base: Double, columnName: String): Column = log(base, Column(columnName)) + /** * Computes the logarithm of the given value in base 10. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index e2daaf6b730c5..7c9c121b956bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -236,6 +236,19 @@ class MathExpressionsSuite extends QueryTest { testOneToOneNonNegativeMathFunction(log1p, math.log1p) } + test("binary log") { + val df = Seq[(Integer, Integer)]((123, null)).toDF("a", "b") + checkAnswer( + df.select(org.apache.spark.sql.functions.log("a"), + org.apache.spark.sql.functions.log(2.0, "a"), + org.apache.spark.sql.functions.log("b")), + Row(math.log(123), math.log(123) / math.log(2), null)) + + checkAnswer( + df.selectExpr("log(a)", "log(2.0, a)", "log(b)"), + Row(math.log(123), math.log(123) / math.log(2), null)) + } + test("abs") { val input = Seq[(java.lang.Double, java.lang.Double)]((null, null), (0.0, 0.0), (1.5, 1.5), (-2.5, 2.5)) From e86fbdb1e6f1538f65ef78d90bbc41604f6bd580 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Wed, 17 Jun 2015 23:46:57 -0700 Subject: [PATCH 03/22] [SPARK-8283][SQL] Resolve udf_struct test failure in HiveCompatibilitySuite This PR aimed to resolve udf_struct test failure in HiveCompatibilitySuite. Currently, this is done by loosening CreateStruct's children type from NamedExpression to Expression and automatically generating StructField name for non-NamedExpression children. The naming convention for unnamed children follows the udf's counterpart in Hive: `col1, col2, col3, ...` Author: Yijie Shen Closes #6828 from yijieshen/SPARK-8283 and squashes the following commits: 6052b73 [Yijie Shen] Doc fix 677e0b7 [Yijie Shen] Resolve udf_struct test failure by automatically generate structField name for non-NamedExpression children --- .../sql/catalyst/expressions/complexTypes.scala | 13 +++++++++---- .../sql/hive/execution/HiveCompatibilitySuite.scala | 2 +- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index 1aaf9b309efc3..72fdcebb4cbc8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -53,7 +53,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { * Returns a Row containing the evaluation of all children expressions. * TODO: [[CreateStruct]] does not support codegen. 
*/ -case class CreateStruct(children: Seq[NamedExpression]) extends Expression { +case class CreateStruct(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) @@ -62,9 +62,14 @@ case class CreateStruct(children: Seq[NamedExpression]) extends Expression { override lazy val dataType: StructType = { assert(resolved, s"CreateStruct contains unresolvable children: ${children.filterNot(_.resolved)}.") - val fields = children.map { child => - StructField(child.name, child.dataType, child.nullable, child.metadata) - } + val fields = children.zipWithIndex.map { case (child, idx) => + child match { + case ne: NamedExpression => + StructField(ne.name, ne.dataType, ne.nullable, ne.metadata) + case _ => + StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty) + } + } StructType(fields) } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 432de2564d080..f88e62763ca70 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -933,7 +933,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_stddev_pop", "udf_stddev_samp", "udf_string", - // "udf_struct", TODO: FIX THIS and enable it. + "udf_struct", "udf_substring", "udf_subtract", "udf_sum", From ddc5baf17d7b09623b91190ee7754a6c8f7b5d10 Mon Sep 17 00:00:00 2001 From: Neelesh Srinivas Salian Date: Thu, 18 Jun 2015 09:44:36 -0700 Subject: [PATCH 04/22] [SPARK-8320] [STREAMING] Add example in streaming programming guide that shows union of multiple input streams Added python code to https://spark.apache.org/docs/latest/streaming-programming-guide.html to the Level of Parallelism in Data Receiving section. Please review and let me know if there are any additional changes that are needed. Thank you. Author: Neelesh Srinivas Salian Closes #6862 from nssalian/SPARK-8320 and squashes the following commits: 4bfd126 [Neelesh Srinivas Salian] Changed loop structure to be more in line with Python style e5345de [Neelesh Srinivas Salian] Changes to kafak append, for loop and show to print() 3fc5c6d [Neelesh Srinivas Salian] SPARK-8320 --- docs/streaming-programming-guide.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 1eb3b30332e4f..b784d59666fec 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -1937,6 +1937,14 @@ JavaPairDStream unifiedStream = streamingContext.union(kafkaStre unifiedStream.print(); {% endhighlight %} +
+{% highlight python %}
+numStreams = 5
+kafkaStreams = [KafkaUtils.createStream(...) for _ in range (numStreams)]
+unifiedStream = streamingContext.union(kafkaStreams)
+unifiedStream.print()
+{% endhighlight %}
+
Another parameter that should be considered is the receiver's blocking interval, From 31641128b34d6f2aa7cb67324c24dd8b3ed84689 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 18 Jun 2015 13:00:31 -0700 Subject: [PATCH 05/22] [SPARK-8363][SQL] Move sqrt to math and extend UnaryMathExpression JIRA: https://issues.apache.org/jira/browse/SPARK-8363 Author: Liang-Chi Hsieh Closes #6823 from viirya/move_sqrt and squashes the following commits: 8977e11 [Liang-Chi Hsieh] Remove unnecessary old tests. d23e79e [Liang-Chi Hsieh] Explicitly indicate sqrt value sequence. 699f48b [Liang-Chi Hsieh] Use correct @since tag. 8dff6d1 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into move_sqrt bc2ed77 [Liang-Chi Hsieh] Remove/move arithmetic expression test and expression type checking test. Remove unnecessary Sqrt type rule. d38492f [Liang-Chi Hsieh] Now sqrt accepts boolean because type casting is handled by HiveTypeCoercion. 297cc90 [Liang-Chi Hsieh] Sqrt only accepts double input. ef4a21a [Liang-Chi Hsieh] Move sqrt to math. --- .../catalyst/analysis/HiveTypeCoercion.scala | 1 - .../sql/catalyst/expressions/arithmetic.scala | 32 ------------------- .../spark/sql/catalyst/expressions/math.scala | 2 ++ .../ArithmeticExpressionSuite.scala | 15 --------- .../ExpressionTypeCheckingSuite.scala | 2 -- .../expressions/MathFunctionsSuite.scala | 10 ++++++ .../org/apache/spark/sql/functions.scala | 10 +++++- .../spark/sql/MathExpressionsSuite.scala | 10 ++++++ 8 files changed, 31 insertions(+), 51 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 189451d0d9ad7..8012b224eb444 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -307,7 +307,6 @@ trait HiveTypeCoercion { case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) case Average(e @ StringType()) => Average(Cast(e, DoubleType)) - case Sqrt(e @ StringType()) => Sqrt(Cast(e, DoubleType)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 167e460d5a93e..ace8427c8ddaf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -67,38 +67,6 @@ case class UnaryPositive(child: Expression) extends UnaryArithmetic { protected override def evalInternal(evalE: Any) = evalE } -case class Sqrt(child: Expression) extends UnaryArithmetic { - override def dataType: DataType = DoubleType - override def nullable: Boolean = true - override def toString: String = s"SQRT($child)" - - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "function sqrt") - - private lazy val numeric = TypeUtils.getNumeric(child.dataType) - - protected override def evalInternal(evalE: Any) = { - val value = numeric.toDouble(evalE) - if (value < 0) null - else math.sqrt(value) - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val eval = child.gen(ctx) - eval.code + s""" - boolean ${ev.isNull} = ${eval.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if 
(!${ev.isNull}) { - if (${eval.primitive} < 0.0) { - ${ev.isNull} = true; - } else { - ${ev.primitive} = java.lang.Math.sqrt(${eval.primitive}); - } - } - """ - } -} - /** * A function that get the absolute value of the numeric value. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 67cb0b508ca9e..3b83c6da0e60c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -193,6 +193,8 @@ case class Sin(child: Expression) extends UnaryMathExpression(math.sin, "SIN") case class Sinh(child: Expression) extends UnaryMathExpression(math.sinh, "SINH") +case class Sqrt(child: Expression) extends UnaryMathExpression(math.sqrt, "SQRT") + case class Tan(child: Expression) extends UnaryMathExpression(math.tan, "TAN") case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, "TANH") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 3f4843259e80b..4bbbbe6c7f091 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -142,19 +142,4 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MinOf(Literal.create(null, IntegerType), 1), 1) checkEvaluation(MinOf(1, Literal.create(null, IntegerType)), 1) } - - test("SQRT") { - val inputSequence = (1 to (1<<24) by 511).map(_ * (1L<<24)) - val expectedResults = inputSequence.map(l => math.sqrt(l.toDouble)) - val rowSequence = inputSequence.map(l => create_row(l.toDouble)) - val d = 'a.double.at(0) - - for ((row, expected) <- rowSequence zip expectedResults) { - checkEvaluation(Sqrt(d), expected, row) - } - - checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null)) - checkEvaluation(Sqrt(-1), null, EmptyRow) - checkEvaluation(Sqrt(-1.5), null, EmptyRow) - } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala index dcb3635c5ccae..49b111989799b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala @@ -54,8 +54,6 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { test("check types for unary arithmetic") { assertError(UnaryMinus('stringField), "operator - accepts numeric type") - assertSuccess(Sqrt('stringField)) // We will cast String to Double for sqrt - assertError(Sqrt('booleanField), "function sqrt accepts numeric type") assertError(Abs('stringField), "function abs accepts numeric type") assertError(BitwiseNot('stringField), "operator ~ accepts integral type") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 0050ad3fe8302..21e9b92b7214e 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types.DoubleType @@ -191,6 +192,15 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { testUnary(Log2, f, (-5 to -1).map(_ * 1.0), expectNull = true) } + test("sqrt") { + testUnary(Sqrt, math.sqrt, (0 to 20).map(_ * 0.1)) + testUnary(Sqrt, math.sqrt, (-5 to -1).map(_ * 1.0), expectNull = true) + + checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null)) + checkEvaluation(Sqrt(Literal(-1.0)), null, EmptyRow) + checkEvaluation(Sqrt(Literal(-1.5)), null, EmptyRow) + } + test("pow") { testBinary(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0))) testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNull = true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index dff0932c450a8..d8a91bead7c33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -707,11 +707,19 @@ object functions { /** * Computes the square root of the specified float value. * - * @group normal_funcs + * @group math_funcs * @since 1.3.0 */ def sqrt(e: Column): Column = Sqrt(e.expr) + /** + * Computes the square root of the specified float value. + * + * @group math_funcs + * @since 1.5.0 + */ + def sqrt(colName: String): Column = sqrt(Column(colName)) + /** * Creates a new struct column. The input column must be a column in a [[DataFrame]], or * a derived column expression that is named (i.e. aliased). diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 7c9c121b956bb..2768d7dfc8030 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -270,6 +270,16 @@ class MathExpressionsSuite extends QueryTest { checkAnswer(ctx.sql("SELECT LOG2(8), LOG2(null)"), Row(3, null)) } + test("sqrt") { + val df = Seq((1, 4)).toDF("a", "b") + checkAnswer( + df.select(sqrt("a"), sqrt("b")), + Row(1.0, 2.0)) + + checkAnswer(ctx.sql("SELECT SQRT(4.0), SQRT(null)"), Row(2.0, null)) + checkAnswer(df.selectExpr("sqrt(a)", "sqrt(b)", "sqrt(null)"), Row(1.0, 2.0, null)) + } + test("negative") { checkAnswer( ctx.sql("SELECT negative(1), negative(0), negative(-1)"), From 9b2002722273f98e193ad6cd54c9626292ab27d1 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 18 Jun 2015 13:45:58 -0700 Subject: [PATCH 06/22] [SPARK-8202] [PYSPARK] fix infinite loop during external sort in PySpark The batch size during external sort will grow up to max 10000, then shrink down to zero, causing infinite loop. Given the assumption that the items usually have similar size, so we don't need to adjust the batch size after first spill. 
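To make the failure mode concrete: repeatedly halving an integer batch size after every spill eventually reaches zero, and a loop that reads `batch` items per pass then stops making progress. The following is a purely illustrative Scala sketch of that arithmetic (the actual fix is in python/pyspark/shuffle.py; the object and names below are not part of this patch, apart from the 10000 cap mentioned above):

    // Illustration only: a batch size halved on every spill hits zero after a
    // handful of spills, and a pass that reads zero items never advances.
    object BatchSizeDemo extends App {
      var batch = 10000                 // the cap mentioned in the description above
      val sizes = Iterator.continually { batch /= 2; batch }.takeWhile(_ > 0).toList
      println(sizes.mkString(", "))     // 5000, 2500, ..., 2, 1; the next halving yields 0
      // Keeping the batch size fixed after the first spill, as this patch does,
      // avoids the degenerate zero-sized batch entirely.
    }
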
cc JoshRosen rxin angelini Author: Davies Liu Closes #6714 from davies/batch_size and squashes the following commits: b170dfb [Davies Liu] update test b9be832 [Davies Liu] Merge branch 'batch_size' of github.com:davies/spark into batch_size 6ade745 [Davies Liu] update test 5c21777 [Davies Liu] Update shuffle.py e746aec [Davies Liu] fix batch size during sort --- python/pyspark/shuffle.py | 5 +---- python/pyspark/tests.py | 5 ++++- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 81c420ce16541..67752c0d150b9 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -486,7 +486,7 @@ def sorted(self, iterator, key=None, reverse=False): goes above the limit. """ global MemoryBytesSpilled, DiskBytesSpilled - batch, limit = 100, self.memory_limit + batch, limit = 100, self._next_limit() chunks, current_chunk = [], [] iterator = iter(iterator) while True: @@ -512,9 +512,6 @@ def load(f): f.close() chunks.append(load(open(path, 'rb'))) current_chunk = [] - gc.collect() - batch //= 2 - limit = self._next_limit() MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20 DiskBytesSpilled += os.path.getsize(path) os.unlink(path) # data will be deleted after close diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 11b402e6df6c1..78265423682b0 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -179,9 +179,12 @@ def test_in_memory_sort(self): list(sorter.sorted(l, key=lambda x: -x, reverse=True))) def test_external_sort(self): + class CustomizedSorter(ExternalSorter): + def _next_limit(self): + return self.memory_limit l = list(range(1024)) random.shuffle(l) - sorter = ExternalSorter(1) + sorter = CustomizedSorter(1) self.assertEqual(sorted(l), list(sorter.sorted(l))) self.assertGreater(shuffle.DiskBytesSpilled, 0) last = shuffle.DiskBytesSpilled From 44c931f006194a833f09517c9e35fb3cdf5852b1 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 18 Jun 2015 15:10:09 -0700 Subject: [PATCH 07/22] [SPARK-8353] [DOCS] Show anchor links when hovering over documentation headers This patch uses [AnchorJS](https://bryanbraun.github.io/anchorjs/) to show deep anchor links when hovering over headers in the Spark documentation. For example: ![image](https://cloud.githubusercontent.com/assets/50748/8240800/1502f85c-15ba-11e5-819a-97b231370a39.png) This makes it easier for users to link to specific sections of the documentation. I also removed some dead Javascript which isn't used in our current docs (it was introduced for the old AMPCamp training, but isn't used anymore). Author: Josh Rosen Closes #6808 from JoshRosen/SPARK-8353 and squashes the following commits: e59d8a7 [Josh Rosen] Suppress underline on hover f518b6a [Josh Rosen] Turn on for all headers, since we use H1s in a bunch of places a9fec01 [Josh Rosen] Add anchor links when hovering over headers; remove some dead JS code --- LICENSE | 1 + docs/_layouts/global.html | 1 + docs/css/main.css | 5 +++++ docs/js/main.js | 34 ++++++---------------------------- docs/js/vendor/anchor.min.js | 6 ++++++ 5 files changed, 19 insertions(+), 28 deletions(-) create mode 100755 docs/js/vendor/anchor.min.js diff --git a/LICENSE b/LICENSE index d0cd0dcb4bdb7..42010d9f5f0e6 100644 --- a/LICENSE +++ b/LICENSE @@ -950,3 +950,4 @@ The following components are provided under the MIT License. 
See project link fo (MIT License) scopt (com.github.scopt:scopt_2.10:3.2.0 - https://github.com/scopt/scopt) (The MIT License) Mockito (org.mockito:mockito-all:1.8.5 - http://www.mockito.org) (MIT License) jquery (https://jquery.org/license/) + (MIT License) AnchorJS (https://github.com/bryanbraun/anchorjs) diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index eebb3faf90fc0..b4952fe97ca0e 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -138,6 +138,7 @@

{{ page.title }}

+ diff --git a/docs/css/main.css b/docs/css/main.css index f6fe7d5f07da1..89305a7d3a358 100755 --- a/docs/css/main.css +++ b/docs/css/main.css @@ -146,3 +146,8 @@ ul.nav li.dropdown ul.dropdown-menu li.dropdown-submenu ul.dropdown-menu { .MathJax .mi { color: inherit } .MathJax .mf { color: inherit } .MathJax .mh { color: inherit } + +/** + * AnchorJS (anchor links when hovering over headers) + */ +a.anchorjs-link:hover { text-decoration: none; } diff --git a/docs/js/main.js b/docs/js/main.js index f1a90e47e89a7..f5d66b16f7b21 100755 --- a/docs/js/main.js +++ b/docs/js/main.js @@ -68,38 +68,11 @@ function codeTabs() { }); } -function makeCollapsable(elt, accordionClass, accordionBodyId, title) { - $(elt).addClass("accordion-inner"); - $(elt).wrap('
') - $(elt).wrap('
') - $(elt).wrap('
') - $(elt).parent().before( - '
' + - '' + - title + - '' + - '
' - ); -} - -// Enable "view solution" sections (for exercises) -function viewSolution() { - var counter = 0 - $("div.solution").each(function() { - var id = "solution_" + counter - makeCollapsable(this, "", id, - '' + - '' + "View Solution"); - counter++; - }); -} // A script to fix internal hash links because we have an overlapping top bar. // Based on https://github.com/twitter/bootstrap/issues/193#issuecomment-2281510 function maybeScrollToHash() { - console.log("HERE"); if (window.location.hash && $(window.location.hash).length) { - console.log("HERE2", $(window.location.hash), $(window.location.hash).offset().top); var newTop = $(window.location.hash).offset().top - 57; $(window).scrollTop(newTop); } @@ -107,7 +80,12 @@ function maybeScrollToHash() { $(function() { codeTabs(); - viewSolution(); + // Display anchor links when hovering over headers. For documentation of the + // configuration options, see the AnchorJS documentation. + anchors.options = { + placement: 'left' + }; + anchors.add(); $(window).bind('hashchange', function() { maybeScrollToHash(); diff --git a/docs/js/vendor/anchor.min.js b/docs/js/vendor/anchor.min.js new file mode 100755 index 0000000000000..68c3cb7073b6d --- /dev/null +++ b/docs/js/vendor/anchor.min.js @@ -0,0 +1,6 @@ +/*! + * AnchorJS - v1.1.1 - 2015-05-23 + * https://github.com/bryanbraun/anchorjs + * Copyright (c) 2015 Bryan Braun; Licensed MIT + */ +function AnchorJS(A){"use strict";this.options=A||{},this._applyRemainingDefaultOptions=function(A){this.options.icon=this.options.hasOwnProperty("icon")?A.icon:"",this.options.visible=this.options.hasOwnProperty("visible")?A.visible:"hover",this.options.placement=this.options.hasOwnProperty("placement")?A.placement:"right",this.options.class=this.options.hasOwnProperty("class")?A.class:""},this._applyRemainingDefaultOptions(A),this.add=function(A){var e,t,o,n,i,s,a,l,c,r,h,g,B,Q;if(this._applyRemainingDefaultOptions(this.options),A){if("string"!=typeof A)throw new Error("The selector provided to AnchorJS was invalid.")}else A="h1, h2, h3, h4, h5, h6";if(e=document.querySelectorAll(A),0===e.length)return!1;for(this._addBaselineStyles(),t=document.querySelectorAll("[id]"),o=[].map.call(t,function(A){return A.id}),i=0;i',B=document.createElement("div"),B.innerHTML=g,Q=B.childNodes,"always"===this.options.visible&&(Q[0].style.opacity="1"),""===this.options.icon&&(Q[0].style.fontFamily="anchorjs-icons",Q[0].style.fontStyle="normal",Q[0].style.fontVariant="normal",Q[0].style.fontWeight="normal"),"left"===this.options.placement?(Q[0].style.position="absolute",Q[0].style.marginLeft="-1em",Q[0].style.paddingRight="0.5em",e[i].insertBefore(Q[0],e[i].firstChild)):(Q[0].style.paddingLeft="0.375em",e[i].appendChild(Q[0]))}return this},this.remove=function(A){for(var e,t=document.querySelectorAll(A),o=0;o Date: Thu, 18 Jun 2015 16:00:27 -0700 Subject: [PATCH 08/22] [SPARK-8376] [DOCS] Add common lang3 to the Spark Flume Sink doc Commons Lang 3 has been added as one of the dependencies of Spark Flume Sink since #5703. This PR updates the doc for it. 
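If the extra JARs for the Flume agent are resolved through a build tool rather than downloaded by hand, the coordinate documented above corresponds to roughly the following sbt setting (illustrative only; the guide itself lists the Maven artifact details and a direct download link):

    // Hypothetical sbt equivalent of the Commons Lang 3 coordinate added to the doc.
    libraryDependencies += "org.apache.commons" % "commons-lang3" % "3.3.2"
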
Author: zsxwing Closes #6829 from zsxwing/flume-sink-dep and squashes the following commits: f8617f0 [zsxwing] Add common lang3 to the Spark Flume Sink doc --- docs/streaming-flume-integration.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/streaming-flume-integration.md b/docs/streaming-flume-integration.md index c8ab146bcae0a..8d6e74370918f 100644 --- a/docs/streaming-flume-integration.md +++ b/docs/streaming-flume-integration.md @@ -99,6 +99,12 @@ Configuring Flume on the chosen machine requires the following two steps. artifactId = scala-library version = {{site.SCALA_VERSION}} + (iii) *Commons Lang 3 JAR*: Download the Commons Lang 3 JAR. It can be found with the following artifact detail (or, [direct link](http://search.maven.org/remotecontent?filepath=org/apache/commons/commons-lang3/3.3.2/commons-lang3-3.3.2.jar)). + + groupId = org.apache.commons + artifactId = commons-lang3 + version = 3.3.2 + 2. **Configuration file**: On that machine, configure Flume agent to send data to an Avro sink by having the following in the configuration file. agent.sinks = spark From 207a98ca59757d9cdd033d0f72863ad9ffb4e4b9 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 18 Jun 2015 16:45:14 -0700 Subject: [PATCH 09/22] [SPARK-8446] [SQL] Add helper functions for testing SparkPlan physical operators This patch introduces `SparkPlanTest`, a base class for unit tests of SparkPlan physical operators. This is analogous to Spark SQL's existing `QueryTest`, which does something similar for end-to-end tests with actual queries. These helper methods provide nicer error output when tests fail and help developers to avoid writing lots of boilerplate in order to execute manually constructed physical plans. Author: Josh Rosen Author: Josh Rosen Author: Michael Armbrust Closes #6885 from JoshRosen/spark-plan-test and squashes the following commits: f8ce275 [Josh Rosen] Fix some IntelliJ inspections and delete some dead code 84214be [Josh Rosen] Add an extra column which isn't part of the sort ae1896b [Josh Rosen] Provide implicits automatically a80f9b0 [Josh Rosen] Merge pull request #4 from marmbrus/pr/6885 d9ab1e4 [Michael Armbrust] Add simple resolver c60a44d [Josh Rosen] Manually bind references 996332a [Josh Rosen] Add types so that tests compile a46144a [Josh Rosen] WIP --- .../spark/sql/execution/SortSuite.scala | 44 +++++ .../spark/sql/execution/SparkPlanTest.scala | 167 ++++++++++++++++++ 2 files changed, 211 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala new file mode 100644 index 0000000000000..a1e3ca11b1ad9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.dsl.expressions._ + +class SortSuite extends SparkPlanTest { + + // This test was originally added as an example of how to use [[SparkPlanTest]]; + // it's not designed to be a comprehensive test of ExternalSort. + test("basic sorting using ExternalSort") { + + val input = Seq( + ("Hello", 4, 2.0), + ("Hello", 1, 1.0), + ("World", 8, 3.0) + ) + + checkAnswer( + input.toDF("a", "b", "c"), + ExternalSort('a.asc :: 'b.asc :: Nil, global = false, _: SparkPlan), + input.sorted) + + checkAnswer( + input.toDF("a", "b", "c"), + ExternalSort('b.asc :: 'a.asc :: Nil, global = false, _: SparkPlan), + input.sortBy(t => (t._2, t._1))) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala new file mode 100644 index 0000000000000..13f3be8ca28d6 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import scala.language.implicitConversions +import scala.reflect.runtime.universe.TypeTag +import scala.util.control.NonFatal + +import org.apache.spark.SparkFunSuite + +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.BoundReference +import org.apache.spark.sql.catalyst.util._ + +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.{DataFrameHolder, Row, DataFrame} + +/** + * Base class for writing tests for individual physical operators. For an example of how this + * class's test helper methods can be used, see [[SortSuite]]. + */ +class SparkPlanTest extends SparkFunSuite { + + /** + * Creates a DataFrame from a local Seq of Product. + */ + implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = { + TestSQLContext.implicits.localSeqToDataFrameHolder(data) + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * @param input the input data to be used. + * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate + * the physical operator that's being tested. 
+ * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + */ + protected def checkAnswer( + input: DataFrame, + planFunction: SparkPlan => SparkPlan, + expectedAnswer: Seq[Row]): Unit = { + SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer) match { + case Some(errorMessage) => fail(errorMessage) + case None => + } + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * @param input the input data to be used. + * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate + * the physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s. + */ + protected def checkAnswer[A <: Product : TypeTag]( + input: DataFrame, + planFunction: SparkPlan => SparkPlan, + expectedAnswer: Seq[A]): Unit = { + val expectedRows = expectedAnswer.map(Row.fromTuple) + SparkPlanTest.checkAnswer(input, planFunction, expectedRows) match { + case Some(errorMessage) => fail(errorMessage) + case None => + } + } +} + +/** + * Helper methods for writing tests of individual physical operators. + */ +object SparkPlanTest { + + /** + * Runs the plan and makes sure the answer matches the expected result. + * @param input the input data to be used. + * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate + * the physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + */ + def checkAnswer( + input: DataFrame, + planFunction: SparkPlan => SparkPlan, + expectedAnswer: Seq[Row]): Option[String] = { + + val outputPlan = planFunction(input.queryExecution.sparkPlan) + + // A very simple resolver to make writing tests easier. In contrast to the real resolver + // this is always case sensitive and does not try to handle scoping or complex type resolution. + val resolvedPlan = outputPlan transform { + case plan: SparkPlan => + val inputMap = plan.children.flatMap(_.output).zipWithIndex.map { + case (a, i) => + (a.name, BoundReference(i, a.dataType, a.nullable)) + }.toMap + + plan.transformExpressions { + case UnresolvedAttribute(Seq(u)) => + inputMap.getOrElse(u, + sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) + } + } + + def prepareAnswer(answer: Seq[Row]): Seq[Row] = { + // Converts data to types that we can do equality comparison using Scala collections. + // For BigDecimal type, the Scala type has a better definition of equality test (similar to + // Java's java.math.BigDecimal.compareTo). + // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for + // equality test. 
+ // This function is copied from Catalyst's QueryTest + val converted: Seq[Row] = answer.map { s => + Row.fromSeq(s.toSeq.map { + case d: java.math.BigDecimal => BigDecimal(d) + case b: Array[Byte] => b.toSeq + case o => o + }) + } + converted.sortBy(_.toString()) + } + + val sparkAnswer: Seq[Row] = try { + resolvedPlan.executeCollect().toSeq + } catch { + case NonFatal(e) => + val errorMessage = + s""" + | Exception thrown while executing Spark plan: + | $outputPlan + | == Exception == + | $e + | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} + """.stripMargin + return Some(errorMessage) + } + + if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) { + val errorMessage = + s""" + | Results do not match for Spark plan: + | $outputPlan + | == Results == + | ${sideBySide( + s"== Correct Answer - ${expectedAnswer.size} ==" +: + prepareAnswer(expectedAnswer).map(_.toString()), + s"== Spark Answer - ${sparkAnswer.size} ==" +: + prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")} + """.stripMargin + return Some(errorMessage) + } + + None + } +} + From dc413138995b45a7a957acae007dc11622110310 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 18 Jun 2015 18:41:15 -0700 Subject: [PATCH 10/22] [SPARK-8218][SQL] Binary log math function update. Some minor updates based on after merging #6725. Author: Reynold Xin Closes #6871 from rxin/log and squashes the following commits: ab51542 [Reynold Xin] Use JVM log 76fc8de [Reynold Xin] Fixed arg. a7c1522 [Reynold Xin] [SPARK-8218][SQL] Binary log math function update. --- python/pyspark/sql/functions.py | 13 +++++++++---- .../spark/sql/catalyst/expressions/math.scala | 4 ++++ 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 177fc196e0834..acdb01d3d3f5f 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -404,18 +404,23 @@ def when(condition, value): return Column(jc) -@since(1.4) -def log(col, base=math.e): +@since(1.5) +def log(arg1, arg2=None): """Returns the first argument-based logarithm of the second argument. - >>> df.select(log(df.age, 10.0).alias('ten')).map(lambda l: str(l.ten)[:7]).collect() + If there is only one argument, then this takes the natural logarithm of the argument. + + >>> df.select(log(10.0, df.age).alias('ten')).map(lambda l: str(l.ten)[:7]).collect() ['0.30102', '0.69897'] >>> df.select(log(df.age).alias('e')).map(lambda l: str(l.e)[:7]).collect() ['0.69314', '1.60943'] """ sc = SparkContext._active_spark_context - jc = sc._jvm.functions.log(base, _to_java_column(col)) + if arg2 is None: + jc = sc._jvm.functions.log(_to_java_column(arg1)) + else: + jc = sc._jvm.functions.log(arg1, _to_java_column(arg2)) return Column(jc) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 3b83c6da0e60c..f79bf4aee00d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -260,6 +260,10 @@ case class Pow(left: Expression, right: Expression) case class Logarithm(left: Expression, right: Expression) extends BinaryMathExpression((c1, c2) => math.log(c2) / math.log(c1), "LOG") { + + /** + * Natural log, i.e. using e as the base. 
+ */ def this(child: Expression) = { this(EulerNumber(), child) } From 43f50decdd20fafc55913c56ffa30f56040090e4 Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Thu, 18 Jun 2015 19:36:05 -0700 Subject: [PATCH 11/22] [SPARK-8135] Don't load defaults when reconstituting Hadoop Configurations Author: Sandy Ryza Closes #6679 from sryza/sandy-spark-8135 and squashes the following commits: c5554ff [Sandy Ryza] SPARK-8135. In SerializableWritable, don't load defaults when instantiating Configuration --- .../apache/spark/SerializableWritable.scala | 2 +- .../scala/org/apache/spark/SparkContext.scala | 2 +- .../org/apache/spark/SparkHadoopWriter.scala | 3 +- .../spark/api/python/PythonHadoopUtil.scala | 6 +-- .../apache/spark/api/python/PythonRDD.scala | 12 +++--- .../org/apache/spark/rdd/CheckpointRDD.scala | 11 +++--- .../org/apache/spark/rdd/HadoopRDD.scala | 8 ++-- .../org/apache/spark/rdd/NewHadoopRDD.scala | 4 +- .../apache/spark/rdd/PairRDDFunctions.scala | 6 +-- .../apache/spark/rdd/RDDCheckpointData.scala | 3 +- .../util/SerializableConfiguration.scala | 36 ++++++++++++++++++ .../spark/util/SerializableJobConf.scala | 37 +++++++++++++++++++ .../sql/parquet/ParquetTableOperations.scala | 5 ++- .../apache/spark/sql/parquet/newParquet.scala | 7 ++-- .../sql/sources/DataSourceStrategy.scala | 8 ++-- .../spark/sql/sources/SqlNewHadoopRDD.scala | 4 +- .../apache/spark/sql/sources/commands.scala | 3 +- .../apache/spark/sql/sources/interfaces.scala | 6 +-- .../apache/spark/sql/hive/TableReader.scala | 9 ++--- .../hive/execution/InsertIntoHiveTable.scala | 7 ++-- .../spark/sql/hive/hiveWriterContainers.scala | 3 +- .../spark/sql/hive/orc/OrcRelation.scala | 5 ++- .../streaming/dstream/FileInputDStream.scala | 5 +-- .../dstream/PairDStreamFunctions.scala | 7 ++-- .../rdd/WriteAheadLogBackedBlockRDD.scala | 5 +-- .../streaming/scheduler/ReceiverTracker.scala | 9 +++-- 26 files changed, 146 insertions(+), 67 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/util/SerializableConfiguration.scala create mode 100644 core/src/main/scala/org/apache/spark/util/SerializableJobConf.scala diff --git a/core/src/main/scala/org/apache/spark/SerializableWritable.scala b/core/src/main/scala/org/apache/spark/SerializableWritable.scala index cb2cae185256a..beb2e27254725 100644 --- a/core/src/main/scala/org/apache/spark/SerializableWritable.scala +++ b/core/src/main/scala/org/apache/spark/SerializableWritable.scala @@ -41,7 +41,7 @@ class SerializableWritable[T <: Writable](@transient var t: T) extends Serializa private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { in.defaultReadObject() val ow = new ObjectWritable() - ow.setConf(new Configuration()) + ow.setConf(new Configuration(false)) ow.readFields(in) t = ow.get().asInstanceOf[T] } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index a453c9bf4864a..141276ac901fb 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -974,7 +974,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli minPartitions: Int = defaultMinPartitions): RDD[(K, V)] = withScope { assertNotStopped() // A Hadoop configuration can be about 10 KB, which is pretty big, so broadcast it. 
- val confBroadcast = broadcast(new SerializableWritable(hadoopConfiguration)) + val confBroadcast = broadcast(new SerializableConfiguration(hadoopConfiguration)) val setInputPathsFunc = (jobConf: JobConf) => FileInputFormat.setInputPaths(jobConf, path) new HadoopRDD( this, diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index 59ac82ccec53b..f5dd36cbcfe6d 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -28,6 +28,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.HadoopRDD +import org.apache.spark.util.SerializableJobConf /** * Internal helper class that saves an RDD using a Hadoop OutputFormat. @@ -42,7 +43,7 @@ class SparkHadoopWriter(@transient jobConf: JobConf) with Serializable { private val now = new Date() - private val conf = new SerializableWritable(jobConf) + private val conf = new SerializableJobConf(jobConf) private var jobID = 0 private var splitID = 0 diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala index c9181a29d4756..b959b683d1674 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala @@ -19,8 +19,8 @@ package org.apache.spark.api.python import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD -import org.apache.spark.util.Utils -import org.apache.spark.{Logging, SerializableWritable, SparkException} +import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.{Logging, SparkException} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io._ import scala.util.{Failure, Success, Try} @@ -61,7 +61,7 @@ private[python] object Converter extends Logging { * Other objects are passed through without conversion. 
*/ private[python] class WritableToJavaConverter( - conf: Broadcast[SerializableWritable[Configuration]]) extends Converter[Any, Any] { + conf: Broadcast[SerializableConfiguration]) extends Converter[Any, Any] { /** * Converts a [[org.apache.hadoop.io.Writable]] to the underlying primitive, String or diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 55a37f8c944b2..dc9f62f39e6d5 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -36,7 +36,7 @@ import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} import org.apache.spark.broadcast.Broadcast import org.apache.spark.input.PortableDataStream import org.apache.spark.rdd.RDD -import org.apache.spark.util.Utils +import org.apache.spark.util.{SerializableConfiguration, Utils} import scala.util.control.NonFatal @@ -445,7 +445,7 @@ private[spark] object PythonRDD extends Logging { val kc = Utils.classForName(keyClass).asInstanceOf[Class[K]] val vc = Utils.classForName(valueClass).asInstanceOf[Class[V]] val rdd = sc.sc.sequenceFile[K, V](path, kc, vc, minSplits) - val confBroadcasted = sc.sc.broadcast(new SerializableWritable(sc.hadoopConfiguration())) + val confBroadcasted = sc.sc.broadcast(new SerializableConfiguration(sc.hadoopConfiguration())) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) @@ -471,7 +471,7 @@ private[spark] object PythonRDD extends Logging { val rdd = newAPIHadoopRDDFromClassNames[K, V, F](sc, Some(path), inputFormatClass, keyClass, valueClass, mergedConf) - val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf)) + val confBroadcasted = sc.sc.broadcast(new SerializableConfiguration(mergedConf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) @@ -497,7 +497,7 @@ private[spark] object PythonRDD extends Logging { val rdd = newAPIHadoopRDDFromClassNames[K, V, F](sc, None, inputFormatClass, keyClass, valueClass, conf) - val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf)) + val confBroadcasted = sc.sc.broadcast(new SerializableConfiguration(conf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) @@ -540,7 +540,7 @@ private[spark] object PythonRDD extends Logging { val rdd = hadoopRDDFromClassNames[K, V, F](sc, Some(path), inputFormatClass, keyClass, valueClass, mergedConf) - val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf)) + val confBroadcasted = sc.sc.broadcast(new SerializableConfiguration(mergedConf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) @@ -566,7 +566,7 @@ private[spark] object PythonRDD extends Logging { val rdd = hadoopRDDFromClassNames[K, V, F](sc, None, inputFormatClass, keyClass, valueClass, conf) - val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf)) + val confBroadcasted = sc.sc.broadcast(new SerializableConfiguration(conf)) val converted = convertRDD(rdd, keyConverterClass, 
valueConverterClass, new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index a4715e3437d94..33e6998b2cb10 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -21,13 +21,12 @@ import java.io.IOException import scala.reflect.ClassTag -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.util.Utils +import org.apache.spark.util.{SerializableConfiguration, Utils} private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {} @@ -38,7 +37,7 @@ private[spark] class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String) extends RDD[T](sc, Nil) { - val broadcastedConf = sc.broadcast(new SerializableWritable(sc.hadoopConfiguration)) + val broadcastedConf = sc.broadcast(new SerializableConfiguration(sc.hadoopConfiguration)) @transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration) @@ -87,7 +86,7 @@ private[spark] object CheckpointRDD extends Logging { def writeToFile[T: ClassTag]( path: String, - broadcastedConf: Broadcast[SerializableWritable[Configuration]], + broadcastedConf: Broadcast[SerializableConfiguration], blockSize: Int = -1 )(ctx: TaskContext, iterator: Iterator[T]) { val env = SparkEnv.get @@ -135,7 +134,7 @@ private[spark] object CheckpointRDD extends Logging { def readFromFile[T]( path: Path, - broadcastedConf: Broadcast[SerializableWritable[Configuration]], + broadcastedConf: Broadcast[SerializableConfiguration], context: TaskContext ): Iterator[T] = { val env = SparkEnv.get @@ -164,7 +163,7 @@ private[spark] object CheckpointRDD extends Logging { val path = new Path(hdfsPath, "temp") val conf = SparkHadoopUtil.get.newConfiguration(new SparkConf()) val fs = path.getFileSystem(conf) - val broadcastedConf = sc.broadcast(new SerializableWritable(conf)) + val broadcastedConf = sc.broadcast(new SerializableConfiguration(conf)) sc.runJob(rdd, CheckpointRDD.writeToFile[Int](path.toString, broadcastedConf, 1024) _) val cpRDD = new CheckpointRDD[Int](sc, path.toString) assert(cpRDD.partitions.length == rdd.partitions.length, "Number of partitions is not the same") diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 2cefe63d44b20..bee59a437f120 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -44,7 +44,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD -import org.apache.spark.util.{NextIterator, Utils} +import org.apache.spark.util.{SerializableConfiguration, NextIterator, Utils} import org.apache.spark.scheduler.{HostTaskLocation, HDFSCacheTaskLocation} import org.apache.spark.storage.StorageLevel @@ -100,7 +100,7 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSp @DeveloperApi class HadoopRDD[K, V]( @transient sc: SparkContext, - broadcastedConf: Broadcast[SerializableWritable[Configuration]], + broadcastedConf: 
Broadcast[SerializableConfiguration], initLocalJobConfFuncOpt: Option[JobConf => Unit], inputFormatClass: Class[_ <: InputFormat[K, V]], keyClass: Class[K], @@ -121,8 +121,8 @@ class HadoopRDD[K, V]( minPartitions: Int) = { this( sc, - sc.broadcast(new SerializableWritable(conf)) - .asInstanceOf[Broadcast[SerializableWritable[Configuration]]], + sc.broadcast(new SerializableConfiguration(conf)) + .asInstanceOf[Broadcast[SerializableConfiguration]], None /* initLocalJobConfFuncOpt */, inputFormatClass, keyClass, diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 84456d6d868dc..f827270ee6a44 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -33,7 +33,7 @@ import org.apache.spark._ import org.apache.spark.executor.DataReadMethod import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD -import org.apache.spark.util.Utils +import org.apache.spark.util.{SerializableConfiguration, Utils} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.storage.StorageLevel @@ -74,7 +74,7 @@ class NewHadoopRDD[K, V]( with Logging { // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it - private val confBroadcast = sc.broadcast(new SerializableWritable(conf)) + private val confBroadcast = sc.broadcast(new SerializableConfiguration(conf)) // private val serializableConf = new SerializableWritable(conf) private val jobTrackerId: String = { diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index cfd3e26faf2b9..91a6a2d039852 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -44,7 +44,7 @@ import org.apache.spark.executor.{DataWriteMethod, OutputMetrics} import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.serializer.Serializer -import org.apache.spark.util.Utils +import org.apache.spark.util.{SerializableConfiguration, Utils} import org.apache.spark.util.collection.CompactBuffer import org.apache.spark.util.random.StratifiedSamplingUtils @@ -1002,7 +1002,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val formatter = new SimpleDateFormat("yyyyMMddHHmm") val jobtrackerID = formatter.format(new Date()) val stageId = self.id - val wrappedConf = new SerializableWritable(job.getConfiguration) + val wrappedConf = new SerializableConfiguration(job.getConfiguration) val outfmt = job.getOutputFormatClass val jobFormat = outfmt.newInstance @@ -1065,7 +1065,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) def saveAsHadoopDataset(conf: JobConf): Unit = self.withScope { // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). 
val hadoopConf = conf - val wrappedConf = new SerializableWritable(hadoopConf) + val wrappedConf = new SerializableConfiguration(hadoopConf) val outputFormatInstance = hadoopConf.getOutputFormat val keyClass = hadoopConf.getOutputKeyClass val valueClass = hadoopConf.getOutputValueClass diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala index 1722c27e55003..acbd31aacdf59 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -23,6 +23,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark._ import org.apache.spark.scheduler.{ResultTask, ShuffleMapTask} +import org.apache.spark.util.SerializableConfiguration /** * Enumeration to manage state transitions of an RDD through checkpointing @@ -91,7 +92,7 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) // Save to file, and reload it as an RDD val broadcastedConf = rdd.context.broadcast( - new SerializableWritable(rdd.context.hadoopConfiguration)) + new SerializableConfiguration(rdd.context.hadoopConfiguration)) val newRDD = new CheckpointRDD[T](rdd.context, path.toString) if (rdd.conf.getBoolean("spark.cleaner.referenceTracking.cleanCheckpoints", false)) { rdd.context.cleaner.foreach { cleaner => diff --git a/core/src/main/scala/org/apache/spark/util/SerializableConfiguration.scala b/core/src/main/scala/org/apache/spark/util/SerializableConfiguration.scala new file mode 100644 index 0000000000000..30bcf1d2f24d5 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/SerializableConfiguration.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.util + +import java.io.{ObjectInputStream, ObjectOutputStream} + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.util.Utils + +private[spark] +class SerializableConfiguration(@transient var value: Configuration) extends Serializable { + private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { + out.defaultWriteObject() + value.write(out) + } + + private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { + value = new Configuration(false) + value.readFields(in) + } +} diff --git a/core/src/main/scala/org/apache/spark/util/SerializableJobConf.scala b/core/src/main/scala/org/apache/spark/util/SerializableJobConf.scala new file mode 100644 index 0000000000000..afbcc6efc850c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/SerializableJobConf.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.io.{ObjectInputStream, ObjectOutputStream} + +import org.apache.hadoop.mapred.JobConf + +import org.apache.spark.util.Utils + +private[spark] +class SerializableJobConf(@transient var value: JobConf) extends Serializable { + private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { + out.defaultWriteObject() + value.write(out) + } + + private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { + value = new JobConf(false) + value.readFields(in) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 65ecad9878f8e..b30fc171c0af1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -49,7 +49,8 @@ import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, InternalRow, _} import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode} import org.apache.spark.sql.types.StructType -import org.apache.spark.{Logging, SerializableWritable, TaskContext} +import org.apache.spark.{Logging, TaskContext} +import org.apache.spark.util.SerializableConfiguration /** * :: DeveloperApi :: @@ -329,7 +330,7 @@ private[sql] case class InsertIntoParquetTable( job.setOutputKeyClass(keyType) job.setOutputValueClass(classOf[InternalRow]) NewFileOutputFormat.setOutputPath(job, new Path(path)) - val wrappedConf = new SerializableWritable(job.getConfiguration) + val wrappedConf = new SerializableConfiguration(job.getConfiguration) val formatter = new SimpleDateFormat("yyyyMMddHHmm") val jobtrackerID = formatter.format(new Date()) val stageId = sqlContext.sparkContext.newRddId() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 4c702c3b0d43f..c9de45e0ddfbb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -24,7 +24,6 @@ import scala.collection.JavaConversions._ import scala.util.Try import com.google.common.base.Objects -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ @@ -42,8 +41,8 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{DataType, StructType} -import org.apache.spark.util.Utils -import org.apache.spark.{Logging, SerializableWritable, SparkException, Partition => 
SparkPartition} +import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.{Logging, SparkException, Partition => SparkPartition} private[sql] class DefaultSource extends HadoopFsRelationProvider { override def createRelation( @@ -258,7 +257,7 @@ private[sql] class ParquetRelation2( requiredColumns: Array[String], filters: Array[Filter], inputFiles: Array[FileStatus], - broadcastedConf: Broadcast[SerializableWritable[Configuration]]): RDD[Row] = { + broadcastedConf: Broadcast[SerializableConfiguration]): RDD[Row] = { val useMetadataCache = sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA) val parquetFilterPushDown = sqlContext.conf.parquetFilterPushDown // Create the function to set variable Parquet confs at both driver and executor side. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala index 4cf67439b9b8d..a8f56f4767407 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.sources +import org.apache.spark.{Logging, TaskContext} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{MapPartitionsRDD, RDD, UnionRDD} import org.apache.spark.sql._ @@ -27,9 +28,8 @@ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.sql.{SaveMode, Strategy, execution, sources} -import org.apache.spark.util.Utils +import org.apache.spark.util.{SerializableConfiguration, Utils} import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.{Logging, SerializableWritable, TaskContext} /** * A Strategy for planning scans over data sources defined using the sources API. @@ -91,7 +91,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { // broadcast HadoopConf. val sharedHadoopConf = SparkHadoopUtil.get.conf val confBroadcast = - t.sqlContext.sparkContext.broadcast(new SerializableWritable(sharedHadoopConf)) + t.sqlContext.sparkContext.broadcast(new SerializableConfiguration(sharedHadoopConf)) pruneFilterProject( l, projects, @@ -126,7 +126,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { // Otherwise, the cost of broadcasting HadoopConf in every RDD will be high. val sharedHadoopConf = SparkHadoopUtil.get.conf val confBroadcast = - relation.sqlContext.sparkContext.broadcast(new SerializableWritable(sharedHadoopConf)) + relation.sqlContext.sparkContext.broadcast(new SerializableConfiguration(sharedHadoopConf)) // Builds RDD[Row]s for each selected partition. 
val perPartitionRows = partitions.map { case Partition(partitionValues, dir) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala index ebad0c1564ec0..2bdc341021256 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala @@ -34,7 +34,7 @@ import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.{RDD, HadoopRDD} import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.Utils +import org.apache.spark.util.{SerializableConfiguration, Utils} import scala.reflect.ClassTag @@ -65,7 +65,7 @@ private[spark] class SqlNewHadoopPartition( */ private[sql] class SqlNewHadoopRDD[K, V]( @transient sc : SparkContext, - broadcastedConf: Broadcast[SerializableWritable[Configuration]], + broadcastedConf: Broadcast[SerializableConfiguration], @transient initDriverSideJobFuncOpt: Option[Job => Unit], initLocalJobFuncOpt: Option[Job => Unit], inputFormatClass: Class[_ <: InputFormat[K, V]], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index d39a20b388375..c16bd9ae52c81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{DataFrame, Row, SQLConf, SQLContext, SaveMode} +import org.apache.spark.util.SerializableConfiguration private[sql] case class InsertIntoDataSource( logicalRelation: LogicalRelation, @@ -260,7 +261,7 @@ private[sql] abstract class BaseWriterContainer( with Logging with Serializable { - protected val serializableConf = new SerializableWritable(ContextUtil.getConfiguration(job)) + protected val serializableConf = new SerializableConfiguration(ContextUtil.getConfiguration(job)) // This is only used on driver side. 
@transient private val jobContext: JobContext = job diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 43d3507d7d2ba..7005c7079af91 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -27,12 +27,12 @@ import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD -import org.apache.spark.SerializableWritable import org.apache.spark.sql.execution.RDDConversions import org.apache.spark.sql.{DataFrame, Row, SaveMode, SQLContext} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.types.StructType +import org.apache.spark.util.SerializableConfiguration /** * ::DeveloperApi:: @@ -518,7 +518,7 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio requiredColumns: Array[String], filters: Array[Filter], inputPaths: Array[String], - broadcastedConf: Broadcast[SerializableWritable[Configuration]]): RDD[Row] = { + broadcastedConf: Broadcast[SerializableConfiguration]): RDD[Row] = { val inputStatuses = inputPaths.flatMap { input => val path = new Path(input) @@ -648,7 +648,7 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio requiredColumns: Array[String], filters: Array[Filter], inputFiles: Array[FileStatus], - broadcastedConf: Broadcast[SerializableWritable[Configuration]]): RDD[Row] = { + broadcastedConf: Broadcast[SerializableConfiguration]): RDD[Row] = { buildScan(requiredColumns, filters, inputFiles) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 485810320f3c1..439f39bafc926 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.hive -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{Path, PathFilter} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.metastore.api.hive_metastoreConstants._ @@ -30,12 +29,12 @@ import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf} -import org.apache.spark.{Logging, SerializableWritable} +import org.apache.spark.{Logging} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.DateUtils -import org.apache.spark.util.Utils +import org.apache.spark.util.{SerializableConfiguration, Utils} /** * A trait for subclasses that handle table scans. @@ -72,7 +71,7 @@ class HadoopTableReader( // TODO: set aws s3 credentials. 
private val _broadcastedHiveConf = - sc.sparkContext.broadcast(new SerializableWritable(hiveExtraConf)) + sc.sparkContext.broadcast(new SerializableConfiguration(hiveExtraConf)) override def makeRDDForTable(hiveTable: HiveTable): RDD[InternalRow] = makeRDDForTable( @@ -276,7 +275,7 @@ class HadoopTableReader( val rdd = new HadoopRDD( sc.sparkContext, - _broadcastedHiveConf.asInstanceOf[Broadcast[SerializableWritable[Configuration]]], + _broadcastedHiveConf.asInstanceOf[Broadcast[SerializableConfiguration]], Some(initializeJobConfFunc), inputFormatClass, classOf[Writable], diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 1d306c5d10af8..404bb937aaf87 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -35,9 +35,10 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, InternalRow} import org.apache.spark.sql.execution.{UnaryNode, SparkPlan} import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.hive._ -import org.apache.spark.{SerializableWritable, SparkException, TaskContext} +import org.apache.spark.{SparkException, TaskContext} import scala.collection.JavaConversions._ +import org.apache.spark.util.SerializableJobConf private[hive] case class InsertIntoHiveTable( @@ -64,7 +65,7 @@ case class InsertIntoHiveTable( rdd: RDD[InternalRow], valueClass: Class[_], fileSinkConf: FileSinkDesc, - conf: SerializableWritable[JobConf], + conf: SerializableJobConf, writerContainer: SparkHiveWriterContainer): Unit = { assert(valueClass != null, "Output value class not set") conf.value.setOutputValueClass(valueClass) @@ -172,7 +173,7 @@ case class InsertIntoHiveTable( } val jobConf = new JobConf(sc.hiveconf) - val jobConfSer = new SerializableWritable(jobConf) + val jobConfSer = new SerializableJobConf(jobConf) val writerContainer = if (numDynamicPartitions > 0) { val dynamicPartColNames = partitionColumnNames.takeRight(numDynamicPartitions) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index ee440e304ec19..0bc69c00c241c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -37,6 +37,7 @@ import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter} import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.types._ +import org.apache.spark.util.SerializableJobConf /** * Internal helper class that saves an RDD using a Hive OutputFormat. 
@@ -57,7 +58,7 @@ private[hive] class SparkHiveWriterContainer( PlanUtils.configureOutputJobPropertiesForStorageHandler(tableDesc) Utilities.copyTableJobPropertiesToConf(tableDesc, jobConf) } - protected val conf = new SerializableWritable(jobConf) + protected val conf = new SerializableJobConf(jobConf) private var jobID = 0 private var splitID = 0 diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index f03c4cd54e7e6..77f1ca9ae0875 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -39,7 +39,8 @@ import org.apache.spark.sql.hive.{HiveContext, HiveInspectors, HiveMetastoreType import org.apache.spark.sql.sources.{Filter, _} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.{Logging, SerializableWritable} +import org.apache.spark.{Logging} +import org.apache.spark.util.SerializableConfiguration /* Implicit conversions */ import scala.collection.JavaConversions._ @@ -283,7 +284,7 @@ private[orc] case class OrcTableScan( classOf[Writable] ).asInstanceOf[HadoopRDD[NullWritable, Writable]] - val wrappedConf = new SerializableWritable(conf) + val wrappedConf = new SerializableConfiguration(conf) rdd.mapPartitionsWithInputSplit { case (split: OrcSplit, iterator) => val mutableRow = new SpecificMutableRow(attributes.map(_.dataType)) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index 6c1fab56740ee..86a8e2beff57c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -26,10 +26,9 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path, PathFilter} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} -import org.apache.spark.{SparkConf, SerializableWritable} import org.apache.spark.rdd.{RDD, UnionRDD} import org.apache.spark.streaming._ -import org.apache.spark.util.{TimeStampedHashMap, Utils} +import org.apache.spark.util.{SerializableConfiguration, TimeStampedHashMap, Utils} /** * This class represents an input stream that monitors a Hadoop-compatible filesystem for new @@ -78,7 +77,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]) extends InputDStream[(K, V)](ssc_) { - private val serializableConfOpt = conf.map(new SerializableWritable(_)) + private val serializableConfOpt = conf.map(new SerializableConfiguration(_)) /** * Minimum duration of remembering the information of selected files. Defaults to 60 seconds. 
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala index 358e4c66df7ba..71bec96d46c8d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala @@ -24,10 +24,11 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapred.{JobConf, OutputFormat} import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} -import org.apache.spark.{HashPartitioner, Partitioner, SerializableWritable} +import org.apache.spark.{HashPartitioner, Partitioner} import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Duration, Time} import org.apache.spark.streaming.StreamingContext.rddToFileName +import org.apache.spark.util.{SerializableConfiguration, SerializableJobConf} /** * Extra functions available on DStream of (key, value) pairs through an implicit conversion. @@ -688,7 +689,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) conf: JobConf = new JobConf(ssc.sparkContext.hadoopConfiguration) ): Unit = ssc.withScope { // Wrap conf in SerializableWritable so that ForeachDStream can be serialized for checkpoints - val serializableConf = new SerializableWritable(conf) + val serializableConf = new SerializableJobConf(conf) val saveFunc = (rdd: RDD[(K, V)], time: Time) => { val file = rddToFileName(prefix, suffix, time) rdd.saveAsHadoopFile(file, keyClass, valueClass, outputFormatClass, serializableConf.value) @@ -721,7 +722,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) conf: Configuration = ssc.sparkContext.hadoopConfiguration ): Unit = ssc.withScope { // Wrap conf in SerializableWritable so that ForeachDStream can be serialized for checkpoints - val serializableConf = new SerializableWritable(conf) + val serializableConf = new SerializableConfiguration(conf) val saveFunc = (rdd: RDD[(K, V)], time: Time) => { val file = rddToFileName(prefix, suffix, time) rdd.saveAsNewAPIHadoopFile( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index ffce6a4c3c74c..31ce8e1ec14d7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -23,12 +23,11 @@ import java.util.UUID import scala.reflect.ClassTag import scala.util.control.NonFatal -import org.apache.commons.io.FileUtils - import org.apache.spark._ import org.apache.spark.rdd.BlockRDD import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.streaming.util._ +import org.apache.spark.util.SerializableConfiguration /** * Partition class for [[org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD]]. @@ -94,7 +93,7 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( // Hadoop configuration is not serializable, so broadcast it as a serializable. 
@transient private val hadoopConfig = sc.hadoopConfiguration - private val broadcastedHadoopConf = new SerializableWritable(hadoopConfig) + private val broadcastedHadoopConf = new SerializableConfiguration(hadoopConfig) override def isValid(): Boolean = true diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index f1504b09c9873..e6cdbec11e94c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -21,10 +21,12 @@ import scala.collection.mutable.{HashMap, SynchronizedMap} import scala.language.existentials import org.apache.spark.streaming.util.WriteAheadLogUtils -import org.apache.spark.{Logging, SerializableWritable, SparkEnv, SparkException} +import org.apache.spark.{Logging, SparkEnv, SparkException} import org.apache.spark.rpc._ import org.apache.spark.streaming.{StreamingContext, Time} -import org.apache.spark.streaming.receiver.{CleanupOldBlocks, Receiver, ReceiverSupervisorImpl, StopReceiver} +import org.apache.spark.streaming.receiver.{CleanupOldBlocks, Receiver, ReceiverSupervisorImpl, + StopReceiver} +import org.apache.spark.util.SerializableConfiguration /** * Messages used by the NetworkReceiver and the ReceiverTracker to communicate @@ -294,7 +296,8 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false } val checkpointDirOption = Option(ssc.checkpointDir) - val serializableHadoopConf = new SerializableWritable(ssc.sparkContext.hadoopConfiguration) + val serializableHadoopConf = + new SerializableConfiguration(ssc.sparkContext.hadoopConfiguration) // Function to start the receiver on the worker node val startReceiver = (iterator: Iterator[Receiver[_]]) => { From 4ce3bab89f6bdf6208fdad2fbfaba0b53d1954e3 Mon Sep 17 00:00:00 2001 From: Lars Francke Date: Thu, 18 Jun 2015 19:40:32 -0700 Subject: [PATCH 12/22] [SPARK-8462] [DOCS] Documentation fixes for Spark SQL This fixes various minor documentation issues on the Spark SQL page Author: Lars Francke Closes #6890 from lfrancke/SPARK-8462 and squashes the following commits: dd7e302 [Lars Francke] Merge branch 'master' into SPARK-8462 34eff2c [Lars Francke] Minor documentation fixes --- docs/sql-programming-guide.md | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index c6e6ec88a205f..9b5ea394a6efb 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -819,8 +819,8 @@ saveDF(select(df, "name", "age"), "namesAndAges.parquet") You can also manually specify the data source that will be used along with any extra options that you would like to pass to the data source. Data sources are specified by their fully qualified -name (i.e., `org.apache.spark.sql.parquet`), but for built-in sources you can also use the shorted -name (`json`, `parquet`, `jdbc`). DataFrames of any type can be converted into other types +name (i.e., `org.apache.spark.sql.parquet`), but for built-in sources you can also use their short +names (`json`, `parquet`, `jdbc`). DataFrames of any type can be converted into other types using this syntax.
@@ -828,7 +828,7 @@ using this syntax. {% highlight scala %} val df = sqlContext.read.format("json").load("examples/src/main/resources/people.json") -df.select("name", "age").write.format("json").save("namesAndAges.parquet") +df.select("name", "age").write.format("json").save("namesAndAges.json") {% endhighlight %}
@@ -975,7 +975,7 @@ schemaPeople.write().parquet("people.parquet"); // The result of loading a parquet file is also a DataFrame. DataFrame parquetFile = sqlContext.read().parquet("people.parquet"); -//Parquet files can also be registered as tables and then used in SQL statements. +// Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile"); DataFrame teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); List teenagerNames = teenagers.javaRDD().map(new Function() { @@ -1059,7 +1059,7 @@ SELECT * FROM parquetTable Table partitioning is a common optimization approach used in systems like Hive. In a partitioned table, data are usually stored in different directories, with partitioning column values encoded in the path of each partition directory. The Parquet data source is now able to discover and infer -partitioning information automatically. For exmaple, we can store all our previously used +partitioning information automatically. For example, we can store all our previously used population data into a partitioned table using the following directory structure, with two extra columns, `gender` and `country` as partitioning columns: @@ -1125,12 +1125,12 @@ source is now able to automatically detect this case and merge schemas of all th import sqlContext.implicits._ // Create a simple DataFrame, stored into a partition directory -val df1 = sparkContext.makeRDD(1 to 5).map(i => (i, i * 2)).toDF("single", "double") +val df1 = sc.makeRDD(1 to 5).map(i => (i, i * 2)).toDF("single", "double") df1.write.parquet("data/test_table/key=1") // Create another DataFrame in a new partition directory, // adding a new column and dropping an existing column -val df2 = sparkContext.makeRDD(6 to 10).map(i => (i, i * 3)).toDF("single", "triple") +val df2 = sc.makeRDD(6 to 10).map(i => (i, i * 3)).toDF("single", "triple") df2.write.parquet("data/test_table/key=2") // Read the partitioned table @@ -1138,7 +1138,7 @@ val df3 = sqlContext.read.parquet("data/test_table") df3.printSchema() // The final schema consists of all 3 columns in the Parquet files together -// with the partiioning column appeared in the partition directory paths. +// with the partitioning column appeared in the partition directory paths. // root // |-- single: int (nullable = true) // |-- double: int (nullable = true) @@ -1169,7 +1169,7 @@ df3 = sqlContext.load("data/test_table", "parquet") df3.printSchema() # The final schema consists of all 3 columns in the Parquet files together -# with the partiioning column appeared in the partition directory paths. +# with the partitioning column appeared in the partition directory paths. # root # |-- single: int (nullable = true) # |-- double: int (nullable = true) @@ -1196,7 +1196,7 @@ df3 <- loadDF(sqlContext, "data/test_table", "parquet") printSchema(df3) # The final schema consists of all 3 columns in the Parquet files together -# with the partiioning column appeared in the partition directory paths. +# with the partitioning column appeared in the partition directory paths. # root # |-- single: int (nullable = true) # |-- double: int (nullable = true) @@ -1253,7 +1253,7 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext` false Turn on Parquet filter pushdown optimization. This feature is turned off by default because of a known - bug in Paruet 1.6.0rc3 (PARQUET-136). + bug in Parquet 1.6.0rc3 (PARQUET-136). 
However, if your table doesn't contain any nullable string or binary columns, it's still safe to turn this feature on. @@ -1402,7 +1402,7 @@ sqlContext <- sparkRSQL.init(sc) # The path can be either a single text file or a directory storing text files. path <- "examples/src/main/resources/people.json" # Create a DataFrame from the file(s) pointed to by path -people <- jsonFile(sqlContex,t path) +people <- jsonFile(sqlContext, path) # The inferred schema can be visualized using the printSchema() method. printSchema(people) @@ -1474,7 +1474,7 @@ sqlContext.sql("FROM src SELECT key, value").collect().foreach(println) When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and adds support for finding tables in the MetaStore and writing queries using HiveQL. In addition to -the `sql` method a `HiveContext` also provides an `hql` methods, which allows queries to be +the `sql` method a `HiveContext` also provides an `hql` method, which allows queries to be expressed in HiveQL. {% highlight java %} @@ -2770,7 +2770,7 @@ from pyspark.sql.types import * MapType - enviroment + environment list(type="map", keyType=keyType, valueType=valueType, valueContainsNull=[valueContainsNull])
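The documentation fixes in the patch above describe the generic load/save API: built-in data sources can be referred to by their short names (`json`, `parquet`, `jdbc`) and the output path should match the chosen format. The following is a minimal, self-contained Scala sketch of that usage; the local master, app name, and file paths are illustrative assumptions and are not part of the patch.

```scala
// Sketch of the documented load/save API (Spark 1.4-era), not code from this patch.
// Assumes a local people.json like the one shipped under examples/src/main/resources/.
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object ShortNameSourcesSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("short-name-sources").setMaster("local[*]"))
    val sqlContext = new SQLContext(sc)

    // Short names resolve to the fully qualified built-in sources,
    // e.g. "parquet" -> org.apache.spark.sql.parquet.
    val df = sqlContext.read.format("json").load("examples/src/main/resources/people.json")

    // Save under an extension that matches the format, as the corrected docs show.
    df.select("name", "age").write.format("json").save("namesAndAges.json")

    sc.stop()
  }
}
```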
Note: The default value of valueContainsNull is True. From 3eaed8769c16e887edb9d54f5816b4ee6da23de5 Mon Sep 17 00:00:00 2001 From: Dibyendu Bhattacharya Date: Thu, 18 Jun 2015 19:58:47 -0700 Subject: [PATCH 13/22] [SPARK-8080] [STREAMING] Receiver.store with Iterator does not give correct count at Spark UI tdas zsxwing this is the new PR for Spark-8080 I have merged https://github.com/apache/spark/pull/6659 Also to mention , for MEMORY_ONLY settings , when Block is not able to unrollSafely to memory if enough space is not there, BlockManager won't try to put the block and ReceivedBlockHandler will throw SparkException as it could not find the block id in PutResult. Thus number of records in block won't be counted if Block failed to unroll in memory. Which is fine. For MEMORY_DISK settings , if BlockManager not able to unroll block to memory, block will still get deseralized to Disk. Same for WAL based store. So for those cases ( storage level = memory + disk ) number of records will be counted even though the block not able to unroll to memory. thus I added the isFullyConsumed in the CountingIterator but have not used it as such case will never happen that block not fully consumed and ReceivedBlockHandler still get the block ID. I have added few test cases to cover those block unrolling scenarios also. Author: Dibyendu Bhattacharya Author: U-PEROOT\UBHATD1 Closes #6707 from dibbhatt/master and squashes the following commits: f6cb6b5 [Dibyendu Bhattacharya] [SPARK-8080][STREAMING] Receiver.store with Iterator does not give correct count at Spark UI f37cfd8 [Dibyendu Bhattacharya] [SPARK-8080][STREAMING] Receiver.store with Iterator does not give correct count at Spark UI 5a8344a [Dibyendu Bhattacharya] [SPARK-8080][STREAMING] Receiver.store with Iterator does not give correct count at Spark UI Count ByteBufferBlock as 1 count fceac72 [Dibyendu Bhattacharya] [SPARK-8080][STREAMING] Receiver.store with Iterator does not give correct count at Spark UI 0153e7e [Dibyendu Bhattacharya] [SPARK-8080][STREAMING] Receiver.store with Iterator does not give correct count at Spark UI Fixed comments given by @zsxwing 4c5931d [Dibyendu Bhattacharya] [SPARK-8080][STREAMING] Receiver.store with Iterator does not give correct count at Spark UI 01e6dc8 [U-PEROOT\UBHATD1] A --- .../receiver/ReceivedBlockHandler.scala | 53 +++++- .../receiver/ReceiverSupervisorImpl.scala | 7 +- .../streaming/ReceivedBlockHandlerSuite.scala | 154 +++++++++++++++++- .../streaming/ReceivedBlockTrackerSuite.scala | 2 +- 4 files changed, 194 insertions(+), 22 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala index 207d64d9414ee..c8dd6e06812dc 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -32,7 +32,10 @@ import org.apache.spark.{Logging, SparkConf, SparkException} /** Trait that represents the metadata related to storage of blocks */ private[streaming] trait ReceivedBlockStoreResult { - def blockId: StreamBlockId // Any implementation of this trait will store a block id + // Any implementation of this trait will store a block id + def blockId: StreamBlockId + // Any implementation of this trait will have to return the number of records + def numRecords: Option[Long] } /** Trait that represents a class that handles the storage of 
blocks received by receiver */ @@ -51,7 +54,8 @@ private[streaming] trait ReceivedBlockHandler { * that stores the metadata related to storage of blocks using * [[org.apache.spark.streaming.receiver.BlockManagerBasedBlockHandler]] */ -private[streaming] case class BlockManagerBasedStoreResult(blockId: StreamBlockId) +private[streaming] case class BlockManagerBasedStoreResult( + blockId: StreamBlockId, numRecords: Option[Long]) extends ReceivedBlockStoreResult @@ -64,11 +68,20 @@ private[streaming] class BlockManagerBasedBlockHandler( extends ReceivedBlockHandler with Logging { def storeBlock(blockId: StreamBlockId, block: ReceivedBlock): ReceivedBlockStoreResult = { + + var numRecords = None: Option[Long] + val putResult: Seq[(BlockId, BlockStatus)] = block match { case ArrayBufferBlock(arrayBuffer) => - blockManager.putIterator(blockId, arrayBuffer.iterator, storageLevel, tellMaster = true) + numRecords = Some(arrayBuffer.size.toLong) + blockManager.putIterator(blockId, arrayBuffer.iterator, storageLevel, + tellMaster = true) case IteratorBlock(iterator) => - blockManager.putIterator(blockId, iterator, storageLevel, tellMaster = true) + val countIterator = new CountingIterator(iterator) + val putResult = blockManager.putIterator(blockId, countIterator, storageLevel, + tellMaster = true) + numRecords = countIterator.count + putResult case ByteBufferBlock(byteBuffer) => blockManager.putBytes(blockId, byteBuffer, storageLevel, tellMaster = true) case o => @@ -79,7 +92,7 @@ private[streaming] class BlockManagerBasedBlockHandler( throw new SparkException( s"Could not store $blockId to block manager with storage level $storageLevel") } - BlockManagerBasedStoreResult(blockId) + BlockManagerBasedStoreResult(blockId, numRecords) } def cleanupOldBlocks(threshTime: Long) { @@ -96,6 +109,7 @@ private[streaming] class BlockManagerBasedBlockHandler( */ private[streaming] case class WriteAheadLogBasedStoreResult( blockId: StreamBlockId, + numRecords: Option[Long], walRecordHandle: WriteAheadLogRecordHandle ) extends ReceivedBlockStoreResult @@ -151,12 +165,17 @@ private[streaming] class WriteAheadLogBasedBlockHandler( */ def storeBlock(blockId: StreamBlockId, block: ReceivedBlock): ReceivedBlockStoreResult = { + var numRecords = None: Option[Long] // Serialize the block so that it can be inserted into both val serializedBlock = block match { case ArrayBufferBlock(arrayBuffer) => + numRecords = Some(arrayBuffer.size.toLong) blockManager.dataSerialize(blockId, arrayBuffer.iterator) case IteratorBlock(iterator) => - blockManager.dataSerialize(blockId, iterator) + val countIterator = new CountingIterator(iterator) + val serializedBlock = blockManager.dataSerialize(blockId, countIterator) + numRecords = countIterator.count + serializedBlock case ByteBufferBlock(byteBuffer) => byteBuffer case _ => @@ -181,7 +200,7 @@ private[streaming] class WriteAheadLogBasedBlockHandler( // Combine the futures, wait for both to complete, and return the write ahead log record handle val combinedFuture = storeInBlockManagerFuture.zip(storeInWriteAheadLogFuture).map(_._2) val walRecordHandle = Await.result(combinedFuture, blockStoreTimeout) - WriteAheadLogBasedStoreResult(blockId, walRecordHandle) + WriteAheadLogBasedStoreResult(blockId, numRecords, walRecordHandle) } def cleanupOldBlocks(threshTime: Long) { @@ -199,3 +218,23 @@ private[streaming] object WriteAheadLogBasedBlockHandler { new Path(checkpointDir, new Path("receivedData", streamId.toString)).toString } } + +/** + * A utility that will wrap the Iterator to get the 
count + */ +private class CountingIterator[T](iterator: Iterator[T]) extends Iterator[T] { + private var _count = 0 + + private def isFullyConsumed: Boolean = !iterator.hasNext + + def hasNext(): Boolean = iterator.hasNext + + def count(): Option[Long] = { + if (isFullyConsumed) Some(_count) else None + } + + def next(): T = { + _count += 1 + iterator.next() + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 8be732b64e3a3..6078cdf8f8790 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -137,15 +137,10 @@ private[streaming] class ReceiverSupervisorImpl( blockIdOption: Option[StreamBlockId] ) { val blockId = blockIdOption.getOrElse(nextBlockId) - val numRecords = receivedBlock match { - case ArrayBufferBlock(arrayBuffer) => Some(arrayBuffer.size.toLong) - case _ => None - } - val time = System.currentTimeMillis val blockStoreResult = receivedBlockHandler.storeBlock(blockId, receivedBlock) logDebug(s"Pushed block $blockId in ${(System.currentTimeMillis - time)} ms") - + val numRecords = blockStoreResult.numRecords val blockInfo = ReceivedBlockInfo(streamId, numRecords, metadataOption, blockStoreResult) trackerEndpoint.askWithRetry[Boolean](AddBlock(blockInfo)) logDebug(s"Reported block $blockId") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index cca8cedb1d080..6c0c926755c20 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -49,7 +49,6 @@ class ReceivedBlockHandlerSuite val conf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.rollingIntervalSecs", "1") val hadoopConf = new Configuration() - val storageLevel = StorageLevel.MEMORY_ONLY_SER val streamId = 1 val securityMgr = new SecurityManager(conf) val mapOutputTracker = new MapOutputTrackerMaster(conf) @@ -57,10 +56,12 @@ class ReceivedBlockHandlerSuite val serializer = new KryoSerializer(conf) val manualClock = new ManualClock val blockManagerSize = 10000000 + val blockManagerBuffer = new ArrayBuffer[BlockManager]() var rpcEnv: RpcEnv = null var blockManagerMaster: BlockManagerMaster = null var blockManager: BlockManager = null + var storageLevel: StorageLevel = null var tempDirectory: File = null before { @@ -70,20 +71,21 @@ class ReceivedBlockHandlerSuite blockManagerMaster = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", new BlockManagerMasterEndpoint(rpcEnv, true, conf, new LiveListenerBus)), conf, true) - blockManager = new BlockManager("bm", rpcEnv, blockManagerMaster, serializer, - blockManagerSize, conf, mapOutputTracker, shuffleManager, - new NioBlockTransferService(conf, securityMgr), securityMgr, 0) - blockManager.initialize("app-id") + storageLevel = StorageLevel.MEMORY_ONLY_SER + blockManager = createBlockManager(blockManagerSize, conf) tempDirectory = Utils.createTempDir() manualClock.setTime(0) } after { - if (blockManager != null) { - blockManager.stop() - blockManager = null + for ( blockManager <- blockManagerBuffer ) { + if (blockManager != null) { + blockManager.stop() + } } + blockManager = null + 
blockManagerBuffer.clear() if (blockManagerMaster != null) { blockManagerMaster.stop() blockManagerMaster = null @@ -174,6 +176,130 @@ class ReceivedBlockHandlerSuite } } + test("Test Block - count messages") { + // Test count with BlockManagedBasedBlockHandler + testCountWithBlockManagerBasedBlockHandler(true) + // Test count with WriteAheadLogBasedBlockHandler + testCountWithBlockManagerBasedBlockHandler(false) + } + + test("Test Block - isFullyConsumed") { + val sparkConf = new SparkConf() + sparkConf.set("spark.storage.unrollMemoryThreshold", "512") + // spark.storage.unrollFraction set to 0.4 for BlockManager + sparkConf.set("spark.storage.unrollFraction", "0.4") + // Block Manager with 12000 * 0.4 = 4800 bytes of free space for unroll + blockManager = createBlockManager(12000, sparkConf) + + // there is not enough space to store this block in MEMORY, + // But BlockManager will be able to sereliaze this block to WAL + // and hence count returns correct value. + testRecordcount(false, StorageLevel.MEMORY_ONLY, + IteratorBlock((List.fill(70)(new Array[Byte](100))).iterator), blockManager, Some(70)) + + // there is not enough space to store this block in MEMORY, + // But BlockManager will be able to sereliaze this block to DISK + // and hence count returns correct value. + testRecordcount(true, StorageLevel.MEMORY_AND_DISK, + IteratorBlock((List.fill(70)(new Array[Byte](100))).iterator), blockManager, Some(70)) + + // there is not enough space to store this block With MEMORY_ONLY StorageLevel. + // BlockManager will not be able to unroll this block + // and hence it will not tryToPut this block, resulting the SparkException + storageLevel = StorageLevel.MEMORY_ONLY + withBlockManagerBasedBlockHandler { handler => + val thrown = intercept[SparkException] { + storeSingleBlock(handler, IteratorBlock((List.fill(70)(new Array[Byte](100))).iterator)) + } + } + } + + private def testCountWithBlockManagerBasedBlockHandler(isBlockManagerBasedBlockHandler: Boolean) { + // ByteBufferBlock-MEMORY_ONLY + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY, + ByteBufferBlock(ByteBuffer.wrap(Array.tabulate(100)(i => i.toByte))), blockManager, None) + // ByteBufferBlock-MEMORY_ONLY_SER + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY_SER, + ByteBufferBlock(ByteBuffer.wrap(Array.tabulate(100)(i => i.toByte))), blockManager, None) + // ArrayBufferBlock-MEMORY_ONLY + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY, + ArrayBufferBlock(ArrayBuffer.fill(25)(0)), blockManager, Some(25)) + // ArrayBufferBlock-MEMORY_ONLY_SER + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY_SER, + ArrayBufferBlock(ArrayBuffer.fill(25)(0)), blockManager, Some(25)) + // ArrayBufferBlock-DISK_ONLY + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.DISK_ONLY, + ArrayBufferBlock(ArrayBuffer.fill(50)(0)), blockManager, Some(50)) + // ArrayBufferBlock-MEMORY_AND_DISK + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_AND_DISK, + ArrayBufferBlock(ArrayBuffer.fill(75)(0)), blockManager, Some(75)) + // IteratorBlock-MEMORY_ONLY + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY, + IteratorBlock((ArrayBuffer.fill(100)(0)).iterator), blockManager, Some(100)) + // IteratorBlock-MEMORY_ONLY_SER + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_ONLY_SER, + IteratorBlock((ArrayBuffer.fill(100)(0)).iterator), blockManager, Some(100)) + // IteratorBlock-DISK_ONLY + 
testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.DISK_ONLY, + IteratorBlock((ArrayBuffer.fill(125)(0)).iterator), blockManager, Some(125)) + // IteratorBlock-MEMORY_AND_DISK + testRecordcount(isBlockManagerBasedBlockHandler, StorageLevel.MEMORY_AND_DISK, + IteratorBlock((ArrayBuffer.fill(150)(0)).iterator), blockManager, Some(150)) + } + + private def createBlockManager( + maxMem: Long, + conf: SparkConf, + name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { + val transfer = new NioBlockTransferService(conf, securityMgr) + val manager = new BlockManager(name, rpcEnv, blockManagerMaster, serializer, maxMem, conf, + mapOutputTracker, shuffleManager, transfer, securityMgr, 0) + manager.initialize("app-id") + blockManagerBuffer += manager + manager + } + + /** + * Test storing of data using different types of Handler, StorageLevle and ReceivedBlocks + * and verify the correct record count + */ + private def testRecordcount(isBlockManagedBasedBlockHandler: Boolean, + sLevel: StorageLevel, + receivedBlock: ReceivedBlock, + bManager: BlockManager, + expectedNumRecords: Option[Long] + ) { + blockManager = bManager + storageLevel = sLevel + var bId: StreamBlockId = null + try { + if (isBlockManagedBasedBlockHandler) { + // test received block with BlockManager based handler + withBlockManagerBasedBlockHandler { handler => + val (blockId, blockStoreResult) = storeSingleBlock(handler, receivedBlock) + bId = blockId + assert(blockStoreResult.numRecords === expectedNumRecords, + "Message count not matches for a " + + receivedBlock.getClass.getName + + " being inserted using BlockManagerBasedBlockHandler with " + sLevel) + } + } else { + // test received block with WAL based handler + withWriteAheadLogBasedBlockHandler { handler => + val (blockId, blockStoreResult) = storeSingleBlock(handler, receivedBlock) + bId = blockId + assert(blockStoreResult.numRecords === expectedNumRecords, + "Message count not matches for a " + + receivedBlock.getClass.getName + + " being inserted using WriteAheadLogBasedBlockHandler with " + sLevel) + } + } + } finally { + // Removing the Block Id to use same blockManager for next test + blockManager.removeBlock(bId, true) + } + } + /** * Test storing of data using different forms of ReceivedBlocks and verify that they succeeded * using the given verification function @@ -251,9 +377,21 @@ class ReceivedBlockHandlerSuite (blockIds, storeResults) } + /** Store single block using a handler */ + private def storeSingleBlock( + handler: ReceivedBlockHandler, + block: ReceivedBlock + ): (StreamBlockId, ReceivedBlockStoreResult) = { + val blockId = generateBlockId + val blockStoreResult = handler.storeBlock(blockId, block) + logDebug("Done inserting") + (blockId, blockStoreResult) + } + private def getWriteAheadLogFiles(): Seq[String] = { getLogFilesInDirectory(checkpointDirToLogDir(tempDirectory.toString, streamId)) } private def generateBlockId(): StreamBlockId = StreamBlockId(streamId, scala.util.Random.nextLong) } + diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index be305b5e0dfea..f793a12843b2f 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -225,7 +225,7 @@ class ReceivedBlockTrackerSuite /** Generate blocks infos using random ids */ def generateBlockInfos(): 
Seq[ReceivedBlockInfo] = { List.fill(5)(ReceivedBlockInfo(streamId, Some(0L), None, - BlockManagerBasedStoreResult(StreamBlockId(streamId, math.abs(Random.nextInt))))) + BlockManagerBasedStoreResult(StreamBlockId(streamId, math.abs(Random.nextInt)), Some(0L)))) } /** Get all the data written in the given write ahead log file. */ From a71cbbdea581573192a59bf8472861c463c40fcb Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 18 Jun 2015 22:01:52 -0700 Subject: [PATCH 14/22] [SPARK-8458] [SQL] Don't strip scheme part of output path when writing ORC files `Path.toUri.getPath` strips scheme part of output path (from `file:///foo` to `/foo`), which causes ORC data source only writes to the file system configured in Hadoop configuration. Should use `Path.toString` instead. Author: Cheng Lian Closes #6892 from liancheng/spark-8458 and squashes the following commits: 87f8199 [Cheng Lian] Don't strip scheme of output path when writing ORC files --- .../main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 77f1ca9ae0875..dbce39f21d271 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -111,7 +111,7 @@ private[orc] class OrcOutputWriter( new OrcOutputFormat().getRecordWriter( new Path(path, filename).getFileSystem(conf), conf.asInstanceOf[JobConf], - new Path(path, filename).toUri.getPath, + new Path(path, filename).toString, Reporter.NULL ).asInstanceOf[RecordWriter[NullWritable, Writable]] } From 754929b153aba3a8f8fbafa1581957da4ccc18be Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 18 Jun 2015 23:13:05 -0700 Subject: [PATCH 15/22] [SPARK-8348][SQL] Add in operator to DataFrame Column I have added it for only Scala. TODO: we should also support `in` operator in Python. Author: Yu ISHIKAWA Closes #6824 from yu-iskw/SPARK-8348 and squashes the following commits: e76d02f [Yu ISHIKAWA] Not use infix notation 6f744ac [Yu ISHIKAWA] Fit the test cases because these used the old test data set. 00077d3 [Yu ISHIKAWA] [SPARK-8348][SQL] Add in operator to DataFrame Column --- .../main/scala/org/apache/spark/sql/Column.scala | 2 +- .../apache/spark/sql/ColumnExpressionSuite.scala | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index d3efa83380d04..b4e008a6e8480 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -621,7 +621,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @since 1.3.0 */ @scala.annotation.varargs - def in(list: Column*): Column = In(expr, list.map(_.expr)) + def in(list: Any*): Column = In(expr, list.map(lit(_).expr)) /** * SQL like expression. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 5a08578e7ba4b..88bb743ab0bc9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -296,6 +296,22 @@ class ColumnExpressionSuite extends QueryTest { checkAnswer(testData.filter($"a".between($"b", $"c")), expectAnswer) } + test("in") { + val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b") + checkAnswer(df.filter($"a".in(1, 2)), + df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".in(3, 2)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2)) + checkAnswer(df.filter($"a".in(3, 1)), + df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1)) + checkAnswer(df.filter($"b".in("y", "x")), + df.collect().toSeq.filter(r => r.getString(1) == "y" || r.getString(1) == "x")) + checkAnswer(df.filter($"b".in("z", "x")), + df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "x")) + checkAnswer(df.filter($"b".in("z", "y")), + df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y")) + } + val booleanData = ctx.createDataFrame(ctx.sparkContext.parallelize( Row(false, false) :: Row(false, true) :: From a2016b4bc4ef13339f168c3f4e135fa422046137 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Fri, 19 Jun 2015 00:07:53 -0700 Subject: [PATCH 16/22] [SPARK-8444] [STREAMING] Adding Python streaming example for queueStream A Python example similar to the existing one for Scala. Author: Bryan Cutler Closes #6884 from BryanCutler/streaming-queueStream-example-8444 and squashes the following commits: 435ba7e [Bryan Cutler] [SPARK-8444] Fixed style checks, increased sleep time to show empty queue 257abb0 [Bryan Cutler] [SPARK-8444] Stop context gracefully, Removed unused import, Added description comment 376ef6e [Bryan Cutler] [SPARK-8444] Fixed bug causing DStream.pprint to append empty parenthesis to output instead of blank line 1ff5f8b [Bryan Cutler] [SPARK-8444] Adding Python streaming example for queue_stream --- .../src/main/python/streaming/queue_stream.py | 50 +++++++++++++++++++ python/pyspark/streaming/dstream.py | 2 +- 2 files changed, 51 insertions(+), 1 deletion(-) create mode 100644 examples/src/main/python/streaming/queue_stream.py diff --git a/examples/src/main/python/streaming/queue_stream.py b/examples/src/main/python/streaming/queue_stream.py new file mode 100644 index 0000000000000..dcd6a0fc6ff91 --- /dev/null +++ b/examples/src/main/python/streaming/queue_stream.py @@ -0,0 +1,50 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +""" + Create a queue of RDDs that will be mapped/reduced one at a time in + 1 second intervals. + + To run this example use + `$ bin/spark-submit examples/src/main/python/streaming/queue_stream.py +""" +import sys +import time + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext + +if __name__ == "__main__": + + sc = SparkContext(appName="PythonStreamingQueueStream") + ssc = StreamingContext(sc, 1) + + # Create the queue through which RDDs can be pushed to + # a QueueInputDStream + rddQueue = [] + for i in xrange(5): + rddQueue += [ssc.sparkContext.parallelize([j for j in xrange(1, 1001)], 10)] + + # Create the QueueInputDStream and use it do some processing + inputStream = ssc.queueStream(rddQueue) + mappedStream = inputStream.map(lambda x: (x % 10, 1)) + reducedStream = mappedStream.reduceByKey(lambda a, b: a + b) + reducedStream.pprint() + + ssc.start() + time.sleep(6) + ssc.stop(stopSparkContext=True, stopGraceFully=True) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index ff097985fae3e..8dcb9645cdc6b 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -176,7 +176,7 @@ def takeAndPrint(time, rdd): print(record) if len(taken) > num: print("...") - print() + print("") self.foreachRDD(takeAndPrint) From fdf63f12490c674cc1877ddf7b70343c4fd6f4f1 Mon Sep 17 00:00:00 2001 From: Kevin Conor Date: Fri, 19 Jun 2015 00:12:20 -0700 Subject: [PATCH 17/22] [SPARK-8339] [PYSPARK] integer division for python 3 Itertools islice requires an integer for the stop argument. Switching to integer division here prevents a ValueError when vs is evaluated above. davies This is my original work, and I license it to the project. Author: Kevin Conor Closes #6794 from kconor/kconor-patch-1 and squashes the following commits: da5e700 [Kevin Conor] Integer division for batch size --- python/pyspark/serializers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index d8cdcda3a3783..7f9d0a338d31e 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -272,7 +272,7 @@ def dump_stream(self, iterator, stream): if size < best: batch *= 2 elif size > best * 10 and batch > 1: - batch /= 2 + batch //= 2 def __repr__(self): return "AutoBatchedSerializer(%s)" % self.serializer From 54557f353e588f5ff622ab8e67068bab408bce92 Mon Sep 17 00:00:00 2001 From: Carson Wang Date: Fri, 19 Jun 2015 09:57:12 +0200 Subject: [PATCH 18/22] [SPARK-8387] [FOLLOWUP ] [WEBUI] Update driver log URL to show only 4096 bytes This is to follow up #6834 , update the driver log URL as well for consistency. 
Author: Carson Wang Closes #6878 from carsonwang/logUrl and squashes the following commits: 13be948 [Carson Wang] update log URL in YarnClusterSuite a0004f4 [Carson Wang] Update driver log URL to show only 4096 bytes --- .../scheduler/cluster/YarnClusterSchedulerBackend.scala | 5 +++-- .../org/apache/spark/deploy/yarn/YarnClusterSuite.scala | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala index 1ace1a97d5156..33f580aaebdc0 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala @@ -115,8 +115,9 @@ private[spark] class YarnClusterSchedulerBackend( val httpScheme = if (yarnHttpPolicy == "HTTPS_ONLY") "https://" else "http://" val baseUrl = s"$httpScheme$httpAddress/node/containerlogs/$containerId/$user" logDebug(s"Base URL for logs: $baseUrl") - driverLogs = Some( - Map("stderr" -> s"$baseUrl/stderr?start=0", "stdout" -> s"$baseUrl/stdout?start=0")) + driverLogs = Some(Map( + "stderr" -> s"$baseUrl/stderr?start=-4096", + "stdout" -> s"$baseUrl/stdout?start=-4096")) } } } catch { diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index a0f25ba450068..335e966519c7c 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -376,7 +376,7 @@ private object YarnClusterDriver extends Logging with Matchers { new URL(urlStr) val containerId = YarnSparkHadoopUtil.get.getContainerId val user = Utils.getCurrentUserName() - assert(urlStr.endsWith(s"/node/containerlogs/$containerId/$user/stderr?start=0")) + assert(urlStr.endsWith(s"/node/containerlogs/$containerId/$user/stderr?start=-4096")) } } From 93360dc3cd6186e9d33c762d153a829a5882b72b Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 19 Jun 2015 11:58:07 +0200 Subject: [PATCH 19/22] [SPARK-7913] [CORE] Make AppendOnlyMap use the same growth strategy of OpenHashSet and consistent exception message This is a follow up PR for #6456 to make AppendOnlyMap consistent with OpenHashSet. /cc srowen andrewor14 Author: zsxwing Closes #6879 from zsxwing/append-only-map and squashes the following commits: 912c0ad [zsxwing] Fix the doc dd4385b [zsxwing] Make AppendOnlyMap use the same growth strategy of OpenHashSet and consistent exception message --- .../apache/spark/util/collection/AppendOnlyMap.scala | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala index d215ee43cb539..4c1e16155462e 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala @@ -32,7 +32,7 @@ import org.apache.spark.annotation.DeveloperApi * size, which is guaranteed to explore all spaces for each key (see * http://en.wikipedia.org/wiki/Quadratic_probing). * - * The map can support up to `536870912 (2 ^ 29)` elements. + * The map can support up to `375809638 (0.7 * 2 ^ 29)` elements. * * TODO: Cache the hash values of each key? java.util.HashMap does that. 
*/ @@ -199,11 +199,8 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) /** Increase table size by 1, rehashing if necessary */ private def incrementSize() { - if (curSize == MAXIMUM_CAPACITY) { - throw new IllegalStateException(s"Can't put more that ${MAXIMUM_CAPACITY} elements") - } curSize += 1 - if (curSize > growThreshold && capacity < MAXIMUM_CAPACITY) { + if (curSize > growThreshold) { growTable() } } @@ -216,7 +213,8 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) /** Double the table's size and re-hash everything */ protected def growTable() { // capacity < MAXIMUM_CAPACITY (2 ^ 29) so capacity * 2 won't overflow - val newCapacity = (capacity * 2).min(MAXIMUM_CAPACITY) + val newCapacity = capacity * 2 + require(newCapacity <= MAXIMUM_CAPACITY, s"Can't contain more than ${growThreshold} elements") val newData = new Array[AnyRef](2 * newCapacity) val newMask = newCapacity - 1 // Insert all our old values into the new array. Note that because our old keys are From ebd363aecde977511469d47fb1ea7cb5df3c3541 Mon Sep 17 00:00:00 2001 From: Jihong MA Date: Fri, 19 Jun 2015 14:05:11 +0200 Subject: [PATCH 20/22] [SPARK-7265] Improving documentation for Spark SQL Hive support Please review this pull request. Author: Jihong MA Closes #5933 from JihongMA/SPARK-7265 and squashes the following commits: dfaa971 [Jihong MA] SPARK-7265 minor fix of the content ace454d [Jihong MA] SPARK-7265 take out PySpark on YARN limitation 9ea0832 [Jihong MA] Merge remote-tracking branch 'upstream/master' d5bf3f5 [Jihong MA] Merge remote-tracking branch 'upstream/master' 7b842e6 [Jihong MA] Merge remote-tracking branch 'upstream/master' 9c84695 [Jihong MA] SPARK-7265 address review comment a399aa6 [Jihong MA] SPARK-7265 Improving documentation for Spark SQL Hive support --- docs/sql-programming-guide.md | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 9b5ea394a6efb..26c036f6648da 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1445,7 +1445,12 @@ This command builds a new assembly jar that includes Hive. Note that this Hive a on all of the worker nodes, as they will need access to the Hive serialization and deserialization libraries (SerDes) in order to access data stored in Hive. -Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. +Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. Please note when running +the query on a YARN cluster (`yarn-cluster` mode), the `datanucleus` jars under the `lib_managed/jars` directory +and `hive-site.xml` under `conf/` directory need to be available on the driver and all executors launched by the +YARN cluster. The convenient way to do this is adding them through the `--jars` option and `--file` option of the +`spark-submit` command. +
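To make the Hive deployment note added above concrete, here is a hedged Scala sketch of the `HiveContext` usage the guide documents. It assumes Spark was built with Hive support and that `hive-site.xml` (plus, in `yarn-cluster` mode, the datanucleus jars) has been made visible to the driver and executors as described; the `src` table is the guide's own example table.

```scala
// Sketch only: requires a Hive-enabled Spark build and a reachable metastore
// (hive-site.xml on the driver/executor classpath, e.g. shipped via spark-submit).
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.hive.HiveContext

object HiveQuerySketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("hive-query-sketch"))
    val hiveContext = new HiveContext(sc)

    // HiveQL statements run through the ordinary sql() entry point on HiveContext.
    hiveContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
    hiveContext.sql("FROM src SELECT key, value").collect().foreach(println)

    sc.stop()
  }
}
```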
From 47af7c1ebfdbd7637f626ab07bf2bda6534f37ea Mon Sep 17 00:00:00 2001 From: cody koeninger Date: Fri, 19 Jun 2015 14:51:19 +0200 Subject: [PATCH 21/22] [SPARK-8389] [STREAMING] [KAFKA] Example of getting offset ranges out o… MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …f the existing java direct stream api Author: cody koeninger Closes #6846 from koeninger/SPARK-8389 and squashes the following commits: 3f3c57a [cody koeninger] [Streaming][Kafka][SPARK-8389] Example of getting offset ranges out of the existing java direct stream api --- .../kafka/JavaDirectKafkaStreamSuite.java | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java index c0669fb336657..3913b711ba28b 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java @@ -32,6 +32,7 @@ import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.function.Function; import org.apache.spark.streaming.Durations; import org.apache.spark.streaming.api.java.JavaDStream; @@ -65,8 +66,8 @@ public void tearDown() { @Test public void testKafkaStream() throws InterruptedException { - String topic1 = "topic1"; - String topic2 = "topic2"; + final String topic1 = "topic1"; + final String topic2 = "topic2"; String[] topic1data = createTopicAndSendData(topic1); String[] topic2data = createTopicAndSendData(topic2); @@ -87,6 +88,16 @@ public void testKafkaStream() throws InterruptedException { StringDecoder.class, kafkaParams, topicToSet(topic1) + ).transformToPair( + // Make sure you can get offset ranges from the rdd + new Function<JavaPairRDD<String, String>, JavaPairRDD<String, String>>() { + @Override + public JavaPairRDD<String, String> call(JavaPairRDD<String, String> rdd) throws Exception { + OffsetRange[] offsets = ((HasOffsetRanges)rdd.rdd()).offsetRanges(); + Assert.assertEquals(offsets[0].topic(), topic1); + return rdd; + } + } ).map( new Function<Tuple2<String, String>, String>() { @Override From 43c7ec6384e51105dedf3a53354b6a3732cc27b2 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Fri, 19 Jun 2015 09:46:51 -0700 Subject: [PATCH 22/22] [SPARK-8151] [MLLIB] pipeline components should correctly implement copy Otherwise, extra params get ignored in `PipelineModel.transform`.
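A minimal sketch (not code from this patch) of the contract the change enforces: every pipeline stage now declares `copy(extra)` with its own concrete return type, typically by delegating to the new `defaultCopy` helper added in `params.scala` below, so that values supplied in the extra `ParamMap` reach the copies that `PipelineModel.transform` works with. The `NoopTransformer` class here is hypothetical.

    import org.apache.spark.ml.Transformer
    import org.apache.spark.ml.param.ParamMap
    import org.apache.spark.ml.util.Identifiable
    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.types.StructType

    // Hypothetical stage, not part of this patch: passes rows through unchanged.
    class NoopTransformer(override val uid: String) extends Transformer {

      def this() = this(Identifiable.randomUID("noop"))

      override def transform(dataset: DataFrame): DataFrame = dataset

      override def transformSchema(schema: StructType): StructType = schema

      // defaultCopy builds a new instance with the same UID and copies the embedded
      // params plus everything in `extra` onto it, so extra params are not dropped.
      override def copy(extra: ParamMap): NoopTransformer = defaultCopy(extra)
    }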
jkbradley Author: Xiangrui Meng Closes #6622 from mengxr/SPARK-8087 and squashes the following commits: 0e4c8c4 [Xiangrui Meng] fix merge issues 26fc1f0 [Xiangrui Meng] address comments e607a04 [Xiangrui Meng] merge master b85b57e [Xiangrui Meng] fix examples/compile d6f7891 [Xiangrui Meng] rename defaultCopyWithParams to defaultCopy 84ec278 [Xiangrui Meng] remove setter checks due to generics 2cf2ed0 [Xiangrui Meng] snapshot 291814f [Xiangrui Meng] OneVsRest.copy 1dfe3bd [Xiangrui Meng] PipelineModel.copy should copy stages --- .../examples/ml/JavaDeveloperApiExample.java | 5 ++++ .../examples/ml/DeveloperApiExample.scala | 2 ++ .../scala/org/apache/spark/ml/Estimator.scala | 4 +-- .../scala/org/apache/spark/ml/Model.scala | 5 +--- .../scala/org/apache/spark/ml/Pipeline.scala | 6 ++-- .../scala/org/apache/spark/ml/Predictor.scala | 4 +-- .../org/apache/spark/ml/Transformer.scala | 6 ++-- .../spark/ml/classification/Classifier.scala | 1 + .../DecisionTreeClassifier.scala | 2 ++ .../ml/classification/GBTClassifier.scala | 2 ++ .../classification/LogisticRegression.scala | 2 ++ .../spark/ml/classification/OneVsRest.scala | 16 +++++++++- .../RandomForestClassifier.scala | 2 ++ .../BinaryClassificationEvaluator.scala | 2 ++ .../spark/ml/evaluation/Evaluator.scala | 4 +-- .../ml/evaluation/RegressionEvaluator.scala | 4 ++- .../apache/spark/ml/feature/Binarizer.scala | 2 ++ .../apache/spark/ml/feature/Bucketizer.scala | 2 ++ .../spark/ml/feature/ElementwiseProduct.scala | 2 +- .../apache/spark/ml/feature/HashingTF.scala | 4 ++- .../org/apache/spark/ml/feature/IDF.scala | 13 ++++++-- .../spark/ml/feature/OneHotEncoder.scala | 2 ++ .../ml/feature/PolynomialExpansion.scala | 4 ++- .../spark/ml/feature/StandardScaler.scala | 7 +++++ .../spark/ml/feature/StringIndexer.scala | 7 +++++ .../apache/spark/ml/feature/Tokenizer.scala | 4 +++ .../spark/ml/feature/VectorAssembler.scala | 3 ++ .../spark/ml/feature/VectorIndexer.scala | 9 +++++- .../apache/spark/ml/feature/Word2Vec.scala | 7 +++++ .../org/apache/spark/ml/param/params.scala | 15 +++++++--- .../apache/spark/ml/recommendation/ALS.scala | 7 +++++ .../ml/regression/DecisionTreeRegressor.scala | 2 ++ .../spark/ml/regression/GBTRegressor.scala | 2 ++ .../ml/regression/LinearRegression.scala | 2 ++ .../ml/regression/RandomForestRegressor.scala | 2 ++ .../spark/ml/tuning/CrossValidator.scala | 11 +++++++ .../org/apache/spark/mllib/feature/IDF.scala | 2 +- .../apache/spark/mllib/feature/Word2Vec.scala | 2 +- .../apache/spark/ml/param/JavaTestParams.java | 5 ++++ .../org/apache/spark/ml/PipelineSuite.scala | 10 +++++++ .../DecisionTreeClassifierSuite.scala | 12 ++++++-- .../classification/GBTClassifierSuite.scala | 11 +++++++ .../LogisticRegressionSuite.scala | 9 +++++- .../ml/classification/OneVsRestSuite.scala | 30 +++++++++++++++++++ .../RandomForestClassifierSuite.scala | 10 ++++++- .../BinaryClassificationEvaluatorSuite.scala | 28 +++++++++++++++++ .../evaluation/RegressionEvaluatorSuite.scala | 5 ++++ .../spark/ml/feature/BinarizerSuite.scala | 5 ++++ .../spark/ml/feature/BucketizerSuite.scala | 5 ++++ .../spark/ml/feature/HashingTFSuite.scala | 3 +- .../apache/spark/ml/feature/IDFSuite.scala | 8 +++++ .../spark/ml/feature/OneHotEncoderSuite.scala | 5 ++++ .../ml/feature/PolynomialExpansionSuite.scala | 5 ++++ .../spark/ml/feature/StringIndexerSuite.scala | 7 +++++ .../spark/ml/feature/TokenizerSuite.scala | 12 ++++++++ .../ml/feature/VectorAssemblerSuite.scala | 5 ++++ .../spark/ml/feature/VectorIndexerSuite.scala | 7 +++++ 
.../spark/ml/feature/Word2VecSuite.scala | 8 +++++ .../apache/spark/ml/param/ParamsSuite.scala | 22 +++++++++----- .../apache/spark/ml/param/TestParams.scala | 4 +-- .../ml/param/shared/SharedParamsSuite.scala | 6 ++-- .../spark/ml/tuning/CrossValidatorSuite.scala | 5 +++- 62 files changed, 350 insertions(+), 55 deletions(-) create mode 100644 mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java index ec533d174ebdc..9df26ffca5775 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java @@ -156,6 +156,11 @@ public MyJavaLogisticRegressionModel train(DataFrame dataset) { // Create a model, and return it. return new MyJavaLogisticRegressionModel(uid(), weights).setParent(this); } + + @Override + public MyJavaLogisticRegression copy(ParamMap extra) { + return defaultCopy(extra); + } } /** diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index 3ee456edbe01e..7b8cc21ed8982 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -130,6 +130,8 @@ private class MyLogisticRegression(override val uid: String) // Create a model, and return it. new MyLogisticRegressionModel(uid, weights).setParent(this) } + + override def copy(extra: ParamMap): MyLogisticRegression = defaultCopy(extra) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala index e9a5d7c0e7988..57e416591de69 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala @@ -78,7 +78,5 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage { paramMaps.map(fit(dataset, _)) } - override def copy(extra: ParamMap): Estimator[M] = { - super.copy(extra).asInstanceOf[Estimator[M]] - } + override def copy(extra: ParamMap): Estimator[M] } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Model.scala b/mllib/src/main/scala/org/apache/spark/ml/Model.scala index 186bf7ae7a2f6..252acc156583f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Model.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Model.scala @@ -45,8 +45,5 @@ abstract class Model[M <: Model[M]] extends Transformer { /** Indicates whether this [[Model]] has a corresponding parent. */ def hasParent: Boolean = parent != null - override def copy(extra: ParamMap): M = { - // The default implementation of Params.copy doesn't work for models. 
- throw new NotImplementedError(s"${this.getClass} doesn't implement copy(extra: ParamMap)") - } + override def copy(extra: ParamMap): M } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index a9bd28df71ee1..a1f3851d804ff 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -66,9 +66,7 @@ abstract class PipelineStage extends Params with Logging { outputSchema } - override def copy(extra: ParamMap): PipelineStage = { - super.copy(extra).asInstanceOf[PipelineStage] - } + override def copy(extra: ParamMap): PipelineStage } /** @@ -198,6 +196,6 @@ class PipelineModel private[ml] ( } override def copy(extra: ParamMap): PipelineModel = { - new PipelineModel(uid, stages) + new PipelineModel(uid, stages.map(_.copy(extra))) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index e752b81a14282..edaa2afb790e6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -90,9 +90,7 @@ abstract class Predictor[ copyValues(train(dataset).setParent(this)) } - override def copy(extra: ParamMap): Learner = { - super.copy(extra).asInstanceOf[Learner] - } + override def copy(extra: ParamMap): Learner /** * Train a model using the given dataset and parameters. diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index f07f733a5ddb5..3c7bcf7590e6d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -67,9 +67,7 @@ abstract class Transformer extends PipelineStage { */ def transform(dataset: DataFrame): DataFrame - override def copy(extra: ParamMap): Transformer = { - super.copy(extra).asInstanceOf[Transformer] - } + override def copy(extra: ParamMap): Transformer } /** @@ -120,4 +118,6 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]] dataset.withColumn($(outputCol), callUDF(this.createTransformFunc, outputDataType, dataset($(inputCol)))) } + + override def copy(extra: ParamMap): T = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index 263d580fe2dd3..14c285dbfc54a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor} import org.apache.spark.ml.param.shared.HasRawPredictionCol import org.apache.spark.ml.util.SchemaUtils diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 8030e0728a56c..2dc1824964a42 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -86,6 +86,8 @@ final class DecisionTreeClassifier(override val uid: String) super.getOldStrategy(categoricalFeatures, 
numClasses, OldAlgo.Classification, getOldImpurity, subsamplingRate = 1.0) } + + override def copy(extra: ParamMap): DecisionTreeClassifier = defaultCopy(extra) } @Experimental diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 62f4b51f770e9..554e3b8e052b2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -141,6 +141,8 @@ final class GBTClassifier(override val uid: String) val oldModel = oldGBT.run(oldDataset) GBTClassificationModel.fromOld(oldModel, this, categoricalFeatures) } + + override def copy(extra: ParamMap): GBTClassifier = defaultCopy(extra) } @Experimental diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index f136bcee9cf2b..2e6eedd45ab07 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -220,6 +220,8 @@ class LogisticRegression(override val uid: String) new LogisticRegressionModel(uid, weights.compressed, intercept) } + + override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 825f9ed1b54b2..b657882f8ad3f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -24,7 +24,7 @@ import scala.language.existentials import org.apache.spark.annotation.Experimental import org.apache.spark.ml._ import org.apache.spark.ml.attribute._ -import org.apache.spark.ml.param.Param +import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.sql.{DataFrame, Row} @@ -133,6 +133,12 @@ final class OneVsRestModel private[ml] ( aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata)) .drop(accColName) } + + override def copy(extra: ParamMap): OneVsRestModel = { + val copied = new OneVsRestModel( + uid, labelMetadata, models.map(_.copy(extra).asInstanceOf[ClassificationModel[_, _]])) + copyValues(copied, extra) + } } /** @@ -209,4 +215,12 @@ final class OneVsRest(override val uid: String) val model = new OneVsRestModel(uid, labelAttribute.toMetadata(), models).setParent(this) copyValues(model) } + + override def copy(extra: ParamMap): OneVsRest = { + val copied = defaultCopy(extra).asInstanceOf[OneVsRest] + if (isDefined(classifier)) { + copied.setClassifier($(classifier).copy(extra)) + } + copied + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 852a67e066322..d3c67494a31e4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -97,6 +97,8 @@ final class RandomForestClassifier(override val uid: String) oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, 
getSeed.toInt) RandomForestClassificationModel.fromOld(oldModel, this, categoricalFeatures) } + + override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra) } @Experimental diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index f695ddaeefc72..4a82b77f0edcb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -79,4 +79,6 @@ class BinaryClassificationEvaluator(override val uid: String) metrics.unpersist() metric } + + override def copy(extra: ParamMap): BinaryClassificationEvaluator = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala index 61e937e693699..e56c946a063e8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala @@ -46,7 +46,5 @@ abstract class Evaluator extends Params { */ def evaluate(dataset: DataFrame): Double - override def copy(extra: ParamMap): Evaluator = { - super.copy(extra).asInstanceOf[Evaluator] - } + override def copy(extra: ParamMap): Evaluator } diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala index abb1b35bedea5..8670e9679d055 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.param.{Param, ParamValidators} +import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.RegressionMetrics @@ -80,4 +80,6 @@ final class RegressionEvaluator(override val uid: String) } metric } + + override def copy(extra: ParamMap): RegressionEvaluator = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index b06122d733853..46314854d5e3a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -83,4 +83,6 @@ final class Binarizer(override val uid: String) val outputFields = inputFields :+ attr.toStructField() StructType(outputFields) } + + override def copy(extra: ParamMap): Binarizer = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index a3d1f6f65ccaf..67e4785bc3553 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -89,6 +89,8 @@ final class Bucketizer(override val uid: String) SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType) SchemaUtils.appendColumn(schema, prepOutputField(schema)) } + + override def copy(extra: ParamMap): Bucketizer = defaultCopy(extra) } private[feature] 
object Bucketizer { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala index 1e758cb775de7..a359cb8f37ec3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.Experimental import org.apache.spark.ml.UnaryTransformer -import org.apache.spark.ml.param.Param +import org.apache.spark.ml.param.{ParamMap, Param} import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index f936aef80f8af..319d23e46cef4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.Experimental import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.AttributeGroup -import org.apache.spark.ml.param.{IntParam, ParamValidators} +import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.mllib.feature @@ -74,4 +74,6 @@ class HashingTF(override val uid: String) extends Transformer with HasInputCol w val attrGroup = new AttributeGroup($(outputCol), $(numFeatures)) SchemaUtils.appendColumn(schema, attrGroup.toStructField()) } + + override def copy(extra: ParamMap): HashingTF = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index 376b84530cd57..ecde80810580c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -45,9 +45,6 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol /** @group getParam */ def getMinDocFreq: Int = $(minDocFreq) - /** @group setParam */ - def setMinDocFreq(value: Int): this.type = set(minDocFreq, value) - /** * Validate and transform the input schema. 
*/ @@ -72,6 +69,9 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) + /** @group setParam */ + def setMinDocFreq(value: Int): this.type = set(minDocFreq, value) + override def fit(dataset: DataFrame): IDFModel = { transformSchema(dataset.schema, logging = true) val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v } @@ -82,6 +82,8 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } + + override def copy(extra: ParamMap): IDF = defaultCopy(extra) } /** @@ -109,4 +111,9 @@ class IDFModel private[ml] ( override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } + + override def copy(extra: ParamMap): IDFModel = { + val copied = new IDFModel(uid, idfModel) + copyValues(copied, extra) + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 8f34878c8d329..3825942795645 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -165,4 +165,6 @@ class OneHotEncoder(override val uid: String) extends Transformer dataset.select(col("*"), encode(col(inputColName).cast(DoubleType)).as(outputColName, metadata)) } + + override def copy(extra: ParamMap): OneHotEncoder = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala index 442e95820217a..d85e468562d4a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala @@ -21,7 +21,7 @@ import scala.collection.mutable import org.apache.spark.annotation.Experimental import org.apache.spark.ml.UnaryTransformer -import org.apache.spark.ml.param.{IntParam, ParamValidators} +import org.apache.spark.ml.param.{ParamMap, IntParam, ParamValidators} import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.linalg._ import org.apache.spark.sql.types.DataType @@ -61,6 +61,8 @@ class PolynomialExpansion(override val uid: String) } override protected def outputDataType: DataType = new VectorUDT() + + override def copy(extra: ParamMap): PolynomialExpansion = defaultCopy(extra) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index b0fd06d84fdb3..ca3c1cfb56b7f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -92,6 +92,8 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) StructType(outputFields) } + + override def copy(extra: ParamMap): StandardScaler = defaultCopy(extra) } /** @@ -125,4 +127,9 @@ class StandardScalerModel private[ml] ( val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) StructType(outputFields) } + + override def copy(extra: ParamMap): StandardScalerModel = { + val copied = new StandardScalerModel(uid, scaler) + 
copyValues(copied, extra) + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index f4e250757560a..bf7be363b8224 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -83,6 +83,8 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } + + override def copy(extra: ParamMap): StringIndexer = defaultCopy(extra) } /** @@ -144,4 +146,9 @@ class StringIndexerModel private[ml] ( schema } } + + override def copy(extra: ParamMap): StringIndexerModel = { + val copied = new StringIndexerModel(uid, labels) + copyValues(copied, extra) + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 21c15b6c33f6c..5f9f57a2ebcfa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -43,6 +43,8 @@ class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[S } override protected def outputDataType: DataType = new ArrayType(StringType, false) + + override def copy(extra: ParamMap): Tokenizer = defaultCopy(extra) } /** @@ -112,4 +114,6 @@ class RegexTokenizer(override val uid: String) } override protected def outputDataType: DataType = new ArrayType(StringType, false) + + override def copy(extra: ParamMap): RegexTokenizer = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 229ee27ec5942..9f83c2ee16178 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -23,6 +23,7 @@ import org.apache.spark.SparkException import org.apache.spark.annotation.Experimental import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute} +import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} @@ -117,6 +118,8 @@ class VectorAssembler(override val uid: String) } StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, false)) } + + override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra) } private object VectorAssembler { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index 1d0f23b4fb3db..f4854a5e4b7b7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -25,7 +25,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.attribute._ -import org.apache.spark.ml.param.{IntParam, ParamValidators, Params} +import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators, Params} import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.{Identifiable, SchemaUtils} 
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, VectorUDT} @@ -131,6 +131,8 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerMod SchemaUtils.checkColumnType(schema, $(inputCol), dataType) SchemaUtils.appendColumn(schema, $(outputCol), dataType) } + + override def copy(extra: ParamMap): VectorIndexer = defaultCopy(extra) } private object VectorIndexer { @@ -399,4 +401,9 @@ class VectorIndexerModel private[ml] ( val newAttributeGroup = new AttributeGroup($(outputCol), featureAttributes) newAttributeGroup.toStructField() } + + override def copy(extra: ParamMap): VectorIndexerModel = { + val copied = new VectorIndexerModel(uid, numFeatures, categoryMaps) + copyValues(copied, extra) + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 36f19509f0cfb..6ea6590956300 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -132,6 +132,8 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } + + override def copy(extra: ParamMap): Word2Vec = defaultCopy(extra) } /** @@ -180,4 +182,9 @@ class Word2VecModel private[ml] ( override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } + + override def copy(extra: ParamMap): Word2VecModel = { + val copied = new Word2VecModel(uid, wordVectors) + copyValues(copied, extra) + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index ba94d6a3a80a9..15ebad8838a2a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -492,13 +492,20 @@ trait Params extends Identifiable with Serializable { /** * Creates a copy of this instance with the same UID and some extra params. - * The default implementation tries to create a new instance with the same UID. + * Subclasses should implement this method and set the return type properly. + * + * @see [[defaultCopy()]] + */ + def copy(extra: ParamMap): Params + + /** + * Default implementation of copy with extra params. + * It tries to create a new instance with the same UID. * Then it copies the embedded and extra parameters over and returns the new instance. - * Subclasses should override this method if the default approach is not sufficient. 
*/ - def copy(extra: ParamMap): Params = { + protected final def defaultCopy[T <: Params](extra: ParamMap): T = { val that = this.getClass.getConstructor(classOf[String]).newInstance(uid) - copyValues(that, extra) + copyValues(that, extra).asInstanceOf[T] } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index df009d855ecbb..2e44cd4cc6a22 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -216,6 +216,11 @@ class ALSModel private[ml] ( SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType) SchemaUtils.appendColumn(schema, $(predictionCol), FloatType) } + + override def copy(extra: ParamMap): ALSModel = { + val copied = new ALSModel(uid, rank, userFactors, itemFactors) + copyValues(copied, extra) + } } @@ -330,6 +335,8 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema) } + + override def copy(extra: ParamMap): ALS = defaultCopy(extra) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 43b68e7bb20fa..be1f8063d41d8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -76,6 +76,8 @@ final class DecisionTreeRegressor(override val uid: String) super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity, subsamplingRate = 1.0) } + + override def copy(extra: ParamMap): DecisionTreeRegressor = defaultCopy(extra) } @Experimental diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index b7e374bb6cb49..036e3acb07412 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -131,6 +131,8 @@ final class GBTRegressor(override val uid: String) val oldModel = oldGBT.run(oldDataset) GBTRegressionModel.fromOld(oldModel, this, categoricalFeatures) } + + override def copy(extra: ParamMap): GBTRegressor = defaultCopy(extra) } @Experimental diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 70cd8e9e87fae..01306545fc7cd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -186,6 +186,8 @@ class LinearRegression(override val uid: String) // TODO: Converts to sparse format based on the storage, but may base on the scoring speed. 
copyValues(new LinearRegressionModel(uid, weights.compressed, intercept)) } + + override def copy(extra: ParamMap): LinearRegression = defaultCopy(extra) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 49a1f7ce8c995..21c59061a02fa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -86,6 +86,8 @@ final class RandomForestRegressor(override val uid: String) oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt) RandomForestRegressionModel.fromOld(oldModel, this, categoricalFeatures) } + + override def copy(extra: ParamMap): RandomForestRegressor = defaultCopy(extra) } @Experimental diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index cb29392e8bc63..e2444ab65b43b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -149,6 +149,17 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM est.copy(paramMap).validateParams() } } + + override def copy(extra: ParamMap): CrossValidator = { + val copied = defaultCopy(extra).asInstanceOf[CrossValidator] + if (copied.isDefined(estimator)) { + copied.setEstimator(copied.getEstimator.copy(extra)) + } + if (copied.isDefined(evaluator)) { + copied.setEvaluator(copied.getEvaluator.copy(extra)) + } + copied + } } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala index efbfeb4059f5a..3fab7ea79befc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala @@ -159,7 +159,7 @@ private object IDF { * Represents an IDF model that can transform term frequency vectors. */ @Experimental -class IDFModel private[mllib] (val idf: Vector) extends Serializable { +class IDFModel private[spark] (val idf: Vector) extends Serializable { /** * Transforms term frequency (TF) vectors to TF-IDF vectors. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 51546d41c36a6..f087d06d2a46a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -431,7 +431,7 @@ class Word2Vec extends Serializable with Logging { * Word2Vec model */ @Experimental -class Word2VecModel private[mllib] ( +class Word2VecModel private[spark] ( model: Map[String, Array[Float]]) extends Serializable with Saveable { // wordList: Ordered list of words obtained from model. 
diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java index ff5929235ac2c..3ae09d39ef500 100644 --- a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java @@ -102,4 +102,9 @@ private void init() { setDefault(myDoubleArrayParam(), new double[] {1.0, 2.0}); setDefault(myDoubleArrayParam().w(new double[] {1.0, 2.0})); } + + @Override + public JavaTestParams copy(ParamMap extra) { + return defaultCopy(extra); + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 29394fefcbc43..63d2fa31c7499 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -24,6 +24,7 @@ import org.mockito.Mockito.when import org.scalatest.mock.MockitoSugar.mock import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.feature.HashingTF import org.apache.spark.ml.param.ParamMap import org.apache.spark.sql.DataFrame @@ -84,6 +85,15 @@ class PipelineSuite extends SparkFunSuite { } } + test("PipelineModel.copy") { + val hashingTF = new HashingTF() + .setNumFeatures(100) + val model = new PipelineModel("pipeline", Array[Transformer](hashingTF)) + val copied = model.copy(ParamMap(hashingTF.numFeatures -> 10)) + require(copied.stages(0).asInstanceOf[HashingTF].getNumFeatures === 10, + "copy should handle extra stage params") + } + test("pipeline model constructors") { val transform0 = mock[Transformer] val model1 = mock[MyModel] diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index ae40b0b8ff854..73b4805c4c597 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -19,15 +19,15 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.tree.LeafNode import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, - DecisionTreeSuite => OldDecisionTreeSuite} +import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame - class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { import DecisionTreeClassifierSuite.compareAPIs @@ -55,6 +55,12 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte OldDecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()) } + test("params") { + ParamsSuite.checkParams(new DecisionTreeClassifier) + val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0)) + ParamsSuite.checkParams(model) + } + ///////////////////////////////////////////////////////////////////////////// // Tests calling train() ///////////////////////////////////////////////////////////////////////////// diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 1302da3c373ff..82c345491bb3c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -19,6 +19,9 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.regression.DecisionTreeRegressionModel +import org.apache.spark.ml.tree.LeafNode import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} @@ -51,6 +54,14 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80), 2) } + test("params") { + ParamsSuite.checkParams(new GBTClassifier) + val model = new GBTClassificationModel("gbtc", + Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0))), + Array(1.0)) + ParamsSuite.checkParams(model) + } + test("Binary classification with continuous features: Log Loss") { val categoricalFeatures = Map.empty[Int, Int] testCombinations.foreach { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index a755cac3ea76e..5a6265ea992c6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -18,8 +18,9 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.classification.LogisticRegressionSuite._ -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.linalg.{Vectors, Vector} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} @@ -62,6 +63,12 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("params") { + ParamsSuite.checkParams(new LogisticRegression) + val model = new LogisticRegressionModel("logReg", Vectors.dense(0.0), 0.0) + ParamsSuite.checkParams(model) + } + test("logistic regression: default params") { val lr = new LogisticRegression assert(lr.getLabelCol === "label") diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 1d04ccb509057..75cf5bd4ead4f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -19,15 +19,18 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.NominalAttribute +import org.apache.spark.ml.param.{ParamMap, ParamsSuite} import org.apache.spark.ml.util.MetadataUtils import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import 
org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.types.Metadata class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -52,6 +55,13 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext { dataset = sqlContext.createDataFrame(rdd) } + test("params") { + ParamsSuite.checkParams(new OneVsRest) + val lrModel = new LogisticRegressionModel("lr", Vectors.dense(0.0), 0.0) + val model = new OneVsRestModel("ovr", Metadata.empty, Array(lrModel)) + ParamsSuite.checkParams(model) + } + test("one-vs-rest: default params") { val numClasses = 3 val ova = new OneVsRest() @@ -102,6 +112,26 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext { val output = ovr.fit(dataset).transform(dataset) assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction")) } + + test("OneVsRest.copy and OneVsRestModel.copy") { + val lr = new LogisticRegression() + .setMaxIter(1) + + val ovr = new OneVsRest() + withClue("copy with classifier unset should work") { + ovr.copy(ParamMap(lr.maxIter -> 10)) + } + ovr.setClassifier(lr) + val ovr1 = ovr.copy(ParamMap(lr.maxIter -> 10)) + require(ovr.getClassifier.getOrDefault(lr.maxIter) === 1, "copy should have no side-effects") + require(ovr1.getClassifier.getOrDefault(lr.maxIter) === 10, + "copy should handle extra classifier params") + + val ovrModel = ovr1.fit(dataset).copy(ParamMap(lr.threshold -> 0.1)) + ovrModel.models.foreach { case m: LogisticRegressionModel => + require(m.getThreshold === 0.1, "copy should handle extra model params") + } + } } private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index eee9355a67be3..1b6b69c7dc71e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.tree.LeafNode import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} @@ -27,7 +29,6 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame - /** * Test suite for [[RandomForestClassifier]]. */ @@ -62,6 +63,13 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte compareAPIs(orderedLabeledPoints50_1000, newRF, categoricalFeatures, numClasses) } + test("params") { + ParamsSuite.checkParams(new RandomForestClassifier) + val model = new RandomForestClassificationModel("rfc", + Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0)))) + ParamsSuite.checkParams(model) + } + test("Binary classification with continuous features:" + " comparing DecisionTree vs. 
RandomForest(numTrees = 1)") { val rf = new RandomForestClassifier() diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala new file mode 100644 index 0000000000000..def869fe66777 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.evaluation + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite + +class BinaryClassificationEvaluatorSuite extends SparkFunSuite { + + test("params") { + ParamsSuite.checkParams(new BinaryClassificationEvaluator) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala index 36a1ac6b7996d..aa722da323935 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala @@ -18,12 +18,17 @@ package org.apache.spark.ml.evaluation import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.LinearRegression import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext { + test("params") { + ParamsSuite.checkParams(new RegressionEvaluator) + } + test("Regression Evaluator: default params") { /** * Here is the instruction describing how to export the test data into CSV format diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala index 7953bd0417191..2086043983661 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} @@ -30,6 +31,10 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext { data = Array(0.1, -0.5, 0.2, -0.3, 0.8, 0.7, -0.1, -0.4) } + test("params") { + ParamsSuite.checkParams(new Binarizer) + } + test("Binarize continuous features with default parameter") { val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0) val dataFrame: DataFrame = 
sqlContext.createDataFrame(
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
index 507a8a7db24c7..ec85e0d151e07 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.ml.feature
 import scala.util.Random
 
 import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
@@ -27,6 +28,10 @@ import org.apache.spark.sql.{DataFrame, Row}
 
 class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext {
 
+  test("params") {
+    ParamsSuite.checkParams(new Bucketizer)
+  }
+
   test("Bucket continuous features, without -inf,inf") {
     // Check a set of valid feature values.
     val splits = Array(-0.5, 0.0, 0.5)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
index 7b2d70e644005..4157b84b29d01 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
@@ -28,8 +28,7 @@ import org.apache.spark.util.Utils
 class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext {
 
   test("params") {
-    val hashingTF = new HashingTF
-    ParamsSuite.checkParams(hashingTF, 3)
+    ParamsSuite.checkParams(new HashingTF)
   }
 
   test("hashingTF") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
index d83772e8be755..08f80af03429b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
@@ -18,6 +18,8 @@
 package org.apache.spark.ml.feature
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.mllib.feature.{IDFModel => OldIDFModel}
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
@@ -38,6 +40,12 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext {
     }
   }
 
+  test("params") {
+    ParamsSuite.checkParams(new IDF)
+    val model = new IDFModel("idf", new OldIDFModel(Vectors.dense(1.0)))
+    ParamsSuite.checkParams(model)
+  }
+
   test("compute IDF with default parameter") {
     val numOfFeatures = 4
     val data = Array(
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
index 2e5036a844562..65846a846b7b4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.feature
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute}
+import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.DataFrame
@@ -36,6 +37,10 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext {
     indexer.transform(df)
   }
 
+  test("params") {
+    ParamsSuite.checkParams(new OneHotEncoder)
+  }
+
   test("OneHotEncoder dropLast = false") {
     val transformed = stringIndexed()
     val encoder = new OneHotEncoder()
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
index feca866cd711d..29eebd8960ebc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.ml.feature
 
+import org.apache.spark.ml.param.ParamsSuite
 import org.scalatest.exceptions.TestFailedException
 
 import org.apache.spark.SparkFunSuite
@@ -27,6 +28,10 @@ import org.apache.spark.sql.Row
 
 class PolynomialExpansionSuite extends SparkFunSuite with MLlibTestSparkContext {
 
+  test("params") {
+    ParamsSuite.checkParams(new PolynomialExpansion)
+  }
+
   test("Polynomial expansion with default parameter") {
     val data = Array(
       Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))),
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 5f557e16e5150..99f82bea42688 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -19,10 +19,17 @@ package org.apache.spark.ml.feature
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
+import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 
 class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
 
+  test("params") {
+    ParamsSuite.checkParams(new StringIndexer)
+    val model = new StringIndexerModel("indexer", Array("a", "b"))
+    ParamsSuite.checkParams(model)
+  }
+
   test("StringIndexer") {
     val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
     val df = sqlContext.createDataFrame(data).toDF("id", "label")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
index ac279cb3215c2..e5fd21c3f6fca 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
@@ -20,15 +20,27 @@ package org.apache.spark.ml.feature
 import scala.beans.BeanInfo
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Row}
 
 @BeanInfo
 case class TokenizerTestData(rawText: String, wantedTokens: Array[String])
 
+class TokenizerSuite extends SparkFunSuite {
+
+  test("params") {
+    ParamsSuite.checkParams(new Tokenizer)
+  }
+}
+
 class RegexTokenizerSuite extends SparkFunSuite with MLlibTestSparkContext {
   import org.apache.spark.ml.feature.RegexTokenizerSuite._
 
+  test("params") {
+    ParamsSuite.checkParams(new RegexTokenizer)
+  }
+
   test("RegexTokenizer") {
     val tokenizer0 = new RegexTokenizer()
       .setGaps(false)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
index 489abb5af7130..bb4d5b983e0d4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.feature
 
 import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
+import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.Row
@@ -26,6 +27,10 @@ import org.apache.spark.sql.functions.col
 
 class VectorAssemblerSuite extends SparkFunSuite with MLlibTestSparkContext {
 
+  test("params") {
+    ParamsSuite.checkParams(new VectorAssembler)
+  }
+
   test("assemble") {
     import org.apache.spark.ml.feature.VectorAssembler.assemble
     assert(assemble(0.0) === Vectors.sparse(1, Array.empty, Array.empty))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
index 06affc7305cf5..8c85c96d5c6d8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
@@ -21,6 +21,7 @@ import scala.beans.{BeanInfo, BeanProperty}
 
 import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.attribute._
+import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.rdd.RDD
@@ -91,6 +92,12 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
   private def getIndexer: VectorIndexer =
     new VectorIndexer().setInputCol("features").setOutputCol("indexed")
 
+  test("params") {
+    ParamsSuite.checkParams(new VectorIndexer)
+    val model = new VectorIndexerModel("indexer", 1, Map.empty)
+    ParamsSuite.checkParams(model)
+  }
+
   test("Cannot fit an empty DataFrame") {
     val rdd = sqlContext.createDataFrame(sc.parallelize(Array.empty[Vector], 2).map(FeatureData))
     val vectorIndexer = getIndexer
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index 94ebc3aebfa37..aa6ce533fd885 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -18,13 +18,21 @@
 package org.apache.spark.ml.feature
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.mllib.feature.{Word2VecModel => OldWord2VecModel}
 
 class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext {
 
+  test("params") {
+    ParamsSuite.checkParams(new Word2Vec)
+    val model = new Word2VecModel("w2v", new OldWord2VecModel(Map("a" -> Array(0.0f))))
+    ParamsSuite.checkParams(model)
+  }
+
   test("Word2Vec") {
     val sqlContext = new SQLContext(sc)
     import sqlContext.implicits._
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index 96094d7a099aa..050d4170ea017 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -205,19 +205,27 @@ class ParamsSuite extends SparkFunSuite {
 object ParamsSuite extends SparkFunSuite {
 
   /**
-   * Checks common requirements for [[Params.params]]: 1) number of params; 2) params are ordered
-   * by names; 3) param parent has the same UID as the object's UID; 4) param name is the same as
-   * the param method name.
+   * Checks common requirements for [[Params.params]]:
+   *   - params are ordered by names
+   *   - param parent has the same UID as the object's UID
+   *   - param name is the same as the param method name
+   *   - obj.copy should return the same type as the obj
    */
-  def checkParams(obj: Params, expectedNumParams: Int): Unit = {
+  def checkParams(obj: Params): Unit = {
+    val clazz = obj.getClass
+
     val params = obj.params
-    require(params.length === expectedNumParams,
-      s"Expect $expectedNumParams params but got ${params.length}: ${params.map(_.name).toSeq}.")
     val paramNames = params.map(_.name)
-    require(paramNames === paramNames.sorted)
+    require(paramNames === paramNames.sorted, "params must be ordered by names")
     params.foreach { p =>
       assert(p.parent === obj.uid)
       assert(obj.getParam(p.name) === p)
+      // TODO: Check that setters return self, which needs special handling for generic types.
     }
+
+    val copyMethod = clazz.getMethod("copy", classOf[ParamMap])
+    val copyReturnType = copyMethod.getReturnType
+    require(copyReturnType === obj.getClass,
+      s"${clazz.getName}.copy should return ${clazz.getName} instead of ${copyReturnType.getName}.")
   }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
index a9e78366ad98f..2759248344531 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
@@ -38,7 +38,5 @@ class TestParams(override val uid: String) extends Params with HasMaxIter with H
     require(isDefined(inputCol))
   }
 
-  override def copy(extra: ParamMap): TestParams = {
-    super.copy(extra).asInstanceOf[TestParams]
-  }
+  override def copy(extra: ParamMap): TestParams = defaultCopy(extra)
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala
index eb5408d3fee7c..b3af81a3c60b6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala
@@ -18,13 +18,15 @@
 package org.apache.spark.ml.param.shared
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.param.Params
+import org.apache.spark.ml.param.{ParamMap, Params}
 
 class SharedParamsSuite extends SparkFunSuite {
 
   test("outputCol") {
 
-    class Obj(override val uid: String) extends Params with HasOutputCol
+    class Obj(override val uid: String) extends Params with HasOutputCol {
+      override def copy(extra: ParamMap): Obj = defaultCopy(extra)
+    }
 
     val obj = new Obj("obj")
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 9b3619f0046ea..36af4b34a9e40 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.ml.tuning
 
 import org.apache.spark.SparkFunSuite
-
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.classification.LogisticRegression
 import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator}
@@ -98,6 +97,8 @@ object CrossValidatorSuite {
     override def transformSchema(schema: StructType): StructType = {
       throw new UnsupportedOperationException
     }
+
+    override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra)
   }
 
   class MyEvaluator extends Evaluator {
@@ -107,5 +108,7 @@ object CrossValidatorSuite {
     }
 
     override val uid: String = "eval"
+
+    override def copy(extra: ParamMap): MyEvaluator = defaultCopy(extra)
   }
 }
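
Illustrative note (not part of the patch): the pattern these test changes enforce can be sketched as follows, using a hypothetical MyTransformer class. Each concrete suite calls ParamsSuite.checkParams on a freshly constructed instance, and each Params subclass overrides copy via defaultCopy so the copy method's declared return type matches the concrete class, which is what the new check in checkParams verifies through reflection.

// Sketch only; MyTransformer is a hypothetical Params subclass, not part of this patch.
import org.apache.spark.ml.param.{ParamMap, Params}

class MyTransformer(override val uid: String) extends Params {
  // defaultCopy re-instantiates the same concrete class (via its uid constructor) and copies
  // the param values, satisfying the copy-return-type requirement checked by checkParams.
  override def copy(extra: ParamMap): MyTransformer = defaultCopy(extra)
}

// In the corresponding suite:
//   test("params") {
//     ParamsSuite.checkParams(new MyTransformer("my"))
//   }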