From 29aabdd6c20197adb16706823a8c7f48a0074352 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Wed, 15 Apr 2015 10:23:53 +0100 Subject: [PATCH 001/191] [HOTFIX] [SPARK-6896] [SQL] fix compile error in hive-thriftserver SPARK-6440 #5424 import guava but did not promote guava dependency to compile level. [INFO] compiler plugin: BasicArtifact(org.scalamacros,paradise_2.10.4,2.0.1,null) [info] Compiling 8 Scala sources to /root/projects/spark/sql/hive-thriftserver/target/scala-2.10/classes... [error] bad symbolic reference. A signature in Utils.class refers to term util [error] in package com.google.common which is not available. [error] It may be completely missing from the current classpath, or the version on [error] the classpath might be incompatible with the version used when compiling Utils.class. [error] [error] while compiling: /root/projects/spark/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala [error] during phase: erasure [error] library version: version 2.10.4 [error] compiler version: version 2.10.4 [error] reconstructed args: -deprecation -classpath Author: Daoyuan Wang Closes #5507 from adrian-wang/guava and squashes the following commits: c337dad [Daoyuan Wang] fix compile error --- sql/hive-thriftserver/pom.xml | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index a96b1ffc26966..f38c796241df1 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -44,7 +44,6 @@ com.google.guava guava - runtime ${hive.group} From 6c5ed8a6d552abd967d27cdb94b68d46ccb57221 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 15 Apr 2015 15:17:58 +0100 Subject: [PATCH 002/191] SPARK-6861 [BUILD] Scalastyle config prevents building Maven child modules alone Move scalastyle-config.xml to dev/ (SBT config still doesn't work) to fix running mvn targets from subdirs; make scalastyle a verify stage target again in Maven; output results in target not project root; update to scalastyle 0.7.0 Author: Sean Owen Closes #5471 from srowen/SPARK-6861 and squashes the following commits: acac637 [Sean Owen] Oops, add back execution but leave it at the default verify phase 35a4fd2 [Sean Owen] Revert change to scalastyle-config.xml location, but return scalastyle Maven check to verify phase instead of package to get it farther out of the way, since the Maven invocation is optional c4fb42c [Sean Owen] Move scalastyle-config.xml to dev/ (SBT config still doesn't work) to fix running mvn targets from subdirs; make scalastyle a verify stage target again in Maven; output results in target not project root; update to scalastyle 0.7.0 --- pom.xml | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pom.xml b/pom.xml index 261292d5b6cde..bcc2f57f1af5d 100644 --- a/pom.xml +++ b/pom.xml @@ -1447,7 +1447,7 @@ org.scalastyle scalastyle-maven-plugin - 0.4.0 + 0.7.0 false true @@ -1456,13 +1456,12 @@ ${basedir}/src/main/scala ${basedir}/src/test/scala scalastyle-config.xml - scalastyle-output.xml + ${basedir}/target/scalastyle-output.xml ${project.build.sourceEncoding} ${project.reporting.outputEncoding} - package check From f11288d5272bc18585b8cad4ee3bd59eade7c296 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 15 Apr 2015 12:58:02 -0700 Subject: [PATCH 003/191] [SPARK-6886] [PySpark] fix big closure with shuffle Currently, the created broadcast object will have same life cycle as RDD in Python. 
For multistage jobs, an PythonRDD will be created in JVM and the RDD in Python may be GCed, then the broadcast will be destroyed in JVM before the PythonRDD. This PR change to use PythonRDD to track the lifecycle of the broadcast object. It also have a refactor about getNumPartitions() to avoid unnecessary creation of PythonRDD, which could be heavy. cc JoshRosen Author: Davies Liu Closes #5496 from davies/big_closure and squashes the following commits: 9a0ea4c [Davies Liu] fix big closure with shuffle --- python/pyspark/rdd.py | 15 +++++---------- python/pyspark/tests.py | 6 ++---- 2 files changed, 7 insertions(+), 14 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index c9ac95d117574..93e658eded9e2 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -1197,7 +1197,7 @@ def take(self, num): [91, 92, 93] """ items = [] - totalParts = self._jrdd.partitions().size() + totalParts = self.getNumPartitions() partsScanned = 0 while len(items) < num and partsScanned < totalParts: @@ -1260,7 +1260,7 @@ def isEmpty(self): >>> sc.parallelize([1]).isEmpty() False """ - return self._jrdd.partitions().size() == 0 or len(self.take(1)) == 0 + return self.getNumPartitions() == 0 or len(self.take(1)) == 0 def saveAsNewAPIHadoopDataset(self, conf, keyConverter=None, valueConverter=None): """ @@ -2235,11 +2235,9 @@ def _prepare_for_python_RDD(sc, command, obj=None): ser = CloudPickleSerializer() pickled_command = ser.dumps((command, sys.version_info[:2])) if len(pickled_command) > (1 << 20): # 1M + # The broadcast will have same life cycle as created PythonRDD broadcast = sc.broadcast(pickled_command) pickled_command = ser.dumps(broadcast) - # tracking the life cycle by obj - if obj is not None: - obj._broadcast = broadcast broadcast_vars = ListConverter().convert( [x._jbroadcast for x in sc._pickled_broadcast_vars], sc._gateway._gateway_client) @@ -2294,12 +2292,9 @@ def pipeline_func(split, iterator): self._jrdd_deserializer = self.ctx.serializer self._bypass_serializer = False self.partitioner = prev.partitioner if self.preservesPartitioning else None - self._broadcast = None - def __del__(self): - if self._broadcast: - self._broadcast.unpersist() - self._broadcast = None + def getNumPartitions(self): + return self._prev_jrdd.partitions().size() @property def _jrdd(self): diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index b938b9ce12395..ee67e80d539f8 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -550,10 +550,8 @@ def test_large_closure(self): data = [float(i) for i in xrange(N)] rdd = self.sc.parallelize(range(1), 1).map(lambda x: len(data)) self.assertEquals(N, rdd.first()) - self.assertTrue(rdd._broadcast is not None) - rdd = self.sc.parallelize(range(1), 1).map(lambda x: 1) - self.assertEqual(1, rdd.first()) - self.assertTrue(rdd._broadcast is None) + # regression test for SPARK-6886 + self.assertEqual(1, rdd.map(lambda x: (x, 1)).groupByKey().count()) def test_zip_with_different_serializers(self): a = self.sc.parallelize(range(5)) From b75b3070740803480d235b0c9a86673721344f30 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 15 Apr 2015 13:00:19 -0700 Subject: [PATCH 004/191] [SPARK-6730][SQL] Allow using keyword as identifier in OPTIONS JIRA: https://issues.apache.org/jira/browse/SPARK-6730 It is very possible that keyword will be used as identifier in `OPTIONS`, this pr makes it works. However, another approach is that we can request that `OPTIONS` can't include keywords and has to use alternative identifier (e.g. 
table -> cassandraTable) if needed. If so, please let me know to close this pr. Thanks. Author: Liang-Chi Hsieh Closes #5520 from viirya/relax_options and squashes the following commits: 339fd68 [Liang-Chi Hsieh] Use regex parser. 92be11c [Liang-Chi Hsieh] Allow using keyword as identifier in OPTIONS. --- .../scala/org/apache/spark/sql/sources/ddl.scala | 15 ++++++++++++++- .../apache/spark/sql/sources/DDLTestSuite.scala | 11 ++++++----- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index 319de710fbc3e..2e861b84b7133 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.sources import scala.language.existentials +import scala.util.matching.Regex import scala.language.implicitConversions import org.apache.spark.Logging @@ -155,7 +156,19 @@ private[sql] class DDLParser( protected lazy val className: Parser[String] = repsep(ident, ".") ^^ { case s => s.mkString(".")} - protected lazy val pair: Parser[(String, String)] = ident ~ stringLit ^^ { case k ~ v => (k,v) } + override implicit def regexToParser(regex: Regex): Parser[String] = acceptMatch( + s"identifier matching regex ${regex}", { + case lexical.Identifier(str) if regex.unapplySeq(str).isDefined => str + case lexical.Keyword(str) if regex.unapplySeq(str).isDefined => str + } + ) + + protected lazy val optionName: Parser[String] = "[_a-zA-Z][a-zA-Z0-9]*".r ^^ { + case name => name + } + + protected lazy val pair: Parser[(String, String)] = + optionName ~ stringLit ^^ { case k ~ v => (k,v) } protected lazy val column: Parser[StructField] = ident ~ dataType ~ (COMMENT ~> stringLit).? 
^^ { case columnName ~ typ ~ cm => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala index 3f24a497390c1..ca25751b9583d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala @@ -25,17 +25,17 @@ class DDLScanSource extends RelationProvider { override def createRelation( sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { - SimpleDDLScan(parameters("from").toInt, parameters("TO").toInt)(sqlContext) + SimpleDDLScan(parameters("from").toInt, parameters("TO").toInt, parameters("Table"))(sqlContext) } } -case class SimpleDDLScan(from: Int, to: Int)(@transient val sqlContext: SQLContext) +case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlContext: SQLContext) extends BaseRelation with TableScan { override def schema: StructType = StructType(Seq( StructField("intType", IntegerType, nullable = false, - new MetadataBuilder().putString("comment", "test comment").build()), + new MetadataBuilder().putString("comment", s"test comment $table").build()), StructField("stringType", StringType, nullable = false), StructField("dateType", DateType, nullable = false), StructField("timestampType", TimestampType, nullable = false), @@ -73,7 +73,8 @@ class DDLTestSuite extends DataSourceTest { |USING org.apache.spark.sql.sources.DDLScanSource |OPTIONS ( | From '1', - | To '10' + | To '10', + | Table 'test1' |) """.stripMargin) } @@ -81,7 +82,7 @@ class DDLTestSuite extends DataSourceTest { sqlTest( "describe ddlPeople", Seq( - Row("intType", "int", "test comment"), + Row("intType", "int", "test comment test1"), Row("stringType", "string", ""), Row("dateType", "date", ""), Row("timestampType", "timestamp", ""), From e3e4e9a38b25174ed8bb460ba2b375813ebf3b4b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 15 Apr 2015 13:01:29 -0700 Subject: [PATCH 005/191] [SPARK-6800][SQL] Update doc for JDBCRelation's columnPartition JIRA https://issues.apache.org/jira/browse/SPARK-6800 Author: Liang-Chi Hsieh Closes #5488 from viirya/fix_jdbc_where and squashes the following commits: 51386c8 [Liang-Chi Hsieh] Update code comment. 1dcc929 [Liang-Chi Hsieh] Update document. 3eb74d6 [Liang-Chi Hsieh] Revert and modify doc. df11783 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into fix_jdbc_where 3e7db15 [Liang-Chi Hsieh] Fix wrong logic to generate WHERE clause for JDBC. --- docs/sql-programming-guide.md | 5 ++++- .../src/main/scala/org/apache/spark/sql/SQLContext.scala | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 332618edf0c55..03500867df70f 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1371,7 +1371,10 @@ the Data Sources API. The following options are supported: These options must all be specified if any of them is specified. They describe how to partition the table when reading in parallel from multiple workers. - partitionColumn must be a numeric column from the table in question. + partitionColumn must be a numeric column from the table in question. Notice + that lowerBound and upperBound are just used to decide the + partition stride, not for filtering the rows in table. So all rows in the table will be + partitioned and returned. 
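To make the doc change above concrete: with lowerBound = 0, upperBound = 100 and numPartitions = 4, the bounds only fix the stride of the generated WHERE clauses, and the first and last clauses stay open-ended, so rows outside the bounds are still returned rather than filtered. The sketch below is illustrative only and is not part of this patch; PartitionStrideSketch and whereClauses are hypothetical names that roughly mirror the idea behind JDBCRelation.columnPartition, not the actual Spark code.

object PartitionStrideSketch {
  // Hypothetical helper: turn (lowerBound, upperBound, numPartitions) into per-partition predicates.
  def whereClauses(column: String, lowerBound: Long, upperBound: Long,
                   numPartitions: Int): Seq[String] = {
    val stride = (upperBound - lowerBound) / numPartitions
    (0 until numPartitions).map { i =>
      val lower = lowerBound + i * stride
      val upper = lower + stride
      if (numPartitions == 1) "1 = 1"                        // single partition: no predicate at all
      else if (i == 0) s"$column < $upper"                   // open-ended below, keeps rows under lowerBound
      else if (i == numPartitions - 1) s"$column >= $lower"  // open-ended above, keeps rows over upperBound
      else s"$column >= $lower AND $column < $upper"
    }
  }

  def main(args: Array[String]): Unit = {
    // lowerBound = 0, upperBound = 100, numPartitions = 4 gives a stride of 25:
    whereClauses("id", 0L, 100L, 4).foreach(println)
    // id < 25
    // id >= 25 AND id < 50
    // id >= 50 AND id < 75
    // id >= 75
  }
}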
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index c25ef58e6f62a..b237fe684cdc1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -873,8 +873,8 @@ class SQLContext(@transient val sparkContext: SparkContext) * passed to this function. * * @param columnName the name of a column of integral type that will be used for partitioning. - * @param lowerBound the minimum value of `columnName` to retrieve - * @param upperBound the maximum value of `columnName` to retrieve + * @param lowerBound the minimum value of `columnName` used to decide partition stride + * @param upperBound the maximum value of `columnName` used to decide partition stride * @param numPartitions the number of partitions. the range `minValue`-`maxValue` will be split * evenly into this many partitions * From 785f95586b951d7b05481ee925fb95c20c4d6b6f Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 15 Apr 2015 13:04:03 -0700 Subject: [PATCH 006/191] [SPARK-6887][SQL] ColumnBuilder misses FloatType https://issues.apache.org/jira/browse/SPARK-6887 Author: Yin Huai Closes #5499 from yhuai/inMemFloat and squashes the following commits: 84cba38 [Yin Huai] Add test. 4b75ba6 [Yin Huai] Add FloatType back. --- .../spark/sql/columnar/ColumnBuilder.scala | 1 + .../org/apache/spark/sql/QueryTest.scala | 3 + .../columnar/InMemoryColumnarQuerySuite.scala | 59 ++++++++++++++++++- 3 files changed, 62 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index c881747751520..00ed70430b84d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -153,6 +153,7 @@ private[sql] object ColumnBuilder { val builder: ColumnBuilder = dataType match { case IntegerType => new IntColumnBuilder case LongType => new LongColumnBuilder + case FloatType => new FloatColumnBuilder case DoubleType => new DoubleColumnBuilder case BooleanType => new BooleanColumnBuilder case ByteType => new ByteColumnBuilder diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 9a81fc5d72819..59f9508444f25 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -104,9 +104,12 @@ object QueryTest { // Converts data to types that we can do equality comparison using Scala collections. // For BigDecimal type, the Scala type has a better definition of equality test (similar to // Java's java.math.BigDecimal.compareTo). + // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for + // equality test. 
val converted: Seq[Row] = answer.map { s => Row.fromSeq(s.toSeq.map { case d: java.math.BigDecimal => BigDecimal(d) + case b: Array[Byte] => b.toSeq case o => o }) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 479210d1c9c43..56591d9dba29e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.columnar +import java.sql.{Date, Timestamp} + import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.test.TestSQLContext.implicits._ -import org.apache.spark.sql.types.{DecimalType, Decimal} +import org.apache.spark.sql.types._ import org.apache.spark.sql.{QueryTest, TestData} import org.apache.spark.storage.StorageLevel.MEMORY_ONLY @@ -132,4 +134,59 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT * FROM test_fixed_decimal"), (1 to 10).map(i => Row(Decimal(i, 15, 10).toJavaBigDecimal))) } + + test("test different data types") { + // Create the schema. + val struct = + StructType( + StructField("f1", FloatType, true) :: + StructField("f2", ArrayType(BooleanType), true) :: Nil) + val dataTypes = + Seq(StringType, BinaryType, NullType, BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType.Unlimited, DecimalType(6, 5), + DateType, TimestampType, + ArrayType(IntegerType), MapType(StringType, LongType), struct) + val fields = dataTypes.zipWithIndex.map { case (dataType, index) => + StructField(s"col$index", dataType, true) + } + val allColumns = fields.map(_.name).mkString(",") + val schema = StructType(fields) + + // Create a RDD for the schema + val rdd = + sparkContext.parallelize((1 to 100), 10).map { i => + Row( + s"str${i}: test cache.", + s"binary${i}: test cache.".getBytes("UTF-8"), + null, + i % 2 == 0, + i.toByte, + i.toShort, + i, + Long.MaxValue - i.toLong, + (i + 0.25).toFloat, + (i + 0.75), + BigDecimal(Long.MaxValue.toString + ".12345"), + new java.math.BigDecimal(s"${i % 9 + 1}" + ".23456"), + new Date(i), + new Timestamp(i), + (1 to i).toSeq, + (0 to i).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap, + Row((i - 0.25).toFloat, (1 to i).toSeq)) + } + createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types") + // Cache the table. + sql("cache table InMemoryCache_different_data_types") + // Make sure the table is indeed cached. + val tableScan = table("InMemoryCache_different_data_types").queryExecution.executedPlan + assert( + isCached("InMemoryCache_different_data_types"), + "InMemoryCache_different_data_types should be cached.") + // Issue a query and check the results. + checkAnswer( + sql(s"SELECT DISTINCT ${allColumns} FROM InMemoryCache_different_data_types"), + table("InMemoryCache_different_data_types").collect()) + dropTempTable("InMemoryCache_different_data_types") + } } From 85842760dc4616577162f44cc0fa9db9bd23bd9c Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 15 Apr 2015 13:06:38 -0700 Subject: [PATCH 007/191] [SPARK-6638] [SQL] Improve performance of StringType in SQL This PR change the internal representation for StringType from java.lang.String to UTF8String, which is implemented use ArrayByte. 
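A quick way to see what the new representation behaves like (illustrative sketch only, not part of this patch): the snippet below assumes the org.apache.spark.sql.types.UTF8String class added later in this patch is on the classpath, and it only calls methods that the patch defines (apply, length, slice, startsWith, toString).

import org.apache.spark.sql.types.UTF8String

object UTF8StringSketch {
  def main(args: Array[String]): Unit = {
    val s = UTF8String("Spark SQL")             // internally a UTF-8 encoded Array[Byte]
    println(s.length())                         // 9, counted in code points rather than bytes
    println(s.slice(0, 5))                      // "Spark", sliced by code point
    println(s.startsWith(UTF8String("Spark")))  // true, compared on the raw bytes
    println(s.toString)                         // back to java.lang.String at the API boundary
  }
}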
This PR should not break any public API, Row.getString() will still return java.lang.String. This is the first step of improve the performance of String in SQL. cc rxin Author: Davies Liu Closes #5350 from davies/string and squashes the following commits: 3b7bfa8 [Davies Liu] fix schema of AddJar 2772f0d [Davies Liu] fix new test failure 6d776a9 [Davies Liu] Merge branch 'master' of github.com:apache/spark into string 59025c8 [Davies Liu] address comments from @marmbrus 341ec2c [Davies Liu] turn off scala style check in UTF8StringSuite 744788f [Davies Liu] Merge branch 'master' of github.com:apache/spark into string b04a19c [Davies Liu] add comment for getString/setString 08d897b [Davies Liu] Merge branch 'master' of github.com:apache/spark into string 5116b43 [Davies Liu] rollback unrelated changes 1314a37 [Davies Liu] address comments from Yin 867bf50 [Davies Liu] fix String filter push down 13d9d42 [Davies Liu] Merge branch 'master' of github.com:apache/spark into string 2089d24 [Davies Liu] add hashcode check back ac18ae6 [Davies Liu] address comment fd11364 [Davies Liu] optimize UTF8String 8d17f21 [Davies Liu] fix hive compatibility tests e5fa5b8 [Davies Liu] remove clone in UTF8String 28f3d81 [Davies Liu] Merge branch 'master' of github.com:apache/spark into string 28d6f32 [Davies Liu] refactor 537631c [Davies Liu] some comment about Date 9f4c194 [Davies Liu] convert data type for data source 956b0a4 [Davies Liu] fix hive tests 73e4363 [Davies Liu] Merge branch 'master' of github.com:apache/spark into string 9dc32d1 [Davies Liu] fix some hive tests 23a766c [Davies Liu] refactor 8b45864 [Davies Liu] fix codegen with UTF8String bb52e44 [Davies Liu] fix scala style c7dd4d2 [Davies Liu] fix some catalyst tests 38c303e [Davies Liu] fix python sql tests 5f9e120 [Davies Liu] fix sql tests 6b499ac [Davies Liu] fix style a85fb27 [Davies Liu] refactor d32abd1 [Davies Liu] fix utf8 for python api 4699c3a [Davies Liu] use Array[Byte] in UTF8String 21f67c6 [Davies Liu] cleanup 685fd07 [Davies Liu] use UTF8String instead of String for StringType --- python/pyspark/sql/dataframe.py | 10 +- .../main/scala/org/apache/spark/sql/Row.scala | 3 +- .../sql/catalyst/CatalystTypeConverters.scala | 37 +++ .../spark/sql/catalyst/ScalaReflection.scala | 1 + .../catalyst/analysis/HiveTypeCoercion.scala | 6 +- .../spark/sql/catalyst/expressions/Cast.scala | 36 +-- .../expressions/SpecificMutableRow.scala | 12 +- .../expressions/codegen/CodeGenerator.scala | 32 ++- .../codegen/GenerateProjection.scala | 46 ++-- .../sql/catalyst/expressions/generators.scala | 7 +- .../sql/catalyst/expressions/literals.scala | 7 +- .../sql/catalyst/expressions/predicates.scala | 3 +- .../spark/sql/catalyst/expressions/rows.scala | 14 +- .../expressions/stringOperations.scala | 90 ++++---- .../sql/catalyst/optimizer/Optimizer.scala | 21 +- .../apache/spark/sql/types/DateUtils.scala | 1 + .../apache/spark/sql/types/UTF8String.scala | 214 ++++++++++++++++++ .../apache/spark/sql/types/dataTypes.scala | 6 +- .../ExpressionEvaluationSuite.scala | 90 ++++---- .../GeneratedMutableEvaluationSuite.scala | 4 +- .../spark/sql/types/UTF8StringSuite.scala | 70 ++++++ .../org/apache/spark/sql/SQLContext.scala | 1 + .../spark/sql/columnar/ColumnStats.scala | 6 +- .../spark/sql/columnar/ColumnType.scala | 20 +- .../spark/sql/execution/ExistingRDD.scala | 31 ++- .../apache/spark/sql/execution/commands.scala | 13 +- .../spark/sql/execution/debug/package.scala | 2 +- .../spark/sql/execution/pythonUdfs.scala | 4 +- 
.../org/apache/spark/sql/jdbc/JDBCRDD.scala | 4 +- .../apache/spark/sql/jdbc/JDBCRelation.scala | 2 + .../org/apache/spark/sql/jdbc/jdbc.scala | 5 +- .../apache/spark/sql/json/JSONRelation.scala | 8 +- .../org/apache/spark/sql/json/JsonRDD.scala | 2 +- .../spark/sql/parquet/ParquetConverter.scala | 19 +- .../spark/sql/parquet/ParquetFilters.scala | 12 +- .../sql/parquet/ParquetTableSupport.scala | 7 +- .../apache/spark/sql/parquet/newParquet.scala | 11 +- .../sql/sources/DataSourceStrategy.scala | 37 +-- .../apache/spark/sql/sources/interfaces.scala | 10 + .../scala/org/apache/spark/sql/RowSuite.scala | 2 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 10 +- .../spark/sql/columnar/ColumnTypeSuite.scala | 8 +- .../sql/columnar/ColumnarTestUtils.scala | 4 +- .../spark/sql/sources/TableScanSuite.scala | 10 +- .../spark/sql/hive/HiveInspectors.scala | 22 +- .../spark/sql/hive/HiveStrategies.scala | 13 +- .../hive/execution/ScriptTransformation.scala | 17 +- .../spark/sql/hive/execution/commands.scala | 10 +- .../org/apache/spark/sql/hive/Shim12.scala | 4 +- .../org/apache/spark/sql/hive/Shim13.scala | 36 ++- 50 files changed, 742 insertions(+), 298 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index ef91a9c4f522d..f2c3b74a185cf 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -456,7 +456,7 @@ def join(self, other, joinExprs=None, joinType=None): One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`. >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect() - [Row(name=None, height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)] + [Row(name=None, height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)] """ if joinExprs is None: @@ -637,9 +637,9 @@ def groupBy(self, *cols): >>> df.groupBy().avg().collect() [Row(AVG(age)=3.5)] >>> df.groupBy('name').agg({'age': 'mean'}).collect() - [Row(name=u'Bob', AVG(age)=5.0), Row(name=u'Alice', AVG(age)=2.0)] + [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)] >>> df.groupBy(df.name).avg().collect() - [Row(name=u'Bob', AVG(age)=5.0), Row(name=u'Alice', AVG(age)=2.0)] + [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)] """ jcols = ListConverter().convert([_to_java_column(c) for c in cols], self._sc._gateway._gateway_client) @@ -867,11 +867,11 @@ def agg(self, *exprs): >>> gdf = df.groupBy(df.name) >>> gdf.agg({"*": "count"}).collect() - [Row(name=u'Bob', COUNT(1)=1), Row(name=u'Alice', COUNT(1)=1)] + [Row(name=u'Alice', COUNT(1)=1), Row(name=u'Bob', COUNT(1)=1)] >>> from pyspark.sql import functions as F >>> gdf.agg(F.min(df.age)).collect() - [Row(MIN(age)=5), Row(MIN(age)=2)] + [Row(MIN(age)=2), Row(MIN(age)=5)] """ assert exprs, "exprs should not be empty" if len(exprs) == 1 and isinstance(exprs[0], dict): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index d794f034f5578..ac8a782976465 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import scala.util.hashing.MurmurHash3 import org.apache.spark.sql.catalyst.expressions.GenericRow -import 
org.apache.spark.sql.types.{StructType, DateUtils} +import org.apache.spark.sql.types.StructType object Row { /** @@ -257,6 +257,7 @@ trait Row extends Serializable { * * @throws ClassCastException when data type does not match. */ + // TODO(davies): This is not the right default implementation, we use Int as Date internally def getDate(i: Int): java.sql.Date = apply(i).asInstanceOf[java.sql.Date] /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 91976fef6dc0d..d4f9fdacda4fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -77,6 +77,9 @@ object CatalystTypeConverters { } new GenericRowWithSchema(ar, structType) + case (d: String, _) => + UTF8String(d) + case (d: BigDecimal, _) => Decimal(d) @@ -175,6 +178,11 @@ object CatalystTypeConverters { case other => other } + case dataType: StringType => (item: Any) => extractOption(item) match { + case s: String => UTF8String(s) + case other => other + } + case _ => (item: Any) => extractOption(item) match { case d: BigDecimal => Decimal(d) @@ -184,6 +192,26 @@ object CatalystTypeConverters { } } + /** + * Converts Scala objects to catalyst rows / types. + * + * Note: This should be called before do evaluation on Row + * (It does not support UDT) + * This is used to create an RDD or test results with correct types for Catalyst. + */ + def convertToCatalyst(a: Any): Any = a match { + case s: String => UTF8String(s) + case d: java.sql.Date => DateUtils.fromJavaDate(d) + case d: BigDecimal => Decimal(d) + case d: java.math.BigDecimal => Decimal(d) + case seq: Seq[Any] => seq.map(convertToCatalyst) + case r: Row => Row(r.toSeq.map(convertToCatalyst): _*) + case arr: Array[Any] => arr.toSeq.map(convertToCatalyst).toArray + case m: Map[Any, Any] => + m.map { case (k, v) => (convertToCatalyst(k), convertToCatalyst(v)) }.toMap + case other => other + } + /** * Converts Catalyst types used internally in rows to standard Scala types * This method is slow, and for batch conversion you should be using converter @@ -211,6 +239,9 @@ object CatalystTypeConverters { case (i: Int, DateType) => DateUtils.toJavaDate(i) + case (s: UTF8String, StringType) => + s.toString() + case (other, _) => other } @@ -262,6 +293,12 @@ object CatalystTypeConverters { case other => other } + case StringType => + (item: Any) => item match { + case s: UTF8String => s.toString() + case other => other + } + case other => (item: Any) => item } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 01d5c1512201a..d9521953cad73 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -138,6 +138,7 @@ trait ScalaReflection { // The data type can be determined without ambiguity. 
case obj: BooleanType.JvmType => BooleanType case obj: BinaryType.JvmType => BinaryType + case obj: String => StringType case obj: StringType.JvmType => StringType case obj: ByteType.JvmType => ByteType case obj: ShortType.JvmType => ShortType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 3aeb964994d37..35c7f00d4e42a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -115,7 +115,7 @@ trait HiveTypeCoercion { * the appropriate numeric equivalent. */ object ConvertNaNs extends Rule[LogicalPlan] { - val stringNaN = Literal.create("NaN", StringType) + val stringNaN = Literal("NaN") def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressions { @@ -563,6 +563,10 @@ trait HiveTypeCoercion { case Sum(e @ TimestampType()) => Sum(Cast(e, DoubleType)) case Average(e @ TimestampType()) => Average(Cast(e, DoubleType)) + // Compatible with Hive + case Substring(e, start, len) if e.dataType != StringType => + Substring(Cast(e, StringType), start, len) + // Coalesce should return the first non-null value, which could be any column // from the list. So we need to make sure the return type is deterministic and // compatible with every child column. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 31f1a5fdc7e53..adf941ab2a45f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -21,7 +21,6 @@ import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.types._ /** Cast the child expression to the target data type. 
*/ @@ -112,21 +111,21 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // UDFToString private[this] def castToString(from: DataType): Any => Any = from match { - case BinaryType => buildCast[Array[Byte]](_, new String(_, "UTF-8")) - case DateType => buildCast[Int](_, d => DateUtils.toString(d)) - case TimestampType => buildCast[Timestamp](_, timestampToString) - case _ => buildCast[Any](_, _.toString) + case BinaryType => buildCast[Array[Byte]](_, UTF8String(_)) + case DateType => buildCast[Int](_, d => UTF8String(DateUtils.toString(d))) + case TimestampType => buildCast[Timestamp](_, t => UTF8String(timestampToString(t))) + case _ => buildCast[Any](_, o => UTF8String(o.toString)) } // BinaryConverter private[this] def castToBinary(from: DataType): Any => Any = from match { - case StringType => buildCast[String](_, _.getBytes("UTF-8")) + case StringType => buildCast[UTF8String](_, _.getBytes) } // UDFToBoolean private[this] def castToBoolean(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, _.length() != 0) + buildCast[UTF8String](_, _.length() != 0) case TimestampType => buildCast[Timestamp](_, t => t.getTime() != 0 || t.getNanos() != 0) case DateType => @@ -151,8 +150,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // TimestampConverter private[this] def castToTimestamp(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => { + buildCast[UTF8String](_, utfs => { // Throw away extra if more than 9 decimal places + val s = utfs.toString val periodIdx = s.indexOf(".") var n = s if (periodIdx != -1 && n.length() - periodIdx > 9) { @@ -227,8 +227,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // DateConverter private[this] def castToDate(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => - try DateUtils.fromJavaDate(Date.valueOf(s)) + buildCast[UTF8String](_, s => + try DateUtils.fromJavaDate(Date.valueOf(s.toString)) catch { case _: java.lang.IllegalArgumentException => null } ) case TimestampType => @@ -245,7 +245,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // LongConverter private[this] def castToLong(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => try s.toLong catch { + buildCast[UTF8String](_, s => try s.toString.toLong catch { case _: NumberFormatException => null }) case BooleanType => @@ -261,7 +261,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // IntConverter private[this] def castToInt(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => try s.toInt catch { + buildCast[UTF8String](_, s => try s.toString.toInt catch { case _: NumberFormatException => null }) case BooleanType => @@ -277,7 +277,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // ShortConverter private[this] def castToShort(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => try s.toShort catch { + buildCast[UTF8String](_, s => try s.toString.toShort catch { case _: NumberFormatException => null }) case BooleanType => @@ -293,7 +293,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // ByteConverter private[this] def castToByte(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => try s.toByte catch { 
+ buildCast[UTF8String](_, s => try s.toString.toByte catch { case _: NumberFormatException => null }) case BooleanType => @@ -323,7 +323,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match { case StringType => - buildCast[String](_, s => try changePrecision(Decimal(s.toDouble), target) catch { + buildCast[UTF8String](_, s => try { + changePrecision(Decimal(s.toString.toDouble), target) + } catch { case _: NumberFormatException => null }) case BooleanType => @@ -348,7 +350,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // DoubleConverter private[this] def castToDouble(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => try s.toDouble catch { + buildCast[UTF8String](_, s => try s.toString.toDouble catch { case _: NumberFormatException => null }) case BooleanType => @@ -364,7 +366,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // FloatConverter private[this] def castToFloat(from: DataType): Any => Any = from match { case StringType => - buildCast[String](_, s => try s.toFloat catch { + buildCast[UTF8String](_, s => try s.toString.toFloat catch { case _: NumberFormatException => null }) case BooleanType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 47b6f358ed1b1..3475ed05f4454 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -230,13 +230,17 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR new GenericRow(newValues) } - override def update(ordinal: Int, value: Any): Unit = { - if (value == null) setNullAt(ordinal) else values(ordinal).update(value) + override def update(ordinal: Int, value: Any) { + if (value == null) { + setNullAt(ordinal) + } else { + values(ordinal).update(value) + } } - override def setString(ordinal: Int, value: String): Unit = update(ordinal, value) + override def setString(ordinal: Int, value: String): Unit = update(ordinal, UTF8String(value)) - override def getString(ordinal: Int): String = apply(ordinal).asInstanceOf[String] + override def getString(ordinal: Int): String = apply(ordinal).toString override def setInt(ordinal: Int, value: Int): Unit = { val currentValue = values(ordinal).asInstanceOf[MutableInt] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index d141354a0f427..be2c101d63a63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -216,10 +216,11 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin val $primitiveTerm: ${termForType(dataType)} = $value """.children - case expressions.Literal(value: String, dataType) => + case expressions.Literal(value: UTF8String, dataType) => q""" val $nullTerm = ${value == null} - val $primitiveTerm: ${termForType(dataType)} = $value + val $primitiveTerm: 
${termForType(dataType)} = + org.apache.spark.sql.types.UTF8String(${value.getBytes}) """.children case expressions.Literal(value: Int, dataType) => @@ -243,11 +244,14 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin if($nullTerm) ${defaultPrimitive(StringType)} else - new String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]]) + org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.asInstanceOf[Array[Byte]]) """.children case Cast(child @ DateType(), StringType) => - child.castOrNull(c => q"org.apache.spark.sql.types.DateUtils.toString($c)", StringType) + child.castOrNull(c => + q"""org.apache.spark.sql.types.UTF8String( + org.apache.spark.sql.types.DateUtils.toString($c))""", + StringType) case Cast(child @ NumericType(), IntegerType) => child.castOrNull(c => q"$c.toInt", IntegerType) @@ -272,9 +276,18 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin if($nullTerm) ${defaultPrimitive(StringType)} else - ${eval.primitiveTerm}.toString + org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.toString) """.children + case EqualTo(e1: BinaryType, e2: BinaryType) => + (e1, e2).evaluateAs (BooleanType) { + case (eval1, eval2) => + q""" + java.util.Arrays.equals($eval1.asInstanceOf[Array[Byte]], + $eval2.asInstanceOf[Array[Byte]]) + """ + } + case EqualTo(e1, e2) => (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q"$eval1 == $eval2" } @@ -597,7 +610,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin val localLogger = log val localLoggerTree = reify { localLogger } q""" - $localLoggerTree.debug(${e.toString} + ": " + (if($nullTerm) "null" else $primitiveTerm)) + $localLoggerTree.debug( + ${e.toString} + ": " + (if ($nullTerm) "null" else $primitiveTerm.toString)) """ :: Nil } else { Nil @@ -608,6 +622,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin protected def getColumn(inputRow: TermName, dataType: DataType, ordinal: Int) = { dataType match { + case StringType => q"$inputRow($ordinal).asInstanceOf[org.apache.spark.sql.types.UTF8String]" case dt @ NativeType() => q"$inputRow.${accessorForType(dt)}($ordinal)" case _ => q"$inputRow.apply($ordinal).asInstanceOf[${termForType(dataType)}]" } @@ -619,6 +634,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin ordinal: Int, value: TermName) = { dataType match { + case StringType => q"$destinationRow.update($ordinal, $value)" case dt @ NativeType() => q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)" case _ => q"$destinationRow.update($ordinal, $value)" } @@ -642,13 +658,13 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin case DoubleType => "Double" case FloatType => "Float" case BooleanType => "Boolean" - case StringType => "String" + case StringType => "org.apache.spark.sql.types.UTF8String" } protected def defaultPrimitive(dt: DataType) = dt match { case BooleanType => ru.Literal(Constant(false)) case FloatType => ru.Literal(Constant(-1.0.toFloat)) - case StringType => ru.Literal(Constant("")) + case StringType => q"""org.apache.spark.sql.types.UTF8String("")""" case ShortType => ru.Literal(Constant(-1.toShort)) case LongType => ru.Literal(Constant(-1L)) case ByteType => ru.Literal(Constant(-1.toByte)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 69397a73a8880..6f572ff959fb4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -111,36 +111,54 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { val specificAccessorFunctions = NativeType.all.map { dataType => val ifStatements = expressions.zipWithIndex.flatMap { - case (e, i) if e.dataType == dataType => + // getString() is not used by expressions + case (e, i) if e.dataType == dataType && dataType != StringType => val elementName = newTermName(s"c$i") // TODO: The string of ifs gets pretty inefficient as the row grows in size. // TODO: Optional null checks? q"if(i == $i) return $elementName" :: Nil case _ => Nil } - - q""" - override def ${accessorForType(dataType)}(i: Int):${termForType(dataType)} = { - ..$ifStatements; - $accessorFailure - }""" + dataType match { + // Row() need this interface to compile + case StringType => + q""" + override def getString(i: Int): String = { + $accessorFailure + }""" + case other => + q""" + override def ${accessorForType(dataType)}(i: Int): ${termForType(dataType)} = { + ..$ifStatements; + $accessorFailure + }""" + } } val specificMutatorFunctions = NativeType.all.map { dataType => val ifStatements = expressions.zipWithIndex.flatMap { - case (e, i) if e.dataType == dataType => + // setString() is not used by expressions + case (e, i) if e.dataType == dataType && dataType != StringType => val elementName = newTermName(s"c$i") // TODO: The string of ifs gets pretty inefficient as the row grows in size. // TODO: Optional null checks? 
q"if(i == $i) { nullBits($i) = false; $elementName = value; return }" :: Nil case _ => Nil } - - q""" - override def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}): Unit = { - ..$ifStatements; - $accessorFailure - }""" + dataType match { + case StringType => + // MutableRow() need this interface to compile + q""" + override def setString(i: Int, value: String) { + $accessorFailure + }""" + case other => + q""" + override def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}) { + ..$ifStatements; + $accessorFailure + }""" + } } val hashValues = expressions.zipWithIndex.map { case (e,i) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 860b72fad38b3..67caadb839ff9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.Map -import org.apache.spark.sql.catalyst.trees +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, trees} import org.apache.spark.sql.types._ /** @@ -85,8 +85,11 @@ case class UserDefinedGenerator( override protected def makeOutput(): Seq[Attribute] = schema override def eval(input: Row): TraversableOnce[Row] = { + // TODO(davies): improve this + // Convert the objects into Scala Type before calling function, we need schema to support UDT + val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true))) val inputRow = new InterpretedProjection(children) - function(inputRow(input)) + function(CatalystTypeConverters.convertToScala(inputRow(input), inputSchema).asInstanceOf[Row]) } override def toString: String = s"UserDefinedGenerator(${children.mkString(",")})" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 0e2d593e94124..18cba4cc46707 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.types._ object Literal { @@ -29,7 +30,7 @@ object Literal { case f: Float => Literal(f, FloatType) case b: Byte => Literal(b, ByteType) case s: Short => Literal(s, ShortType) - case s: String => Literal(s, StringType) + case s: String => Literal(UTF8String(s), StringType) case b: Boolean => Literal(b, BooleanType) case d: BigDecimal => Literal(Decimal(d), DecimalType.Unlimited) case d: java.math.BigDecimal => Literal(Decimal(d), DecimalType.Unlimited) @@ -42,7 +43,9 @@ object Literal { throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v) } - def create(v: Any, dataType: DataType): Literal = Literal(v, dataType) + def create(v: Any, dataType: DataType): Literal = { + Literal(CatalystTypeConverters.convertToCatalyst(v), dataType) + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 
7e47cb3fffe12..fcd6352079b4d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -179,8 +179,7 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison val r = right.eval(input) if (r == null) null else if (left.dataType != BinaryType) l == r - else BinaryType.ordering.compare( - l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) == 0 + else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 0a275b84086cf..1b62e17ff47fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.types.{StructType, NativeType} +import org.apache.spark.sql.types.{UTF8String, StructType, NativeType} /** @@ -37,6 +37,7 @@ trait MutableRow extends Row { def setByte(ordinal: Int, value: Byte) def setFloat(ordinal: Int, value: Float) def setString(ordinal: Int, value: String) + // TODO(davies): add setDate() and setDecimal() } /** @@ -114,9 +115,15 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row { } override def getString(i: Int): String = { - values(i).asInstanceOf[String] + values(i) match { + case null => null + case s: String => s + case utf8: UTF8String => utf8.toString + } } + // TODO(davies): add getDate and getDecimal + // Custom hashCode function that matches the efficient code generated version. 
override def hashCode: Int = { var result: Int = 37 @@ -189,8 +196,7 @@ class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow { override def setFloat(ordinal: Int, value: Float): Unit = { values(ordinal) = value } override def setInt(ordinal: Int, value: Int): Unit = { values(ordinal) = value } override def setLong(ordinal: Int, value: Long): Unit = { values(ordinal) = value } - override def setString(ordinal: Int, value: String): Unit = { values(ordinal) = value } - + override def setString(ordinal: Int, value: String) { values(ordinal) = UTF8String(value)} override def setNullAt(i: Int): Unit = { values(i) = null } override def setShort(ordinal: Int, value: Short): Unit = { values(ordinal) = value } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index acfbbace608ef..d597bf7ce756a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -19,11 +19,8 @@ package org.apache.spark.sql.catalyst.expressions import java.util.regex.Pattern -import scala.collection.IndexedSeqOptimized - - import org.apache.spark.sql.catalyst.analysis.UnresolvedException -import org.apache.spark.sql.types.{BinaryType, BooleanType, DataType, StringType} +import org.apache.spark.sql.types._ trait StringRegexExpression { self: BinaryExpression => @@ -60,38 +57,17 @@ trait StringRegexExpression { if(r == null) { null } else { - val regex = pattern(r.asInstanceOf[String]) + val regex = pattern(r.asInstanceOf[UTF8String].toString) if(regex == null) { null } else { - matches(regex, l.asInstanceOf[String]) + matches(regex, l.asInstanceOf[UTF8String].toString) } } } } } -trait CaseConversionExpression { - self: UnaryExpression => - - type EvaluatedType = Any - - def convert(v: String): String - - override def foldable: Boolean = child.foldable - def nullable: Boolean = child.nullable - def dataType: DataType = StringType - - override def eval(input: Row): Any = { - val evaluated = child.eval(input) - if (evaluated == null) { - null - } else { - convert(evaluated.toString) - } - } -} - /** * Simple RegEx pattern matching function */ @@ -134,12 +110,33 @@ case class RLike(left: Expression, right: Expression) override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) } +trait CaseConversionExpression { + self: UnaryExpression => + + type EvaluatedType = Any + + def convert(v: UTF8String): UTF8String + + override def foldable: Boolean = child.foldable + def nullable: Boolean = child.nullable + def dataType: DataType = StringType + + override def eval(input: Row): Any = { + val evaluated = child.eval(input) + if (evaluated == null) { + null + } else { + convert(evaluated.asInstanceOf[UTF8String]) + } + } +} + /** * A function that converts the characters of a string to uppercase. 
*/ case class Upper(child: Expression) extends UnaryExpression with CaseConversionExpression { - override def convert(v: String): String = v.toUpperCase() + override def convert(v: UTF8String): UTF8String = v.toUpperCase override def toString: String = s"Upper($child)" } @@ -149,7 +146,7 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE */ case class Lower(child: Expression) extends UnaryExpression with CaseConversionExpression { - override def convert(v: String): String = v.toLowerCase() + override def convert(v: UTF8String): UTF8String = v.toLowerCase override def toString: String = s"Lower($child)" } @@ -162,15 +159,16 @@ trait StringComparison { override def nullable: Boolean = left.nullable || right.nullable - def compare(l: String, r: String): Boolean + def compare(l: UTF8String, r: UTF8String): Boolean override def eval(input: Row): Any = { - val leftEval = left.eval(input).asInstanceOf[String] + val leftEval = left.eval(input) if(leftEval == null) { null } else { - val rightEval = right.eval(input).asInstanceOf[String] - if (rightEval == null) null else compare(leftEval, rightEval) + val rightEval = right.eval(input) + if (rightEval == null) null + else compare(leftEval.asInstanceOf[UTF8String], rightEval.asInstanceOf[UTF8String]) } } @@ -184,7 +182,7 @@ trait StringComparison { */ case class Contains(left: Expression, right: Expression) extends BinaryPredicate with StringComparison { - override def compare(l: String, r: String): Boolean = l.contains(r) + override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r) } /** @@ -192,7 +190,7 @@ case class Contains(left: Expression, right: Expression) */ case class StartsWith(left: Expression, right: Expression) extends BinaryPredicate with StringComparison { - override def compare(l: String, r: String): Boolean = l.startsWith(r) + override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r) } /** @@ -200,7 +198,7 @@ case class StartsWith(left: Expression, right: Expression) */ case class EndsWith(left: Expression, right: Expression) extends BinaryPredicate with StringComparison { - override def compare(l: String, r: String): Boolean = l.endsWith(r) + override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r) } /** @@ -224,9 +222,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends override def children: Seq[Expression] = str :: pos :: len :: Nil @inline - def slice[T, C <: Any](str: C, startPos: Int, sliceLen: Int) - (implicit ev: (C=>IndexedSeqOptimized[T,_])): Any = { - val len = str.length + def slicePos(startPos: Int, sliceLen: Int, length: () => Int): (Int, Int) = { // Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and // negative indices for start positions. If a start index i is greater than 0, it // refers to element i-1 in the sequence. 
If a start index i is less than 0, it refers @@ -235,7 +231,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends val start = startPos match { case pos if pos > 0 => pos - 1 - case neg if neg < 0 => len + neg + case neg if neg < 0 => length() + neg case _ => 0 } @@ -244,12 +240,11 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends case x => start + x } - str.slice(start, end) + (start, end) } override def eval(input: Row): Any = { val string = str.eval(input) - val po = pos.eval(input) val ln = len.eval(input) @@ -257,11 +252,14 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends null } else { val start = po.asInstanceOf[Int] - val length = ln.asInstanceOf[Int] - + val length = ln.asInstanceOf[Int] string match { - case ba: Array[Byte] => slice(ba, start, length) - case other => slice(other.toString, start, length) + case ba: Array[Byte] => + val (st, end) = slicePos(start, length, () => ba.length) + ba.slice(st, end) + case s: UTF8String => + val (st, end) = slicePos(start, length, () => s.length) + s.slice(st, end) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 93e69d409cb91..7c80634d2c852 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -198,14 +198,19 @@ object LikeSimplification extends Rule[LogicalPlan] { val equalTo = "([^_%]*)".r def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case Like(l, Literal(startsWith(pattern), StringType)) if !pattern.endsWith("\\") => - StartsWith(l, Literal(pattern)) - case Like(l, Literal(endsWith(pattern), StringType)) => - EndsWith(l, Literal(pattern)) - case Like(l, Literal(contains(pattern), StringType)) if !pattern.endsWith("\\") => - Contains(l, Literal(pattern)) - case Like(l, Literal(equalTo(pattern), StringType)) => - EqualTo(l, Literal(pattern)) + case Like(l, Literal(utf, StringType)) => + utf.toString match { + case startsWith(pattern) if !pattern.endsWith("\\") => + StartsWith(l, Literal(pattern)) + case endsWith(pattern) => + EndsWith(l, Literal(pattern)) + case contains(pattern) if !pattern.endsWith("\\") => + Contains(l, Literal(pattern)) + case equalTo(pattern) => + EqualTo(l, Literal(pattern)) + case _ => + Like(l, Literal.create(utf, StringType)) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala index 504fb05842505..d36a49159b87f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala @@ -40,6 +40,7 @@ object DateUtils { millisToDays(d.getTime) } + // we should use the exact day as Int, for example, (year, month, day) -> day def millisToDays(millisLocal: Long): Int = { ((millisLocal + LOCAL_TIMEZONE.get().getOffset(millisLocal)) / MILLIS_PER_DAY).toInt } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala new file mode 100644 index 0000000000000..fc02ba6c9c43e --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala @@ -0,0 +1,214 @@ +/* +* Licensed to the Apache Software 
Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.types + +import java.util.Arrays + +/** + * A UTF-8 String, as internal representation of StringType in SparkSQL + * + * A String encoded in UTF-8 as an Array[Byte], which can be used for comparison, + * search, see http://en.wikipedia.org/wiki/UTF-8 for details. + * + * Note: This is not designed for general use cases, should not be used outside SQL. + */ + +final class UTF8String extends Ordered[UTF8String] with Serializable { + + private[this] var bytes: Array[Byte] = _ + + /** + * Update the UTF8String with String. + */ + def set(str: String): UTF8String = { + bytes = str.getBytes("utf-8") + this + } + + /** + * Update the UTF8String with Array[Byte], which should be encoded in UTF-8 + */ + def set(bytes: Array[Byte]): UTF8String = { + this.bytes = bytes + this + } + + /** + * Return the number of bytes for a code point with the first byte as `b` + * @param b The first byte of a code point + */ + @inline + private[this] def numOfBytes(b: Byte): Int = { + val offset = (b & 0xFF) - 192 + if (offset >= 0) UTF8String.bytesOfCodePointInUTF8(offset) else 1 + } + + /** + * Return the number of code points in it. + * + * This is only used by Substring() when `start` is negative. 
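// Editor's illustration, not part of the patch: a minimal REPL-style sketch of the
// UTF8String class defined in this file. length() and slice() count code points
// rather than bytes, so multi-byte characters count once; prefix/suffix checks
// compare the raw UTF-8 bytes.
val s = UTF8String("世 界")                 // two CJK characters and a space
assert(s.length() == 3)                     // three code points
assert(s.getBytes.length == 7)              // 3 + 1 + 3 bytes in UTF-8
assert(s.slice(0, 1) == UTF8String("世"))   // slice() also counts code points
assert(s.startsWith(UTF8String("世")))      // prefix check compares raw bytes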
+ */ + def length(): Int = { + var len = 0 + var i: Int = 0 + while (i < bytes.length) { + i += numOfBytes(bytes(i)) + len += 1 + } + len + } + + def getBytes: Array[Byte] = { + bytes + } + + /** + * Return a substring of this, + * @param start the position of first code point + * @param until the position after last code point + */ + def slice(start: Int, until: Int): UTF8String = { + if (until <= start || start >= bytes.length || bytes == null) { + new UTF8String + } + + var c = 0 + var i: Int = 0 + while (c < start && i < bytes.length) { + i += numOfBytes(bytes(i)) + c += 1 + } + var j = i + while (c < until && j < bytes.length) { + j += numOfBytes(bytes(j)) + c += 1 + } + UTF8String(Arrays.copyOfRange(bytes, i, j)) + } + + def contains(sub: UTF8String): Boolean = { + val b = sub.getBytes + if (b.length == 0) { + return true + } + var i: Int = 0 + while (i <= bytes.length - b.length) { + // In worst case, it's O(N*K), but should works fine with SQL + if (bytes(i) == b(0) && Arrays.equals(Arrays.copyOfRange(bytes, i, i + b.length), b)) { + return true + } + i += 1 + } + false + } + + def startsWith(prefix: UTF8String): Boolean = { + val b = prefix.getBytes + if (b.length > bytes.length) { + return false + } + Arrays.equals(Arrays.copyOfRange(bytes, 0, b.length), b) + } + + def endsWith(suffix: UTF8String): Boolean = { + val b = suffix.getBytes + if (b.length > bytes.length) { + return false + } + Arrays.equals(Arrays.copyOfRange(bytes, bytes.length - b.length, bytes.length), b) + } + + def toUpperCase(): UTF8String = { + // upper case depends on locale, fallback to String. + UTF8String(toString().toUpperCase) + } + + def toLowerCase(): UTF8String = { + // lower case depends on locale, fallback to String. + UTF8String(toString().toLowerCase) + } + + override def toString(): String = { + new String(bytes, "utf-8") + } + + override def clone(): UTF8String = new UTF8String().set(this.bytes) + + override def compare(other: UTF8String): Int = { + var i: Int = 0 + val b = other.getBytes + while (i < bytes.length && i < b.length) { + val res = bytes(i).compareTo(b(i)) + if (res != 0) return res + i += 1 + } + bytes.length - b.length + } + + override def compareTo(other: UTF8String): Int = { + compare(other) + } + + override def equals(other: Any): Boolean = other match { + case s: UTF8String => + Arrays.equals(bytes, s.getBytes) + case s: String => + // This is only used for Catalyst unit tests + // fail fast + bytes.length >= s.length && length() == s.length && toString() == s + case _ => + false + } + + override def hashCode(): Int = { + Arrays.hashCode(bytes) + } +} + +object UTF8String { + // number of tailing bytes in a UTF8 sequence for a code point + // see http://en.wikipedia.org/wiki/UTF-8, 192-256 of Byte 1 + private[types] val bytesOfCodePointInUTF8: Array[Int] = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 4, 4, 4, 4, 4, 4, 4, 4, + 5, 5, 5, 5, + 6, 6, 6, 6) + + /** + * Create a UTF-8 String from String + */ + def apply(s: String): UTF8String = { + if (s != null) { + new UTF8String().set(s) + } else{ + null + } + } + + /** + * Create a UTF-8 String from Array[Byte], which should be encoded in UTF-8 + */ + def apply(bytes: Array[Byte]): UTF8String = { + if (bytes != null) { + new UTF8String().set(bytes) + } else { + null + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala index cdf2bc68d9c5e..c6fb22c26bd3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala @@ -350,7 +350,7 @@ class StringType private() extends NativeType with PrimitiveType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "StringType$" in byte code. // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type JvmType = String + private[sql] type JvmType = UTF8String @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } private[sql] val ordering = implicitly[Ordering[JvmType]] @@ -1196,8 +1196,8 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { /** * Convert the user type to a SQL datum * - * TODO: Can we make this take obj: UserType? The issue is in ScalaReflection.convertToCatalyst, - * where we need to convert Any to UserType. + * TODO: Can we make this take obj: UserType? The issue is in + * CatalystTypeConverters.convertToCatalyst, where we need to convert Any to UserType. */ def serialize(obj: Any): Any diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index d4362a91d992c..76298f03c94ae 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -25,8 +25,9 @@ import org.scalactic.TripleEqualsSupport.Spread import org.scalatest.FunSuite import org.scalatest.Matchers._ -import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types._ @@ -59,6 +60,10 @@ class ExpressionEvaluationBaseSuite extends FunSuite { class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { + def create_row(values: Any*): Row = { + new GenericRow(values.map(CatalystTypeConverters.convertToCatalyst).toArray) + } + test("literals") { checkEvaluation(Literal(1), 1) checkEvaluation(Literal(true), true) @@ -265,24 +270,23 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { test("LIKE Non-literal Regular Expression") { val regEx = 'a.string.at(0) - checkEvaluation("abcd" like regEx, null, new GenericRow(Array[Any](null))) - checkEvaluation("abdef" like regEx, true, new GenericRow(Array[Any]("abdef"))) - checkEvaluation("a_%b" like regEx, true, new GenericRow(Array[Any]("a\\__b"))) - checkEvaluation("addb" like regEx, true, new GenericRow(Array[Any]("a_%b"))) - checkEvaluation("addb" like regEx, false, new GenericRow(Array[Any]("a\\__b"))) - checkEvaluation("addb" like regEx, false, new GenericRow(Array[Any]("a%\\%b"))) - checkEvaluation("a_%b" like regEx, true, new GenericRow(Array[Any]("a%\\%b"))) - checkEvaluation("addb" like regEx, true, new GenericRow(Array[Any]("a%"))) - checkEvaluation("addb" like regEx, false, new GenericRow(Array[Any]("**"))) - checkEvaluation("abc" like regEx, true, new GenericRow(Array[Any]("a%"))) - 
checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("b%"))) - checkEvaluation("abc" like regEx, false, new GenericRow(Array[Any]("bc%"))) - checkEvaluation("a\nb" like regEx, true, new GenericRow(Array[Any]("a_b"))) - checkEvaluation("ab" like regEx, true, new GenericRow(Array[Any]("a%b"))) - checkEvaluation("a\nb" like regEx, true, new GenericRow(Array[Any]("a%b"))) - - checkEvaluation(Literal.create(null, StringType) like regEx, null, - new GenericRow(Array[Any]("bc%"))) + checkEvaluation("abcd" like regEx, null, create_row(null)) + checkEvaluation("abdef" like regEx, true, create_row("abdef")) + checkEvaluation("a_%b" like regEx, true, create_row("a\\__b")) + checkEvaluation("addb" like regEx, true, create_row("a_%b")) + checkEvaluation("addb" like regEx, false, create_row("a\\__b")) + checkEvaluation("addb" like regEx, false, create_row("a%\\%b")) + checkEvaluation("a_%b" like regEx, true, create_row("a%\\%b")) + checkEvaluation("addb" like regEx, true, create_row("a%")) + checkEvaluation("addb" like regEx, false, create_row("**")) + checkEvaluation("abc" like regEx, true, create_row("a%")) + checkEvaluation("abc" like regEx, false, create_row("b%")) + checkEvaluation("abc" like regEx, false, create_row("bc%")) + checkEvaluation("a\nb" like regEx, true, create_row("a_b")) + checkEvaluation("ab" like regEx, true, create_row("a%b")) + checkEvaluation("a\nb" like regEx, true, create_row("a%b")) + + checkEvaluation(Literal.create(null, StringType) like regEx, null, create_row("bc%")) } test("RLIKE literal Regular Expression") { @@ -313,14 +317,14 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { test("RLIKE Non-literal Regular Expression") { val regEx = 'a.string.at(0) - checkEvaluation("abdef" rlike regEx, true, new GenericRow(Array[Any]("abdef"))) - checkEvaluation("abbbbc" rlike regEx, true, new GenericRow(Array[Any]("a.*c"))) - checkEvaluation("fofo" rlike regEx, true, new GenericRow(Array[Any]("^fo"))) - checkEvaluation("fo\no" rlike regEx, true, new GenericRow(Array[Any]("^fo\no$"))) - checkEvaluation("Bn" rlike regEx, true, new GenericRow(Array[Any]("^Ba*n"))) + checkEvaluation("abdef" rlike regEx, true, create_row("abdef")) + checkEvaluation("abbbbc" rlike regEx, true, create_row("a.*c")) + checkEvaluation("fofo" rlike regEx, true, create_row("^fo")) + checkEvaluation("fo\no" rlike regEx, true, create_row("^fo\no$")) + checkEvaluation("Bn" rlike regEx, true, create_row("^Ba*n")) intercept[java.util.regex.PatternSyntaxException] { - evaluate("abbbbc" rlike regEx, new GenericRow(Array[Any]("**"))) + evaluate("abbbbc" rlike regEx, create_row("**")) } } @@ -763,7 +767,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } test("null checking") { - val row = new GenericRow(Array[Any]("^Ba*n", null, true, null)) + val row = create_row("^Ba*n", null, true, null) val c1 = 'a.string.at(0) val c2 = 'a.string.at(1) val c3 = 'a.boolean.at(2) @@ -803,7 +807,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } test("case when") { - val row = new GenericRow(Array[Any](null, false, true, "a", "b", "c")) + val row = create_row(null, false, true, "a", "b", "c") val c1 = 'a.boolean.at(0) val c2 = 'a.boolean.at(1) val c3 = 'a.boolean.at(2) @@ -846,13 +850,13 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } test("complex type") { - val row = new GenericRow(Array[Any]( - "^Ba*n", // 0 - null.asInstanceOf[String], // 1 - new GenericRow(Array[Any]("aa", "bb")), // 2 - Map("aa"->"bb"), // 3 
- Seq("aa", "bb") // 4 - )) + val row = create_row( + "^Ba*n", // 0 + null.asInstanceOf[UTF8String], // 1 + create_row("aa", "bb"), // 2 + Map("aa"->"bb"), // 3 + Seq("aa", "bb") // 4 + ) val typeS = StructType( StructField("a", StringType, true) :: StructField("b", StringType, true) :: Nil @@ -909,7 +913,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } test("arithmetic") { - val row = new GenericRow(Array[Any](1, 2, 3, null)) + val row = create_row(1, 2, 3, null) val c1 = 'a.int.at(0) val c2 = 'a.int.at(1) val c3 = 'a.int.at(2) @@ -934,7 +938,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } test("fractional arithmetic") { - val row = new GenericRow(Array[Any](1.1, 2.0, 3.1, null)) + val row = create_row(1.1, 2.0, 3.1, null) val c1 = 'a.double.at(0) val c2 = 'a.double.at(1) val c3 = 'a.double.at(2) @@ -958,7 +962,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } test("BinaryComparison") { - val row = new GenericRow(Array[Any](1, 2, 3, null, 3, null)) + val row = create_row(1, 2, 3, null, 3, null) val c1 = 'a.int.at(0) val c2 = 'a.int.at(1) val c3 = 'a.int.at(2) @@ -988,7 +992,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } test("StringComparison") { - val row = new GenericRow(Array[Any]("abc", null)) + val row = create_row("abc", null) val c1 = 'a.string.at(0) val c2 = 'a.string.at(1) @@ -1009,7 +1013,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } test("Substring") { - val row = new GenericRow(Array[Any]("example", "example".toArray.map(_.toByte))) + val row = create_row("example", "example".toArray.map(_.toByte)) val s = 'a.string.at(0) @@ -1053,7 +1057,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { // substring(null, _, _) -> null checkEvaluation(Substring(s, Literal.create(100, IntegerType), Literal.create(4, IntegerType)), - null, new GenericRow(Array[Any](null))) + null, create_row(null)) // substring(_, null, _) -> null checkEvaluation(Substring(s, Literal.create(null, IntegerType), Literal.create(4, IntegerType)), @@ -1102,20 +1106,20 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { test("SQRT") { val inputSequence = (1 to (1<<24) by 511).map(_ * (1L<<24)) val expectedResults = inputSequence.map(l => math.sqrt(l.toDouble)) - val rowSequence = inputSequence.map(l => new GenericRow(Array[Any](l.toDouble))) + val rowSequence = inputSequence.map(l => create_row(l.toDouble)) val d = 'a.double.at(0) for ((row, expected) <- rowSequence zip expectedResults) { checkEvaluation(Sqrt(d), expected, row) } - checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, new GenericRow(Array[Any](null))) + checkEvaluation(Sqrt(Literal.create(null, DoubleType)), null, create_row(null)) checkEvaluation(Sqrt(-1), null, EmptyRow) checkEvaluation(Sqrt(-1.5), null, EmptyRow) } test("Bitwise operations") { - val row = new GenericRow(Array[Any](1, 2, 3, null)) + val row = create_row(1, 2, 3, null) val c1 = 'a.int.at(0) val c2 = 'a.int.at(1) val c3 = 'a.int.at(2) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala index 275ea2627ebcd..bcc0c404d2cfb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions.codegen._ /** @@ -43,7 +43,7 @@ class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite { } val actual = plan(inputRow) - val expectedRow = new GenericRow(Array[Any](expected)) + val expectedRow = new GenericRow(Array[Any](CatalystTypeConverters.convertToCatalyst(expected))) if (actual.hashCode() != expectedRow.hashCode()) { fail( s""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala new file mode 100644 index 0000000000000..a22aa6f244c48 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala @@ -0,0 +1,70 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.types + +import org.scalatest.FunSuite + +// scalastyle:off +class UTF8StringSuite extends FunSuite { + test("basic") { + def check(str: String, len: Int) { + + assert(UTF8String(str).length == len) + assert(UTF8String(str.getBytes("utf8")).length() == len) + + assert(UTF8String(str) == str) + assert(UTF8String(str.getBytes("utf8")) == str) + assert(UTF8String(str).toString == str) + assert(UTF8String(str.getBytes("utf8")).toString == str) + assert(UTF8String(str.getBytes("utf8")) == UTF8String(str)) + + assert(UTF8String(str).hashCode() == UTF8String(str.getBytes("utf8")).hashCode()) + } + + check("hello", 5) + check("世 界", 3) + } + + test("contains") { + assert(UTF8String("hello").contains(UTF8String("ello"))) + assert(!UTF8String("hello").contains(UTF8String("vello"))) + assert(UTF8String("大千世界").contains(UTF8String("千世"))) + assert(!UTF8String("大千世界").contains(UTF8String("世千"))) + } + + test("prefix") { + assert(UTF8String("hello").startsWith(UTF8String("hell"))) + assert(!UTF8String("hello").startsWith(UTF8String("ell"))) + assert(UTF8String("大千世界").startsWith(UTF8String("大千"))) + assert(!UTF8String("大千世界").startsWith(UTF8String("千"))) + } + + test("suffix") { + assert(UTF8String("hello").endsWith(UTF8String("ello"))) + assert(!UTF8String("hello").endsWith(UTF8String("ellov"))) + assert(UTF8String("大千世界").endsWith(UTF8String("世界"))) + assert(!UTF8String("大千世界").endsWith(UTF8String("世"))) + } + + test("slice") { + assert(UTF8String("hello").slice(1, 3) == UTF8String("el")) + assert(UTF8String("大千世界").slice(0, 1) == UTF8String("大")) + assert(UTF8String("大千世界").slice(1, 3) == UTF8String("千世")) + assert(UTF8String("大千世界").slice(3, 5) == UTF8String("界")) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index b237fe684cdc1..89a4faf35e0d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -1195,6 +1195,7 @@ class SQLContext(@transient val sparkContext: SparkContext) case FloatType => true case DateType => true case TimestampType => true + case StringType => true case ArrayType(_, _) => true case MapType(_, _, _) => true case StructType(_) => true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index 87a6631da8300..b0f983c180673 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -216,13 +216,13 @@ private[sql] class IntColumnStats extends ColumnStats { } private[sql] class StringColumnStats extends ColumnStats { - protected var upper: String = null - protected var lower: String = null + protected var upper: UTF8String = null + protected var lower: UTF8String = null override def gatherStats(row: Row, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { - val value = row.getString(ordinal) + val value = row(ordinal).asInstanceOf[UTF8String] if (upper == null || value.compareTo(upper) > 0) upper = value if (lower == null || value.compareTo(lower) < 0) lower = value sizeInBytes += STRING.actualSize(row, ordinal) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index c47497e0662d9..1b9e0df2dcb5e 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer -import java.sql.{Date, Timestamp} +import java.sql.Timestamp import scala.reflect.runtime.universe.TypeTag @@ -312,26 +312,28 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { row.getString(ordinal).getBytes("utf-8").length + 4 } - override def append(v: String, buffer: ByteBuffer): Unit = { - val stringBytes = v.getBytes("utf-8") + override def append(v: UTF8String, buffer: ByteBuffer): Unit = { + val stringBytes = v.getBytes buffer.putInt(stringBytes.length).put(stringBytes, 0, stringBytes.length) } - override def extract(buffer: ByteBuffer): String = { + override def extract(buffer: ByteBuffer): UTF8String = { val length = buffer.getInt() val stringBytes = new Array[Byte](length) buffer.get(stringBytes, 0, length) - new String(stringBytes, "utf-8") + UTF8String(stringBytes) } - override def setField(row: MutableRow, ordinal: Int, value: String): Unit = { - row.setString(ordinal, value) + override def setField(row: MutableRow, ordinal: Int, value: UTF8String): Unit = { + row.update(ordinal, value) } - override def getField(row: Row, ordinal: Int): String = row.getString(ordinal) + override def getField(row: Row, ordinal: Int): UTF8String = { + row(ordinal).asInstanceOf[UTF8String] + } override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { - to.setString(toOrdinal, from.getString(fromOrdinal)) + to.update(toOrdinal, from(fromOrdinal)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 656bdd7212f56..1fd387eec7e57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, Attribute} +import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow, SpecificMutableRow} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{Row, SQLContext} /** * :: DeveloperApi :: @@ -54,6 +54,33 @@ object RDDConversions { } } } + + /** + * Convert the objects inside Row into the types Catalyst expected. + */ + def rowToRowRdd(data: RDD[Row], schema: StructType): RDD[Row] = { + data.mapPartitions { iterator => + if (iterator.isEmpty) { + Iterator.empty + } else { + val bufferedIterator = iterator.buffered + val mutableRow = new GenericMutableRow(bufferedIterator.head.toSeq.toArray) + val schemaFields = schema.fields.toArray + val converters = schemaFields.map { + f => CatalystTypeConverters.createToCatalystConverter(f.dataType) + } + bufferedIterator.map { r => + var i = 0 + while (i < mutableRow.length) { + mutableRow(i) = converters(i)(r(i)) + i += 1 + } + + mutableRow + } + } + } + } } /** Logical plan node for scanning data from an RDD. 
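// Editor's sketch, not part of the patch: rowToRowRdd (added above) builds one
// CatalystTypeConverters converter per schema field and applies it to each row,
// so a java.lang.String column becomes UTF8String before rows reach Catalyst
// operators. The schema, context and RDD names below are hypothetical.
val schema = StructType(StructField("name", StringType, nullable = true) :: Nil)
val externalRows: RDD[Row] = sqlContext.sparkContext.parallelize(Seq(Row("spark")))
val catalystRows: RDD[Row] = RDDConversions.rowToRowRdd(externalRows, schema)
// catalystRows now contains Row(UTF8String("spark")) rather than Row("spark")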
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index fad7a281dc1e2..99f24910fd61f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -20,12 +20,13 @@ package org.apache.spark.sql.execution import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.types.{BooleanType, StructField, StructType, StringType} -import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext} +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Row, Attribute} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Row} import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext} /** * A logical command that is executed for its side-effects. `RunnableCommand`s are @@ -61,7 +62,11 @@ case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan { override def executeTake(limit: Int): Array[Row] = sideEffectResult.take(limit).toArray - override def execute(): RDD[Row] = sqlContext.sparkContext.parallelize(sideEffectResult, 1) + override def execute(): RDD[Row] = { + val converted = sideEffectResult.map(r => + CatalystTypeConverters.convertToCatalyst(r, schema).asInstanceOf[Row]) + sqlContext.sparkContext.parallelize(converted, 1) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index e916e68e58b5d..710787096e6cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -164,7 +164,7 @@ package object debug { case (_: Long, LongType) => case (_: Int, IntegerType) => - case (_: String, StringType) => + case (_: UTF8String, StringType) => case (_: Float, FloatType) => case (_: Byte, ByteType) => case (_: Short, ShortType) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index 5b308d88d4cdf..7a43bfd8bc8d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -140,6 +140,7 @@ object EvaluatePython { case (ud, udt: UserDefinedType[_]) => toJava(udt.serialize(ud), udt.sqlType) case (date: Int, DateType) => DateUtils.toJavaDate(date) + case (s: UTF8String, StringType) => s.toString // Pyrolite can handle Timestamp and Decimal case (other, _) => other @@ -192,7 +193,8 @@ object EvaluatePython { case (c: Long, IntegerType) => c.toInt case (c: Int, LongType) => c.toLong case (c: Double, FloatType) => c.toFloat - case (c, StringType) if !c.isInstanceOf[String] => c.toString + case (c: String, StringType) => UTF8String(c) + case (c, StringType) if !c.isInstanceOf[String] => UTF8String(c.toString) case (c, _) => c } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index 
463e1dcc268bc..b9022fcd9e3ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -233,7 +233,7 @@ private[sql] class JDBCRDD( * Converts value to SQL expression. */ private def compileValue(value: Any): Any = value match { - case stringValue: String => s"'${escapeSql(stringValue)}'" + case stringValue: UTF8String => s"'${escapeSql(stringValue.toString)}'" case _ => value } @@ -349,12 +349,14 @@ private[sql] class JDBCRDD( val pos = i + 1 conversions(i) match { case BooleanConversion => mutableRow.setBoolean(i, rs.getBoolean(pos)) + // TODO(davies): convert Date into Int case DateConversion => mutableRow.update(i, rs.getDate(pos)) case DecimalConversion => mutableRow.update(i, rs.getBigDecimal(pos)) case DoubleConversion => mutableRow.setDouble(i, rs.getDouble(pos)) case FloatConversion => mutableRow.setFloat(i, rs.getFloat(pos)) case IntegerConversion => mutableRow.setInt(i, rs.getInt(pos)) case LongConversion => mutableRow.setLong(i, rs.getLong(pos)) + // TODO(davies): use getBytes for better performance, if the encoding is UTF-8 case StringConversion => mutableRow.setString(i, rs.getString(pos)) case TimestampConversion => mutableRow.update(i, rs.getTimestamp(pos)) case BinaryConversion => mutableRow.update(i, rs.getBytes(pos)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala index 4fa84dc076f7e..99b755c9f25d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala @@ -130,6 +130,8 @@ private[sql] case class JDBCRelation( extends BaseRelation with PrunedFilteredScan { + override val needConversion: Boolean = false + override val schema: StructType = JDBCRDD.resolveTable(url, table, properties) override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala index 34f864f5fda7a..d4e0abc040bc6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala @@ -18,11 +18,8 @@ package org.apache.spark.sql import java.sql.{Connection, DriverManager, PreparedStatement} -import org.apache.spark.{Logging, Partition} -import org.apache.spark.sql._ -import org.apache.spark.sql.sources.LogicalRelation -import org.apache.spark.sql.jdbc.{JDBCPartitioningInfo, JDBCRelation, JDBCPartition} +import org.apache.spark.Logging import org.apache.spark.sql.types._ package object jdbc { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index f4c99b4b56606..e3352d02787fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -20,12 +20,12 @@ package org.apache.spark.sql.json import java.io.IOException import org.apache.hadoop.fs.Path + import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.Row - -import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext} import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{DataFrame, 
SQLContext, SaveMode} private[sql] class DefaultSource @@ -113,6 +113,8 @@ private[sql] case class JSONRelation( // TODO: Support partitioned JSON relation. private def baseRDD = sqlContext.sparkContext.textFile(path) + override val needConversion: Boolean = false + override val schema = userSpecifiedSchema.getOrElse( JsonRDD.nullTypeToStringType( JsonRDD.inferSchema( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index b1e8521383756..29de7401dda71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -409,7 +409,7 @@ private[sql] object JsonRDD extends Logging { null } else { desiredType match { - case StringType => toString(value) + case StringType => UTF8String(toString(value)) case _ if value == null || value == "" => null // guard the non string type case IntegerType => value.asInstanceOf[IntegerType.JvmType] case LongType => toLong(value) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index 43ca359b51735..bc108e37dfb0f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -219,8 +219,8 @@ private[parquet] abstract class CatalystConverter extends GroupConverter { protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit = updateField(fieldIndex, value.getBytes) - protected[parquet] def updateString(fieldIndex: Int, value: String): Unit = - updateField(fieldIndex, value) + protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit = + updateField(fieldIndex, UTF8String(value)) protected[parquet] def updateTimestamp(fieldIndex: Int, value: Binary): Unit = updateField(fieldIndex, readTimestamp(value)) @@ -418,8 +418,8 @@ private[parquet] class CatalystPrimitiveRowConverter( override protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit = current.update(fieldIndex, value.getBytes) - override protected[parquet] def updateString(fieldIndex: Int, value: String): Unit = - current.setString(fieldIndex, value) + override protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit = + current.update(fieldIndex, UTF8String(value)) override protected[parquet] def updateTimestamp(fieldIndex: Int, value: Binary): Unit = current.update(fieldIndex, readTimestamp(value)) @@ -475,19 +475,18 @@ private[parquet] class CatalystPrimitiveConverter( private[parquet] class CatalystPrimitiveStringConverter(parent: CatalystConverter, fieldIndex: Int) extends CatalystPrimitiveConverter(parent, fieldIndex) { - private[this] var dict: Array[String] = null + private[this] var dict: Array[Array[Byte]] = null override def hasDictionarySupport: Boolean = true override def setDictionary(dictionary: Dictionary):Unit = - dict = Array.tabulate(dictionary.getMaxId + 1) {dictionary.decodeToBinary(_).toStringUsingUTF8} - + dict = Array.tabulate(dictionary.getMaxId + 1) { dictionary.decodeToBinary(_).getBytes } override def addValueFromDictionary(dictionaryId: Int): Unit = parent.updateString(fieldIndex, dict(dictionaryId)) override def addBinary(value: Binary): Unit = - parent.updateString(fieldIndex, value.toStringUsingUTF8) + parent.updateString(fieldIndex, value.getBytes) } private[parquet] object CatalystArrayConverter { @@ 
-714,9 +713,9 @@ private[parquet] class CatalystNativeArrayConverter( elements += 1 } - override protected[parquet] def updateString(fieldIndex: Int, value: String): Unit = { + override protected[parquet] def updateString(fieldIndex: Int, value: Array[Byte]): Unit = { checkGrowBuffer() - buffer(elements) = value.asInstanceOf[NativeType] + buffer(elements) = UTF8String(value).asInstanceOf[NativeType] elements += 1 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala index 0357dcc4688be..5eb1c6abc2432 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala @@ -55,7 +55,7 @@ private[sql] object ParquetFilters { case StringType => (n: String, v: Any) => FilterApi.eq( binaryColumn(n), - Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) + Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull) case BinaryType => (n: String, v: Any) => FilterApi.eq( binaryColumn(n), @@ -76,7 +76,7 @@ private[sql] object ParquetFilters { case StringType => (n: String, v: Any) => FilterApi.notEq( binaryColumn(n), - Option(v).map(s => Binary.fromString(s.asInstanceOf[String])).orNull) + Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull) case BinaryType => (n: String, v: Any) => FilterApi.notEq( binaryColumn(n), @@ -94,7 +94,7 @@ private[sql] object ParquetFilters { (n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) case StringType => (n: String, v: Any) => - FilterApi.lt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) case BinaryType => (n: String, v: Any) => FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) @@ -111,7 +111,7 @@ private[sql] object ParquetFilters { (n: String, v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) case StringType => (n: String, v: Any) => - FilterApi.ltEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) case BinaryType => (n: String, v: Any) => FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) @@ -128,7 +128,7 @@ private[sql] object ParquetFilters { (n: String, v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) case StringType => (n: String, v: Any) => - FilterApi.gt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) case BinaryType => (n: String, v: Any) => FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) @@ -145,7 +145,7 @@ private[sql] object ParquetFilters { (n: String, v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) case StringType => (n: String, v: Any) => - FilterApi.gtEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) + FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) case BinaryType => (n: String, v: Any) => FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 5a1b15490d273..e05a4c20b0d41 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -198,10 +198,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { if (value != null) { schema match { case StringType => writer.addBinary( - Binary.fromByteArray( - value.asInstanceOf[String].getBytes("utf-8") - ) - ) + Binary.fromByteArray(value.asInstanceOf[UTF8String].getBytes)) case BinaryType => writer.addBinary( Binary.fromByteArray(value.asInstanceOf[Array[Byte]])) case IntegerType => writer.addInteger(value.asInstanceOf[Int]) @@ -349,7 +346,7 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport { index: Int): Unit = { ctype match { case StringType => writer.addBinary( - Binary.fromByteArray(record(index).asInstanceOf[String].getBytes("utf-8"))) + Binary.fromByteArray(record(index).asInstanceOf[UTF8String].getBytes)) case BinaryType => writer.addBinary( Binary.fromByteArray(record(index).asInstanceOf[Array[Byte]])) case IntegerType => writer.addInteger(record.getInt(index)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 20fdf5e58ef82..af7b3c81ae7b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -33,7 +33,6 @@ import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat import org.apache.hadoop.mapreduce.{InputSplit, Job, JobContext} - import parquet.filter2.predicate.FilterApi import parquet.format.converter.ParquetMetadataConverter import parquet.hadoop.metadata.CompressionCodecName @@ -45,13 +44,13 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.{NewHadoopPartition, NewHadoopRDD, RDD} -import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, expressions} import org.apache.spark.sql.parquet.ParquetTypesConverter._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{IntegerType, StructField, StructType, _} import org.apache.spark.sql.{DataFrame, Row, SQLConf, SQLContext, SaveMode} -import org.apache.spark.{Logging, Partition => SparkPartition, SerializableWritable, SparkException, TaskContext} +import org.apache.spark.{Logging, SerializableWritable, SparkException, TaskContext, Partition => SparkPartition} /** * Allows creation of Parquet based tables using the syntax: @@ -409,6 +408,9 @@ private[sql] case class ParquetRelation2( file.getName == ParquetFileWriter.PARQUET_METADATA_FILE } + // Skip type conversion + override val needConversion: Boolean = false + // TODO Should calculate per scan size // It's common that a query only scans a fraction of a large Parquet file. Returning size of the // whole Parquet file disables some optimizations in this case (e.g. broadcast join). 
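// Editor's sketch, not part of the patch: needConversion (added to BaseRelation in
// this patch) lets a relation whose scan already produces Catalyst-internal values
// skip the extra row-conversion pass, as ParquetRelation2 does just above. The
// class below is purely illustrative.
class AlreadyInternalRelation(val sqlContext: SQLContext) extends BaseRelation with TableScan {
  override val schema = StructType(StructField("s", StringType, nullable = true) :: Nil)
  // buildScan already emits UTF8String, so DataSourceStrategy can use the RDD as-is.
  override val needConversion: Boolean = false
  override def buildScan(): RDD[Row] =
    sqlContext.sparkContext.parallelize(Seq(Row(UTF8String("already-internal"))))
}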
@@ -550,7 +552,8 @@ private[sql] case class ParquetRelation2( baseRDD.mapPartitionsWithInputSplit { case (split: ParquetInputSplit, iterator) => val partValues = selectedPartitions.collectFirst { - case p if split.getPath.getParent.toString == p.path => p.values + case p if split.getPath.getParent.toString == p.path => + CatalystTypeConverters.convertToCatalyst(p.values).asInstanceOf[Row] }.get val requiredPartOrdinal = partitionKeyLocations.keys.toSeq diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala index 34d048e426d10..b3d71f687a60a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -23,7 +23,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.types.{UTF8String, StringType} import org.apache.spark.sql.{Row, Strategy, execution, sources} /** @@ -53,7 +54,7 @@ private[sql] object DataSourceStrategy extends Strategy { (a, _) => t.buildScan(a)) :: Nil case l @ LogicalRelation(t: TableScan) => - execution.PhysicalRDD(l.output, t.buildScan()) :: Nil + createPhysicalRDD(l.relation, l.output, t.buildScan()) :: Nil case i @ logical.InsertIntoTable( l @ LogicalRelation(t: InsertableRelation), part, query, overwrite, false) if part.isEmpty => @@ -102,20 +103,30 @@ private[sql] object DataSourceStrategy extends Strategy { projectList.asInstanceOf[Seq[Attribute]] // Safe due to if above. .map(relation.attributeMap) // Match original case of attributes. - val scan = - execution.PhysicalRDD( - projectList.map(_.toAttribute), + val scan = createPhysicalRDD(relation.relation, projectList.map(_.toAttribute), scanBuilder(requestedColumns, pushedFilters)) filterCondition.map(execution.Filter(_, scan)).getOrElse(scan) } else { val requestedColumns = (projectSet ++ filterSet).map(relation.attributeMap).toSeq - val scan = - execution.PhysicalRDD(requestedColumns, scanBuilder(requestedColumns, pushedFilters)) + val scan = createPhysicalRDD(relation.relation, requestedColumns, + scanBuilder(requestedColumns, pushedFilters)) execution.Project(projectList, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)) } } + private[this] def createPhysicalRDD( + relation: BaseRelation, + output: Seq[Attribute], + rdd: RDD[Row]): SparkPlan = { + val converted = if (relation.needConversion) { + execution.RDDConversions.rowToRowRdd(rdd, relation.schema) + } else { + rdd + } + execution.PhysicalRDD(output, converted) + } + /** * Selects Catalyst predicate [[Expression]]s which are convertible into data source [[Filter]]s, * and convert them. 
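// Editor's illustration, not part of the patch: after this change the Catalyst side
// of a pushed-down string predicate carries a UTF8String literal, and translate()
// (next hunk) converts it back to java.lang.String for the public Filter API.
// The attribute below is hypothetical.
val nameCol: Attribute = AttributeReference("name", StringType)()
val catalystPredicate = expressions.StartsWith(nameCol, Literal(UTF8String("sp"), StringType))
// translate(catalystPredicate) is then expected to produce:
val pushedFilter = sources.StringStartsWith("name", "sp")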
@@ -167,14 +178,14 @@ private[sql] object DataSourceStrategy extends Strategy { case expressions.Not(child) => translate(child).map(sources.Not) - case expressions.StartsWith(a: Attribute, Literal(v: String, StringType)) => - Some(sources.StringStartsWith(a.name, v)) + case expressions.StartsWith(a: Attribute, Literal(v: UTF8String, StringType)) => + Some(sources.StringStartsWith(a.name, v.toString)) - case expressions.EndsWith(a: Attribute, Literal(v: String, StringType)) => - Some(sources.StringEndsWith(a.name, v)) + case expressions.EndsWith(a: Attribute, Literal(v: UTF8String, StringType)) => + Some(sources.StringEndsWith(a.name, v.toString)) - case expressions.Contains(a: Attribute, Literal(v: String, StringType)) => - Some(sources.StringContains(a.name, v)) + case expressions.Contains(a: Attribute, Literal(v: UTF8String, StringType)) => + Some(sources.StringContains(a.name, v.toString)) case _ => None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 8f9946a5a801e..ca53dcdb92c52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -126,6 +126,16 @@ abstract class BaseRelation { * could lead to execution plans that are suboptimal (i.e. broadcasting a very large table). */ def sizeInBytes: Long = sqlContext.conf.defaultSizeInBytes + + /** + * Whether does it need to convert the objects in Row to internal representation, for example: + * java.lang.String -> UTF8String + * java.lang.Decimal -> Decimal + * + * Note: The internal representation is not stable across releases and thus data sources outside + * of Spark SQL should leave this as true. 
+ */ + def needConversion: Boolean = true } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index 36465cc2fa11a..bf6cf1321a056 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -30,7 +30,7 @@ class RowSuite extends FunSuite { test("create row") { val expected = new GenericMutableRow(4) expected.update(0, 2147483647) - expected.update(1, "this is a string") + expected.setString(1, "this is a string") expected.update(2, false) expected.update(3, null) val actual1 = Row(2147483647, "this is a string", false, null) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 0174aaee94246..4c48dca44498b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -17,18 +17,14 @@ package org.apache.spark.sql -import org.apache.spark.sql.execution.GeneratedAggregate -import org.apache.spark.sql.test.TestSQLContext import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql.TestData._ import org.apache.spark.sql.execution.GeneratedAggregate import org.apache.spark.sql.functions._ -import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.types._ - -import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext.{udf => _, _} +import org.apache.spark.sql.types._ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index 5f08834f73c6b..c86ef338fc644 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -65,7 +65,7 @@ class ColumnTypeSuite extends FunSuite with Logging { checkActualSize(FLOAT, Float.MaxValue, 4) checkActualSize(FIXED_DECIMAL(15, 10), Decimal(0, 15, 10), 8) checkActualSize(BOOLEAN, true, 1) - checkActualSize(STRING, "hello", 4 + "hello".getBytes("utf-8").length) + checkActualSize(STRING, UTF8String("hello"), 4 + "hello".getBytes("utf-8").length) checkActualSize(DATE, 0, 4) checkActualSize(TIMESTAMP, new Timestamp(0L), 12) @@ -108,8 +108,8 @@ class ColumnTypeSuite extends FunSuite with Logging { testNativeColumnType[StringType.type]( STRING, - (buffer: ByteBuffer, string: String) => { - val bytes = string.getBytes("utf-8") + (buffer: ByteBuffer, string: UTF8String) => { + val bytes = string.getBytes buffer.putInt(bytes.length) buffer.put(bytes) }, @@ -117,7 +117,7 @@ class ColumnTypeSuite extends FunSuite with Logging { val length = buffer.getInt() val bytes = new Array[Byte](length) buffer.get(bytes) - new String(bytes, "utf-8") + UTF8String(bytes) }) testColumnType[BinaryType.type, Array[Byte]]( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala index b301818a008e7..f76314b9dab5e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala @@ -24,7 +24,7 @@ import scala.util.Random import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.types.{Decimal, DataType, NativeType} +import org.apache.spark.sql.types.{UTF8String, DataType, Decimal, NativeType} object ColumnarTestUtils { def makeNullRow(length: Int): GenericMutableRow = { @@ -48,7 +48,7 @@ object ColumnarTestUtils { case FLOAT => Random.nextFloat() case DOUBLE => Random.nextDouble() case FIXED_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale) - case STRING => Random.nextString(Random.nextInt(32)) + case STRING => UTF8String(Random.nextString(Random.nextInt(32))) case BOOLEAN => Random.nextBoolean() case BINARY => randomBytes(Random.nextInt(32)) case DATE => Random.nextInt() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 60c8c00bda4d5..3b47b8adf313b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -74,7 +74,7 @@ case class AllDataTypesScan( i.toDouble, new java.math.BigDecimal(i), new java.math.BigDecimal(i), - new Date((i + 1) * 8640000), + new Date(1970, 1, 1), new Timestamp(20000 + i), s"varchar_$i", Seq(i, i + 1), @@ -82,7 +82,7 @@ case class AllDataTypesScan( Map(i -> i.toString), Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)), Row(i, i.toString), - Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date((i + 2) * 8640000))))) + Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date(1970, 1, i + 1))))) } } } @@ -103,7 +103,7 @@ class TableScanSuite extends DataSourceTest { i.toDouble, new java.math.BigDecimal(i), new java.math.BigDecimal(i), - new Date((i + 1) * 8640000), + new Date(1970, 1, 1), new Timestamp(20000 + i), s"varchar_$i", Seq(i, i + 1), @@ -111,7 +111,7 @@ class TableScanSuite extends DataSourceTest { Map(i -> i.toString), Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)), Row(i, i.toString), - Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date((i + 2) * 8640000))))) + Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(new Date(1970, 1, i + 1))))) }.toSeq before { @@ -266,7 +266,7 @@ class TableScanSuite extends DataSourceTest { sqlTest( "SELECT structFieldComplex.Value.`value_(2)` FROM tableWithSchema", - (1 to 10).map(i => Row(Seq(new Date((i + 2) * 8640000)))).toSeq) + (1 to 10).map(i => Row(Seq(new Date(1970, 1, i + 1)))).toSeq) test("Caching") { // Cached Query Execution diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 921c6194c7b76..74ae984f34866 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -34,7 +34,7 @@ import scala.collection.JavaConversions._ * 1. 
The Underlying data type in catalyst and in Hive * In catalyst: * Primitive => - * java.lang.String + * UTF8String * int / scala.Int * boolean / scala.Boolean * float / scala.Float @@ -239,9 +239,10 @@ private[hive] trait HiveInspectors { */ def unwrap(data: Any, oi: ObjectInspector): Any = oi match { case coi: ConstantObjectInspector if coi.getWritableConstantValue == null => null - case poi: WritableConstantStringObjectInspector => poi.getWritableConstantValue.toString + case poi: WritableConstantStringObjectInspector => + UTF8String(poi.getWritableConstantValue.toString) case poi: WritableConstantHiveVarcharObjectInspector => - poi.getWritableConstantValue.getHiveVarchar.getValue + UTF8String(poi.getWritableConstantValue.getHiveVarchar.getValue) case poi: WritableConstantHiveDecimalObjectInspector => HiveShim.toCatalystDecimal( PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector, @@ -284,10 +285,13 @@ private[hive] trait HiveInspectors { case pi: PrimitiveObjectInspector => pi match { // We think HiveVarchar is also a String case hvoi: HiveVarcharObjectInspector if hvoi.preferWritable() => - hvoi.getPrimitiveWritableObject(data).getHiveVarchar.getValue - case hvoi: HiveVarcharObjectInspector => hvoi.getPrimitiveJavaObject(data).getValue + UTF8String(hvoi.getPrimitiveWritableObject(data).getHiveVarchar.getValue) + case hvoi: HiveVarcharObjectInspector => + UTF8String(hvoi.getPrimitiveJavaObject(data).getValue) case x: StringObjectInspector if x.preferWritable() => - x.getPrimitiveWritableObject(data).toString + UTF8String(x.getPrimitiveWritableObject(data).toString) + case x: StringObjectInspector => + UTF8String(x.getPrimitiveJavaObject(data)) case x: IntObjectInspector if x.preferWritable() => x.get(data) case x: BooleanObjectInspector if x.preferWritable() => x.get(data) case x: FloatObjectInspector if x.preferWritable() => x.get(data) @@ -340,7 +344,9 @@ private[hive] trait HiveInspectors { */ protected def wrapperFor(oi: ObjectInspector): Any => Any = oi match { case _: JavaHiveVarcharObjectInspector => - (o: Any) => new HiveVarchar(o.asInstanceOf[String], o.asInstanceOf[String].size) + (o: Any) => + val s = o.asInstanceOf[UTF8String].toString + new HiveVarchar(s, s.size) case _: JavaHiveDecimalObjectInspector => (o: Any) => HiveShim.createDecimal(o.asInstanceOf[Decimal].toJavaBigDecimal) @@ -409,7 +415,7 @@ private[hive] trait HiveInspectors { case x: PrimitiveObjectInspector => x match { // TODO we don't support the HiveVarcharObjectInspector yet. 
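// Editor's note, not part of the patch: string values now cross this boundary as
// UTF8String on the Catalyst side, so (schematically, with hypothetical inspector
// names):
//   unwrap(new hadoopIo.Text("hive"), writableStringOI)  // => UTF8String("hive")
//   wrap(UTF8String("hive"), javaStringOI)                // => "hive": java.lang.String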
case _: StringObjectInspector if x.preferWritable() => HiveShim.getStringWritable(a) - case _: StringObjectInspector => a.asInstanceOf[java.lang.String] + case _: StringObjectInspector => a.asInstanceOf[UTF8String].toString() case _: IntObjectInspector if x.preferWritable() => HiveShim.getIntWritable(a) case _: IntObjectInspector => a.asInstanceOf[java.lang.Integer] case _: BooleanObjectInspector if x.preferWritable() => HiveShim.getBooleanWritable(a) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 1ccb0c279c60e..a6f4fbe8aba06 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -17,24 +17,21 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.catalyst.expressions.Row - import scala.collection.JavaConversions._ import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate +import org.apache.spark.sql.catalyst.expressions.{Row, _} import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.sources.DescribeCommand -import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand} -import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand, _} import org.apache.spark.sql.hive.execution._ import org.apache.spark.sql.parquet.ParquetRelation -import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, CreateTableUsing} +import org.apache.spark.sql.sources.{CreateTableUsing, CreateTableUsingAsSelect, DescribeCommand} import org.apache.spark.sql.types.StringType @@ -131,7 +128,7 @@ private[hive] trait HiveStrategies { val partitionValues = part.getValues var i = 0 while (i < partitionValues.size()) { - inputData(i) = partitionValues(i) + inputData(i) = CatalystTypeConverters.convertToCatalyst(partitionValues(i)) i += 1 } pruningCondition(inputData) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index 8efed7f0299bf..cab0fdd35723a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.hive.execution -import java.io.{BufferedReader, InputStreamReader} -import java.io.{DataInputStream, DataOutputStream, EOFException} +import java.io.{BufferedReader, DataInputStream, DataOutputStream, EOFException, InputStreamReader} import java.util.Properties import scala.collection.JavaConversions._ @@ -28,12 +27,13 @@ import org.apache.hadoop.hive.serde2.AbstractSerDe import org.apache.hadoop.hive.serde2.objectinspector._ import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema import org.apache.spark.sql.execution._ -import 
org.apache.spark.sql.types.DataType -import org.apache.spark.sql.hive.{HiveContext, HiveInspectors} import org.apache.spark.sql.hive.HiveShim._ +import org.apache.spark.sql.hive.{HiveContext, HiveInspectors} +import org.apache.spark.sql.types.DataType import org.apache.spark.util.Utils /** @@ -121,14 +121,13 @@ case class ScriptTransformation( if (outputSerde == null) { val prevLine = curLine curLine = reader.readLine() - if (!ioschema.schemaLess) { - new GenericRow( - prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD")) + new GenericRow(CatalystTypeConverters.convertToCatalyst( + prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"))) .asInstanceOf[Array[Any]]) } else { - new GenericRow( - prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2) + new GenericRow(CatalystTypeConverters.convertToCatalyst( + prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2)) .asInstanceOf[Array[Any]]) } } else { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 902a12785e3e9..a40a1e53117cd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -22,11 +22,11 @@ import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext} -import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.catalyst.expressions.{Attribute, Row} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.hive.HiveContext -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types._ /** * Analyzes the given table in the current database to generate statistics, which will be @@ -76,6 +76,12 @@ case class DropTable( private[hive] case class AddJar(path: String) extends RunnableCommand { + override val output: Seq[Attribute] = { + val schema = StructType( + StructField("result", IntegerType, false) :: Nil) + schema.toAttributes + } + override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] hiveContext.runSqlHive(s"ADD JAR $path") diff --git a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala index 0ed93c2c5b1fa..33e96eaabfbf6 100644 --- a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala +++ b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala @@ -41,7 +41,7 @@ import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, TypeInfoFactory} import org.apache.hadoop.io.{NullWritable, Writable} import org.apache.hadoop.mapred.InputFormat -import org.apache.spark.sql.types.{Decimal, DecimalType} +import org.apache.spark.sql.types.{UTF8String, Decimal, DecimalType} private[hive] case class HiveFunctionWrapper(functionClassName: String) extends java.io.Serializable { @@ -135,7 +135,7 @@ private[hive] object HiveShim { PrimitiveCategory.VOID, null) def getStringWritable(value: Any): hadoopIo.Text = - if (value == null) null else new hadoopIo.Text(value.asInstanceOf[String]) + if (value == null) null else new hadoopIo.Text(value.asInstanceOf[UTF8String].toString) def getIntWritable(value: 
Any): hadoopIo.IntWritable = if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int]) diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala index 7577309900209..d331c210e8939 100644 --- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala +++ b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala @@ -17,37 +17,35 @@ package org.apache.spark.sql.hive -import java.util -import java.util.{ArrayList => JArrayList} -import java.util.Properties import java.rmi.server.UID +import java.util.{Properties, ArrayList => JArrayList} import scala.collection.JavaConversions._ import scala.language.implicitConversions +import com.esotericsoftware.kryo.Kryo import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.hadoop.io.{NullWritable, Writable} -import org.apache.hadoop.mapred.InputFormat import org.apache.hadoop.hive.common.StatsSetupConst -import org.apache.hadoop.hive.common.`type`.{HiveDecimal} +import org.apache.hadoop.hive.common.`type`.HiveDecimal import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Context -import org.apache.hadoop.hive.ql.metadata.{Table, Hive, Partition} +import org.apache.hadoop.hive.ql.exec.{UDF, Utilities} +import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc} import org.apache.hadoop.hive.ql.processors.CommandProcessorFactory import org.apache.hadoop.hive.serde.serdeConstants -import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, DecimalTypeInfo, TypeInfoFactory} -import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory} -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, PrimitiveObjectInspector, ObjectInspector} -import org.apache.hadoop.hive.serde2.{Deserializer, ColumnProjectionUtils} -import org.apache.hadoop.hive.serde2.{io => hiveIo} import org.apache.hadoop.hive.serde2.avro.AvroGenericRecordWritable +import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory} +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorConverters, PrimitiveObjectInspector} +import org.apache.hadoop.hive.serde2.typeinfo.{DecimalTypeInfo, TypeInfo, TypeInfoFactory} +import org.apache.hadoop.hive.serde2.{ColumnProjectionUtils, Deserializer, io => hiveIo} +import org.apache.hadoop.io.{NullWritable, Writable} +import org.apache.hadoop.mapred.InputFormat import org.apache.hadoop.{io => hadoopIo} import org.apache.spark.Logging -import org.apache.spark.sql.types.{Decimal, DecimalType} - +import org.apache.spark.sql.types.{Decimal, DecimalType, UTF8String} /** * This class provides the UDF creation and also the UDF instance serialization and @@ -63,18 +61,14 @@ private[hive] case class HiveFunctionWrapper(var functionClassName: String) // for Serialization def this() = this(null) - import java.io.{OutputStream, InputStream} - import com.esotericsoftware.kryo.Kryo import org.apache.spark.util.Utils._ - import org.apache.hadoop.hive.ql.exec.Utilities - import org.apache.hadoop.hive.ql.exec.UDF @transient private val methodDeSerialize = { val method = classOf[Utilities].getDeclaredMethod( "deserializeObjectByKryo", classOf[Kryo], - classOf[InputStream], + classOf[java.io.InputStream], 
classOf[Class[_]]) method.setAccessible(true) @@ -87,7 +81,7 @@ private[hive] case class HiveFunctionWrapper(var functionClassName: String) "serializeObjectByKryo", classOf[Kryo], classOf[Object], - classOf[OutputStream]) + classOf[java.io.OutputStream]) method.setAccessible(true) method @@ -224,7 +218,7 @@ private[hive] object HiveShim { TypeInfoFactory.voidTypeInfo, null) def getStringWritable(value: Any): hadoopIo.Text = - if (value == null) null else new hadoopIo.Text(value.asInstanceOf[String]) + if (value == null) null else new hadoopIo.Text(value.asInstanceOf[UTF8String].toString) def getIntWritable(value: Any): hadoopIo.IntWritable = if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int]) From cf38fe04f8782ff4573ae106ec0de8e8d183cb2b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 15 Apr 2015 13:15:58 -0700 Subject: [PATCH 008/191] [SPARK-6844][SQL] Clean up accumulators used in InMemoryRelation when it is uncached JIRA: https://issues.apache.org/jira/browse/SPARK-6844 Author: Liang-Chi Hsieh Closes #5475 from viirya/cache_memory_leak and squashes the following commits: 0b41235 [Liang-Chi Hsieh] fix style. dc1d5d5 [Liang-Chi Hsieh] For comments. 78af229 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into cache_memory_leak 26c9bb6 [Liang-Chi Hsieh] Add configuration to enable in-memory table scan accumulators. 1c3b06e [Liang-Chi Hsieh] Clean up accumulators used in InMemoryRelation when it is uncached. --- .../org/apache/spark/sql/CacheManager.scala | 2 +- .../columnar/InMemoryColumnarTableScan.scala | 47 ++++++++++++++----- .../apache/spark/sql/CachedTableSuite.scala | 18 +++++++ .../columnar/PartitionBatchPruningSuite.scala | 2 + 4 files changed, 55 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala index ca4a127120b37..18584c2dcf797 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala @@ -112,7 +112,7 @@ private[sql] class CacheManager(sqlContext: SQLContext) extends Logging { val planToCache = query.queryExecution.analyzed val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan)) require(dataIndex >= 0, s"Table $query is not cached.") - cachedData(dataIndex).cachedRepresentation.cachedColumnBuffers.unpersist(blocking) + cachedData(dataIndex).cachedRepresentation.uncache(blocking) cachedData.remove(dataIndex) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 6eee0c86d6a1c..d9b6fb43ab83d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -19,13 +19,15 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer -import org.apache.spark.Accumulator +import org.apache.spark.{Accumulable, Accumulator, Accumulators} import org.apache.spark.sql.catalyst.expressions import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row +import org.apache.spark.SparkContext import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.dsl.expressions._ import 
org.apache.spark.sql.catalyst.expressions._ @@ -53,11 +55,16 @@ private[sql] case class InMemoryRelation( child: SparkPlan, tableName: Option[String])( private var _cachedColumnBuffers: RDD[CachedBatch] = null, - private var _statistics: Statistics = null) + private var _statistics: Statistics = null, + private var _batchStats: Accumulable[ArrayBuffer[Row], Row] = null) extends LogicalPlan with MultiInstanceRelation { - private val batchStats = - child.sqlContext.sparkContext.accumulableCollection(ArrayBuffer.empty[Row]) + private val batchStats: Accumulable[ArrayBuffer[Row], Row] = + if (_batchStats == null) { + child.sqlContext.sparkContext.accumulableCollection(ArrayBuffer.empty[Row]) + } else { + _batchStats + } val partitionStatistics = new PartitionStatistics(output) @@ -161,7 +168,7 @@ private[sql] case class InMemoryRelation( def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = { InMemoryRelation( newOutput, useCompression, batchSize, storageLevel, child, tableName)( - _cachedColumnBuffers, statisticsToBePropagated) + _cachedColumnBuffers, statisticsToBePropagated, batchStats) } override def children: Seq[LogicalPlan] = Seq.empty @@ -175,13 +182,20 @@ private[sql] case class InMemoryRelation( child, tableName)( _cachedColumnBuffers, - statisticsToBePropagated).asInstanceOf[this.type] + statisticsToBePropagated, + batchStats).asInstanceOf[this.type] } def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers override protected def otherCopyArgs: Seq[AnyRef] = - Seq(_cachedColumnBuffers, statisticsToBePropagated) + Seq(_cachedColumnBuffers, statisticsToBePropagated, batchStats) + + private[sql] def uncache(blocking: Boolean): Unit = { + Accumulators.remove(batchStats.id) + cachedColumnBuffers.unpersist(blocking) + _cachedColumnBuffers = null + } } private[sql] case class InMemoryColumnarTableScan( @@ -244,15 +258,20 @@ private[sql] case class InMemoryColumnarTableScan( } } + lazy val enableAccumulators: Boolean = + sqlContext.getConf("spark.sql.inMemoryTableScanStatistics.enable", "false").toBoolean + // Accumulators used for testing purposes - val readPartitions: Accumulator[Int] = sparkContext.accumulator(0) - val readBatches: Accumulator[Int] = sparkContext.accumulator(0) + lazy val readPartitions: Accumulator[Int] = sparkContext.accumulator(0) + lazy val readBatches: Accumulator[Int] = sparkContext.accumulator(0) private val inMemoryPartitionPruningEnabled = sqlContext.conf.inMemoryPartitionPruning override def execute(): RDD[Row] = { - readPartitions.setValue(0) - readBatches.setValue(0) + if (enableAccumulators) { + readPartitions.setValue(0) + readBatches.setValue(0) + } relation.cachedColumnBuffers.mapPartitions { cachedBatchIterator => val partitionFilter = newPredicate( @@ -302,7 +321,7 @@ private[sql] case class InMemoryColumnarTableScan( } } - if (rows.hasNext) { + if (rows.hasNext && enableAccumulators) { readPartitions += 1 } @@ -321,7 +340,9 @@ private[sql] case class InMemoryColumnarTableScan( logInfo(s"Skipping partition based on stats $statsString") false } else { - readBatches += 1 + if (enableAccumulators) { + readBatches += 1 + } true } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index f7b5f08beb92f..01e3b8671071e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -22,6 +22,7 @@ import scala.language.{implicitConversions, postfixOps} 
import org.scalatest.concurrent.Eventually._ +import org.apache.spark.Accumulators import org.apache.spark.sql.TestData._ import org.apache.spark.sql.columnar._ import org.apache.spark.sql.test.TestSQLContext._ @@ -297,4 +298,21 @@ class CachedTableSuite extends QueryTest { sql("Clear CACHE") assert(cacheManager.isEmpty) } + + test("Clear accumulators when uncacheTable to prevent memory leaking") { + val accsSize = Accumulators.originals.size + + sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") + sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") + cacheTable("t1") + cacheTable("t2") + sql("SELECT * FROM t1").count() + sql("SELECT * FROM t2").count() + sql("SELECT * FROM t1").count() + sql("SELECT * FROM t2").count() + uncacheTable("t1") + uncacheTable("t2") + + assert(accsSize >= Accumulators.originals.size) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index e57bb06e7263b..2a0b701cad7fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -39,6 +39,8 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be // Enable in-memory partition pruning setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true") + // Enable in-memory table scan accumulators + setConf("spark.sql.inMemoryTableScanStatistics.enable", "true") } override protected def afterAll(): Unit = { From 557a797a273f1668065806cba53e19e6134a66d3 Mon Sep 17 00:00:00 2001 From: sboeschhuawei Date: Wed, 15 Apr 2015 13:28:10 -0700 Subject: [PATCH 009/191] [SPARK-6937][MLLIB] Fixed bug in PICExample in which the radius were not being accepted on c... 
Tiny bug in PowerIterationClusteringExample in which radius not accepted from command line Author: sboeschhuawei Closes #5531 from javadba/picsub and squashes the following commits: 2aab8cf [sboeschhuawei] Fixed bug in PICExample in which the radius were not being accepted on command line --- .../examples/mllib/PowerIterationClusteringExample.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala index 9f22d40c15f3f..6d8b806569dfd 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala @@ -65,7 +65,7 @@ object PowerIterationClusteringExample { def main(args: Array[String]) { val defaultParams = Params() - val parser = new OptionParser[Params]("PIC Circles") { + val parser = new OptionParser[Params]("PowerIterationClusteringExample") { head("PowerIterationClusteringExample: an example PIC app using concentric circles.") opt[Int]('k', "k") .text(s"number of circles (/clusters), default: ${defaultParams.k}") @@ -76,9 +76,9 @@ object PowerIterationClusteringExample { opt[Int]("maxIterations") .text(s"number of iterations, default: ${defaultParams.maxIterations}") .action((x, c) => c.copy(maxIterations = x)) - opt[Int]('r', "r") + opt[Double]('r', "r") .text(s"radius of outermost circle, default: ${defaultParams.outerRadius}") - .action((x, c) => c.copy(numPoints = x)) + .action((x, c) => c.copy(outerRadius = x)) } parser.parse(args, defaultParams).map { params => @@ -154,3 +154,4 @@ object PowerIterationClusteringExample { coeff * math.exp(expCoeff * ssquares) } } + From 4754e16f4746ebd882b2ce7f1efc6e4d4408922c Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 15 Apr 2015 13:39:12 -0700 Subject: [PATCH 010/191] [SPARK-6898][SQL] completely support special chars in column names Even if we wrap column names in backticks like `` `a#$b.c` ``, we still handle the "." inside column name specially. I think it's fragile to use a special char to split name parts, why not put name parts in `UnresolvedAttribute` directly? 
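A minimal standalone sketch of the fragility described above, assuming a single column literally named "a.b" (the object name below is invented and this is not code from the patch): splitting the quoted name on "." misreads it as a two-part reference, while carrying the parts in a Seq, as the new UnresolvedAttribute(nameParts: Seq[String]) does, stays unambiguous.

    object NamePartsSketch {
      def main(args: Array[String]): Unit = {
        val quoted = "a.b"                        // one column whose name contains a dot
        val byString = quoted.split("\\.").toSeq  // Seq("a", "b"): wrongly read as table "a", column "b"
        val byParts = Seq(quoted)                 // Seq("a.b"): unambiguous, mirrors UnresolvedAttribute.quoted
        println(s"split on '.': $byString, kept as name parts: $byParts")
      }
    }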
Author: Wenchen Fan This patch had conflicts when merged, resolved by Committer: Michael Armbrust Closes #5511 from cloud-fan/6898 and squashes the following commits: 48e3e57 [Wenchen Fan] more style fix 820dc45 [Wenchen Fan] do not ignore newName in UnresolvedAttribute d81ad43 [Wenchen Fan] fix style 11699d6 [Wenchen Fan] completely support special chars in column names --- .../apache/spark/sql/catalyst/SqlParser.scala | 4 +-- .../sql/catalyst/analysis/Analyzer.scala | 13 ++++----- .../sql/catalyst/analysis/CheckAnalysis.scala | 6 ++++- .../sql/catalyst/analysis/unresolved.scala | 14 ++++++++-- .../catalyst/plans/logical/LogicalPlan.scala | 27 +++++++++---------- .../sql/catalyst/analysis/AnalysisSuite.scala | 2 -- .../org/apache/spark/sql/DataFrame.scala | 4 +-- .../org/apache/spark/sql/SQLQuerySuite.scala | 13 ++++++--- .../org/apache/spark/sql/hive/HiveQl.scala | 2 +- 9 files changed, 52 insertions(+), 33 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 9a3531ceb3343..0af969cc5cc67 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -381,13 +381,13 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { | "(" ~> expression <~ ")" | function | dotExpressionHeader - | ident ^^ UnresolvedAttribute + | ident ^^ {case i => UnresolvedAttribute.quoted(i)} | signedPrimary | "~" ~> expression ^^ BitwiseNot ) protected lazy val dotExpressionHeader: Parser[Expression] = (ident <~ ".") ~ ident ~ rep("." ~> ident) ^^ { - case i1 ~ i2 ~ rest => UnresolvedAttribute((Seq(i1, i2) ++ rest).mkString(".")) + case i1 ~ i2 ~ rest => UnresolvedAttribute(Seq(i1, i2) ++ rest) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8b68b0df35f48..cb49e5ad5586f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -297,14 +297,15 @@ class Analyzer( case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString}") q transformExpressionsUp { - case u @ UnresolvedAttribute(name) if resolver(name, VirtualColumn.groupingIdName) && + case u @ UnresolvedAttribute(nameParts) if nameParts.length == 1 && + resolver(nameParts(0), VirtualColumn.groupingIdName) && q.isInstanceOf[GroupingAnalytics] => // Resolve the virtual column GROUPING__ID for the operator GroupingAnalytics q.asInstanceOf[GroupingAnalytics].gid - case u @ UnresolvedAttribute(name) => + case u @ UnresolvedAttribute(nameParts) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. val result = - withPosition(u) { q.resolveChildren(name, resolver).getOrElse(u) } + withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) } logDebug(s"Resolving $u to $result") result case UnresolvedGetField(child, fieldName) if child.resolved => @@ -383,12 +384,12 @@ class Analyzer( child: LogicalPlan, grandchild: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = { // Find any attributes that remain unresolved in the sort. 
- val unresolved: Seq[String] = - ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name }) + val unresolved: Seq[Seq[String]] = + ordering.flatMap(_.collect { case UnresolvedAttribute(nameParts) => nameParts }) // Create a map from name, to resolved attributes, when the desired name can be found // prior to the projection. - val resolved: Map[String, NamedExpression] = + val resolved: Map[Seq[String], NamedExpression] = unresolved.flatMap(u => grandchild.resolve(u, resolver).map(a => u -> a)).toMap // Construct a set that contains all of the attributes that we need to evaluate the diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index fa02111385c06..1155dac28fc78 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -46,8 +46,12 @@ trait CheckAnalysis { operator transformExpressionsUp { case a: Attribute if !a.resolved => if (operator.childrenResolved) { + val nameParts = a match { + case UnresolvedAttribute(nameParts) => nameParts + case _ => Seq(a.name) + } // Throw errors for specific problems with get field. - operator.resolveChildren(a.name, resolver, throwErrors = true) + operator.resolveChildren(nameParts, resolver, throwErrors = true) } val from = operator.inputSet.map(_.name).mkString(", ") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 300e9ba187bc5..3f567e3e8b2a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -49,7 +49,12 @@ case class UnresolvedRelation( /** * Holds the name of an attribute that has yet to be resolved. */ -case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNode[Expression] { +case class UnresolvedAttribute(nameParts: Seq[String]) + extends Attribute with trees.LeafNode[Expression] { + + def name: String = + nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".") + override def exprId: ExprId = throw new UnresolvedException(this, "exprId") override def dataType: DataType = throw new UnresolvedException(this, "dataType") override def nullable: Boolean = throw new UnresolvedException(this, "nullable") @@ -59,7 +64,7 @@ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNo override def newInstance(): UnresolvedAttribute = this override def withNullability(newNullability: Boolean): UnresolvedAttribute = this override def withQualifiers(newQualifiers: Seq[String]): UnresolvedAttribute = this - override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute(name) + override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName) // Unresolved attributes are transient at compile time and don't get evaluated during execution. 
override def eval(input: Row = null): EvaluatedType = @@ -68,6 +73,11 @@ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNo override def toString: String = s"'$name" } +object UnresolvedAttribute { + def apply(name: String): UnresolvedAttribute = new UnresolvedAttribute(name.split("\\.")) + def quoted(name: String): UnresolvedAttribute = new UnresolvedAttribute(Seq(name)) +} + case class UnresolvedFunction(name: String, children: Seq[Expression]) extends Expression { override def dataType: DataType = throw new UnresolvedException(this, "dataType") override def foldable: Boolean = throw new UnresolvedException(this, "foldable") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 579a0fb8d3f93..ae4620a4e5abf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -19,12 +19,11 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.Logging import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedGetField, Resolver} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, EliminateSubQueries, Resolver} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.trees -import org.apache.spark.sql.types.{ArrayType, StructType, StructField} abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { @@ -111,10 +110,10 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { * as string in the following form: `[scope].AttributeName.[nested].[fields]...`. */ def resolveChildren( - name: String, + nameParts: Seq[String], resolver: Resolver, throwErrors: Boolean = false): Option[NamedExpression] = - resolve(name, children.flatMap(_.output), resolver, throwErrors) + resolve(nameParts, children.flatMap(_.output), resolver, throwErrors) /** * Optionally resolves the given string to a [[NamedExpression]] based on the output of this @@ -122,10 +121,10 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { * `[scope].AttributeName.[nested].[fields]...`. */ def resolve( - name: String, + nameParts: Seq[String], resolver: Resolver, throwErrors: Boolean = false): Option[NamedExpression] = - resolve(name, output, resolver, throwErrors) + resolve(nameParts, output, resolver, throwErrors) /** * Resolve the given `name` string against the given attribute, returning either 0 or 1 match. @@ -135,7 +134,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { * See the comment above `candidates` variable in resolve() for semantics the returned data. */ private def resolveAsTableColumn( - nameParts: Array[String], + nameParts: Seq[String], resolver: Resolver, attribute: Attribute): Option[(Attribute, List[String])] = { assert(nameParts.length > 1) @@ -155,7 +154,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { * See the comment above `candidates` variable in resolve() for semantics the returned data. 
*/ private def resolveAsColumn( - nameParts: Array[String], + nameParts: Seq[String], resolver: Resolver, attribute: Attribute): Option[(Attribute, List[String])] = { if (resolver(attribute.name, nameParts.head)) { @@ -167,13 +166,11 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { /** Performs attribute resolution given a name and a sequence of possible attributes. */ protected def resolve( - name: String, + nameParts: Seq[String], input: Seq[Attribute], resolver: Resolver, throwErrors: Boolean): Option[NamedExpression] = { - val parts = name.split("\\.") - // A sequence of possible candidate matches. // Each candidate is a tuple. The first element is a resolved attribute, followed by a list // of parts that are to be resolved. @@ -182,9 +179,9 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { // and the second element will be List("c"). var candidates: Seq[(Attribute, List[String])] = { // If the name has 2 or more parts, try to resolve it as `table.column` first. - if (parts.length > 1) { + if (nameParts.length > 1) { input.flatMap { option => - resolveAsTableColumn(parts, resolver, option) + resolveAsTableColumn(nameParts, resolver, option) } } else { Seq.empty @@ -194,10 +191,12 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { // If none of attributes match `table.column` pattern, we try to resolve it as a column. if (candidates.isEmpty) { candidates = input.flatMap { candidate => - resolveAsColumn(parts, resolver, candidate) + resolveAsColumn(nameParts, resolver, candidate) } } + def name = UnresolvedAttribute(nameParts).name + candidates.distinct match { // One match, no nested fields, use it. case Seq((a, Nil)) => Some(a) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 6e3d6b9263e86..e10ddfdf5127c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -27,8 +27,6 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import scala.collection.immutable - class AnalysisSuite extends FunSuite with BeforeAndAfter { val caseSensitiveCatalog = new SimpleCatalog(true) val caseInsensitiveCatalog = new SimpleCatalog(false) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 94ae2d65fd0e4..3235f85d5bbd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -158,7 +158,7 @@ class DataFrame private[sql]( } protected[sql] def resolve(colName: String): NamedExpression = { - queryExecution.analyzed.resolve(colName, sqlContext.analyzer.resolver).getOrElse { + queryExecution.analyzed.resolve(colName.split("\\."), sqlContext.analyzer.resolver).getOrElse { throw new AnalysisException( s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""") } @@ -166,7 +166,7 @@ class DataFrame private[sql]( protected[sql] def numericColumns: Seq[Expression] = { schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n => - queryExecution.analyzed.resolve(n.name, sqlContext.analyzer.resolver).get + queryExecution.analyzed.resolve(n.name.split("\\."), 
sqlContext.analyzer.resolver).get } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 4c48dca44498b..d739e550f3e56 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -19,14 +19,13 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.TestData._ import org.apache.spark.sql.execution.GeneratedAggregate import org.apache.spark.sql.functions._ +import org.apache.spark.sql.TestData._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext.{udf => _, _} import org.apache.spark.sql.types._ - class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { // Make sure the tables are loaded. TestData @@ -1125,7 +1124,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { val data = sparkContext.parallelize( Seq("""{"key?number1": "value1", "key.number2": "value2"}""")) jsonRDD(data).registerTempTable("records") - sql("SELECT `key?number1` FROM records") + sql("SELECT `key?number1`, `key.number2` FROM records") } test("SPARK-3814 Support Bitwise & operator") { @@ -1225,4 +1224,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY c0.a"), Row(1)) checkAnswer(sql("SELECT b[0].a FROM t ORDER BY c0.a"), Row(1)) } + + test("SPARK-6898: complete support for special chars in column names") { + jsonRDD(sparkContext.makeRDD( + """{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil)) + .registerTempTable("t") + + checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1)) + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 53a204b8c2932..fd305eb480e63 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -1101,7 +1101,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token(".", qualifier :: Token(attr, Nil) :: Nil) => nodeToExpr(qualifier) match { case UnresolvedAttribute(qualifierName) => - UnresolvedAttribute(qualifierName + "." + cleanIdentifier(attr)) + UnresolvedAttribute(qualifierName :+ cleanIdentifier(attr)) case other => UnresolvedGetField(other, attr) } From 585638e81ce09a72b9e7f95d38e0d432cfa02456 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Wed, 15 Apr 2015 14:06:10 -0700 Subject: [PATCH 011/191] [SPARK-2213] [SQL] sort merge join for spark sql Thanks for the initial work from Ishiihara in #3173 This PR introduce a new join method of sort merge join, which firstly ensure that keys of same value are in the same partition, and inside each partition the Rows are sorted by key. Then we can run down both sides together, find matched rows using [sort merge join](http://en.wikipedia.org/wiki/Sort-merge_join). In this way, we don't have to store the whole hash table of one side as hash join, thus we have less memory usage. Also, this PR would benefit from #3438 , making the sorting phrase much more efficient. We introduced a new configuration of "spark.sql.planner.sortMergeJoin" to switch between this(`true`) and ShuffledHashJoin(`false`), probably we want the default value of it be `false` at first. 
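To make the join strategy described above concrete, here is a minimal, self-contained sketch of the merge step on two key-sorted inputs (illustrative only: the name SortMergeSketch is invented, and the operator added by this patch instead works on partitioned, key-sorted RDDs of Rows, buffering matches in a CompactBuffer). Only the run of right-side rows sharing the current key is kept in memory, which is where the saving over building a full hash table comes from.

    object SortMergeSketch {
      // Joins two inputs that are already sorted by key, buffering only the
      // run of right-side rows that share the current key.
      def sortMergeJoin[K: Ordering, A, B](
          left: Seq[(K, A)], right: Seq[(K, B)]): Seq[(K, A, B)] = {
        val ord = implicitly[Ordering[K]]
        val out = scala.collection.mutable.ArrayBuffer.empty[(K, A, B)]
        var i = 0
        var j = 0
        while (i < left.length && j < right.length) {
          val c = ord.compare(left(i)._1, right(j)._1)
          if (c < 0) i += 1          // left key too small, advance left
          else if (c > 0) j += 1     // right key too small, advance right
          else {
            val key = left(i)._1
            var jEnd = j             // find the end of the matching run on the right
            while (jEnd < right.length && ord.equiv(right(jEnd)._1, key)) jEnd += 1
            while (i < left.length && ord.equiv(left(i)._1, key)) {
              var k = j              // replay the buffered right run for each matching left row
              while (k < jEnd) { out += ((key, left(i)._2, right(k)._2)); k += 1 }
              i += 1
            }
            j = jEnd
          }
        }
        out.toSeq
      }

      def main(args: Array[String]): Unit = {
        val l = Seq(1 -> "a", 2 -> "b", 2 -> "c", 4 -> "d")
        val r = Seq(2 -> "x", 2 -> "y", 3 -> "z", 4 -> "w")
        // expected matches: (2,b,x), (2,b,y), (2,c,x), (2,c,y), (4,d,w)
        println(sortMergeJoin(l, r))
      }
    }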
Author: Daoyuan Wang Author: Michael Armbrust This patch had conflicts when merged, resolved by Committer: Michael Armbrust Closes #5208 from adrian-wang/smj and squashes the following commits: 2493b9f [Daoyuan Wang] fix style 5049d88 [Daoyuan Wang] propagate rowOrdering for RangePartitioning f91a2ae [Daoyuan Wang] yin's comment: use external sort if option is enabled, add comments f515cd2 [Daoyuan Wang] yin's comment: outputOrdering, join suite refine ec8061b [Daoyuan Wang] minor change 413fd24 [Daoyuan Wang] Merge pull request #3 from marmbrus/pr/5208 952168a [Michael Armbrust] add type 5492884 [Michael Armbrust] copy when ordering 7ddd656 [Michael Armbrust] Cleanup addition of ordering requirements b198278 [Daoyuan Wang] inherit ordering in project c8e82a3 [Daoyuan Wang] fix style 6e897dd [Daoyuan Wang] hide boundReference from manually construct RowOrdering for key compare in smj 8681d73 [Daoyuan Wang] refactor Exchange and fix copy for sorting 2875ef2 [Daoyuan Wang] fix changed configuration 61d7f49 [Daoyuan Wang] add omitted comment 00a4430 [Daoyuan Wang] fix bug 078d69b [Daoyuan Wang] address comments: add comments, do sort in shuffle, and others 3af6ba5 [Daoyuan Wang] use buffer for only one side 171001f [Daoyuan Wang] change default outputordering 47455c9 [Daoyuan Wang] add apache license ... a28277f [Daoyuan Wang] fix style 645c70b [Daoyuan Wang] address comments using sort 068c35d [Daoyuan Wang] fix new style and add some tests 925203b [Daoyuan Wang] address comments 07ce92f [Daoyuan Wang] fix ArrayIndexOutOfBound 42fca0e [Daoyuan Wang] code clean e3ec096 [Daoyuan Wang] fix comment style.. 2edd235 [Daoyuan Wang] fix outputpartitioning 57baa40 [Daoyuan Wang] fix sort eval bug 303b6da [Daoyuan Wang] fix several errors 95db7ad [Daoyuan Wang] fix brackets for if-statement 4464f16 [Daoyuan Wang] fix error 880d8e9 [Daoyuan Wang] sort merge join for spark sql --- .../spark/sql/catalyst/expressions/rows.scala | 10 +- .../plans/physical/partitioning.scala | 13 ++ .../scala/org/apache/spark/sql/SQLConf.scala | 8 + .../org/apache/spark/sql/SQLContext.scala | 2 +- .../apache/spark/sql/execution/Exchange.scala | 148 ++++++++++++--- .../spark/sql/execution/SparkPlan.scala | 6 + .../spark/sql/execution/SparkStrategies.scala | 11 +- .../spark/sql/execution/basicOperators.scala | 10 ++ .../sql/execution/joins/SortMergeJoin.scala | 169 ++++++++++++++++++ .../org/apache/spark/sql/JoinSuite.scala | 28 ++- .../SortMergeCompatibilitySuite.scala | 162 +++++++++++++++++ 11 files changed, 534 insertions(+), 33 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala create mode 100644 sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 1b62e17ff47fd..b6ec7d3417ef8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.types.{UTF8String, StructType, NativeType} - +import org.apache.spark.sql.types.{UTF8String, DataType, StructType, NativeType} /** * An extended interface to [[Row]] that allows the values for each column to be updated. 
Setting @@ -239,3 +238,10 @@ class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[Row] { return 0 } } + +object RowOrdering { + def forSchema(dataTypes: Seq[DataType]): RowOrdering = + new RowOrdering(dataTypes.zipWithIndex.map { + case(dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending) + }) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 288c11f69fe22..fb4217a44807b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -94,6 +94,9 @@ sealed trait Partitioning { * only compatible if the `numPartitions` of them is the same. */ def compatibleWith(other: Partitioning): Boolean + + /** Returns the expressions that are used to key the partitioning. */ + def keyExpressions: Seq[Expression] } case class UnknownPartitioning(numPartitions: Int) extends Partitioning { @@ -106,6 +109,8 @@ case class UnknownPartitioning(numPartitions: Int) extends Partitioning { case UnknownPartitioning(_) => true case _ => false } + + override def keyExpressions: Seq[Expression] = Nil } case object SinglePartition extends Partitioning { @@ -117,6 +122,8 @@ case object SinglePartition extends Partitioning { case SinglePartition => true case _ => false } + + override def keyExpressions: Seq[Expression] = Nil } case object BroadcastPartitioning extends Partitioning { @@ -128,6 +135,8 @@ case object BroadcastPartitioning extends Partitioning { case SinglePartition => true case _ => false } + + override def keyExpressions: Seq[Expression] = Nil } /** @@ -158,6 +167,8 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) case _ => false } + override def keyExpressions: Seq[Expression] = expressions + override def eval(input: Row = null): EvaluatedType = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } @@ -200,6 +211,8 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) case _ => false } + override def keyExpressions: Seq[Expression] = ordering.map(_.child) + override def eval(input: Row): EvaluatedType = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index ee641bdfeb2d7..5c65f04ee8497 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -47,6 +47,7 @@ private[spark] object SQLConf { // Options that control which operators can be chosen by the query planner. These should be // considered hints and may be ignored by future versions of Spark SQL. val EXTERNAL_SORT = "spark.sql.planner.externalSort" + val SORTMERGE_JOIN = "spark.sql.planner.sortMergeJoin" // This is only used for the thriftserver val THRIFTSERVER_POOL = "spark.sql.thriftserver.scheduler.pool" @@ -128,6 +129,13 @@ private[sql] class SQLConf extends Serializable { /** When true the planner will use the external sort, which may spill to disk. 
*/ private[spark] def externalSortEnabled: Boolean = getConf(EXTERNAL_SORT, "false").toBoolean + /** + * Sort merge join would sort the two side of join first, and then iterate both sides together + * only once to get all matches. Using sort merge join can save a lot of memory usage compared + * to HashJoin. + */ + private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN, "false").toBoolean + /** * When set to true, Spark SQL will use the Scala compiler at runtime to generate custom bytecode * that evaluates expressions found in queries. In general this custom code runs much faster diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 89a4faf35e0d2..f9f3eb2e03817 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -1081,7 +1081,7 @@ class SQLContext(@transient val sparkContext: SparkContext) @transient protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] { val batches = - Batch("Add exchange", Once, AddExchange(self)) :: Nil + Batch("Add exchange", Once, EnsureRequirements(self)) :: Nil } protected[sql] def openSession(): SQLSession = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 437408d30bfd2..518fc9e57c708 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -19,24 +19,42 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.shuffle.sort.SortShuffleManager -import org.apache.spark.sql.catalyst.expressions import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner, SparkConf} import org.apache.spark.rdd.{RDD, ShuffledRDD} import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.{Attribute, RowOrdering} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.util.MutablePair +object Exchange { + /** + * Returns true when the ordering expressions are a subset of the key. + * if true, ShuffledRDD can use `setKeyOrdering(orderingKey)` to sort within [[Exchange]]. + */ + def canSortWithShuffle(partitioning: Partitioning, desiredOrdering: Seq[SortOrder]): Boolean = { + desiredOrdering.map(_.child).toSet.subsetOf(partitioning.keyExpressions.toSet) + } +} + /** * :: DeveloperApi :: + * Performs a shuffle that will result in the desired `newPartitioning`. Optionally sorts each + * resulting partition based on expressions from the partition key. It is invalid to construct an + * exchange operator with a `newOrdering` that cannot be calculated using the partitioning key. 
*/ @DeveloperApi -case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode { +case class Exchange( + newPartitioning: Partitioning, + newOrdering: Seq[SortOrder], + child: SparkPlan) + extends UnaryNode { override def outputPartitioning: Partitioning = newPartitioning + override def outputOrdering: Seq[SortOrder] = newOrdering + override def output: Seq[Attribute] = child.output /** We must copy rows when sort based shuffle is on */ @@ -45,6 +63,20 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una private val bypassMergeThreshold = child.sqlContext.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) + private val keyOrdering = { + if (newOrdering.nonEmpty) { + val key = newPartitioning.keyExpressions + val boundOrdering = newOrdering.map { o => + val ordinal = key.indexOf(o.child) + if (ordinal == -1) sys.error(s"Invalid ordering on $o requested for $newPartitioning") + o.copy(child = BoundReference(ordinal, o.child.dataType, o.child.nullable)) + } + new RowOrdering(boundOrdering) + } else { + null // Ordering will not be used + } + } + override def execute(): RDD[Row] = attachTree(this , "execute") { newPartitioning match { case HashPartitioning(expressions, numPartitions) => @@ -56,7 +88,9 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una // we can avoid the defensive copies to improve performance. In the long run, we probably // want to include information in shuffle dependencies to indicate whether elements in the // source RDD should be copied. - val rdd = if (sortBasedShuffleOn && numPartitions > bypassMergeThreshold) { + val willMergeSort = sortBasedShuffleOn && numPartitions > bypassMergeThreshold + + val rdd = if (willMergeSort || newOrdering.nonEmpty) { child.execute().mapPartitions { iter => val hashExpressions = newMutableProjection(expressions, child.output)() iter.map(r => (hashExpressions(r).copy(), r.copy())) @@ -69,12 +103,17 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una } } val part = new HashPartitioner(numPartitions) - val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part) + val shuffled = + if (newOrdering.nonEmpty) { + new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(keyOrdering) + } else { + new ShuffledRDD[Row, Row, Row](rdd, part) + } shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) shuffled.map(_._2) case RangePartitioning(sortingExpressions, numPartitions) => - val rdd = if (sortBasedShuffleOn) { + val rdd = if (sortBasedShuffleOn || newOrdering.nonEmpty) { child.execute().mapPartitions { iter => iter.map(row => (row.copy(), null))} } else { child.execute().mapPartitions { iter => @@ -87,7 +126,12 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una implicit val ordering = new RowOrdering(sortingExpressions, child.output) val part = new RangePartitioner(numPartitions, rdd, ascending = true) - val shuffled = new ShuffledRDD[Row, Null, Null](rdd, part) + val shuffled = + if (newOrdering.nonEmpty) { + new ShuffledRDD[Row, Null, Null](rdd, part).setKeyOrdering(keyOrdering) + } else { + new ShuffledRDD[Row, Null, Null](rdd, part) + } shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) shuffled.map(_._1) @@ -120,27 +164,34 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una * Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]] * of input data meets the * 
[[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] requirements for - * each operator by inserting [[Exchange]] Operators where required. + * each operator by inserting [[Exchange]] Operators where required. Also ensure that the + * required input partition ordering requirements are met. */ -private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPlan] { +private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] { // TODO: Determine the number of partitions. def numPartitions: Int = sqlContext.conf.numShufflePartitions def apply(plan: SparkPlan): SparkPlan = plan.transformUp { case operator: SparkPlan => - // Check if every child's outputPartitioning satisfies the corresponding + // True iff every child's outputPartitioning satisfies the corresponding // required data distribution. def meetsRequirements: Boolean = - !operator.requiredChildDistribution.zip(operator.children).map { + operator.requiredChildDistribution.zip(operator.children).forall { case (required, child) => val valid = child.outputPartitioning.satisfies(required) logDebug( s"${if (valid) "Valid" else "Invalid"} distribution," + s"required: $required current: ${child.outputPartitioning}") valid - }.exists(!_) + } - // Check if outputPartitionings of children are compatible with each other. + // True iff any of the children are incorrectly sorted. + def needsAnySort: Boolean = + operator.requiredChildOrdering.zip(operator.children).exists { + case (required, child) => required.nonEmpty && required != child.outputOrdering + } + + // True iff outputPartitionings of children are compatible with each other. // It is possible that every child satisfies its required data distribution // but two children have incompatible outputPartitionings. For example, // A dataset is range partitioned by "a.asc" (RangePartitioning) and another @@ -157,28 +208,69 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl case Seq(a,b) => a compatibleWith b }.exists(!_) - // Check if the partitioning we want to ensure is the same as the child's output - // partitioning. If so, we do not need to add the Exchange operator. - def addExchangeIfNecessary(partitioning: Partitioning, child: SparkPlan): SparkPlan = - if (child.outputPartitioning != partitioning) Exchange(partitioning, child) else child + // Adds Exchange or Sort operators as required + def addOperatorsIfNecessary( + partitioning: Partitioning, + rowOrdering: Seq[SortOrder], + child: SparkPlan): SparkPlan = { + val needSort = rowOrdering.nonEmpty && child.outputOrdering != rowOrdering + val needsShuffle = child.outputPartitioning != partitioning + val canSortWithShuffle = Exchange.canSortWithShuffle(partitioning, rowOrdering) + + if (needSort && needsShuffle && canSortWithShuffle) { + Exchange(partitioning, rowOrdering, child) + } else { + val withShuffle = if (needsShuffle) { + Exchange(partitioning, Nil, child) + } else { + child + } - if (meetsRequirements && compatible) { + val withSort = if (needSort) { + if (sqlContext.conf.externalSortEnabled) { + ExternalSort(rowOrdering, global = false, withShuffle) + } else { + Sort(rowOrdering, global = false, withShuffle) + } + } else { + withShuffle + } + + withSort + } + } + + if (meetsRequirements && compatible && !needsAnySort) { operator } else { // At least one child does not satisfies its required data distribution or // at least one child's outputPartitioning is not compatible with another child's // outputPartitioning. 
In this case, we need to add Exchange operators. - val repartitionedChildren = operator.requiredChildDistribution.zip(operator.children).map { - case (AllTuples, child) => - addExchangeIfNecessary(SinglePartition, child) - case (ClusteredDistribution(clustering), child) => - addExchangeIfNecessary(HashPartitioning(clustering, numPartitions), child) - case (OrderedDistribution(ordering), child) => - addExchangeIfNecessary(RangePartitioning(ordering, numPartitions), child) - case (UnspecifiedDistribution, child) => child - case (dist, _) => sys.error(s"Don't know how to ensure $dist") + val requirements = + (operator.requiredChildDistribution, operator.requiredChildOrdering, operator.children) + + val fixedChildren = requirements.zipped.map { + case (AllTuples, rowOrdering, child) => + addOperatorsIfNecessary(SinglePartition, rowOrdering, child) + case (ClusteredDistribution(clustering), rowOrdering, child) => + addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child) + case (OrderedDistribution(ordering), rowOrdering, child) => + addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child) + + case (UnspecifiedDistribution, Seq(), child) => + child + case (UnspecifiedDistribution, rowOrdering, child) => + if (sqlContext.conf.externalSortEnabled) { + ExternalSort(rowOrdering, global = false, child) + } else { + Sort(rowOrdering, global = false, child) + } + + case (dist, ordering, _) => + sys.error(s"Don't know how to ensure $dist with ordering $ordering") } - operator.withNewChildren(repartitionedChildren) + + operator.withNewChildren(fixedChildren) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index fabcf6b4a0570..e159ffe66cb24 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -72,6 +72,12 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ def requiredChildDistribution: Seq[Distribution] = Seq.fill(children.size)(UnspecifiedDistribution) + /** Specifies how data is ordered in each partition. */ + def outputOrdering: Seq[SortOrder] = Nil + + /** Specifies sort order for each partition requirements on the input data for this operator. */ + def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil) + /** * Runs this query returning the result as an RDD. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 5b99e40c2f491..e687d01f57520 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -90,6 +90,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { left.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold => makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft) + // If the sort merge join option is set, we want to use sort merge join prior to hashjoin + // for now let's support inner join first, then add outer join + case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) + if sqlContext.conf.sortMergeJoinEnabled => + val mergeJoin = + joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right)) + condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil + case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) => val buildSide = if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) { @@ -309,7 +317,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.OneRowRelation => execution.PhysicalRDD(Nil, singleRowRdd) :: Nil case logical.Repartition(expressions, child) => - execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil + execution.Exchange( + HashPartitioning(expressions, numPartitions), Nil, planLater(child)) :: Nil case e @ EvaluatePython(udf, child, _) => BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index f8221f41bc6c3..308dae236a5ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -41,6 +41,8 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends val resuableProjection = buildProjection() iter.map(resuableProjection) } + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering } /** @@ -55,6 +57,8 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { override def execute(): RDD[Row] = child.execute().mapPartitions { iter => iter.filter(conditionEvaluator) } + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering } /** @@ -147,6 +151,8 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) // TODO: Terminal split should be implemented differently from non-terminal split. // TODO: Pick num splits based on |limit|. 
override def execute(): RDD[Row] = sparkContext.makeRDD(collectData(), 1) + + override def outputOrdering: Seq[SortOrder] = sortOrder } /** @@ -172,6 +178,8 @@ case class Sort( } override def output: Seq[Attribute] = child.output + + override def outputOrdering: Seq[SortOrder] = sortOrder } /** @@ -202,6 +210,8 @@ case class ExternalSort( } override def output: Seq[Attribute] = child.output + + override def outputOrdering: Seq[SortOrder] = sortOrder } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala new file mode 100644 index 0000000000000..b5123668ba11e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import java.util.NoSuchElementException + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.util.collection.CompactBuffer + +/** + * :: DeveloperApi :: + * Performs an sort merge join of two child relations. 
+ */ +@DeveloperApi +case class SortMergeJoin( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode { + + override def output: Seq[Attribute] = left.output ++ right.output + + override def outputPartitioning: Partitioning = left.outputPartitioning + + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + // this is to manually construct an ordering that can be used to compare keys from both sides + private val keyOrdering: RowOrdering = RowOrdering.forSchema(leftKeys.map(_.dataType)) + + override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys) + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil + + @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output) + @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output) + + private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = + keys.map(SortOrder(_, Ascending)) + + override def execute(): RDD[Row] = { + val leftResults = left.execute().map(_.copy()) + val rightResults = right.execute().map(_.copy()) + + leftResults.zipPartitions(rightResults) { (leftIter, rightIter) => + new Iterator[Row] { + // Mutable per row objects. + private[this] val joinRow = new JoinedRow5 + private[this] var leftElement: Row = _ + private[this] var rightElement: Row = _ + private[this] var leftKey: Row = _ + private[this] var rightKey: Row = _ + private[this] var rightMatches: CompactBuffer[Row] = _ + private[this] var rightPosition: Int = -1 + private[this] var stop: Boolean = false + private[this] var matchKey: Row = _ + + // initialize iterator + initialize() + + override final def hasNext: Boolean = nextMatchingPair() + + override final def next(): Row = { + if (hasNext) { + // we are using the buffered right rows and run down left iterator + val joinedRow = joinRow(leftElement, rightMatches(rightPosition)) + rightPosition += 1 + if (rightPosition >= rightMatches.size) { + rightPosition = 0 + fetchLeft() + if (leftElement == null || keyOrdering.compare(leftKey, matchKey) != 0) { + stop = false + rightMatches = null + } + } + joinedRow + } else { + // no more result + throw new NoSuchElementException + } + } + + private def fetchLeft() = { + if (leftIter.hasNext) { + leftElement = leftIter.next() + leftKey = leftKeyGenerator(leftElement) + } else { + leftElement = null + } + } + + private def fetchRight() = { + if (rightIter.hasNext) { + rightElement = rightIter.next() + rightKey = rightKeyGenerator(rightElement) + } else { + rightElement = null + } + } + + private def initialize() = { + fetchLeft() + fetchRight() + } + + /** + * Searches the right iterator for the next rows that have matches in left side, and store + * them in a buffer. + * + * @return true if the search is successful, and false if the right iterator runs out of + * tuples. 
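The iterator above advances both sorted inputs in lock step, buffers the run of right-side rows that share the current key (rightMatches), and then replays that buffer for every left row carrying the same key. Stripped of Spark's Row, projection and null-key handling, the control flow reduces to the sketch below over two pre-sorted sequences keyed by Int; the object and method names are illustrative only.

    object SortMergeSketch {
      // Inner join of two inputs that are already sorted by key.
      def innerJoin[A, B](left: Seq[(Int, A)], right: Seq[(Int, B)]): Seq[(Int, A, B)] = {
        val out = scala.collection.mutable.ArrayBuffer.empty[(Int, A, B)]
        var i = 0
        var j = 0
        while (i < left.length && j < right.length) {
          val lk = left(i)._1
          val rk = right(j)._1
          if (lk < rk) {
            i += 1            // left key too small: advance left
          } else if (lk > rk) {
            j += 1            // right key too small: advance right
          } else {
            // Matching key: buffer the whole right-side run for this key...
            val start = j
            while (j < right.length && right(j)._1 == lk) j += 1
            val matches = right.slice(start, j)
            // ...and replay it against every left row with the same key.
            while (i < left.length && left(i)._1 == lk) {
              matches.foreach { case (_, b) => out += ((lk, left(i)._2, b)) }
              i += 1
            }
          }
        }
        out.toSeq
      }
    }

    // SortMergeSketch.innerJoin(Seq(1 -> "a", 2 -> "b", 2 -> "c"),
    //                           Seq(2 -> "x", 2 -> "y", 3 -> "z"))
    // => Seq((2,"b","x"), (2,"b","y"), (2,"c","x"), (2,"c","y"))

Spark's operator additionally filters out rows whose join keys contain nulls, as noted in the inner-join comment above; the sketch omits that step.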
+ */ + private def nextMatchingPair(): Boolean = { + if (!stop && rightElement != null) { + // run both side to get the first match pair + while (!stop && leftElement != null && rightElement != null) { + val comparing = keyOrdering.compare(leftKey, rightKey) + // for inner join, we need to filter those null keys + stop = comparing == 0 && !leftKey.anyNull + if (comparing > 0 || rightKey.anyNull) { + fetchRight() + } else if (comparing < 0 || leftKey.anyNull) { + fetchLeft() + } + } + rightMatches = new CompactBuffer[Row]() + if (stop) { + stop = false + // iterate the right side to buffer all rows that matches + // as the records should be ordered, exit when we meet the first that not match + while (!stop && rightElement != null) { + rightMatches += rightElement + fetchRight() + stop = keyOrdering.compare(leftKey, rightKey) != 0 + } + if (rightMatches.size > 0) { + rightPosition = 0 + matchKey = leftKey + } + } + } + rightMatches != null && rightMatches.size > 0 + } + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index e4dee87849fd4..037d392c1f929 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -51,6 +51,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { case j: CartesianProduct => j case j: BroadcastNestedLoopJoin => j case j: BroadcastLeftSemiJoinHash => j + case j: SortMergeJoin => j } assert(operators.size === 1) @@ -62,6 +63,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("join operator selection") { cacheManager.clearCache() + val SORTMERGEJOIN_ENABLED: Boolean = conf.sortMergeJoinEnabled Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]), ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]), @@ -91,17 +93,41 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)", classOf[BroadcastNestedLoopJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + try { + conf.setConf("spark.sql.planner.sortMergeJoin", "true") + Seq( + ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + } finally { + conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString) + } } test("broadcasted hash join operator selection") { cacheManager.clearCache() sql("CACHE TABLE testData") + val SORTMERGEJOIN_ENABLED: Boolean = conf.sortMergeJoinEnabled Seq( ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), ("SELECT * FROM testData join testData2 ON key = a and key = 2", classOf[BroadcastHashJoin]), - ("SELECT * FROM testData join testData2 ON key = a where key = 2", classOf[BroadcastHashJoin]) + ("SELECT * FROM testData join testData2 ON key = a where key = 2", + classOf[BroadcastHashJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + try { + conf.setConf("spark.sql.planner.sortMergeJoin", "true") + Seq( + ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a and key = 2", + 
classOf[BroadcastHashJoin]), + ("SELECT * FROM testData join testData2 ON key = a where key = 2", + classOf[BroadcastHashJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + } finally { + conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString) + } sql("UNCACHE TABLE testData") } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala new file mode 100644 index 0000000000000..65d070bd3cbde --- /dev/null +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.hive.test.TestHive + +/** + * Runs the test cases that are included in the hive distribution with sort merge join is true. 
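Like the JoinSuite changes above, this suite turns the sort-merge flag on for the duration of the tests and restores the previous value afterwards so the setting cannot leak into other suites. A generic helper that captures that save/try/finally/restore discipline could look like the sketch below; SimpleConf and withFlag are illustrative stand-ins, not Spark APIs, and the fallback default of "false" is an assumption.

    // A tiny mutable configuration and a helper that restores the previous value
    // even when the test body throws.
    class SimpleConf {
      private val settings = scala.collection.mutable.Map.empty[String, String]
      def set(key: String, value: String): Unit = settings(key) = value
      def get(key: String, default: String): String = settings.getOrElse(key, default)
    }

    object ConfTestHelper {
      def withFlag[T](conf: SimpleConf, key: String, value: String)(body: => T): T = {
        val previous = conf.get(key, "false")   // assumed default when the key is unset
        conf.set(key, value)
        try body finally conf.set(key, previous)
      }
    }

    // Usage:
    // val conf = new SimpleConf
    // ConfTestHelper.withFlag(conf, "spark.sql.planner.sortMergeJoin", "true") {
    //   // assertions that expect SortMergeJoin to be selected
    // }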
+ */ +class SortMergeCompatibilitySuite extends HiveCompatibilitySuite { + override def beforeAll() { + super.beforeAll() + TestHive.setConf(SQLConf.SORTMERGE_JOIN, "true") + } + + override def afterAll() { + TestHive.setConf(SQLConf.SORTMERGE_JOIN, "false") + super.afterAll() + } + + override def whiteList = Seq( + "auto_join0", + "auto_join1", + "auto_join10", + "auto_join11", + "auto_join12", + "auto_join13", + "auto_join14", + "auto_join14_hadoop20", + "auto_join15", + "auto_join17", + "auto_join18", + "auto_join19", + "auto_join2", + "auto_join20", + "auto_join21", + "auto_join22", + "auto_join23", + "auto_join24", + "auto_join25", + "auto_join26", + "auto_join27", + "auto_join28", + "auto_join3", + "auto_join30", + "auto_join31", + "auto_join32", + "auto_join4", + "auto_join5", + "auto_join6", + "auto_join7", + "auto_join8", + "auto_join9", + "auto_join_filters", + "auto_join_nulls", + "auto_join_reordering_values", + "auto_smb_mapjoin_14", + "auto_sortmerge_join_1", + "auto_sortmerge_join_10", + "auto_sortmerge_join_11", + "auto_sortmerge_join_12", + "auto_sortmerge_join_13", + "auto_sortmerge_join_14", + "auto_sortmerge_join_15", + "auto_sortmerge_join_16", + "auto_sortmerge_join_2", + "auto_sortmerge_join_3", + "auto_sortmerge_join_4", + "auto_sortmerge_join_5", + "auto_sortmerge_join_6", + "auto_sortmerge_join_7", + "auto_sortmerge_join_8", + "auto_sortmerge_join_9", + "correlationoptimizer1", + "correlationoptimizer10", + "correlationoptimizer11", + "correlationoptimizer13", + "correlationoptimizer14", + "correlationoptimizer15", + "correlationoptimizer2", + "correlationoptimizer3", + "correlationoptimizer4", + "correlationoptimizer6", + "correlationoptimizer7", + "correlationoptimizer8", + "correlationoptimizer9", + "join0", + "join1", + "join10", + "join11", + "join12", + "join13", + "join14", + "join14_hadoop20", + "join15", + "join16", + "join17", + "join18", + "join19", + "join2", + "join20", + "join21", + "join22", + "join23", + "join24", + "join25", + "join26", + "join27", + "join28", + "join29", + "join3", + "join30", + "join31", + "join32", + "join32_lessSize", + "join33", + "join34", + "join35", + "join36", + "join37", + "join38", + "join39", + "join4", + "join40", + "join41", + "join5", + "join6", + "join7", + "join8", + "join9", + "join_1to1", + "join_array", + "join_casesensitive", + "join_empty", + "join_filters", + "join_hive_626", + "join_map_ppr", + "join_nulls", + "join_nullsafe", + "join_rc", + "join_reorder2", + "join_reorder3", + "join_reorder4", + "join_star" + ) +} From d5f1b9650b6e46cf6a9d61f01cda0df0cda5b1c9 Mon Sep 17 00:00:00 2001 From: Isaias Barroso Date: Wed, 15 Apr 2015 22:40:52 +0100 Subject: [PATCH 012/191] [SPARK-2312] Logging Unhandled messages The previous solution has changed based on https://github.com/apache/spark/pull/2048 discussions. 
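The change below is small: Akka calls isDefinedAt to decide whether a receive function will handle a message, so checking the wrapped function there and logging at debug level when it returns false makes dropped messages visible without altering which messages are handled. The same idea for a plain PartialFunction, independent of Akka, as a sketch with invented names:

    object UnhandledMessageLogging {
      // Wrap a handler so that messages it does not cover are logged
      // (and still reported as unhandled) rather than silently ignored.
      def wrap[A](handler: PartialFunction[A, Unit])
                 (logDebug: String => Unit): PartialFunction[A, Unit] =
        new PartialFunction[A, Unit] {
          override def isDefinedAt(msg: A): Boolean = {
            val handled = handler.isDefinedAt(msg)
            if (!handled) {
              logDebug(s"Received unexpected message: $msg")
            }
            handled
          }
          override def apply(msg: A): Unit = handler(msg)
        }
    }

    // Usage:
    // val receive: PartialFunction[Any, Unit] = { case "ping" => println("pong") }
    // val logged = UnhandledMessageLogging.wrap(receive)(println)
    // logged.isDefinedAt("unknown")   // prints the message, returns false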
Author: Isaias Barroso Closes #2055 from isaias/SPARK-2312 and squashes the following commits: f61d9e6 [Isaias Barroso] Change Log level for unhandled message to debug f341777 [Isaias Barroso] [SPARK-2312] Logging Unhandled messages --- .../scala/org/apache/spark/util/ActorLogReceive.scala | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala b/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala index 332d0cbb2dc0c..81a7cbde01ce5 100644 --- a/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala +++ b/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala @@ -43,7 +43,13 @@ private[spark] trait ActorLogReceive { private val _receiveWithLogging = receiveWithLogging - override def isDefinedAt(o: Any): Boolean = _receiveWithLogging.isDefinedAt(o) + override def isDefinedAt(o: Any): Boolean = { + val handled = _receiveWithLogging.isDefinedAt(o) + if (!handled) { + log.debug(s"Received unexpected actor system event: $o") + } + handled + } override def apply(o: Any): Unit = { if (log.isDebugEnabled) { From 8a53de16fc8208358b76d0f3d45538e0304bcc8e Mon Sep 17 00:00:00 2001 From: Max Seiden Date: Wed, 15 Apr 2015 16:15:11 -0700 Subject: [PATCH 013/191] [SPARK-5277][SQL] - SparkSqlSerializer doesn't always register user specified KryoRegistrators [SPARK-5277][SQL] - SparkSqlSerializer doesn't always register user specified KryoRegistrators There were a few places where new SparkSqlSerializer instances were created with new, empty SparkConfs resulting in user specified registrators sometimes not getting initialized. The fix is to try and pull a conf from the SparkEnv, and construct a new conf (that loads defaults) if one cannot be found. The changes touched: 1) SparkSqlSerializer's resource pool (this appears to fix the issue in the comment) 2) execution.Exchange (for all of the partitioners) 3) execution.Limit (for the HashPartitioner) A few tests were added to ColumnTypeSuite, ensuring that a custom registrator and serde is initialized and used when in-memory columns are written. Author: Max Seiden This patch had conflicts when merged, resolved by Committer: Michael Armbrust Closes #5237 from mhseiden/sql_udt_kryo and squashes the following commits: 3175c2f [Max Seiden] [SPARK-5277][SQL] - address code review comments e5011fb [Max Seiden] [SPARK-5277][SQL] - SparkSqlSerializer does not register user specified KryoRegistrators --- .../apache/spark/sql/execution/Exchange.scala | 9 +-- .../sql/execution/SparkSqlSerializer.scala | 7 +-- .../spark/sql/execution/basicOperators.scala | 2 +- .../spark/sql/columnar/ColumnTypeSuite.scala | 62 ++++++++++++++++++- 4 files changed, 68 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 518fc9e57c708..69a620e1ec929 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -78,6 +78,8 @@ case class Exchange( } override def execute(): RDD[Row] = attachTree(this , "execute") { + lazy val sparkConf = child.sqlContext.sparkContext.getConf + newPartitioning match { case HashPartitioning(expressions, numPartitions) => // TODO: Eliminate redundant expressions in grouping key and value. 
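The essence of the SPARK-5277 fix visible in this diff is that `new SparkConf(false)` is an empty configuration, so a user-supplied `spark.kryo.registrator` never reached the shuffle serializer; threading the application's real conf (from the plan's SparkContext, or from SparkEnv) through instead makes the registrator take effect. For context, this is roughly how such a registrator is wired up on the user side; `MyRegistrator` is a hypothetical example class, not part of this patch:

    import com.esotericsoftware.kryo.Kryo
    import org.apache.spark.SparkConf
    import org.apache.spark.serializer.KryoRegistrator

    // A user-supplied registrator (hypothetical example).
    class MyRegistrator extends KryoRegistrator {
      override def registerClasses(kryo: Kryo): Unit = {
        kryo.register(classOf[Array[Double]])
      }
    }

    object KryoConfSketch {
      // The registrator only takes effect for serializers built from a conf that
      // actually carries this setting; an empty `new SparkConf(false)` drops it.
      def buildConf(): SparkConf = new SparkConf()
        .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
        .set("spark.kryo.registrator", classOf[MyRegistrator].getName)
    }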
@@ -109,7 +111,7 @@ case class Exchange( } else { new ShuffledRDD[Row, Row, Row](rdd, part) } - shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) + shuffled.setSerializer(new SparkSqlSerializer(sparkConf)) shuffled.map(_._2) case RangePartitioning(sortingExpressions, numPartitions) => @@ -132,8 +134,7 @@ case class Exchange( } else { new ShuffledRDD[Row, Null, Null](rdd, part) } - shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) - + shuffled.setSerializer(new SparkSqlSerializer(sparkConf)) shuffled.map(_._1) case SinglePartition => @@ -151,7 +152,7 @@ case class Exchange( } val partitioner = new HashPartitioner(1) val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner) - shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) + shuffled.setSerializer(new SparkSqlSerializer(sparkConf)) shuffled.map(_._2) case _ => sys.error(s"Exchange not implemented for $newPartitioning") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index 914f387dec78f..eea15aff5dbcf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -65,12 +65,9 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co private[execution] class KryoResourcePool(size: Int) extends ResourcePool[SerializerInstance](size) { - val ser: KryoSerializer = { + val ser: SparkSqlSerializer = { val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) - // TODO (lian) Using KryoSerializer here is workaround, needs further investigation - // Using SparkSqlSerializer here makes BasicQuerySuite to fail because of Kryo serialization - // related error. 
- new KryoSerializer(sparkConf) + new SparkSqlSerializer(sparkConf) } def newInstance(): SerializerInstance = ser.newInstance() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 308dae236a5ed..d286fe81bee5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -121,7 +121,7 @@ case class Limit(limit: Int, child: SparkPlan) } val part = new HashPartitioner(1) val shuffled = new ShuffledRDD[Boolean, Row, Row](rdd, part) - shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) + shuffled.setSerializer(new SparkSqlSerializer(child.sqlContext.sparkContext.getConf)) shuffled.mapPartitions(_.take(limit).map(_._2)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index c86ef338fc644..b48bed1871c50 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -20,9 +20,12 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer import java.sql.Timestamp +import com.esotericsoftware.kryo.{Serializer, Kryo} +import com.esotericsoftware.kryo.io.{Input, Output} +import org.apache.spark.serializer.KryoRegistrator import org.scalatest.FunSuite -import org.apache.spark.Logging +import org.apache.spark.{SparkConf, Logging} import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.execution.SparkSqlSerializer @@ -73,7 +76,7 @@ class ColumnTypeSuite extends FunSuite with Logging { checkActualSize(BINARY, binary, 4 + 4) val generic = Map(1 -> "a") - checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 11) + checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 8) } testNativeColumnType[BooleanType.type]( @@ -158,6 +161,41 @@ class ColumnTypeSuite extends FunSuite with Logging { } } + test("CUSTOM") { + val conf = new SparkConf() + conf.set("spark.kryo.registrator", "org.apache.spark.sql.columnar.Registrator") + val serializer = new SparkSqlSerializer(conf).newInstance() + + val buffer = ByteBuffer.allocate(512) + val obj = CustomClass(Int.MaxValue,Long.MaxValue) + val serializedObj = serializer.serialize(obj).array() + + GENERIC.append(serializer.serialize(obj).array(), buffer) + buffer.rewind() + + val length = buffer.getInt + assert(length === serializedObj.length) + assert(13 == length) // id (1) + int (4) + long (8) + + val genericSerializedObj = SparkSqlSerializer.serialize(obj) + assert(length != genericSerializedObj.length) + assert(length < genericSerializedObj.length) + + assertResult(obj, "Custom deserialized object didn't equal the original object") { + val bytes = new Array[Byte](length) + buffer.get(bytes, 0, length) + serializer.deserialize(ByteBuffer.wrap(bytes)) + } + + buffer.rewind() + buffer.putInt(serializedObj.length).put(serializedObj) + + assertResult(obj, "Custom deserialized object didn't equal the original object") { + buffer.rewind() + serializer.deserialize(ByteBuffer.wrap(GENERIC.extract(buffer))) + } + } + def testNativeColumnType[T <: NativeType]( columnType: NativeColumnType[T], putter: (ByteBuffer, T#JvmType) => Unit, @@ -229,3 +267,23 @@ class ColumnTypeSuite 
extends FunSuite with Logging { } } } + +private[columnar] final case class CustomClass(a: Int, b: Long) + +private[columnar] object CustomerSerializer extends Serializer[CustomClass] { + override def write(kryo: Kryo, output: Output, t: CustomClass) { + output.writeInt(t.a) + output.writeLong(t.b) + } + override def read(kryo: Kryo, input: Input, aClass: Class[CustomClass]): CustomClass = { + val a = input.readInt() + val b = input.readLong() + CustomClass(a,b) + } +} + +private[columnar] final class Registrator extends KryoRegistrator { + override def registerClasses(kryo: Kryo) { + kryo.register(classOf[CustomClass], CustomerSerializer) + } +} From 52c3439a8a107ce1fc10e4f0b59fd7881e851622 Mon Sep 17 00:00:00 2001 From: Juliet Hougland Date: Wed, 15 Apr 2015 21:52:25 -0700 Subject: [PATCH 014/191] SPARK-6938: All require statements now have an informative error message. This pr adds informative error messages to all require statements in the Vectors class that did not previously have them. This references [SPARK-6938](https://issues.apache.org/jira/browse/SPARK-6938). Author: Juliet Hougland Closes #5532 from jhlch/SPARK-6938 and squashes the following commits: ab321bb [Juliet Hougland] Remove braces from string interpolation when not required. 1221f94 [Juliet Hougland] All require statements now have an informative error message. --- .../org/apache/spark/mllib/linalg/Vectors.scala | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 328dbe2ce11fa..4ef171f4f0419 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -227,7 +227,7 @@ object Vectors { * @param elements vector elements in (index, value) pairs. */ def sparse(size: Int, elements: Seq[(Int, Double)]): Vector = { - require(size > 0) + require(size > 0, "The size of the requested sparse vector must be greater than 0.") val (indices, values) = elements.sortBy(_._1).unzip var prev = -1 @@ -235,7 +235,8 @@ object Vectors { require(prev < i, s"Found duplicate indices: $i.") prev = i } - require(prev < size) + require(prev < size, s"You may not write an element to index $prev because the declared " + + s"size of your vector is $size") new SparseVector(size, indices.toArray, values.toArray) } @@ -309,7 +310,8 @@ object Vectors { * @return norm in L^p^ space. */ def norm(vector: Vector, p: Double): Double = { - require(p >= 1.0) + require(p >= 1.0, "To compute the p-norm of the vector, we require that you specify a p>=1. " + + s"You specified p=$p.") val values = vector match { case DenseVector(vs) => vs case SparseVector(n, ids, vs) => vs @@ -360,7 +362,8 @@ object Vectors { * @return squared distance between two Vectors. */ def sqdist(v1: Vector, v2: Vector): Double = { - require(v1.size == v2.size, "vector dimension mismatch") + require(v1.size == v2.size, s"Vector dimensions do not match: Dim(v1)=${v1.size} and Dim(v2)" + + s"=${v2.size}.") var squaredDistance = 0.0 (v1, v2) match { case (v1: SparseVector, v2: SparseVector) => @@ -518,7 +521,9 @@ class SparseVector( val indices: Array[Int], val values: Array[Double]) extends Vector { - require(indices.length == values.length) + require(indices.length == values.length, "Sparse vectors require that the dimension of the" + + s" indices match the dimension of the values. 
You provided ${indices.size} indices and " + + s" ${values.size} values.") override def toString: String = "(%s,%s,%s)".format(size, indices.mkString("[", ",", "]"), values.mkString("[", ",", "]")) From 57cd1e86d1d450f85fc9e296aff498a940452113 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 15 Apr 2015 23:49:42 -0700 Subject: [PATCH 015/191] [SPARK-6893][ML] default pipeline parameter handling in python Same as #5431 but for Python. jkbradley Author: Xiangrui Meng Closes #5534 from mengxr/SPARK-6893 and squashes the following commits: d3b519b [Xiangrui Meng] address comments ebaccc6 [Xiangrui Meng] style update fce244e [Xiangrui Meng] update explainParams with test 4d6b07a [Xiangrui Meng] add tests 5294500 [Xiangrui Meng] update default param handling in python --- .../org/apache/spark/ml/Identifiable.scala | 2 +- .../apache/spark/ml/param/TestParams.scala | 9 +- python/pyspark/ml/classification.py | 3 +- python/pyspark/ml/feature.py | 19 +-- python/pyspark/ml/param/__init__.py | 146 +++++++++++++++--- ...d_params.py => _shared_params_code_gen.py} | 42 ++--- python/pyspark/ml/param/shared.py | 106 ++++++------- python/pyspark/ml/pipeline.py | 6 +- python/pyspark/ml/tests.py | 52 ++++++- python/pyspark/ml/util.py | 4 +- python/pyspark/ml/wrapper.py | 2 +- 11 files changed, 270 insertions(+), 121 deletions(-) rename python/pyspark/ml/param/{_gen_shared_params.py => _shared_params_code_gen.py} (70%) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala b/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala index a50090671ae48..a1d49095c24ac 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala @@ -25,7 +25,7 @@ import java.util.UUID private[ml] trait Identifiable extends Serializable { /** - * A unique id for the object. The default implementation concatenates the class name, "-", and 8 + * A unique id for the object. The default implementation concatenates the class name, "_", and 8 * random hex chars. */ private[ml] val uid: String = diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala index 8f9ab687c05cb..641b64b42a5e7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala @@ -17,16 +17,13 @@ package org.apache.spark.ml.param +import org.apache.spark.ml.param.shared.{HasInputCol, HasMaxIter} + /** A subclass of Params for testing. 
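The Scala TestParams change here and the Python changes that follow rely on the same mechanism: default parameter values live in one map, user-supplied values in another, and a lookup falls back from the user map to the default map. A toy Scala sketch of that mechanism, with Param, Params and HasMaxIter as illustrative stand-ins rather than the spark.ml classes:

    // Defaults and user-supplied values are kept separate; the user value wins.
    class Param[T](val name: String, val doc: String)

    trait Params {
      private val paramMap = scala.collection.mutable.Map.empty[Param[_], Any]
      private val defaultParamMap = scala.collection.mutable.Map.empty[Param[_], Any]

      protected def set[T](param: Param[T], value: T): this.type = {
        paramMap(param) = value
        this
      }
      protected def setDefault[T](param: Param[T], value: T): this.type = {
        defaultParamMap(param) = value
        this
      }
      def isSet(param: Param[_]): Boolean = paramMap.contains(param)
      def hasDefault(param: Param[_]): Boolean = defaultParamMap.contains(param)
      def getOrDefault[T](param: Param[T]): T =
        paramMap.get(param).orElse(defaultParamMap.get(param)) match {
          case Some(value) => value.asInstanceOf[T]
          case None => throw new NoSuchElementException(s"Param ${param.name} is not defined")
        }
    }

    trait HasMaxIter extends Params {
      val maxIter = new Param[Int]("maxIter", "max number of iterations")
      def setMaxIter(value: Int): this.type = set(maxIter, value)
      def getMaxIter: Int = getOrDefault(maxIter)
    }

    class TestParamsSketch extends HasMaxIter {
      setDefault(maxIter, 10)
    }

    // new TestParamsSketch().getMaxIter                    // 10, from the default map
    // new TestParamsSketch().setMaxIter(100).getMaxIter    // 100, user value wins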
*/ -class TestParams extends Params { +class TestParams extends Params with HasMaxIter with HasInputCol { - val maxIter = new IntParam(this, "maxIter", "max number of iterations") def setMaxIter(value: Int): this.type = { set(maxIter, value); this } - def getMaxIter: Int = getOrDefault(maxIter) - - val inputCol = new Param[String](this, "inputCol", "input column name") def setInputCol(value: String): this.type = { set(inputCol, value); this } - def getInputCol: String = getOrDefault(inputCol) setDefault(maxIter -> 10) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 7f42de531f3b4..d7bc09fd77adb 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -59,6 +59,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred maxIter=100, regParam=0.1) """ super(LogisticRegression, self).__init__() + self._setDefault(maxIter=100, regParam=0.1) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -71,7 +72,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre Sets params for logistic regression. """ kwargs = self.setParams._input_kwargs - return self._set_params(**kwargs) + return self._set(**kwargs) def _create_model(self, java_model): return LogisticRegressionModel(java_model) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 1cfcd019dfb18..263fe2a5bcc41 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -52,22 +52,22 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol): _java_class = "org.apache.spark.ml.feature.Tokenizer" @keyword_only - def __init__(self, inputCol="input", outputCol="output"): + def __init__(self, inputCol=None, outputCol=None): """ - __init__(self, inputCol="input", outputCol="output") + __init__(self, inputCol=None, outputCol=None) """ super(Tokenizer, self).__init__() kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only - def setParams(self, inputCol="input", outputCol="output"): + def setParams(self, inputCol=None, outputCol=None): """ setParams(self, inputCol="input", outputCol="output") Sets params for this Tokenizer. """ kwargs = self.setParams._input_kwargs - return self._set_params(**kwargs) + return self._set(**kwargs) @inherit_doc @@ -91,22 +91,23 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures): _java_class = "org.apache.spark.ml.feature.HashingTF" @keyword_only - def __init__(self, numFeatures=1 << 18, inputCol="input", outputCol="output"): + def __init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None): """ - __init__(self, numFeatures=1 << 18, inputCol="input", outputCol="output") + __init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None) """ super(HashingTF, self).__init__() + self._setDefault(numFeatures=1 << 18) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only - def setParams(self, numFeatures=1 << 18, inputCol="input", outputCol="output"): + def setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None): """ - setParams(self, numFeatures=1 << 18, inputCol="input", outputCol="output") + setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None) Sets params for this HashingTF. 
""" kwargs = self.setParams._input_kwargs - return self._set_params(**kwargs) + return self._set(**kwargs) if __name__ == "__main__": diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index e3a53dd780c4c..5c62620562a84 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -25,23 +25,21 @@ class Param(object): """ - A param with self-contained documentation and optionally default value. + A param with self-contained documentation. """ - def __init__(self, parent, name, doc, defaultValue=None): - if not isinstance(parent, Identifiable): - raise ValueError("Parent must be identifiable but got type %s." % type(parent).__name__) + def __init__(self, parent, name, doc): + if not isinstance(parent, Params): + raise ValueError("Parent must be a Params but got type %s." % type(parent).__name__) self.parent = parent self.name = str(name) self.doc = str(doc) - self.defaultValue = defaultValue def __str__(self): - return str(self.parent) + "-" + self.name + return str(self.parent) + "__" + self.name def __repr__(self): - return "Param(parent=%r, name=%r, doc=%r, defaultValue=%r)" % \ - (self.parent, self.name, self.doc, self.defaultValue) + return "Param(parent=%r, name=%r, doc=%r)" % (self.parent, self.name, self.doc) class Params(Identifiable): @@ -52,26 +50,128 @@ class Params(Identifiable): __metaclass__ = ABCMeta - def __init__(self): - super(Params, self).__init__() - #: embedded param map - self.paramMap = {} + #: internal param map for user-supplied values param map + paramMap = {} + + #: internal param map for default values + defaultParamMap = {} @property def params(self): """ - Returns all params. The default implementation uses - :py:func:`dir` to get all attributes of type + Returns all params ordered by name. The default implementation + uses :py:func:`dir` to get all attributes of type :py:class:`Param`. """ return filter(lambda attr: isinstance(attr, Param), [getattr(self, x) for x in dir(self) if x != "params"]) - def _merge_params(self, params): - paramMap = self.paramMap.copy() - paramMap.update(params) + def _explain(self, param): + """ + Explains a single param and returns its name, doc, and optional + default value and user-supplied value in a string. + """ + param = self._resolveParam(param) + values = [] + if self.isDefined(param): + if param in self.defaultParamMap: + values.append("default: %s" % self.defaultParamMap[param]) + if param in self.paramMap: + values.append("current: %s" % self.paramMap[param]) + else: + values.append("undefined") + valueStr = "(" + ", ".join(values) + ")" + return "%s: %s %s" % (param.name, param.doc, valueStr) + + def explainParams(self): + """ + Returns the documentation of all params with their optionally + default values and user-supplied values. + """ + return "\n".join([self._explain(param) for param in self.params]) + + def getParam(self, paramName): + """ + Gets a param by its name. + """ + param = getattr(self, paramName) + if isinstance(param, Param): + return param + else: + raise ValueError("Cannot find param with name %s." % paramName) + + def isSet(self, param): + """ + Checks whether a param is explicitly set by user. + """ + param = self._resolveParam(param) + return param in self.paramMap + + def hasDefault(self, param): + """ + Checks whether a param has a default value. 
+ """ + param = self._resolveParam(param) + return param in self.defaultParamMap + + def isDefined(self, param): + """ + Checks whether a param is explicitly set by user or has a default value. + """ + return self.isSet(param) or self.hasDefault(param) + + def getOrDefault(self, param): + """ + Gets the value of a param in the user-supplied param map or its + default value. Raises an error if either is set. + """ + if isinstance(param, Param): + if param in self.paramMap: + return self.paramMap[param] + else: + return self.defaultParamMap[param] + elif isinstance(param, str): + return self.getOrDefault(self.getParam(param)) + else: + raise KeyError("Cannot recognize %r as a param." % param) + + def extractParamMap(self, extraParamMap={}): + """ + Extracts the embedded default param values and user-supplied + values, and then merges them with extra values from input into + a flat param map, where the latter value is used if there exist + conflicts, i.e., with ordering: default param values < + user-supplied values < extraParamMap. + :param extraParamMap: extra param values + :return: merged param map + """ + paramMap = self.defaultParamMap.copy() + paramMap.update(self.paramMap) + paramMap.update(extraParamMap) return paramMap + def _shouldOwn(self, param): + """ + Validates that the input param belongs to this Params instance. + """ + if param.parent is not self: + raise ValueError("Param %r does not belong to %r." % (param, self)) + + def _resolveParam(self, param): + """ + Resolves a param and validates the ownership. + :param param: param name or the param instance, which must + belong to this Params instance + :return: resolved param instance + """ + if isinstance(param, Param): + self._shouldOwn(param) + return param + elif isinstance(param, str): + return self.getParam(param) + else: + raise ValueError("Cannot resolve %r as a param." % param) + @staticmethod def _dummy(): """ @@ -81,10 +181,18 @@ def _dummy(): dummy.uid = "undefined" return dummy - def _set_params(self, **kwargs): + def _set(self, **kwargs): """ - Sets params. + Sets user-supplied params. """ for param, value in kwargs.iteritems(): self.paramMap[getattr(self, param)] = value return self + + def _setDefault(self, **kwargs): + """ + Sets default params. + """ + for param, value in kwargs.iteritems(): + self.defaultParamMap[getattr(self, param)] = value + return self diff --git a/python/pyspark/ml/param/_gen_shared_params.py b/python/pyspark/ml/param/_shared_params_code_gen.py similarity index 70% rename from python/pyspark/ml/param/_gen_shared_params.py rename to python/pyspark/ml/param/_shared_params_code_gen.py index 5eb81106f116c..55f422497672f 100644 --- a/python/pyspark/ml/param/_gen_shared_params.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -32,29 +32,34 @@ # limitations under the License. #""" +# Code generator for shared params (shared.py). Run under this folder with: +# python _shared_params_code_gen.py > shared.py -def _gen_param_code(name, doc, defaultValue): + +def _gen_param_code(name, doc, defaultValueStr): """ Generates Python code for a shared param class. :param name: param name :param doc: param doc - :param defaultValue: string representation of the param + :param defaultValueStr: string representation of the default value :return: code string """ # TODO: How to correctly inherit instance attributes? template = '''class Has$Name(Params): """ - Params with $name. + Mixin for param $name: $doc. 
""" # a placeholder to make it appear in the generated doc - $name = Param(Params._dummy(), "$name", "$doc", $defaultValue) + $name = Param(Params._dummy(), "$name", "$doc") def __init__(self): super(Has$Name, self).__init__() #: param for $doc - self.$name = Param(self, "$name", "$doc", $defaultValue) + self.$name = Param(self, "$name", "$doc") + if $defaultValueStr is not None: + self._setDefault($name=$defaultValueStr) def set$Name(self, value): """ @@ -67,32 +72,29 @@ def get$Name(self): """ Gets the value of $name or its default value. """ - if self.$name in self.paramMap: - return self.paramMap[self.$name] - else: - return self.$name.defaultValue''' + return self.getOrDefault(self.$name)''' - upperCamelName = name[0].upper() + name[1:] + Name = name[0].upper() + name[1:] return template \ .replace("$name", name) \ - .replace("$Name", upperCamelName) \ + .replace("$Name", Name) \ .replace("$doc", doc) \ - .replace("$defaultValue", defaultValue) + .replace("$defaultValueStr", str(defaultValueStr)) if __name__ == "__main__": print header - print "\n# DO NOT MODIFY. The code is generated by _gen_shared_params.py.\n" + print "\n# DO NOT MODIFY THIS FILE! It was generated by _shared_params_code_gen.py.\n" print "from pyspark.ml.param import Param, Params\n\n" shared = [ - ("maxIter", "max number of iterations", "100"), - ("regParam", "regularization constant", "0.1"), + ("maxIter", "max number of iterations", None), + ("regParam", "regularization constant", None), ("featuresCol", "features column name", "'features'"), ("labelCol", "label column name", "'label'"), ("predictionCol", "prediction column name", "'prediction'"), - ("inputCol", "input column name", "'input'"), - ("outputCol", "output column name", "'output'"), - ("numFeatures", "number of features", "1 << 18")] + ("inputCol", "input column name", None), + ("outputCol", "output column name", None), + ("numFeatures", "number of features", None)] code = [] - for name, doc, defaultValue in shared: - code.append(_gen_param_code(name, doc, defaultValue)) + for name, doc, defaultValueStr in shared: + code.append(_gen_param_code(name, doc, defaultValueStr)) print "\n\n\n".join(code) diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 586822f2de423..13b6749998ad0 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -15,23 +15,25 @@ # limitations under the License. # -# DO NOT MODIFY. The code is generated by _gen_shared_params.py. +# DO NOT MODIFY THIS FILE! It was generated by _shared_params_code_gen.py. from pyspark.ml.param import Param, Params class HasMaxIter(Params): """ - Params with maxIter. + Mixin for param maxIter: max number of iterations. """ # a placeholder to make it appear in the generated doc - maxIter = Param(Params._dummy(), "maxIter", "max number of iterations", 100) + maxIter = Param(Params._dummy(), "maxIter", "max number of iterations") def __init__(self): super(HasMaxIter, self).__init__() #: param for max number of iterations - self.maxIter = Param(self, "maxIter", "max number of iterations", 100) + self.maxIter = Param(self, "maxIter", "max number of iterations") + if None is not None: + self._setDefault(maxIter=None) def setMaxIter(self, value): """ @@ -44,24 +46,23 @@ def getMaxIter(self): """ Gets the value of maxIter or its default value. 
""" - if self.maxIter in self.paramMap: - return self.paramMap[self.maxIter] - else: - return self.maxIter.defaultValue + return self.getOrDefault(self.maxIter) class HasRegParam(Params): """ - Params with regParam. + Mixin for param regParam: regularization constant. """ # a placeholder to make it appear in the generated doc - regParam = Param(Params._dummy(), "regParam", "regularization constant", 0.1) + regParam = Param(Params._dummy(), "regParam", "regularization constant") def __init__(self): super(HasRegParam, self).__init__() #: param for regularization constant - self.regParam = Param(self, "regParam", "regularization constant", 0.1) + self.regParam = Param(self, "regParam", "regularization constant") + if None is not None: + self._setDefault(regParam=None) def setRegParam(self, value): """ @@ -74,24 +75,23 @@ def getRegParam(self): """ Gets the value of regParam or its default value. """ - if self.regParam in self.paramMap: - return self.paramMap[self.regParam] - else: - return self.regParam.defaultValue + return self.getOrDefault(self.regParam) class HasFeaturesCol(Params): """ - Params with featuresCol. + Mixin for param featuresCol: features column name. """ # a placeholder to make it appear in the generated doc - featuresCol = Param(Params._dummy(), "featuresCol", "features column name", 'features') + featuresCol = Param(Params._dummy(), "featuresCol", "features column name") def __init__(self): super(HasFeaturesCol, self).__init__() #: param for features column name - self.featuresCol = Param(self, "featuresCol", "features column name", 'features') + self.featuresCol = Param(self, "featuresCol", "features column name") + if 'features' is not None: + self._setDefault(featuresCol='features') def setFeaturesCol(self, value): """ @@ -104,24 +104,23 @@ def getFeaturesCol(self): """ Gets the value of featuresCol or its default value. """ - if self.featuresCol in self.paramMap: - return self.paramMap[self.featuresCol] - else: - return self.featuresCol.defaultValue + return self.getOrDefault(self.featuresCol) class HasLabelCol(Params): """ - Params with labelCol. + Mixin for param labelCol: label column name. """ # a placeholder to make it appear in the generated doc - labelCol = Param(Params._dummy(), "labelCol", "label column name", 'label') + labelCol = Param(Params._dummy(), "labelCol", "label column name") def __init__(self): super(HasLabelCol, self).__init__() #: param for label column name - self.labelCol = Param(self, "labelCol", "label column name", 'label') + self.labelCol = Param(self, "labelCol", "label column name") + if 'label' is not None: + self._setDefault(labelCol='label') def setLabelCol(self, value): """ @@ -134,24 +133,23 @@ def getLabelCol(self): """ Gets the value of labelCol or its default value. """ - if self.labelCol in self.paramMap: - return self.paramMap[self.labelCol] - else: - return self.labelCol.defaultValue + return self.getOrDefault(self.labelCol) class HasPredictionCol(Params): """ - Params with predictionCol. + Mixin for param predictionCol: prediction column name. 
""" # a placeholder to make it appear in the generated doc - predictionCol = Param(Params._dummy(), "predictionCol", "prediction column name", 'prediction') + predictionCol = Param(Params._dummy(), "predictionCol", "prediction column name") def __init__(self): super(HasPredictionCol, self).__init__() #: param for prediction column name - self.predictionCol = Param(self, "predictionCol", "prediction column name", 'prediction') + self.predictionCol = Param(self, "predictionCol", "prediction column name") + if 'prediction' is not None: + self._setDefault(predictionCol='prediction') def setPredictionCol(self, value): """ @@ -164,24 +162,23 @@ def getPredictionCol(self): """ Gets the value of predictionCol or its default value. """ - if self.predictionCol in self.paramMap: - return self.paramMap[self.predictionCol] - else: - return self.predictionCol.defaultValue + return self.getOrDefault(self.predictionCol) class HasInputCol(Params): """ - Params with inputCol. + Mixin for param inputCol: input column name. """ # a placeholder to make it appear in the generated doc - inputCol = Param(Params._dummy(), "inputCol", "input column name", 'input') + inputCol = Param(Params._dummy(), "inputCol", "input column name") def __init__(self): super(HasInputCol, self).__init__() #: param for input column name - self.inputCol = Param(self, "inputCol", "input column name", 'input') + self.inputCol = Param(self, "inputCol", "input column name") + if None is not None: + self._setDefault(inputCol=None) def setInputCol(self, value): """ @@ -194,24 +191,23 @@ def getInputCol(self): """ Gets the value of inputCol or its default value. """ - if self.inputCol in self.paramMap: - return self.paramMap[self.inputCol] - else: - return self.inputCol.defaultValue + return self.getOrDefault(self.inputCol) class HasOutputCol(Params): """ - Params with outputCol. + Mixin for param outputCol: output column name. """ # a placeholder to make it appear in the generated doc - outputCol = Param(Params._dummy(), "outputCol", "output column name", 'output') + outputCol = Param(Params._dummy(), "outputCol", "output column name") def __init__(self): super(HasOutputCol, self).__init__() #: param for output column name - self.outputCol = Param(self, "outputCol", "output column name", 'output') + self.outputCol = Param(self, "outputCol", "output column name") + if None is not None: + self._setDefault(outputCol=None) def setOutputCol(self, value): """ @@ -224,24 +220,23 @@ def getOutputCol(self): """ Gets the value of outputCol or its default value. """ - if self.outputCol in self.paramMap: - return self.paramMap[self.outputCol] - else: - return self.outputCol.defaultValue + return self.getOrDefault(self.outputCol) class HasNumFeatures(Params): """ - Params with numFeatures. + Mixin for param numFeatures: number of features. """ # a placeholder to make it appear in the generated doc - numFeatures = Param(Params._dummy(), "numFeatures", "number of features", 1 << 18) + numFeatures = Param(Params._dummy(), "numFeatures", "number of features") def __init__(self): super(HasNumFeatures, self).__init__() #: param for number of features - self.numFeatures = Param(self, "numFeatures", "number of features", 1 << 18) + self.numFeatures = Param(self, "numFeatures", "number of features") + if None is not None: + self._setDefault(numFeatures=None) def setNumFeatures(self, value): """ @@ -254,7 +249,4 @@ def getNumFeatures(self): """ Gets the value of numFeatures or its default value. 
""" - if self.numFeatures in self.paramMap: - return self.paramMap[self.numFeatures] - else: - return self.numFeatures.defaultValue + return self.getOrDefault(self.numFeatures) diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index 83880a5afcd1d..d94ecfff09f66 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -124,10 +124,10 @@ def setParams(self, stages=[]): Sets params for Pipeline. """ kwargs = self.setParams._input_kwargs - return self._set_params(**kwargs) + return self._set(**kwargs) def fit(self, dataset, params={}): - paramMap = self._merge_params(params) + paramMap = self.extractParamMap(params) stages = paramMap[self.stages] for stage in stages: if not (isinstance(stage, Estimator) or isinstance(stage, Transformer)): @@ -164,7 +164,7 @@ def __init__(self, transformers): self.transformers = transformers def transform(self, dataset, params={}): - paramMap = self._merge_params(params) + paramMap = self.extractParamMap(params) for t in self.transformers: dataset = t.transform(dataset, paramMap) return dataset diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index b627c2b4e930b..3a42bcf723894 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -33,6 +33,7 @@ from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase from pyspark.sql import DataFrame from pyspark.ml.param import Param +from pyspark.ml.param.shared import HasMaxIter, HasInputCol from pyspark.ml.pipeline import Transformer, Estimator, Pipeline @@ -46,7 +47,7 @@ class MockTransformer(Transformer): def __init__(self): super(MockTransformer, self).__init__() - self.fake = Param(self, "fake", "fake", None) + self.fake = Param(self, "fake", "fake") self.dataset_index = None self.fake_param_value = None @@ -62,7 +63,7 @@ class MockEstimator(Estimator): def __init__(self): super(MockEstimator, self).__init__() - self.fake = Param(self, "fake", "fake", None) + self.fake = Param(self, "fake", "fake") self.dataset_index = None self.fake_param_value = None self.model = None @@ -111,5 +112,52 @@ def test_pipeline(self): self.assertEqual(6, dataset.index) +class TestParams(HasMaxIter, HasInputCol): + """ + A subclass of Params mixed with HasMaxIter and HasInputCol. 
+ """ + + def __init__(self): + super(TestParams, self).__init__() + self._setDefault(maxIter=10) + + +class ParamTests(PySparkTestCase): + + def test_param(self): + testParams = TestParams() + maxIter = testParams.maxIter + self.assertEqual(maxIter.name, "maxIter") + self.assertEqual(maxIter.doc, "max number of iterations") + self.assertTrue(maxIter.parent is testParams) + + def test_params(self): + testParams = TestParams() + maxIter = testParams.maxIter + inputCol = testParams.inputCol + + params = testParams.params + self.assertEqual(params, [inputCol, maxIter]) + + self.assertTrue(testParams.hasDefault(maxIter)) + self.assertFalse(testParams.isSet(maxIter)) + self.assertTrue(testParams.isDefined(maxIter)) + self.assertEqual(testParams.getMaxIter(), 10) + testParams.setMaxIter(100) + self.assertTrue(testParams.isSet(maxIter)) + self.assertEquals(testParams.getMaxIter(), 100) + + self.assertFalse(testParams.hasDefault(inputCol)) + self.assertFalse(testParams.isSet(inputCol)) + self.assertFalse(testParams.isDefined(inputCol)) + with self.assertRaises(KeyError): + testParams.getInputCol() + + self.assertEquals( + testParams.explainParams(), + "\n".join(["inputCol: input column name (undefined)", + "maxIter: max number of iterations (default: 10, current: 100)"])) + + if __name__ == "__main__": unittest.main() diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 6f7f39c40eb5a..d3cb100a9efa5 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -40,8 +40,8 @@ class Identifiable(object): def __init__(self): #: A unique id for the object. The default implementation - #: concatenates the class name, "-", and 8 random hex chars. - self.uid = type(self).__name__ + "-" + uuid.uuid4().hex[:8] + #: concatenates the class name, "_", and 8 random hex chars. 
+ self.uid = type(self).__name__ + "_" + uuid.uuid4().hex[:8] def __repr__(self): return self.uid diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 31a66b3d2f730..394f23c5e9b12 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -64,7 +64,7 @@ def _transfer_params_to_java(self, params, java_obj): :param params: additional params (overwriting embedded values) :param java_obj: Java object to receive the params """ - paramMap = self._merge_params(params) + paramMap = self.extractParamMap(params) for param in self.params: if param in paramMap: java_obj.set(param.name, paramMap[param]) From 8370550593f3549e90ace446961281dad0cd7498 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 16 Apr 2015 10:39:02 +0100 Subject: [PATCH 016/191] [Streaming][minor] Remove additional quote and unneeded imports Author: jerryshao Closes #5540 from jerryshao/minor-fix and squashes the following commits: ebaa646 [jerryshao] Minor fix --- .../apache/spark/examples/streaming/DirectKafkaWordCount.scala | 2 +- .../main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala b/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala index 1c8a20bf8f1ae..11a8cf09533ce 100644 --- a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala +++ b/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala @@ -41,7 +41,7 @@ object DirectKafkaWordCount { | is a list of one or more Kafka brokers | is a list of one or more kafka topics to consume from | - """".stripMargin) + """.stripMargin) System.exit(1) } diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala index a0b8a0c565210..a1b4a12e5d6a0 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala @@ -23,10 +23,9 @@ import org.apache.spark.{Logging, Partition, SparkContext, SparkException, TaskC import org.apache.spark.rdd.RDD import org.apache.spark.util.NextIterator -import java.util.Properties import kafka.api.{FetchRequestBuilder, FetchResponse} import kafka.common.{ErrorMapping, TopicAndPartition} -import kafka.consumer.{ConsumerConfig, SimpleConsumer} +import kafka.consumer.SimpleConsumer import kafka.message.{MessageAndMetadata, MessageAndOffset} import kafka.serializer.Decoder import kafka.utils.VerifiableProperties From 6179a948371897cecb7322ebda366c2de8ecaedd Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 16 Apr 2015 10:45:32 +0100 Subject: [PATCH 017/191] SPARK-4783 [CORE] System.exit() calls in SparkContext disrupt applications embedding Spark Avoid `System.exit(1)` in `TaskSchedulerImpl` and convert to `SparkException`; ensure scheduler calls `sc.stop()` even when this exception is thrown. CC mateiz aarondav as those who may have last touched this code. 
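The two edits described here share one principle: code that runs embedded in a host application should report fatal conditions by throwing and still release its own resources, rather than calling System.exit. A condensed sketch of that shape, with invented types standing in for the scheduler and the context:

    // Fail by throwing, and make the cleanup unconditional.
    class ClusterSchedulerError(message: String) extends RuntimeException(message)

    trait StoppableContext { def stop(): Unit }

    object ErrorHandlingSketch {
      def reportError(activeTaskSets: Int, message: String): Unit = {
        if (activeTaskSets == 0) {
          // Previously System.exit(1); now the embedding application can catch this.
          throw new ClusterSchedulerError(s"Exiting due to error from cluster scheduler: $message")
        }
        // otherwise: fail each active task set with the message (omitted here)
      }

      def onApplicationKilled(reason: String, context: StoppableContext): Unit = {
        try {
          reportError(activeTaskSets = 0, message = reason)   // may throw, as above
        } finally {
          context.stop()   // runs either way, so no further jobs can be submitted
        }
      }
    }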
Author: Sean Owen Closes #5492 from srowen/SPARK-4783 and squashes the following commits: 60dc682 [Sean Owen] Avoid System.exit(1) in TaskSchedulerImpl and convert to SparkException; ensure scheduler calls sc.stop() even when this exception is thrown --- .../org/apache/spark/scheduler/TaskSchedulerImpl.scala | 5 ++--- .../scheduler/cluster/SparkDeploySchedulerBackend.scala | 9 ++++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 2362cc7240039..ecc8bf189986d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -394,7 +394,7 @@ private[spark] class TaskSchedulerImpl( def error(message: String) { synchronized { - if (activeTaskSets.size > 0) { + if (activeTaskSets.nonEmpty) { // Have each task set throw a SparkException with the error for ((taskSetId, manager) <- activeTaskSets) { try { @@ -407,8 +407,7 @@ private[spark] class TaskSchedulerImpl( // No task sets are active but we still got an error. Just exit since this // must mean the error is during registration. // It might be good to do something smarter here in the future. - logError("Exiting due to error from cluster scheduler: " + message) - System.exit(1) + throw new SparkException(s"Exiting due to error from cluster scheduler: $message") } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index ed5b7c1088196..ccf1dc5af6120 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -118,9 +118,12 @@ private[spark] class SparkDeploySchedulerBackend( notifyContext() if (!stopping) { logError("Application has been killed. Reason: " + reason) - scheduler.error(reason) - // Ensure the application terminates, as we can no longer run jobs. - sc.stop() + try { + scheduler.error(reason) + } finally { + // Ensure the application terminates, as we can no longer run jobs. + sc.stop() + } } } From de4fa6b6d12e2bee0307ffba2abfca0c33f15e45 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 16 Apr 2015 10:48:31 +0100 Subject: [PATCH 018/191] [SPARK-4194] [core] Make SparkContext initialization exception-safe. SparkContext has a very long constructor, where multiple things are initialized, multiple threads are spawned, and multiple opportunities for exceptions to be thrown exist. If one of these happens at an innoportune time, lots of garbage tends to stick around. This patch re-organizes SparkContext so that its internal state is initialized in a big "try" block. The fields keeping state are now completely private to SparkContext, and are "vars", because Scala doesn't allow you to initialize a val later. The existing API interface is kept by turning vals into defs (which works because Scala guarantees the same binary interface for those). On top of that, a few things in other areas were changed to avoid more things leaking: - Executor was changed to explicitly wait for the heartbeat thread to stop. LocalBackend was changed to wait for the "StopExecutor" message to be received, since otherwise there could be a race between that message arriving and the actor system being shut down. 
- ConnectionManager could possibly hang during shutdown, because an interrupt at the wrong moment could cause the selector thread to still call select and then wait forever. So also wake up the selector so that this situation is avoided. Author: Marcelo Vanzin Closes #5335 from vanzin/SPARK-4194 and squashes the following commits: 746b661 [Marcelo Vanzin] Fix borked merge. 80fc00e [Marcelo Vanzin] Merge branch 'master' into SPARK-4194 408dada [Marcelo Vanzin] Merge branch 'master' into SPARK-4194 2621609 [Marcelo Vanzin] Merge branch 'master' into SPARK-4194 6b73fcb [Marcelo Vanzin] Scalastyle. c671c46 [Marcelo Vanzin] Fix merge. 3979aad [Marcelo Vanzin] Merge branch 'master' into SPARK-4194 8caa8b3 [Marcelo Vanzin] [SPARK-4194] [core] Make SparkContext initialization exception-safe. 071f16e [Marcelo Vanzin] Nits. 27456b9 [Marcelo Vanzin] More exception safety. a0b0881 [Marcelo Vanzin] Stop alloc manager before scheduler. 5545d83 [Marcelo Vanzin] [SPARK-6650] [core] Stop ExecutorAllocationManager when context stops. --- .../scala/org/apache/spark/SparkContext.scala | 505 ++++++++++-------- .../org/apache/spark/executor/Executor.scala | 33 +- .../spark/network/nio/ConnectionManager.scala | 7 +- .../spark/scheduler/TaskSchedulerImpl.scala | 3 +- .../spark/scheduler/local/LocalBackend.scala | 19 +- .../ExecutorAllocationManagerSuite.scala | 6 - 6 files changed, 329 insertions(+), 244 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 3f1a7dd99d635..e106c5c4bef60 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -31,6 +31,7 @@ import scala.collection.JavaConversions._ import scala.collection.generic.Growable import scala.collection.mutable.HashMap import scala.reflect.{ClassTag, classTag} +import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path @@ -50,9 +51,10 @@ import org.apache.spark.executor.{ExecutorEndpoint, TriggerThreadDump} import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, FixedLengthBinaryInputFormat} import org.apache.spark.io.CompressionCodec +import org.apache.spark.metrics.MetricsSystem import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ -import org.apache.spark.rpc.RpcAddress +import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SparkDeploySchedulerBackend, SimrSchedulerBackend} @@ -192,8 +194,42 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // log out Spark Version in Spark driver log logInfo(s"Running Spark version $SPARK_VERSION") - private[spark] val conf = config.clone() - conf.validateSettings() + /* ------------------------------------------------------------------------------------- * + | Private variables. These variables keep the internal state of the context, and are | + | not accessible by the outside world. They're mutable since we want to initialize all | + | of them to some neutral value ahead of time, so that calling "stop()" while the | + | constructor is still running is safe. 
| + * ------------------------------------------------------------------------------------- */ + + private var _conf: SparkConf = _ + private var _eventLogDir: Option[URI] = None + private var _eventLogCodec: Option[String] = None + private var _env: SparkEnv = _ + private var _metadataCleaner: MetadataCleaner = _ + private var _jobProgressListener: JobProgressListener = _ + private var _statusTracker: SparkStatusTracker = _ + private var _progressBar: Option[ConsoleProgressBar] = None + private var _ui: Option[SparkUI] = None + private var _hadoopConfiguration: Configuration = _ + private var _executorMemory: Int = _ + private var _schedulerBackend: SchedulerBackend = _ + private var _taskScheduler: TaskScheduler = _ + private var _heartbeatReceiver: RpcEndpointRef = _ + @volatile private var _dagScheduler: DAGScheduler = _ + private var _applicationId: String = _ + private var _eventLogger: Option[EventLoggingListener] = None + private var _executorAllocationManager: Option[ExecutorAllocationManager] = None + private var _cleaner: Option[ContextCleaner] = None + private var _listenerBusStarted: Boolean = false + private var _jars: Seq[String] = _ + private var _files: Seq[String] = _ + + /* ------------------------------------------------------------------------------------- * + | Accessors and public fields. These provide access to the internal state of the | + | context. | + * ------------------------------------------------------------------------------------- */ + + private[spark] def conf: SparkConf = _conf /** * Return a copy of this SparkContext's configuration. The configuration ''cannot'' be @@ -201,65 +237,24 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ def getConf: SparkConf = conf.clone() - if (!conf.contains("spark.master")) { - throw new SparkException("A master URL must be set in your configuration") - } - if (!conf.contains("spark.app.name")) { - throw new SparkException("An application name must be set in your configuration") - } - - if (conf.getBoolean("spark.logConf", false)) { - logInfo("Spark configuration:\n" + conf.toDebugString) - } - - // Set Spark driver host and port system properties - conf.setIfMissing("spark.driver.host", Utils.localHostName()) - conf.setIfMissing("spark.driver.port", "0") - - val jars: Seq[String] = - conf.getOption("spark.jars").map(_.split(",")).map(_.filter(_.size != 0)).toSeq.flatten - - val files: Seq[String] = - conf.getOption("spark.files").map(_.split(",")).map(_.filter(_.size != 0)).toSeq.flatten - - val master = conf.get("spark.master") - val appName = conf.get("spark.app.name") + def jars: Seq[String] = _jars + def files: Seq[String] = _files + def master: String = _conf.get("spark.master") + def appName: String = _conf.get("spark.app.name") - private[spark] val isEventLogEnabled = conf.getBoolean("spark.eventLog.enabled", false) - private[spark] val eventLogDir: Option[URI] = { - if (isEventLogEnabled) { - val unresolvedDir = conf.get("spark.eventLog.dir", EventLoggingListener.DEFAULT_LOG_DIR) - .stripSuffix("/") - Some(Utils.resolveURI(unresolvedDir)) - } else { - None - } - } - private[spark] val eventLogCodec: Option[String] = { - val compress = conf.getBoolean("spark.eventLog.compress", false) - if (compress && isEventLogEnabled) { - Some(CompressionCodec.getCodecName(conf)).map(CompressionCodec.getShortName) - } else { - None - } - } + private[spark] def isEventLogEnabled: Boolean = _conf.getBoolean("spark.eventLog.enabled", false) + private[spark] def eventLogDir: Option[URI] = 
_eventLogDir + private[spark] def eventLogCodec: Option[String] = _eventLogCodec // Generate the random name for a temp folder in Tachyon // Add a timestamp as the suffix here to make it more safe val tachyonFolderName = "spark-" + randomUUID.toString() - conf.set("spark.tachyonStore.folderName", tachyonFolderName) - val isLocal = (master == "local" || master.startsWith("local[")) - - if (master == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true") + def isLocal: Boolean = (master == "local" || master.startsWith("local[")) // An asynchronous listener bus for Spark events private[spark] val listenerBus = new LiveListenerBus - conf.set("spark.executor.id", SparkContext.DRIVER_IDENTIFIER) - - // Create the Spark execution environment (cache, map output tracker, etc) - // This function allows components created by SparkEnv to be mocked in unit tests: private[spark] def createSparkEnv( conf: SparkConf, @@ -268,8 +263,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli SparkEnv.createDriverEnv(conf, isLocal, listenerBus) } - private[spark] val env = createSparkEnv(conf, isLocal, listenerBus) - SparkEnv.set(env) + private[spark] def env: SparkEnv = _env // Used to store a URL for each static file/jar together with the file's local timestamp private[spark] val addedFiles = HashMap[String, Long]() @@ -277,35 +271,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Keeps track of all persisted RDDs private[spark] val persistentRdds = new TimeStampedWeakValueHashMap[Int, RDD[_]] - private[spark] val metadataCleaner = - new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup, conf) - + private[spark] def metadataCleaner: MetadataCleaner = _metadataCleaner + private[spark] def jobProgressListener: JobProgressListener = _jobProgressListener - private[spark] val jobProgressListener = new JobProgressListener(conf) - listenerBus.addListener(jobProgressListener) + def statusTracker: SparkStatusTracker = _statusTracker - val statusTracker = new SparkStatusTracker(this) + private[spark] def progressBar: Option[ConsoleProgressBar] = _progressBar - private[spark] val progressBar: Option[ConsoleProgressBar] = - if (conf.getBoolean("spark.ui.showConsoleProgress", true) && !log.isInfoEnabled) { - Some(new ConsoleProgressBar(this)) - } else { - None - } - - // Initialize the Spark UI - private[spark] val ui: Option[SparkUI] = - if (conf.getBoolean("spark.ui.enabled", true)) { - Some(SparkUI.createLiveUI(this, conf, listenerBus, jobProgressListener, - env.securityManager,appName)) - } else { - // For tests, do not enable the UI - None - } - - // Bind the UI before starting the task scheduler to communicate - // the bound port to the cluster manager properly - ui.foreach(_.bind()) + private[spark] def ui: Option[SparkUI] = _ui /** * A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. @@ -313,134 +286,248 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * '''Note:''' As it will be reused in all Hadoop RDDs, it's better not to modify it unless you * plan to set some global configurations for all Hadoop RDDs. */ - val hadoopConfiguration = SparkHadoopUtil.get.newConfiguration(conf) + def hadoopConfiguration: Configuration = _hadoopConfiguration + + private[spark] def executorMemory: Int = _executorMemory + + // Environment variables to pass to our executors. 
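The commit message's claim that swapping vals for defs preserves binary compatibility rests on how scalac compiles fields: a val emits a getter method plus a private backing field, and a def with the same name and result type emits a getter with the identical JVM signature. A small illustration under that assumption, using made-up traits rather than Spark's classes:

    object ValToDefSketch {
      trait V1 { val master: String } // bytecode exposes: def master(): String
      trait V2 { def master: String } // identical getter signature

      def main(args: Array[String]): Unit = {
        val a: V1 = new V1 { val master = "local[*]" }
        val b: V2 = new V2 { def master = "local[*]" }
        println(a.master + " / " + b.master) // call sites are indistinguishable
      }
    }

Code compiled against the val version keeps linking against the def version without recompilation, which is why the accessors above can safely delegate to private vars.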
+ private[spark] val executorEnvs = HashMap[String, String]() + + // Set SPARK_USER for user who is running SparkContext. + val sparkUser = Utils.getCurrentUserName() - // Add each JAR given through the constructor - if (jars != null) { - jars.foreach(addJar) + private[spark] def schedulerBackend: SchedulerBackend = _schedulerBackend + private[spark] def schedulerBackend_=(sb: SchedulerBackend): Unit = { + _schedulerBackend = sb } - if (files != null) { - files.foreach(addFile) + private[spark] def taskScheduler: TaskScheduler = _taskScheduler + private[spark] def taskScheduler_=(ts: TaskScheduler): Unit = { + _taskScheduler = ts } + private[spark] def dagScheduler: DAGScheduler = _dagScheduler + private[spark] def dagScheduler_=(ds: DAGScheduler): Unit = { + _dagScheduler = ds + } + + def applicationId: String = _applicationId + + def metricsSystem: MetricsSystem = if (_env != null) _env.metricsSystem else null + + private[spark] def eventLogger: Option[EventLoggingListener] = _eventLogger + + private[spark] def executorAllocationManager: Option[ExecutorAllocationManager] = + _executorAllocationManager + + private[spark] def cleaner: Option[ContextCleaner] = _cleaner + + private[spark] var checkpointDir: Option[String] = None + + // Thread Local variable that can be used by users to pass information down the stack + private val localProperties = new InheritableThreadLocal[Properties] { + override protected def childValue(parent: Properties): Properties = new Properties(parent) + override protected def initialValue(): Properties = new Properties() + } + + /* ------------------------------------------------------------------------------------- * + | Initialization. This code initializes the context in a manner that is exception-safe. | + | All internal fields holding state are initialized here, and any error prompts the | + | stop() method to be called. | + * ------------------------------------------------------------------------------------- */ + private def warnSparkMem(value: String): String = { logWarning("Using SPARK_MEM to set amount of memory to use per executor process is " + "deprecated, please use spark.executor.memory instead.") value } - private[spark] val executorMemory = conf.getOption("spark.executor.memory") - .orElse(Option(System.getenv("SPARK_EXECUTOR_MEMORY"))) - .orElse(Option(System.getenv("SPARK_MEM")).map(warnSparkMem)) - .map(Utils.memoryStringToMb) - .getOrElse(512) + try { + _conf = config.clone() + _conf.validateSettings() - // Environment variables to pass to our executors. - private[spark] val executorEnvs = HashMap[String, String]() + if (!_conf.contains("spark.master")) { + throw new SparkException("A master URL must be set in your configuration") + } + if (!_conf.contains("spark.app.name")) { + throw new SparkException("An application name must be set in your configuration") + } - // Convert java options to env vars as a work around - // since we can't set env vars directly in sbt. - for { (envKey, propKey) <- Seq(("SPARK_TESTING", "spark.testing")) - value <- Option(System.getenv(envKey)).orElse(Option(System.getProperty(propKey)))} { - executorEnvs(envKey) = value - } - Option(System.getenv("SPARK_PREPEND_CLASSES")).foreach { v => - executorEnvs("SPARK_PREPEND_CLASSES") = v - } - // The Mesos scheduler backend relies on this environment variable to set executor memory. - // TODO: Set this only in the Mesos scheduler. 
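The executorMemory lookup removed above (and re-created inside the try block later in this diff) is a compact Option.orElse fallback chain: explicit config first, then two environment variables, then a default. The same shape, sketched with plain maps; toMb here is a simplified stand-in for the real Utils.memoryStringToMb:

    object FallbackChainSketch {
      // Simplified: the real Utils.memoryStringToMb understands more suffixes.
      private def toMb(s: String): Int = s.toLowerCase.stripSuffix("m").toInt

      def executorMemoryMb(conf: Map[String, String], env: Map[String, String]): Int =
        conf.get("spark.executor.memory")
          .orElse(env.get("SPARK_EXECUTOR_MEMORY"))
          .orElse(env.get("SPARK_MEM")) // deprecated, checked last
          .map(toMb)
          .getOrElse(512) // default, in MB

      def main(args: Array[String]): Unit =
        println(executorMemoryMb(Map.empty, Map("SPARK_MEM" -> "1024m"))) // prints 1024
    }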
- executorEnvs("SPARK_EXECUTOR_MEMORY") = executorMemory + "m" - executorEnvs ++= conf.getExecutorEnv + if (_conf.getBoolean("spark.logConf", false)) { + logInfo("Spark configuration:\n" + _conf.toDebugString) + } - // Set SPARK_USER for user who is running SparkContext. - val sparkUser = Utils.getCurrentUserName() - executorEnvs("SPARK_USER") = sparkUser + // Set Spark driver host and port system properties + _conf.setIfMissing("spark.driver.host", Utils.localHostName()) + _conf.setIfMissing("spark.driver.port", "0") - // We need to register "HeartbeatReceiver" before "createTaskScheduler" because Executor will - // retrieve "HeartbeatReceiver" in the constructor. (SPARK-6640) - private val heartbeatReceiver = env.rpcEnv.setupEndpoint( - HeartbeatReceiver.ENDPOINT_NAME, new HeartbeatReceiver(this)) + _conf.set("spark.executor.id", SparkContext.DRIVER_IDENTIFIER) - // Create and start the scheduler - private[spark] var (schedulerBackend, taskScheduler) = - SparkContext.createTaskScheduler(this, master) + _jars =_conf.getOption("spark.jars").map(_.split(",")).map(_.filter(_.size != 0)).toSeq.flatten + _files = _conf.getOption("spark.files").map(_.split(",")).map(_.filter(_.size != 0)) + .toSeq.flatten - heartbeatReceiver.send(TaskSchedulerIsSet) + _eventLogDir = + if (isEventLogEnabled) { + val unresolvedDir = conf.get("spark.eventLog.dir", EventLoggingListener.DEFAULT_LOG_DIR) + .stripSuffix("/") + Some(Utils.resolveURI(unresolvedDir)) + } else { + None + } - @volatile private[spark] var dagScheduler: DAGScheduler = _ - try { - dagScheduler = new DAGScheduler(this) - } catch { - case e: Exception => { - try { - stop() - } finally { - throw new SparkException("Error while constructing DAGScheduler", e) + _eventLogCodec = { + val compress = _conf.getBoolean("spark.eventLog.compress", false) + if (compress && isEventLogEnabled) { + Some(CompressionCodec.getCodecName(_conf)).map(CompressionCodec.getShortName) + } else { + None } } - } - // start TaskScheduler after taskScheduler sets DAGScheduler reference in DAGScheduler's - // constructor - taskScheduler.start() + _conf.set("spark.tachyonStore.folderName", tachyonFolderName) - val applicationId: String = taskScheduler.applicationId() - conf.set("spark.app.id", applicationId) + if (master == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true") - env.blockManager.initialize(applicationId) + // Create the Spark execution environment (cache, map output tracker, etc) + _env = createSparkEnv(_conf, isLocal, listenerBus) + SparkEnv.set(_env) - val metricsSystem = env.metricsSystem + _metadataCleaner = new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup, _conf) - // The metrics system for Driver need to be set spark.app.id to app ID. - // So it should start after we get app ID from the task scheduler and set spark.app.id. - metricsSystem.start() - // Attach the driver metrics servlet handler to the web ui after the metrics system is started. 
- metricsSystem.getServletHandlers.foreach(handler => ui.foreach(_.attachHandler(handler))) + _jobProgressListener = new JobProgressListener(_conf) + listenerBus.addListener(jobProgressListener) - // Optionally log Spark events - private[spark] val eventLogger: Option[EventLoggingListener] = { - if (isEventLogEnabled) { - val logger = - new EventLoggingListener(applicationId, eventLogDir.get, conf, hadoopConfiguration) - logger.start() - listenerBus.addListener(logger) - Some(logger) - } else None - } + _statusTracker = new SparkStatusTracker(this) - // Optionally scale number of executors dynamically based on workload. Exposed for testing. - private val dynamicAllocationEnabled = conf.getBoolean("spark.dynamicAllocation.enabled", false) - private val dynamicAllocationTesting = conf.getBoolean("spark.dynamicAllocation.testing", false) - private[spark] val executorAllocationManager: Option[ExecutorAllocationManager] = - if (dynamicAllocationEnabled) { - assert(supportDynamicAllocation, - "Dynamic allocation of executors is currently only supported in YARN mode") - Some(new ExecutorAllocationManager(this, listenerBus, conf)) - } else { - None + _progressBar = + if (_conf.getBoolean("spark.ui.showConsoleProgress", true) && !log.isInfoEnabled) { + Some(new ConsoleProgressBar(this)) + } else { + None + } + + _ui = + if (conf.getBoolean("spark.ui.enabled", true)) { + Some(SparkUI.createLiveUI(this, _conf, listenerBus, _jobProgressListener, + _env.securityManager,appName)) + } else { + // For tests, do not enable the UI + None + } + // Bind the UI before starting the task scheduler to communicate + // the bound port to the cluster manager properly + _ui.foreach(_.bind()) + + _hadoopConfiguration = SparkHadoopUtil.get.newConfiguration(_conf) + + // Add each JAR given through the constructor + if (jars != null) { + jars.foreach(addJar) } - executorAllocationManager.foreach(_.start()) - private[spark] val cleaner: Option[ContextCleaner] = { - if (conf.getBoolean("spark.cleaner.referenceTracking", true)) { - Some(new ContextCleaner(this)) - } else { - None + if (files != null) { + files.foreach(addFile) } - } - cleaner.foreach(_.start()) - setupAndStartListenerBus() - postEnvironmentUpdate() - postApplicationStart() + _executorMemory = _conf.getOption("spark.executor.memory") + .orElse(Option(System.getenv("SPARK_EXECUTOR_MEMORY"))) + .orElse(Option(System.getenv("SPARK_MEM")) + .map(warnSparkMem)) + .map(Utils.memoryStringToMb) + .getOrElse(512) + + // Convert java options to env vars as a work around + // since we can't set env vars directly in sbt. + for { (envKey, propKey) <- Seq(("SPARK_TESTING", "spark.testing")) + value <- Option(System.getenv(envKey)).orElse(Option(System.getProperty(propKey)))} { + executorEnvs(envKey) = value + } + Option(System.getenv("SPARK_PREPEND_CLASSES")).foreach { v => + executorEnvs("SPARK_PREPEND_CLASSES") = v + } + // The Mesos scheduler backend relies on this environment variable to set executor memory. + // TODO: Set this only in the Mesos scheduler. + executorEnvs("SPARK_EXECUTOR_MEMORY") = executorMemory + "m" + executorEnvs ++= _conf.getExecutorEnv + executorEnvs("SPARK_USER") = sparkUser + + // We need to register "HeartbeatReceiver" before "createTaskScheduler" because Executor will + // retrieve "HeartbeatReceiver" in the constructor. 
(SPARK-6640) + _heartbeatReceiver = env.rpcEnv.setupEndpoint( + HeartbeatReceiver.ENDPOINT_NAME, new HeartbeatReceiver(this)) + + // Create and start the scheduler + val (sched, ts) = SparkContext.createTaskScheduler(this, master) + _schedulerBackend = sched + _taskScheduler = ts + _dagScheduler = new DAGScheduler(this) + _heartbeatReceiver.send(TaskSchedulerIsSet) + + // start TaskScheduler after taskScheduler sets DAGScheduler reference in DAGScheduler's + // constructor + _taskScheduler.start() + + _applicationId = _taskScheduler.applicationId() + _conf.set("spark.app.id", _applicationId) + _env.blockManager.initialize(_applicationId) + + // The metrics system for Driver need to be set spark.app.id to app ID. + // So it should start after we get app ID from the task scheduler and set spark.app.id. + metricsSystem.start() + // Attach the driver metrics servlet handler to the web ui after the metrics system is started. + metricsSystem.getServletHandlers.foreach(handler => ui.foreach(_.attachHandler(handler))) + + _eventLogger = + if (isEventLogEnabled) { + val logger = + new EventLoggingListener(_applicationId, _eventLogDir.get, _conf, _hadoopConfiguration) + logger.start() + listenerBus.addListener(logger) + Some(logger) + } else { + None + } - private[spark] var checkpointDir: Option[String] = None + // Optionally scale number of executors dynamically based on workload. Exposed for testing. + val dynamicAllocationEnabled = _conf.getBoolean("spark.dynamicAllocation.enabled", false) + _executorAllocationManager = + if (dynamicAllocationEnabled) { + assert(supportDynamicAllocation, + "Dynamic allocation of executors is currently only supported in YARN mode") + Some(new ExecutorAllocationManager(this, listenerBus, _conf)) + } else { + None + } + _executorAllocationManager.foreach(_.start()) - // Thread Local variable that can be used by users to pass information down the stack - private val localProperties = new InheritableThreadLocal[Properties] { - override protected def childValue(parent: Properties): Properties = new Properties(parent) - override protected def initialValue(): Properties = new Properties() + _cleaner = + if (_conf.getBoolean("spark.cleaner.referenceTracking", true)) { + Some(new ContextCleaner(this)) + } else { + None + } + _cleaner.foreach(_.start()) + + setupAndStartListenerBus() + postEnvironmentUpdate() + postApplicationStart() + + // Post init + _taskScheduler.postStartHook() + _env.metricsSystem.registerSource(new DAGSchedulerSource(dagScheduler)) + _env.metricsSystem.registerSource(new BlockManagerSource(_env.blockManager)) + } catch { + case NonFatal(e) => + logError("Error initializing SparkContext.", e) + try { + stop() + } catch { + case NonFatal(inner) => + logError("Error stopping SparkContext after init error.", inner) + } finally { + throw e + } } /** @@ -544,19 +631,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, null) } - // Post init - taskScheduler.postStartHook() - - private val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler) - private val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager) - - private def initDriverMetrics() { - SparkEnv.get.metricsSystem.registerSource(dagSchedulerSource) - SparkEnv.get.metricsSystem.registerSource(blockManagerSource) - } - - initDriverMetrics() - // Methods for creating RDDs /** Distribute a local Scala collection to form an RDD. 
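The NonFatal handler added above is the heart of this patch, and its shape is reusable: run all stateful construction inside one try, attempt cleanup on failure, log (rather than propagate) any secondary failure, and rethrow the original exception so callers see the real cause. A minimal sketch under those assumptions, with hypothetical setup and cleanup callbacks:

    import scala.util.control.NonFatal

    final class SafeInit(setup: () => Unit, cleanup: () => Unit) {
      try {
        setup()
      } catch {
        case NonFatal(e) =>
          // Best-effort cleanup; the original exception still propagates.
          try cleanup() catch {
            case NonFatal(inner) => System.err.println(s"cleanup also failed: $inner")
          } finally {
            throw e
          }
      }
    }

    object SafeInit {
      def main(args: Array[String]): Unit =
        try new SafeInit(() => sys.error("boom"), () => println("cleaned up"))
        catch { case NonFatal(e) => println(s"constructor failed: ${e.getMessage}") }
    }

Note the finally around the rethrow: even when cleanup succeeds, the first exception wins, matching the logError-then-throw structure in the hunk above.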
@@ -1146,7 +1220,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * this application is supported. This is currently only available for YARN. */ private[spark] def supportDynamicAllocation = - master.contains("yarn") || dynamicAllocationTesting + master.contains("yarn") || _conf.getBoolean("spark.dynamicAllocation.testing", false) /** * :: DeveloperApi :: @@ -1163,7 +1237,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * This is currently only supported in YARN mode. Return whether the request is received. */ private[spark] override def requestTotalExecutors(numExecutors: Int): Boolean = { - assert(master.contains("yarn") || dynamicAllocationTesting, + assert(supportDynamicAllocation, "Requesting executors is currently only supported in YARN mode") schedulerBackend match { case b: CoarseGrainedSchedulerBackend => @@ -1403,28 +1477,40 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli def stop() { // Use the stopping variable to ensure no contention for the stop scenario. // Still track the stopped variable for use elsewhere in the code. - if (!stopped.compareAndSet(false, true)) { logInfo("SparkContext already stopped.") return } - + postApplicationEnd() - ui.foreach(_.stop()) - env.metricsSystem.report() - metadataCleaner.cancel() - cleaner.foreach(_.stop()) - executorAllocationManager.foreach(_.stop()) - dagScheduler.stop() - dagScheduler = null - listenerBus.stop() - eventLogger.foreach(_.stop()) - env.rpcEnv.stop(heartbeatReceiver) - progressBar.foreach(_.stop()) - taskScheduler = null + _ui.foreach(_.stop()) + if (env != null) { + env.metricsSystem.report() + } + if (metadataCleaner != null) { + metadataCleaner.cancel() + } + _cleaner.foreach(_.stop()) + _executorAllocationManager.foreach(_.stop()) + if (_dagScheduler != null) { + _dagScheduler.stop() + _dagScheduler = null + } + if (_listenerBusStarted) { + listenerBus.stop() + _listenerBusStarted = false + } + _eventLogger.foreach(_.stop()) + if (env != null && _heartbeatReceiver != null) { + env.rpcEnv.stop(_heartbeatReceiver) + } + _progressBar.foreach(_.stop()) + _taskScheduler = null // TODO: Cache.stop()? 
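The null checks in the rewritten stop() above are the other half of the exception-safe constructor: once construction can fail partway through, every field may still be unassigned when stop() runs, and the compareAndSet guard makes repeated or concurrent calls safe. The teardown shape, as a standalone sketch with hypothetical resource fields:

    import java.util.concurrent.atomic.AtomicBoolean

    final class TeardownSketch {
      private val stopped = new AtomicBoolean(false)
      private var resource: AutoCloseable = _ // may never be assigned if init failed
      private var helper: Option[AutoCloseable] = None

      def stop(): Unit = {
        if (!stopped.compareAndSet(false, true)) return // idempotent, race-free
        helper.foreach(_.close())               // Option-typed fields: uniform foreach
        if (resource != null) resource.close()  // plain fields: explicit null check
      }
    }

The same reasoning explains the Option-typed fields elsewhere in this diff (_ui, _eventLogger, _cleaner): initializing them to None means stop() can always foreach over them without caring how far construction got.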
- env.stop() - SparkEnv.set(null) + if (_env != null) { + _env.stop() + SparkEnv.set(null) + } SparkContext.clearActiveContext() logInfo("Successfully stopped SparkContext") } @@ -1749,6 +1835,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } listenerBus.start(this) + _listenerBusStarted = true } /** Post the application start event */ @@ -2152,7 +2239,7 @@ object SparkContext extends Logging { master match { case "local" => val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true) - val backend = new LocalBackend(scheduler, 1) + val backend = new LocalBackend(sc.getConf, scheduler, 1) scheduler.initialize(backend) (backend, scheduler) @@ -2164,7 +2251,7 @@ object SparkContext extends Logging { throw new SparkException(s"Asked to run locally with $threadCount threads") } val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true) - val backend = new LocalBackend(scheduler, threadCount) + val backend = new LocalBackend(sc.getConf, scheduler, threadCount) scheduler.initialize(backend) (backend, scheduler) @@ -2174,7 +2261,7 @@ object SparkContext extends Logging { // local[N, M] means exactly N threads with M failures val threadCount = if (threads == "*") localCpuCount else threads.toInt val scheduler = new TaskSchedulerImpl(sc, maxFailures.toInt, isLocal = true) - val backend = new LocalBackend(scheduler, threadCount) + val backend = new LocalBackend(sc.getConf, scheduler, threadCount) scheduler.initialize(backend) (backend, scheduler) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 516f619529c48..1b5fdeba28ee2 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -21,7 +21,7 @@ import java.io.File import java.lang.management.ManagementFactory import java.net.URL import java.nio.ByteBuffer -import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.{ConcurrentHashMap, Executors, TimeUnit} import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} @@ -60,8 +60,6 @@ private[spark] class Executor( private val conf = env.conf - @volatile private var isStopped = false - // No ip or host:port - just hostname Utils.checkHost(executorHostname, "Expected executed slave to be a hostname") // must not have port specified. @@ -114,6 +112,10 @@ private[spark] class Executor( // Maintains the list of running tasks. private val runningTasks = new ConcurrentHashMap[Long, TaskRunner] + // Executor for the heartbeat task. + private val heartbeater = Executors.newSingleThreadScheduledExecutor( + Utils.namedThreadFactory("driver-heartbeater")) + startDriverHeartbeater() def launchTask( @@ -138,7 +140,8 @@ private[spark] class Executor( def stop(): Unit = { env.metricsSystem.report() env.rpcEnv.stop(executorEndpoint) - isStopped = true + heartbeater.shutdown() + heartbeater.awaitTermination(10, TimeUnit.SECONDS) threadPool.shutdown() if (!isLocal) { env.stop() @@ -432,23 +435,17 @@ private[spark] class Executor( } /** - * Starts a thread to report heartbeat and partial metrics for active tasks to driver. - * This thread stops running when the executor is stopped. + * Schedules a task to report heartbeat and partial metrics for active tasks to driver. 
*/ private def startDriverHeartbeater(): Unit = { val intervalMs = conf.getTimeAsMs("spark.executor.heartbeatInterval", "10s") - val thread = new Thread() { - override def run() { - // Sleep a random interval so the heartbeats don't end up in sync - Thread.sleep(intervalMs + (math.random * intervalMs).asInstanceOf[Int]) - while (!isStopped) { - reportHeartBeat() - Thread.sleep(intervalMs) - } - } + + // Wait a random interval so the heartbeats don't end up in sync + val initialDelay = intervalMs + (math.random * intervalMs).asInstanceOf[Int] + + val heartbeatTask = new Runnable() { + override def run(): Unit = Utils.logUncaughtExceptions(reportHeartBeat()) } - thread.setDaemon(true) - thread.setName("driver-heartbeater") - thread.start() + heartbeater.scheduleAtFixedRate(heartbeatTask, initialDelay, intervalMs, TimeUnit.MILLISECONDS) } } diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 8e3c30fc3d781..5a74c13b38bf7 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -86,11 +86,11 @@ private[nio] class ConnectionManager( conf.get("spark.network.timeout", "120s")) // Get the thread counts from the Spark Configuration. - // + // // Even though the ThreadPoolExecutor constructor takes both a minimum and maximum value, // we only query for the minimum value because we are using LinkedBlockingDeque. - // - // The JavaDoc for ThreadPoolExecutor points out that when using a LinkedBlockingDeque (which is + // + // The JavaDoc for ThreadPoolExecutor points out that when using a LinkedBlockingDeque (which is // an unbounded queue) no more than corePoolSize threads will ever be created, so only the "min" // parameter is necessary. 
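The Executor heartbeater rewrite above replaces a hand-rolled daemon thread (a sleep loop guarded by a volatile flag) with a single-thread ScheduledExecutorService, which is what lets the new stop() wait for the heartbeat to finish via shutdown() and awaitTermination(). A self-contained sketch of that scheduling shape, not the Executor code itself:

    import java.util.concurrent.{Executors, TimeUnit}

    object HeartbeatSketch {
      def main(args: Array[String]): Unit = {
        val heartbeater = Executors.newSingleThreadScheduledExecutor()
        val intervalMs = 100L
        // Random initial delay so many processes do not heartbeat in lockstep.
        val initialDelay = intervalMs + (scala.util.Random.nextDouble() * intervalMs).toLong
        val task = new Runnable {
          override def run(): Unit = println("heartbeat")
        }
        heartbeater.scheduleAtFixedRate(task, initialDelay, intervalMs, TimeUnit.MILLISECONDS)
        Thread.sleep(500)
        heartbeater.shutdown()
        // Bounded wait for any in-flight task, unlike joining a sleeping thread.
        heartbeater.awaitTermination(10, TimeUnit.SECONDS)
      }
    }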
private val handlerThreadCount = conf.getInt("spark.core.connection.handler.threads.min", 20) @@ -989,6 +989,7 @@ private[nio] class ConnectionManager( def stop() { ackTimeoutMonitor.stop() + selector.wakeup() selectorThread.interrupt() selectorThread.join() selector.close() diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index ecc8bf189986d..13a52d836f32f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -142,11 +142,10 @@ private[spark] class TaskSchedulerImpl( if (!isLocal && conf.getBoolean("spark.speculation", false)) { logInfo("Starting speculative execution thread") - import sc.env.actorSystem.dispatcher sc.env.actorSystem.scheduler.schedule(SPECULATION_INTERVAL_MS milliseconds, SPECULATION_INTERVAL_MS milliseconds) { Utils.tryOrStopSparkContext(sc) { checkSpeculatableTasks() } - } + }(sc.env.actorSystem.dispatcher) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index 70a477a6895cc..50ba0b9d5a612 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -20,12 +20,12 @@ package org.apache.spark.scheduler.local import java.nio.ByteBuffer import java.util.concurrent.{Executors, TimeUnit} -import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEndpointRef, RpcEnv} -import org.apache.spark.util.Utils -import org.apache.spark.{Logging, SparkContext, SparkEnv, TaskState} +import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.{Executor, ExecutorBackend} +import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcCallContext, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer} +import org.apache.spark.util.Utils private case class ReviveOffers() @@ -71,11 +71,15 @@ private[spark] class LocalEndpoint( case KillTask(taskId, interruptThread) => executor.killTask(taskId, interruptThread) + } + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case StopExecutor => executor.stop() + context.reply(true) } + def reviveOffers() { val offers = Seq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores)) val tasks = scheduler.resourceOffers(offers).flatten @@ -104,8 +108,11 @@ private[spark] class LocalEndpoint( * master all run in the same JVM. It sits behind a TaskSchedulerImpl and handles launching tasks * on a single Executor (created by the LocalBackend) running locally. 
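The selector.wakeup() added to ConnectionManager.stop() above encodes a subtle NIO shutdown ordering. As the commit message explains, an interrupt delivered at the wrong moment can be consumed before the selector thread re-enters select(), leaving it blocked with nothing left to wake it; calling Selector.wakeup() first guarantees that the current or next select() returns immediately. The ordering, as a standalone sketch:

    import java.nio.channels.Selector

    object SelectorShutdownSketch {
      def stopSelectorThread(selector: Selector, selectorThread: Thread): Unit = {
        selector.wakeup()          // any blocked or upcoming select() now returns
        selectorThread.interrupt() // signal the loop to exit
        selectorThread.join()      // wait for it to drain
        selector.close()
      }
    }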
*/ -private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: Int) - extends SchedulerBackend with ExecutorBackend { +private[spark] class LocalBackend( + conf: SparkConf, + scheduler: TaskSchedulerImpl, + val totalCores: Int) + extends SchedulerBackend with ExecutorBackend with Logging { private val appId = "local-" + System.currentTimeMillis var localEndpoint: RpcEndpointRef = null @@ -116,7 +123,7 @@ private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: } override def stop() { - localEndpoint.send(StopExecutor) + localEndpoint.sendWithReply(StopExecutor) } override def reviveOffers() { diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 6b3049b28cd5e..22acc270b983e 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -56,19 +56,13 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext wit // Min < 0 val conf1 = conf.clone().set("spark.dynamicAllocation.minExecutors", "-1") intercept[SparkException] { contexts += new SparkContext(conf1) } - SparkEnv.get.stop() - SparkContext.clearActiveContext() // Max < 0 val conf2 = conf.clone().set("spark.dynamicAllocation.maxExecutors", "-1") intercept[SparkException] { contexts += new SparkContext(conf2) } - SparkEnv.get.stop() - SparkContext.clearActiveContext() // Both min and max, but min > max intercept[SparkException] { createSparkContext(2, 1) } - SparkEnv.get.stop() - SparkContext.clearActiveContext() // Both min and max, and min == max val sc1 = createSparkContext(1, 1) From 3ae37b93a7c299bd8b22a36248035bca5de3422f Mon Sep 17 00:00:00 2001 From: Jin Adachi Date: Thu, 16 Apr 2015 23:41:04 +0800 Subject: [PATCH 019/191] [SPARK-6694][SQL]SparkSQL CLI must be able to specify an option --database on the command line. SparkSQL CLI has an option --database as follows. But, the option --database is ignored. ``` $ spark-sql --help : CLI options: : --database Specify the database to use ``` Author: Jin Adachi Author: adachij Closes #5345 from adachij2002/SPARK-6694 and squashes the following commits: 8659084 [Jin Adachi] Merge branch 'master' of https://github.com/apache/spark into SPARK-6694 0301eb9 [Jin Adachi] Merge branch 'master' of https://github.com/apache/spark into SPARK-6694 df81086 [Jin Adachi] Modify code style. 846f83e [Jin Adachi] Merge branch 'master' of https://github.com/apache/spark into SPARK-6694 dbe8c63 [Jin Adachi] Change file permission to 644. 7b58f42 [Jin Adachi] Merge branch 'master' of https://github.com/apache/spark into SPARK-6694 c581d06 [Jin Adachi] Add an option --database test db56122 [Jin Adachi] Merge branch 'SPARK-6694' of https://github.com/adachij2002/spark into SPARK-6694 ee09fa5 [adachij] Merge branch 'master' into SPARK-6694 c804c03 [adachij] SparkSQL CLI must be able to specify an option --database on the command line. 
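Beyond the driver-side fix, the CliSuite rewrite below is a useful testing pattern in its own right: the warehouse and metastore directories move from per-call locals (cleaned up in a finally) to suite-level fixtures reset in before/after hooks, so every test, including the new multi-invocation --database test, starts from a clean slate. A minimal sketch of that fixture shape with ScalaTest; FixtureSketch and scratchDir are invented names:

    import java.nio.file.Files
    import org.scalatest.{BeforeAndAfter, FunSuite}

    class FixtureSketch extends FunSuite with BeforeAndAfter {
      // Created once; deleted around every test so each run starts fresh.
      val scratchDir = Files.createTempDirectory("fixture-sketch").toFile

      before { scratchDir.delete() }
      after { scratchDir.delete() }

      test("scratch directory is absent when a test begins") {
        assert(!scratchDir.exists())
      }
    }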
--- .../hive/thriftserver/SparkSQLCLIDriver.scala | 3 ++ .../sql/hive/thriftserver/CliSuite.scala | 45 +++++++++++++++---- 2 files changed, 39 insertions(+), 9 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 62c061bef690a..85281c6d73a3b 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -145,6 +145,9 @@ private[hive] object SparkSQLCLIDriver { case e: UnsupportedEncodingException => System.exit(3) } + // use the specified database if specified + cli.processSelectDatabase(sessionState); + // Execute -i init files (always in silent mode) cli.processInitFiles(sessionState) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 6d1d7c3a4e698..b070fa8eaa469 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -25,22 +25,31 @@ import scala.concurrent.{Await, Promise} import scala.sys.process.{Process, ProcessLogger} import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} import org.apache.spark.Logging import org.apache.spark.util.Utils -class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { +class CliSuite extends FunSuite with BeforeAndAfter with Logging { + val warehousePath = Utils.createTempDir() + val metastorePath = Utils.createTempDir() + + before { + warehousePath.delete() + metastorePath.delete() + } + + after { + warehousePath.delete() + metastorePath.delete() + } + def runCliWithin( timeout: FiniteDuration, extraArgs: Seq[String] = Seq.empty)( - queriesAndExpectedAnswers: (String, String)*) { + queriesAndExpectedAnswers: (String, String)*): Unit = { val (queries, expectedAnswers) = queriesAndExpectedAnswers.unzip - val warehousePath = Utils.createTempDir() - warehousePath.delete() - val metastorePath = Utils.createTempDir() - metastorePath.delete() val cliScript = "../../bin/spark-sql".split("/").mkString(File.separator) val command = { @@ -95,8 +104,6 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { """.stripMargin, cause) throw cause } finally { - warehousePath.delete() - metastorePath.delete() process.destroy() } } @@ -124,4 +131,24 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging { test("Single command with -e") { runCliWithin(1.minute, Seq("-e", "SHOW DATABASES;"))("" -> "OK") } + + test("Single command with --database") { + runCliWithin(1.minute)( + "CREATE DATABASE hive_test_db;" + -> "OK", + "USE hive_test_db;" + -> "OK", + "CREATE TABLE hive_test(key INT, val STRING);" + -> "OK", + "SHOW TABLES;" + -> "Time taken: " + ) + + runCliWithin(1.minute, Seq("--database", "hive_test_db", "-e", "SHOW TABLES;"))( + "" + -> "OK", + "" + -> "hive_test" + ) + } } From ef3fb801ae971656ed9cd1b0ab95bc5a1548adbd Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 16 Apr 2015 13:45:55 -0500 Subject: [PATCH 020/191] [SPARK-6934][Core] Use 'spark.akka.askTimeout' for the ask 
timeout Fixed my mistake in #4588 Author: zsxwing Closes #5529 from zsxwing/SPARK-6934 and squashes the following commits: 9890b2d [zsxwing] Use 'spark.akka.askTimeout' for the ask timeout --- core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index e259867c14040..f2c1c86af767e 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -284,7 +284,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) private[this] val maxRetries = conf.getInt("spark.akka.num.retries", 3) private[this] val retryWaitMs = conf.getLong("spark.akka.retry.wait", 3000) - private[this] val defaultTimeout = conf.getLong("spark.akka.lookupTimeout", 30) seconds + private[this] val defaultAskTimeout = conf.getLong("spark.akka.askTimeout", 30) seconds /** * return the address for the [[RpcEndpointRef]] @@ -304,7 +304,8 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) * * This method only sends the message once and never retries. */ - def sendWithReply[T: ClassTag](message: Any): Future[T] = sendWithReply(message, defaultTimeout) + def sendWithReply[T: ClassTag](message: Any): Future[T] = + sendWithReply(message, defaultAskTimeout) /** * Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a `Future` to @@ -327,7 +328,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) * @tparam T type of the reply message * @return the reply message from the corresponding [[RpcEndpoint]] */ - def askWithReply[T: ClassTag](message: Any): T = askWithReply(message, defaultTimeout) + def askWithReply[T: ClassTag](message: Any): T = askWithReply(message, defaultAskTimeout) /** * Send a message to the corresponding [[RpcEndpoint.receive]] and get its result within a From 55f553a979db925aa0c3559f7e80b99d2bf3feb4 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Thu, 16 Apr 2015 13:06:34 -0700 Subject: [PATCH 021/191] [SPARK-6855] [SPARKR] Set R includes to get the right collate order. This prevents tools like devtools::document creating invalid collate orders Author: Shivaram Venkataraman Closes #5462 from shivaram/collate-order and squashes the following commits: f3db562 [Shivaram Venkataraman] Set R includes to get the right collate order. 
This prevents tools like devtools::document creating invalid collate orders
---
 R/pkg/DESCRIPTION   | 6 +++---
 R/pkg/R/DataFrame.R | 2 +-
 R/pkg/R/column.R    | 2 +-
 R/pkg/R/group.R     | 3 +++
 R/pkg/R/jobj.R      | 3 +++
 R/pkg/R/pairRDD.R   | 2 ++
 6 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION
index 1842b97d43651..052f68c6c24e2 100644
--- a/R/pkg/DESCRIPTION
+++ b/R/pkg/DESCRIPTION
@@ -17,19 +17,19 @@ License: Apache License (== 2.0)
 Collate:
     'generics.R'
     'jobj.R'
-    'SQLTypes.R'
     'RDD.R'
     'pairRDD.R'
+    'SQLTypes.R'
     'column.R'
     'group.R'
     'DataFrame.R'
     'SQLContext.R'
+    'backend.R'
     'broadcast.R'
+    'client.R'
     'context.R'
     'deserialize.R'
     'serialize.R'
     'sparkR.R'
-    'backend.R'
-    'client.R'
     'utils.R'
     'zzz.R'
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index feafd56909a67..044fdb4d01223 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -17,7 +17,7 @@
 # DataFrame.R - DataFrame class and methods implemented in S4 OO classes
-#' @include jobj.R SQLTypes.R RDD.R pairRDD.R column.R group.R
+#' @include generics.R jobj.R SQLTypes.R RDD.R pairRDD.R column.R group.R
 NULL
 setOldClass("jobj")
diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R
index e196305186b9a..b282001d8b6b5 100644
--- a/R/pkg/R/column.R
+++ b/R/pkg/R/column.R
@@ -17,7 +17,7 @@
 # Column Class
-#' @include generics.R jobj.R
+#' @include generics.R jobj.R SQLTypes.R
 NULL
 setOldClass("jobj")
diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R
index 09fc0a7abe48a..855fbdfc7c4ca 100644
--- a/R/pkg/R/group.R
+++ b/R/pkg/R/group.R
@@ -17,6 +17,9 @@
 # group.R - GroupedData class and methods implemented in S4 OO classes
+#' @include generics.R jobj.R SQLTypes.R column.R
+NULL
+
 setOldClass("jobj")
 #' @title S4 class that represents a GroupedData
diff --git a/R/pkg/R/jobj.R b/R/pkg/R/jobj.R
index 4180f146b7fbc..a8a25230b636d 100644
--- a/R/pkg/R/jobj.R
+++ b/R/pkg/R/jobj.R
@@ -18,6 +18,9 @@
 # References to objects that exist on the JVM backend
 # are maintained using the jobj.
+#' @include generics.R
+NULL
+
 # Maintain a reference count of Java object references
 # This allows us to GC the java object when it is safe
 .validJobjs <- new.env(parent = emptyenv())
diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R
index 739d399f0820f..5d64822859d1f 100644
--- a/R/pkg/R/pairRDD.R
+++ b/R/pkg/R/pairRDD.R
@@ -16,6 +16,8 @@
 #
 # Operations supported on RDDs contains pairs (i.e key, value)
+#' @include generics.R jobj.R RDD.R
+NULL
 ############ Actions and Transformations ############

From 04e44b37cc04f62fbf9e08c7076349e0a4d12ea8 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Thu, 16 Apr 2015 16:20:57 -0700
Subject: [PATCH 022/191] [SPARK-4897] [PySpark] Python 3 support

This PR updates PySpark to support Python 3 (tested with 3.4).

Known issue: unpickling an array from Pyrolite is broken in Python 3, so those tests are skipped.

TODO: ec2/spark-ec2.py is not fully tested with Python 3.
Author: Davies Liu Author: twneale Author: Josh Rosen Closes #5173 from davies/python3 and squashes the following commits: d7d6323 [Davies Liu] fix tests 6c52a98 [Davies Liu] fix mllib test 99e334f [Davies Liu] update timeout b716610 [Davies Liu] Merge branch 'master' of github.com:apache/spark into python3 cafd5ec [Davies Liu] adddress comments from @mengxr bf225d7 [Davies Liu] Merge branch 'master' of github.com:apache/spark into python3 179fc8d [Davies Liu] tuning flaky tests 8c8b957 [Davies Liu] fix ResourceWarning in Python 3 5c57c95 [Davies Liu] Merge branch 'master' of github.com:apache/spark into python3 4006829 [Davies Liu] fix test 2fc0066 [Davies Liu] add python3 path 71535e9 [Davies Liu] fix xrange and divide 5a55ab4 [Davies Liu] Merge branch 'master' of github.com:apache/spark into python3 125f12c [Davies Liu] Merge branch 'master' of github.com:apache/spark into python3 ed498c8 [Davies Liu] fix compatibility with python 3 820e649 [Davies Liu] Merge branch 'master' of github.com:apache/spark into python3 e8ce8c9 [Davies Liu] Merge branch 'master' of github.com:apache/spark into python3 ad7c374 [Davies Liu] fix mllib test and warning ef1fc2f [Davies Liu] fix tests 4eee14a [Davies Liu] Merge branch 'master' of github.com:apache/spark into python3 20112ff [Davies Liu] Merge branch 'master' of github.com:apache/spark into python3 59bb492 [Davies Liu] fix tests 1da268c [Davies Liu] Merge branch 'master' of github.com:apache/spark into python3 ca0fdd3 [Davies Liu] fix code style 9563a15 [Davies Liu] add imap back for python 2 0b1ec04 [Davies Liu] make python examples work with Python 3 d2fd566 [Davies Liu] Merge branch 'master' of github.com:apache/spark into python3 a716d34 [Davies Liu] test with python 3.4 f1700e8 [Davies Liu] fix test in python3 671b1db [Davies Liu] fix test in python3 692ff47 [Davies Liu] fix flaky test 7b9699f [Davies Liu] invalidate import cache for Python 3.3+ 9c58497 [Davies Liu] fix kill worker 309bfbf [Davies Liu] keep compatibility 5707476 [Davies Liu] cleanup, fix hash of string in 3.3+ 8662d5b [Davies Liu] Merge branch 'master' of github.com:apache/spark into python3 f53e1f0 [Davies Liu] fix tests 70b6b73 [Davies Liu] compile ec2/spark_ec2.py in python 3 a39167e [Davies Liu] support customize class in __main__ 814c77b [Davies Liu] run unittests with python 3 7f4476e [Davies Liu] mllib tests passed d737924 [Davies Liu] pass ml tests 375ea17 [Davies Liu] SQL tests pass 6cc42a9 [Davies Liu] rename 431a8de [Davies Liu] streaming tests pass 78901a7 [Davies Liu] fix hash of serializer in Python 3 24b2f2e [Davies Liu] pass all RDD tests 35f48fe [Davies Liu] run future again 1eebac2 [Davies Liu] fix conflict in ec2/spark_ec2.py 6e3c21d [Davies Liu] make cloudpickle work with Python3 2fb2db3 [Josh Rosen] Guard more changes behind sys.version; still doesn't run 1aa5e8f [twneale] Turned out `pickle.DictionaryType is dict` == True, so swapped it out 7354371 [twneale] buffer --> memoryview I'm not super sure if this a valid change, but the 2.7 docs recommend using memoryview over buffer where possible, so hoping it'll work. b69ccdf [twneale] Uses the pure python pickle._Pickler instead of c-extension _pickle.Pickler. It appears pyspark 2.7 uses the pure python pickler as well, so this shouldn't degrade pickling performance (?). 
f40d925 [twneale] xrange --> range e104215 [twneale] Replaces 2.7 types.InstsanceType with 3.4 `object`....could be horribly wrong depending on how types.InstanceType is used elsewhere in the package--see http://bugs.python.org/issue8206 79de9d0 [twneale] Replaces python2.7 `file` with 3.4 _io.TextIOWrapper 2adb42d [Josh Rosen] Fix up some import differences between Python 2 and 3 854be27 [Josh Rosen] Run `futurize` on Python code: 7c5b4ce [Josh Rosen] Remove Python 3 check in shell.py. --- bin/pyspark | 1 + bin/spark-submit | 3 + bin/spark-submit2.cmd | 3 + dev/run-tests | 2 + dev/run-tests-jenkins | 2 +- ec2/spark_ec2.py | 262 ++++---- examples/src/main/python/als.py | 15 +- examples/src/main/python/avro_inputformat.py | 9 +- .../src/main/python/cassandra_inputformat.py | 8 +- .../src/main/python/cassandra_outputformat.py | 6 +- examples/src/main/python/hbase_inputformat.py | 8 +- .../src/main/python/hbase_outputformat.py | 6 +- examples/src/main/python/kmeans.py | 11 +- .../src/main/python/logistic_regression.py | 20 +- .../ml/simple_text_classification_pipeline.py | 20 +- .../src/main/python/mllib/correlations.py | 19 +- .../src/main/python/mllib/dataset_example.py | 13 +- .../main/python/mllib/decision_tree_runner.py | 29 +- .../python/mllib/gaussian_mixture_model.py | 9 +- .../python/mllib/gradient_boosted_trees.py | 7 +- examples/src/main/python/mllib/kmeans.py | 5 +- .../main/python/mllib/logistic_regression.py | 9 +- .../python/mllib/random_forest_example.py | 9 +- .../python/mllib/random_rdd_generation.py | 21 +- .../src/main/python/mllib/sampled_rdds.py | 29 +- examples/src/main/python/mllib/word2vec.py | 5 +- examples/src/main/python/pagerank.py | 16 +- .../src/main/python/parquet_inputformat.py | 7 +- examples/src/main/python/pi.py | 5 +- examples/src/main/python/sort.py | 6 +- examples/src/main/python/sql.py | 4 +- examples/src/main/python/status_api_demo.py | 10 +- .../main/python/streaming/hdfs_wordcount.py | 3 +- .../main/python/streaming/kafka_wordcount.py | 3 +- .../python/streaming/network_wordcount.py | 3 +- .../recoverable_network_wordcount.py | 11 +- .../python/streaming/sql_network_wordcount.py | 5 +- .../streaming/stateful_network_wordcount.py | 3 +- .../src/main/python/transitive_closure.py | 10 +- examples/src/main/python/wordcount.py | 6 +- .../MatrixFactorizationModelWrapper.scala | 9 +- .../mllib/api/python/PythonMLLibAPI.scala | 39 +- python/pyspark/accumulators.py | 9 +- python/pyspark/broadcast.py | 37 +- python/pyspark/cloudpickle.py | 577 +++++------------- python/pyspark/conf.py | 9 +- python/pyspark/context.py | 42 +- python/pyspark/daemon.py | 36 +- python/pyspark/heapq3.py | 24 +- python/pyspark/java_gateway.py | 2 +- python/pyspark/join.py | 1 + python/pyspark/ml/classification.py | 4 +- python/pyspark/ml/feature.py | 22 +- python/pyspark/ml/param/__init__.py | 8 +- .../ml/param/_shared_params_code_gen.py | 10 +- python/pyspark/mllib/__init__.py | 11 +- python/pyspark/mllib/classification.py | 7 +- python/pyspark/mllib/clustering.py | 18 +- python/pyspark/mllib/common.py | 19 +- python/pyspark/mllib/feature.py | 18 +- python/pyspark/mllib/fpm.py | 2 + python/pyspark/mllib/linalg.py | 48 +- python/pyspark/mllib/rand.py | 33 +- python/pyspark/mllib/recommendation.py | 7 +- python/pyspark/mllib/stat/_statistics.py | 25 +- python/pyspark/mllib/tests.py | 20 +- python/pyspark/mllib/tree.py | 15 +- python/pyspark/mllib/util.py | 26 +- python/pyspark/profiler.py | 10 +- python/pyspark/rdd.py | 189 +++--- python/pyspark/rddsampler.py | 4 +- 
python/pyspark/serializers.py | 101 ++- python/pyspark/shell.py | 16 +- python/pyspark/shuffle.py | 126 ++-- python/pyspark/sql/__init__.py | 15 +- python/pyspark/sql/{types.py => _types.py} | 49 +- python/pyspark/sql/context.py | 32 +- python/pyspark/sql/dataframe.py | 63 +- python/pyspark/sql/functions.py | 6 +- python/pyspark/sql/tests.py | 11 +- python/pyspark/statcounter.py | 4 +- python/pyspark/streaming/context.py | 5 +- python/pyspark/streaming/dstream.py | 51 +- python/pyspark/streaming/kafka.py | 8 +- python/pyspark/streaming/tests.py | 39 +- python/pyspark/streaming/util.py | 6 +- python/pyspark/tests.py | 327 +++++----- python/pyspark/worker.py | 16 +- python/run-tests | 15 +- python/test_support/userlib-0.1-py2.7.egg | Bin 1945 -> 0 bytes python/test_support/userlib-0.1.zip | Bin 0 -> 668 bytes 91 files changed, 1398 insertions(+), 1396 deletions(-) rename python/pyspark/sql/{types.py => _types.py} (97%) delete mode 100644 python/test_support/userlib-0.1-py2.7.egg create mode 100644 python/test_support/userlib-0.1.zip diff --git a/bin/pyspark b/bin/pyspark index 776b28dc41099..8acad6113797d 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -89,6 +89,7 @@ export PYTHONSTARTUP="$SPARK_HOME/python/pyspark/shell.py" if [[ -n "$SPARK_TESTING" ]]; then unset YARN_CONF_DIR unset HADOOP_CONF_DIR + export PYTHONHASHSEED=0 if [[ -n "$PYSPARK_DOC_TEST" ]]; then exec "$PYSPARK_DRIVER_PYTHON" -m doctest $1 else diff --git a/bin/spark-submit b/bin/spark-submit index bcff78edd51ca..0e0afe71a0f05 100755 --- a/bin/spark-submit +++ b/bin/spark-submit @@ -19,6 +19,9 @@ SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +# disable randomized hash for string in Python 3.3+ +export PYTHONHASHSEED=0 + # Only define a usage function if an upstream script hasn't done so. if ! type -t usage >/dev/null 2>&1; then usage() { diff --git a/bin/spark-submit2.cmd b/bin/spark-submit2.cmd index 08ddb185742d2..d3fc4a5cc3f6e 100644 --- a/bin/spark-submit2.cmd +++ b/bin/spark-submit2.cmd @@ -20,6 +20,9 @@ rem rem This is the entry point for running Spark submit. To avoid polluting the rem environment, it just launches a new cmd to do the real work. +rem disable randomized hash for string in Python 3.3+ +set PYTHONHASHSEED=0 + set CLASS=org.apache.spark.deploy.SparkSubmit call %~dp0spark-class2.cmd %CLASS% %* set SPARK_ERROR_LEVEL=%ERRORLEVEL% diff --git a/dev/run-tests b/dev/run-tests index bb21ab6c9aa04..861d1671182c2 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -235,6 +235,8 @@ echo "=========================================================================" CURRENT_BLOCK=$BLOCK_PYSPARK_UNIT_TESTS +# add path for python 3 in jenkins +export PATH="${PATH}:/home/anaonda/envs/py3k/bin" ./python/run-tests echo "" diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index 3c1c91a111357..030f2cdddb350 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -47,7 +47,7 @@ COMMIT_URL="https://github.com/apache/spark/commit/${ghprbActualCommit}" # GitHub doesn't auto-link short hashes when submitted via the API, unfortunately. :( SHORT_COMMIT_HASH="${ghprbActualCommit:0:7}" -TESTS_TIMEOUT="120m" # format: http://linux.die.net/man/1/timeout +TESTS_TIMEOUT="150m" # format: http://linux.die.net/man/1/timeout # Array to capture all tests to run on the pull request. These tests are held under the #+ dev/tests/ directory. diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 0c1f24761d0de..87c0818279713 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -19,7 +19,7 @@ # limitations under the License. 
# -from __future__ import with_statement +from __future__ import with_statement, print_function import hashlib import itertools @@ -37,12 +37,17 @@ import tempfile import textwrap import time -import urllib2 import warnings from datetime import datetime from optparse import OptionParser from sys import stderr +if sys.version < "3": + from urllib2 import urlopen, Request, HTTPError +else: + from urllib.request import urlopen, Request + from urllib.error import HTTPError + SPARK_EC2_VERSION = "1.2.1" SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__)) @@ -88,10 +93,10 @@ def setup_external_libs(libs): SPARK_EC2_LIB_DIR = os.path.join(SPARK_EC2_DIR, "lib") if not os.path.exists(SPARK_EC2_LIB_DIR): - print "Downloading external libraries that spark-ec2 needs from PyPI to {path}...".format( + print("Downloading external libraries that spark-ec2 needs from PyPI to {path}...".format( path=SPARK_EC2_LIB_DIR - ) - print "This should be a one-time operation." + )) + print("This should be a one-time operation.") os.mkdir(SPARK_EC2_LIB_DIR) for lib in libs: @@ -100,8 +105,8 @@ def setup_external_libs(libs): if not os.path.isdir(lib_dir): tgz_file_path = os.path.join(SPARK_EC2_LIB_DIR, versioned_lib_name + ".tar.gz") - print " - Downloading {lib}...".format(lib=lib["name"]) - download_stream = urllib2.urlopen( + print(" - Downloading {lib}...".format(lib=lib["name"])) + download_stream = urlopen( "{prefix}/{first_letter}/{lib_name}/{lib_name}-{lib_version}.tar.gz".format( prefix=PYPI_URL_PREFIX, first_letter=lib["name"][:1], @@ -113,13 +118,13 @@ def setup_external_libs(libs): tgz_file.write(download_stream.read()) with open(tgz_file_path) as tar: if hashlib.md5(tar.read()).hexdigest() != lib["md5"]: - print >> stderr, "ERROR: Got wrong md5sum for {lib}.".format(lib=lib["name"]) + print("ERROR: Got wrong md5sum for {lib}.".format(lib=lib["name"]), file=stderr) sys.exit(1) tar = tarfile.open(tgz_file_path) tar.extractall(path=SPARK_EC2_LIB_DIR) tar.close() os.remove(tgz_file_path) - print " - Finished downloading {lib}.".format(lib=lib["name"]) + print(" - Finished downloading {lib}.".format(lib=lib["name"])) sys.path.insert(1, lib_dir) @@ -299,12 +304,12 @@ def parse_args(): if home_dir is None or not os.path.isfile(home_dir + '/.boto'): if not os.path.isfile('/etc/boto.cfg'): if os.getenv('AWS_ACCESS_KEY_ID') is None: - print >> stderr, ("ERROR: The environment variable AWS_ACCESS_KEY_ID " + - "must be set") + print("ERROR: The environment variable AWS_ACCESS_KEY_ID must be set", + file=stderr) sys.exit(1) if os.getenv('AWS_SECRET_ACCESS_KEY') is None: - print >> stderr, ("ERROR: The environment variable AWS_SECRET_ACCESS_KEY " + - "must be set") + print("ERROR: The environment variable AWS_SECRET_ACCESS_KEY must be set", + file=stderr) sys.exit(1) return (opts, action, cluster_name) @@ -316,7 +321,7 @@ def get_or_make_group(conn, name, vpc_id): if len(group) > 0: return group[0] else: - print "Creating security group " + name + print("Creating security group " + name) return conn.create_security_group(name, "Spark EC2 group", vpc_id) @@ -324,18 +329,19 @@ def get_validate_spark_version(version, repo): if "." 
in version: version = version.replace("v", "") if version not in VALID_SPARK_VERSIONS: - print >> stderr, "Don't know about Spark version: {v}".format(v=version) + print("Don't know about Spark version: {v}".format(v=version), file=stderr) sys.exit(1) return version else: github_commit_url = "{repo}/commit/{commit_hash}".format(repo=repo, commit_hash=version) - request = urllib2.Request(github_commit_url) + request = Request(github_commit_url) request.get_method = lambda: 'HEAD' try: - response = urllib2.urlopen(request) - except urllib2.HTTPError, e: - print >> stderr, "Couldn't validate Spark commit: {url}".format(url=github_commit_url) - print >> stderr, "Received HTTP response code of {code}.".format(code=e.code) + response = urlopen(request) + except HTTPError as e: + print("Couldn't validate Spark commit: {url}".format(url=github_commit_url), + file=stderr) + print("Received HTTP response code of {code}.".format(code=e.code), file=stderr) sys.exit(1) return version @@ -394,8 +400,7 @@ def get_spark_ami(opts): instance_type = EC2_INSTANCE_TYPES[opts.instance_type] else: instance_type = "pvm" - print >> stderr,\ - "Don't recognize %s, assuming type is pvm" % opts.instance_type + print("Don't recognize %s, assuming type is pvm" % opts.instance_type, file=stderr) # URL prefix from which to fetch AMI information ami_prefix = "{r}/{b}/ami-list".format( @@ -404,10 +409,10 @@ def get_spark_ami(opts): ami_path = "%s/%s/%s" % (ami_prefix, opts.region, instance_type) try: - ami = urllib2.urlopen(ami_path).read().strip() - print "Spark AMI: " + ami + ami = urlopen(ami_path).read().strip() + print("Spark AMI: " + ami) except: - print >> stderr, "Could not resolve AMI at: " + ami_path + print("Could not resolve AMI at: " + ami_path, file=stderr) sys.exit(1) return ami @@ -419,11 +424,11 @@ def get_spark_ami(opts): # Fails if there already instances running in the cluster's groups. def launch_cluster(conn, opts, cluster_name): if opts.identity_file is None: - print >> stderr, "ERROR: Must provide an identity file (-i) for ssh connections." + print("ERROR: Must provide an identity file (-i) for ssh connections.", file=stderr) sys.exit(1) if opts.key_pair is None: - print >> stderr, "ERROR: Must provide a key pair name (-k) to use on instances." + print("ERROR: Must provide a key pair name (-k) to use on instances.", file=stderr) sys.exit(1) user_data_content = None @@ -431,7 +436,7 @@ def launch_cluster(conn, opts, cluster_name): with open(opts.user_data) as user_data_file: user_data_content = user_data_file.read() - print "Setting up security groups..." 
+ print("Setting up security groups...") master_group = get_or_make_group(conn, cluster_name + "-master", opts.vpc_id) slave_group = get_or_make_group(conn, cluster_name + "-slaves", opts.vpc_id) authorized_address = opts.authorized_address @@ -497,8 +502,8 @@ def launch_cluster(conn, opts, cluster_name): existing_masters, existing_slaves = get_existing_cluster(conn, opts, cluster_name, die_on_error=False) if existing_slaves or (existing_masters and not opts.use_existing_master): - print >> stderr, ("ERROR: There are already instances running in " + - "group %s or %s" % (master_group.name, slave_group.name)) + print("ERROR: There are already instances running in group %s or %s" % + (master_group.name, slave_group.name), file=stderr) sys.exit(1) # Figure out Spark AMI @@ -511,12 +516,12 @@ def launch_cluster(conn, opts, cluster_name): additional_group_ids = [sg.id for sg in conn.get_all_security_groups() if opts.additional_security_group in (sg.name, sg.id)] - print "Launching instances..." + print("Launching instances...") try: image = conn.get_all_images(image_ids=[opts.ami])[0] except: - print >> stderr, "Could not find AMI " + opts.ami + print("Could not find AMI " + opts.ami, file=stderr) sys.exit(1) # Create block device mapping so that we can add EBS volumes if asked to. @@ -542,8 +547,8 @@ def launch_cluster(conn, opts, cluster_name): # Launch slaves if opts.spot_price is not None: # Launch spot instances with the requested price - print ("Requesting %d slaves as spot instances with price $%.3f" % - (opts.slaves, opts.spot_price)) + print("Requesting %d slaves as spot instances with price $%.3f" % + (opts.slaves, opts.spot_price)) zones = get_zones(conn, opts) num_zones = len(zones) i = 0 @@ -566,7 +571,7 @@ def launch_cluster(conn, opts, cluster_name): my_req_ids += [req.id for req in slave_reqs] i += 1 - print "Waiting for spot instances to be granted..." 
+ print("Waiting for spot instances to be granted...") try: while True: time.sleep(10) @@ -579,24 +584,24 @@ def launch_cluster(conn, opts, cluster_name): if i in id_to_req and id_to_req[i].state == "active": active_instance_ids.append(id_to_req[i].instance_id) if len(active_instance_ids) == opts.slaves: - print "All %d slaves granted" % opts.slaves + print("All %d slaves granted" % opts.slaves) reservations = conn.get_all_reservations(active_instance_ids) slave_nodes = [] for r in reservations: slave_nodes += r.instances break else: - print "%d of %d slaves granted, waiting longer" % ( - len(active_instance_ids), opts.slaves) + print("%d of %d slaves granted, waiting longer" % ( + len(active_instance_ids), opts.slaves)) except: - print "Canceling spot instance requests" + print("Canceling spot instance requests") conn.cancel_spot_instance_requests(my_req_ids) # Log a warning if any of these requests actually launched instances: (master_nodes, slave_nodes) = get_existing_cluster( conn, opts, cluster_name, die_on_error=False) running = len(master_nodes) + len(slave_nodes) if running: - print >> stderr, ("WARNING: %d instances are still running" % running) + print(("WARNING: %d instances are still running" % running), file=stderr) sys.exit(0) else: # Launch non-spot instances @@ -618,16 +623,16 @@ def launch_cluster(conn, opts, cluster_name): placement_group=opts.placement_group, user_data=user_data_content) slave_nodes += slave_res.instances - print "Launched {s} slave{plural_s} in {z}, regid = {r}".format( - s=num_slaves_this_zone, - plural_s=('' if num_slaves_this_zone == 1 else 's'), - z=zone, - r=slave_res.id) + print("Launched {s} slave{plural_s} in {z}, regid = {r}".format( + s=num_slaves_this_zone, + plural_s=('' if num_slaves_this_zone == 1 else 's'), + z=zone, + r=slave_res.id)) i += 1 # Launch or resume masters if existing_masters: - print "Starting master..." + print("Starting master...") for inst in existing_masters: if inst.state not in ["shutting-down", "terminated"]: inst.start() @@ -650,10 +655,10 @@ def launch_cluster(conn, opts, cluster_name): user_data=user_data_content) master_nodes = master_res.instances - print "Launched master in %s, regid = %s" % (zone, master_res.id) + print("Launched master in %s, regid = %s" % (zone, master_res.id)) # This wait time corresponds to SPARK-4983 - print "Waiting for AWS to propagate instance metadata..." + print("Waiting for AWS to propagate instance metadata...") time.sleep(5) # Give the instances descriptive names for master in master_nodes: @@ -674,8 +679,8 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True): Get the EC2 instances in an existing cluster if available. Returns a tuple of lists of EC2 instance objects for the masters and slaves. 
""" - print "Searching for existing cluster {c} in region {r}...".format( - c=cluster_name, r=opts.region) + print("Searching for existing cluster {c} in region {r}...".format( + c=cluster_name, r=opts.region)) def get_instances(group_names): """ @@ -693,16 +698,15 @@ def get_instances(group_names): slave_instances = get_instances([cluster_name + "-slaves"]) if any((master_instances, slave_instances)): - print "Found {m} master{plural_m}, {s} slave{plural_s}.".format( - m=len(master_instances), - plural_m=('' if len(master_instances) == 1 else 's'), - s=len(slave_instances), - plural_s=('' if len(slave_instances) == 1 else 's')) + print("Found {m} master{plural_m}, {s} slave{plural_s}.".format( + m=len(master_instances), + plural_m=('' if len(master_instances) == 1 else 's'), + s=len(slave_instances), + plural_s=('' if len(slave_instances) == 1 else 's'))) if not master_instances and die_on_error: - print >> sys.stderr, \ - "ERROR: Could not find a master for cluster {c} in region {r}.".format( - c=cluster_name, r=opts.region) + print("ERROR: Could not find a master for cluster {c} in region {r}.".format( + c=cluster_name, r=opts.region), file=sys.stderr) sys.exit(1) return (master_instances, slave_instances) @@ -713,7 +717,7 @@ def get_instances(group_names): def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): master = get_dns_name(master_nodes[0], opts.private_ips) if deploy_ssh_key: - print "Generating cluster's SSH key on master..." + print("Generating cluster's SSH key on master...") key_setup = """ [ -f ~/.ssh/id_rsa ] || (ssh-keygen -q -t rsa -N '' -f ~/.ssh/id_rsa && @@ -721,10 +725,10 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): """ ssh(master, opts, key_setup) dot_ssh_tar = ssh_read(master, opts, ['tar', 'c', '.ssh']) - print "Transferring cluster's SSH key to slaves..." + print("Transferring cluster's SSH key to slaves...") for slave in slave_nodes: slave_address = get_dns_name(slave, opts.private_ips) - print slave_address + print(slave_address) ssh_write(slave_address, opts, ['tar', 'x'], dot_ssh_tar) modules = ['spark', 'ephemeral-hdfs', 'persistent-hdfs', @@ -738,8 +742,8 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): # NOTE: We should clone the repository before running deploy_files to # prevent ec2-variables.sh from being overwritten - print "Cloning spark-ec2 scripts from {r}/tree/{b} on master...".format( - r=opts.spark_ec2_git_repo, b=opts.spark_ec2_git_branch) + print("Cloning spark-ec2 scripts from {r}/tree/{b} on master...".format( + r=opts.spark_ec2_git_repo, b=opts.spark_ec2_git_branch)) ssh( host=master, opts=opts, @@ -749,7 +753,7 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): b=opts.spark_ec2_git_branch) ) - print "Deploying files to master..." + print("Deploying files to master...") deploy_files( conn=conn, root_dir=SPARK_EC2_DIR + "/" + "deploy.generic", @@ -760,25 +764,25 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): ) if opts.deploy_root_dir is not None: - print "Deploying {s} to master...".format(s=opts.deploy_root_dir) + print("Deploying {s} to master...".format(s=opts.deploy_root_dir)) deploy_user_files( root_dir=opts.deploy_root_dir, opts=opts, master_nodes=master_nodes ) - print "Running setup on master..." + print("Running setup on master...") setup_spark_cluster(master, opts) - print "Done!" 
+ print("Done!") def setup_spark_cluster(master, opts): ssh(master, opts, "chmod u+x spark-ec2/setup.sh") ssh(master, opts, "spark-ec2/setup.sh") - print "Spark standalone cluster started at http://%s:8080" % master + print("Spark standalone cluster started at http://%s:8080" % master) if opts.ganglia: - print "Ganglia started at http://%s:5080/ganglia" % master + print("Ganglia started at http://%s:5080/ganglia" % master) def is_ssh_available(host, opts, print_ssh_output=True): @@ -795,7 +799,7 @@ def is_ssh_available(host, opts, print_ssh_output=True): if s.returncode != 0 and print_ssh_output: # extra leading newline is for spacing in wait_for_cluster_state() - print textwrap.dedent("""\n + print(textwrap.dedent("""\n Warning: SSH connection error. (This could be temporary.) Host: {h} SSH return code: {r} @@ -804,7 +808,7 @@ def is_ssh_available(host, opts, print_ssh_output=True): h=host, r=s.returncode, o=cmd_output.strip() - ) + )) return s.returncode == 0 @@ -865,10 +869,10 @@ def wait_for_cluster_state(conn, opts, cluster_instances, cluster_state): sys.stdout.write("\n") end_time = datetime.now() - print "Cluster is now in '{s}' state. Waited {t} seconds.".format( + print("Cluster is now in '{s}' state. Waited {t} seconds.".format( s=cluster_state, t=(end_time - start_time).seconds - ) + )) # Get number of local disks available for a given EC2 instance type. @@ -916,8 +920,8 @@ def get_num_disks(instance_type): if instance_type in disks_by_instance: return disks_by_instance[instance_type] else: - print >> stderr, ("WARNING: Don't know number of disks on instance type %s; assuming 1" - % instance_type) + print("WARNING: Don't know number of disks on instance type %s; assuming 1" + % instance_type, file=stderr) return 1 @@ -951,7 +955,7 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): # Spark-only custom deploy spark_v = "%s|%s" % (opts.spark_git_repo, opts.spark_version) tachyon_v = "" - print "Deploying Spark via git hash; Tachyon won't be set up" + print("Deploying Spark via git hash; Tachyon won't be set up") modules = filter(lambda x: x != "tachyon", modules) master_addresses = [get_dns_name(i, opts.private_ips) for i in master_nodes] @@ -1067,8 +1071,8 @@ def ssh(host, opts, command): "--key-pair parameters and try again.".format(host)) else: raise e - print >> stderr, \ - "Error executing remote command, retrying after 30 seconds: {0}".format(e) + print("Error executing remote command, retrying after 30 seconds: {0}".format(e), + file=stderr) time.sleep(30) tries = tries + 1 @@ -1107,8 +1111,8 @@ def ssh_write(host, opts, command, arguments): elif tries > 5: raise RuntimeError("ssh_write failed with error %s" % proc.returncode) else: - print >> stderr, \ - "Error {0} while executing remote command, retrying after 30 seconds".format(status) + print("Error {0} while executing remote command, retrying after 30 seconds". + format(status), file=stderr) time.sleep(30) tries = tries + 1 @@ -1162,42 +1166,41 @@ def real_main(): if opts.identity_file is not None: if not os.path.exists(opts.identity_file): - print >> stderr,\ - "ERROR: The identity file '{f}' doesn't exist.".format(f=opts.identity_file) + print("ERROR: The identity file '{f}' doesn't exist.".format(f=opts.identity_file), + file=stderr) sys.exit(1) file_mode = os.stat(opts.identity_file).st_mode if not (file_mode & S_IRUSR) or not oct(file_mode)[-2:] == '00': - print >> stderr, "ERROR: The identity file must be accessible only by you." 
- print >> stderr, 'You can fix this with: chmod 400 "{f}"'.format(f=opts.identity_file) + print("ERROR: The identity file must be accessible only by you.", file=stderr) + print('You can fix this with: chmod 400 "{f}"'.format(f=opts.identity_file), + file=stderr) sys.exit(1) if opts.instance_type not in EC2_INSTANCE_TYPES: - print >> stderr, "Warning: Unrecognized EC2 instance type for instance-type: {t}".format( - t=opts.instance_type) + print("Warning: Unrecognized EC2 instance type for instance-type: {t}".format( + t=opts.instance_type), file=stderr) if opts.master_instance_type != "": if opts.master_instance_type not in EC2_INSTANCE_TYPES: - print >> stderr, \ - "Warning: Unrecognized EC2 instance type for master-instance-type: {t}".format( - t=opts.master_instance_type) + print("Warning: Unrecognized EC2 instance type for master-instance-type: {t}".format( + t=opts.master_instance_type), file=stderr) # Since we try instance types even if we can't resolve them, we check if they resolve first # and, if they do, see if they resolve to the same virtualization type. if opts.instance_type in EC2_INSTANCE_TYPES and \ opts.master_instance_type in EC2_INSTANCE_TYPES: if EC2_INSTANCE_TYPES[opts.instance_type] != \ EC2_INSTANCE_TYPES[opts.master_instance_type]: - print >> stderr, \ - "Error: spark-ec2 currently does not support having a master and slaves " + \ - "with different AMI virtualization types." - print >> stderr, "master instance virtualization type: {t}".format( - t=EC2_INSTANCE_TYPES[opts.master_instance_type]) - print >> stderr, "slave instance virtualization type: {t}".format( - t=EC2_INSTANCE_TYPES[opts.instance_type]) + print("Error: spark-ec2 currently does not support having a master and slaves " + "with different AMI virtualization types.", file=stderr) + print("master instance virtualization type: {t}".format( + t=EC2_INSTANCE_TYPES[opts.master_instance_type]), file=stderr) + print("slave instance virtualization type: {t}".format( + t=EC2_INSTANCE_TYPES[opts.instance_type]), file=stderr) sys.exit(1) if opts.ebs_vol_num > 8: - print >> stderr, "ebs-vol-num cannot be greater than 8" + print("ebs-vol-num cannot be greater than 8", file=stderr) sys.exit(1) # Prevent breaking ami_prefix (/, .git and startswith checks) @@ -1206,23 +1209,22 @@ def real_main(): opts.spark_ec2_git_repo.endswith(".git") or \ not opts.spark_ec2_git_repo.startswith("https://github.com") or \ not opts.spark_ec2_git_repo.endswith("spark-ec2"): - print >> stderr, "spark-ec2-git-repo must be a github repo and it must not have a " \ - "trailing / or .git. " \ - "Furthermore, we currently only support forks named spark-ec2." + print("spark-ec2-git-repo must be a github repo and it must not have a trailing / or .git. " + "Furthermore, we currently only support forks named spark-ec2.", file=stderr) sys.exit(1) if not (opts.deploy_root_dir is None or (os.path.isabs(opts.deploy_root_dir) and os.path.isdir(opts.deploy_root_dir) and os.path.exists(opts.deploy_root_dir))): - print >> stderr, "--deploy-root-dir must be an absolute path to a directory that exists " \ - "on the local file system" + print("--deploy-root-dir must be an absolute path to a directory that exists " + "on the local file system", file=stderr) sys.exit(1) try: conn = ec2.connect_to_region(opts.region) except Exception as e: - print >> stderr, (e) + print((e), file=stderr) sys.exit(1) # Select an AZ at random if it was not specified. 
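[Editor's note] The spark_ec2.py changes above lean on two idioms worth isolating: a version-guarded import so the same file runs on Python 2 and 3, and print(..., file=stderr) in place of the Python 2-only "print >> stderr" statement. A condensed, self-contained sketch follows; the head() helper is illustrative, while the import block mirrors the patch itself:

from __future__ import print_function

import sys
from sys import stderr

if sys.version < "3":
    from urllib2 import urlopen, Request, HTTPError
else:
    from urllib.request import urlopen, Request
    from urllib.error import HTTPError


def head(url):
    # Issue a HEAD request, as get_validate_spark_version() does above.
    request = Request(url)
    request.get_method = lambda: 'HEAD'
    try:
        return urlopen(request)
    except HTTPError as e:  # `except HTTPError, e:` no longer parses on Python 3
        print("Received HTTP response code of {code}.".format(code=e.code),
              file=stderr)

Note that "except HTTPError as e" is the one spelling accepted by both interpreter lines, which is why the patch rewrites every "except X, e" handler.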
@@ -1231,7 +1233,7 @@ def real_main(): if action == "launch": if opts.slaves <= 0: - print >> sys.stderr, "ERROR: You have to start at least 1 slave" + print("ERROR: You have to start at least 1 slave", file=sys.stderr) sys.exit(1) if opts.resume: (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) @@ -1250,18 +1252,18 @@ def real_main(): conn, opts, cluster_name, die_on_error=False) if any(master_nodes + slave_nodes): - print "The following instances will be terminated:" + print("The following instances will be terminated:") for inst in master_nodes + slave_nodes: - print "> %s" % get_dns_name(inst, opts.private_ips) - print "ALL DATA ON ALL NODES WILL BE LOST!!" + print("> %s" % get_dns_name(inst, opts.private_ips)) + print("ALL DATA ON ALL NODES WILL BE LOST!!") msg = "Are you sure you want to destroy the cluster {c}? (y/N) ".format(c=cluster_name) response = raw_input(msg) if response == "y": - print "Terminating master..." + print("Terminating master...") for inst in master_nodes: inst.terminate() - print "Terminating slaves..." + print("Terminating slaves...") for inst in slave_nodes: inst.terminate() @@ -1274,16 +1276,16 @@ def real_main(): cluster_instances=(master_nodes + slave_nodes), cluster_state='terminated' ) - print "Deleting security groups (this will take some time)..." + print("Deleting security groups (this will take some time)...") attempt = 1 while attempt <= 3: - print "Attempt %d" % attempt + print("Attempt %d" % attempt) groups = [g for g in conn.get_all_security_groups() if g.name in group_names] success = True # Delete individual rules in all groups before deleting groups to # remove dependencies between them for group in groups: - print "Deleting rules in security group " + group.name + print("Deleting rules in security group " + group.name) for rule in group.rules: for grant in rule.grants: success &= group.revoke(ip_protocol=rule.ip_protocol, @@ -1298,10 +1300,10 @@ def real_main(): try: # It is needed to use group_id to make it work with VPC conn.delete_security_group(group_id=group.id) - print "Deleted security group %s" % group.name + print("Deleted security group %s" % group.name) except boto.exception.EC2ResponseError: success = False - print "Failed to delete security group %s" % group.name + print("Failed to delete security group %s" % group.name) # Unfortunately, group.revoke() returns True even if a rule was not # deleted, so this needs to be rerun if something fails @@ -1311,17 +1313,16 @@ def real_main(): attempt += 1 if not success: - print "Failed to delete all security groups after 3 tries." - print "Try re-running in a few minutes." + print("Failed to delete all security groups after 3 tries.") + print("Try re-running in a few minutes.") elif action == "login": (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) if not master_nodes[0].public_dns_name and not opts.private_ips: - print "Master has no public DNS name. Maybe you meant to specify " \ - "--private-ips?" + print("Master has no public DNS name. Maybe you meant to specify --private-ips?") else: master = get_dns_name(master_nodes[0], opts.private_ips) - print "Logging into master " + master + "..." + print("Logging into master " + master + "...") proxy_opt = [] if opts.proxy_port is not None: proxy_opt = ['-D', opts.proxy_port] @@ -1336,19 +1337,18 @@ def real_main(): if response == "y": (master_nodes, slave_nodes) = get_existing_cluster( conn, opts, cluster_name, die_on_error=False) - print "Rebooting slaves..." 
+ print("Rebooting slaves...") for inst in slave_nodes: if inst.state not in ["shutting-down", "terminated"]: - print "Rebooting " + inst.id + print("Rebooting " + inst.id) inst.reboot() elif action == "get-master": (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) if not master_nodes[0].public_dns_name and not opts.private_ips: - print "Master has no public DNS name. Maybe you meant to specify " \ - "--private-ips?" + print("Master has no public DNS name. Maybe you meant to specify --private-ips?") else: - print get_dns_name(master_nodes[0], opts.private_ips) + print(get_dns_name(master_nodes[0], opts.private_ips)) elif action == "stop": response = raw_input( @@ -1361,11 +1361,11 @@ def real_main(): if response == "y": (master_nodes, slave_nodes) = get_existing_cluster( conn, opts, cluster_name, die_on_error=False) - print "Stopping master..." + print("Stopping master...") for inst in master_nodes: if inst.state not in ["shutting-down", "terminated"]: inst.stop() - print "Stopping slaves..." + print("Stopping slaves...") for inst in slave_nodes: if inst.state not in ["shutting-down", "terminated"]: if inst.spot_instance_request_id: @@ -1375,11 +1375,11 @@ def real_main(): elif action == "start": (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) - print "Starting slaves..." + print("Starting slaves...") for inst in slave_nodes: if inst.state not in ["shutting-down", "terminated"]: inst.start() - print "Starting master..." + print("Starting master...") for inst in master_nodes: if inst.state not in ["shutting-down", "terminated"]: inst.start() @@ -1403,15 +1403,15 @@ def real_main(): setup_cluster(conn, master_nodes, slave_nodes, opts, False) else: - print >> stderr, "Invalid action: %s" % action + print("Invalid action: %s" % action, file=stderr) sys.exit(1) def main(): try: real_main() - except UsageError, e: - print >> stderr, "\nError:\n", e + except UsageError as e: + print("\nError:\n", e, file=stderr) sys.exit(1) diff --git a/examples/src/main/python/als.py b/examples/src/main/python/als.py index 70b6146e39a87..1c3a787bd0e94 100755 --- a/examples/src/main/python/als.py +++ b/examples/src/main/python/als.py @@ -21,7 +21,8 @@ This example requires numpy (http://www.numpy.org/) """ -from os.path import realpath +from __future__ import print_function + import sys import numpy as np @@ -57,9 +58,9 @@ def update(i, vec, mat, ratings): Usage: als [M] [U] [F] [iterations] [partitions]" """ - print >> sys.stderr, """WARN: This is a naive implementation of ALS and is given as an + print("""WARN: This is a naive implementation of ALS and is given as an example. 
Please use the ALS method found in pyspark.mllib.recommendation for more - conventional use.""" + conventional use.""", file=sys.stderr) sc = SparkContext(appName="PythonALS") M = int(sys.argv[1]) if len(sys.argv) > 1 else 100 @@ -68,8 +69,8 @@ def update(i, vec, mat, ratings): ITERATIONS = int(sys.argv[4]) if len(sys.argv) > 4 else 5 partitions = int(sys.argv[5]) if len(sys.argv) > 5 else 2 - print "Running ALS with M=%d, U=%d, F=%d, iters=%d, partitions=%d\n" % \ - (M, U, F, ITERATIONS, partitions) + print("Running ALS with M=%d, U=%d, F=%d, iters=%d, partitions=%d\n" % + (M, U, F, ITERATIONS, partitions)) R = matrix(rand(M, F)) * matrix(rand(U, F).T) ms = matrix(rand(M, F)) @@ -95,7 +96,7 @@ def update(i, vec, mat, ratings): usb = sc.broadcast(us) error = rmse(R, ms, us) - print "Iteration %d:" % i - print "\nRMSE: %5.4f\n" % error + print("Iteration %d:" % i) + print("\nRMSE: %5.4f\n" % error) sc.stop() diff --git a/examples/src/main/python/avro_inputformat.py b/examples/src/main/python/avro_inputformat.py index 4626bbb7e3b02..da368ac628a49 100644 --- a/examples/src/main/python/avro_inputformat.py +++ b/examples/src/main/python/avro_inputformat.py @@ -15,9 +15,12 @@ # limitations under the License. # +from __future__ import print_function + import sys from pyspark import SparkContext +from functools import reduce """ Read data file users.avro in local Spark distro: @@ -49,7 +52,7 @@ """ if __name__ == "__main__": if len(sys.argv) != 2 and len(sys.argv) != 3: - print >> sys.stderr, """ + print(""" Usage: avro_inputformat [reader_schema_file] Run with example jar: @@ -57,7 +60,7 @@ /path/to/examples/avro_inputformat.py [reader_schema_file] Assumes you have Avro data stored in . Reader schema can be optionally specified in [reader_schema_file]. - """ + """, file=sys.stderr) exit(-1) path = sys.argv[1] @@ -77,6 +80,6 @@ conf=conf) output = avro_rdd.map(lambda x: x[0]).collect() for k in output: - print k + print(k) sc.stop() diff --git a/examples/src/main/python/cassandra_inputformat.py b/examples/src/main/python/cassandra_inputformat.py index 05f34b74df45a..93ca0cfcc9302 100644 --- a/examples/src/main/python/cassandra_inputformat.py +++ b/examples/src/main/python/cassandra_inputformat.py @@ -15,6 +15,8 @@ # limitations under the License. # +from __future__ import print_function + import sys from pyspark import SparkContext @@ -47,14 +49,14 @@ """ if __name__ == "__main__": if len(sys.argv) != 4: - print >> sys.stderr, """ + print(""" Usage: cassandra_inputformat Run with example jar: ./bin/spark-submit --driver-class-path /path/to/example/jar \ /path/to/examples/cassandra_inputformat.py Assumes you have some data in Cassandra already, running on , in and - """ + """, file=sys.stderr) exit(-1) host = sys.argv[1] @@ -77,6 +79,6 @@ conf=conf) output = cass_rdd.collect() for (k, v) in output: - print (k, v) + print((k, v)) sc.stop() diff --git a/examples/src/main/python/cassandra_outputformat.py b/examples/src/main/python/cassandra_outputformat.py index d144539e58b8f..5d643eac92f94 100644 --- a/examples/src/main/python/cassandra_outputformat.py +++ b/examples/src/main/python/cassandra_outputformat.py @@ -15,6 +15,8 @@ # limitations under the License. # +from __future__ import print_function + import sys from pyspark import SparkContext @@ -46,7 +48,7 @@ """ if __name__ == "__main__": if len(sys.argv) != 7: - print >> sys.stderr, """ + print(""" Usage: cassandra_outputformat Run with example jar: @@ -60,7 +62,7 @@ ... fname text, ... lname text ... 
); - """ + """, file=sys.stderr) exit(-1) host = sys.argv[1] diff --git a/examples/src/main/python/hbase_inputformat.py b/examples/src/main/python/hbase_inputformat.py index 3b16010f1cb97..e17819d5feb76 100644 --- a/examples/src/main/python/hbase_inputformat.py +++ b/examples/src/main/python/hbase_inputformat.py @@ -15,6 +15,8 @@ # limitations under the License. # +from __future__ import print_function + import sys from pyspark import SparkContext @@ -47,14 +49,14 @@ """ if __name__ == "__main__": if len(sys.argv) != 3: - print >> sys.stderr, """ + print(""" Usage: hbase_inputformat Run with example jar: ./bin/spark-submit --driver-class-path /path/to/example/jar \ /path/to/examples/hbase_inputformat.py
Assumes you have some data in HBase already, running on <host>, in <table>
- """ + """, file=sys.stderr) exit(-1) host = sys.argv[1] @@ -74,6 +76,6 @@ conf=conf) output = hbase_rdd.collect() for (k, v) in output: - print (k, v) + print((k, v)) sc.stop() diff --git a/examples/src/main/python/hbase_outputformat.py b/examples/src/main/python/hbase_outputformat.py index abb425b1f886a..9e5641789a976 100644 --- a/examples/src/main/python/hbase_outputformat.py +++ b/examples/src/main/python/hbase_outputformat.py @@ -15,6 +15,8 @@ # limitations under the License. # +from __future__ import print_function + import sys from pyspark import SparkContext @@ -40,7 +42,7 @@ """ if __name__ == "__main__": if len(sys.argv) != 7: - print >> sys.stderr, """ + print(""" Usage: hbase_outputformat
Run with example jar: @@ -48,7 +50,7 @@ /path/to/examples/hbase_outputformat.py Assumes you have created
with column family in HBase running on already - """ + """, file=sys.stderr) exit(-1) host = sys.argv[1] diff --git a/examples/src/main/python/kmeans.py b/examples/src/main/python/kmeans.py index 86ef6f32c84e8..19391506463f0 100755 --- a/examples/src/main/python/kmeans.py +++ b/examples/src/main/python/kmeans.py @@ -22,6 +22,7 @@ This example requires NumPy (http://www.numpy.org/). """ +from __future__ import print_function import sys @@ -47,12 +48,12 @@ def closestPoint(p, centers): if __name__ == "__main__": if len(sys.argv) != 4: - print >> sys.stderr, "Usage: kmeans " + print("Usage: kmeans ", file=sys.stderr) exit(-1) - print >> sys.stderr, """WARN: This is a naive implementation of KMeans Clustering and is given + print("""WARN: This is a naive implementation of KMeans Clustering and is given as an example! Please refer to examples/src/main/python/mllib/kmeans.py for an example on - how to use MLlib's KMeans implementation.""" + how to use MLlib's KMeans implementation.""", file=sys.stderr) sc = SparkContext(appName="PythonKMeans") lines = sc.textFile(sys.argv[1]) @@ -69,13 +70,13 @@ def closestPoint(p, centers): pointStats = closest.reduceByKey( lambda (x1, y1), (x2, y2): (x1 + x2, y1 + y2)) newPoints = pointStats.map( - lambda (x, (y, z)): (x, y / z)).collect() + lambda xy: (xy[0], xy[1][0] / xy[1][1])).collect() tempDist = sum(np.sum((kPoints[x] - y) ** 2) for (x, y) in newPoints) for (x, y) in newPoints: kPoints[x] = y - print "Final centers: " + str(kPoints) + print("Final centers: " + str(kPoints)) sc.stop() diff --git a/examples/src/main/python/logistic_regression.py b/examples/src/main/python/logistic_regression.py index 3aa56b0528168..b318b7d87bfdc 100755 --- a/examples/src/main/python/logistic_regression.py +++ b/examples/src/main/python/logistic_regression.py @@ -22,10 +22,8 @@ In practice, one may prefer to use the LogisticRegression algorithm in MLlib, as shown in examples/src/main/python/mllib/logistic_regression.py. """ +from __future__ import print_function -from collections import namedtuple -from math import exp -from os.path import realpath import sys import numpy as np @@ -42,19 +40,19 @@ def readPointBatch(iterator): strs = list(iterator) matrix = np.zeros((len(strs), D + 1)) - for i in xrange(len(strs)): - matrix[i] = np.fromstring(strs[i].replace(',', ' '), dtype=np.float32, sep=' ') + for i, s in enumerate(strs): + matrix[i] = np.fromstring(s.replace(',', ' '), dtype=np.float32, sep=' ') return [matrix] if __name__ == "__main__": if len(sys.argv) != 3: - print >> sys.stderr, "Usage: logistic_regression " + print("Usage: logistic_regression ", file=sys.stderr) exit(-1) - print >> sys.stderr, """WARN: This is a naive implementation of Logistic Regression and is + print("""WARN: This is a naive implementation of Logistic Regression and is given as an example! 
Please refer to examples/src/main/python/mllib/logistic_regression.py - to see how MLlib's implementation is used.""" + to see how MLlib's implementation is used.""", file=sys.stderr) sc = SparkContext(appName="PythonLR") points = sc.textFile(sys.argv[1]).mapPartitions(readPointBatch).cache() @@ -62,7 +60,7 @@ def readPointBatch(iterator): # Initialize w to a random value w = 2 * np.random.ranf(size=D) - 1 - print "Initial w: " + str(w) + print("Initial w: " + str(w)) # Compute logistic regression gradient for a matrix of data points def gradient(matrix, w): @@ -76,9 +74,9 @@ def add(x, y): return x for i in range(iterations): - print "On iteration %i" % (i + 1) + print("On iteration %i" % (i + 1)) w -= points.map(lambda m: gradient(m, w)).reduce(add) - print "Final w: " + str(w) + print("Final w: " + str(w)) sc.stop() diff --git a/examples/src/main/python/ml/simple_text_classification_pipeline.py b/examples/src/main/python/ml/simple_text_classification_pipeline.py index c73edb7fd6b20..fab21f003b233 100644 --- a/examples/src/main/python/ml/simple_text_classification_pipeline.py +++ b/examples/src/main/python/ml/simple_text_classification_pipeline.py @@ -15,6 +15,8 @@ # limitations under the License. # +from __future__ import print_function + from pyspark import SparkContext from pyspark.ml import Pipeline from pyspark.ml.classification import LogisticRegression @@ -37,10 +39,10 @@ # Prepare training documents, which are labeled. LabeledDocument = Row("id", "text", "label") - training = sc.parallelize([(0L, "a b c d e spark", 1.0), - (1L, "b d", 0.0), - (2L, "spark f g h", 1.0), - (3L, "hadoop mapreduce", 0.0)]) \ + training = sc.parallelize([(0, "a b c d e spark", 1.0), + (1, "b d", 0.0), + (2, "spark f g h", 1.0), + (3, "hadoop mapreduce", 0.0)]) \ .map(lambda x: LabeledDocument(*x)).toDF() # Configure an ML pipeline, which consists of tree stages: tokenizer, hashingTF, and lr. @@ -54,16 +56,16 @@ # Prepare test documents, which are unlabeled. Document = Row("id", "text") - test = sc.parallelize([(4L, "spark i j k"), - (5L, "l m n"), - (6L, "mapreduce spark"), - (7L, "apache hadoop")]) \ + test = sc.parallelize([(4, "spark i j k"), + (5, "l m n"), + (6, "mapreduce spark"), + (7, "apache hadoop")]) \ .map(lambda x: Document(*x)).toDF() # Make predictions on test documents and print columns of interest. prediction = model.transform(test) selected = prediction.select("id", "text", "prediction") for row in selected.collect(): - print row + print(row) sc.stop() diff --git a/examples/src/main/python/mllib/correlations.py b/examples/src/main/python/mllib/correlations.py index 4218eca822a99..0e13546b88e67 100755 --- a/examples/src/main/python/mllib/correlations.py +++ b/examples/src/main/python/mllib/correlations.py @@ -18,6 +18,7 @@ """ Correlations using MLlib. 
""" +from __future__ import print_function import sys @@ -29,7 +30,7 @@ if __name__ == "__main__": if len(sys.argv) not in [1, 2]: - print >> sys.stderr, "Usage: correlations ()" + print("Usage: correlations ()", file=sys.stderr) exit(-1) sc = SparkContext(appName="PythonCorrelations") if len(sys.argv) == 2: @@ -41,20 +42,20 @@ points = MLUtils.loadLibSVMFile(sc, filepath)\ .map(lambda lp: LabeledPoint(lp.label, lp.features.toArray())) - print - print 'Summary of data file: ' + filepath - print '%d data points' % points.count() + print() + print('Summary of data file: ' + filepath) + print('%d data points' % points.count()) # Statistics (correlations) - print - print 'Correlation (%s) between label and each feature' % corrType - print 'Feature\tCorrelation' + print() + print('Correlation (%s) between label and each feature' % corrType) + print('Feature\tCorrelation') numFeatures = points.take(1)[0].features.size labelRDD = points.map(lambda lp: lp.label) for i in range(numFeatures): featureRDD = points.map(lambda lp: lp.features[i]) corr = Statistics.corr(labelRDD, featureRDD, corrType) - print '%d\t%g' % (i, corr) - print + print('%d\t%g' % (i, corr)) + print() sc.stop() diff --git a/examples/src/main/python/mllib/dataset_example.py b/examples/src/main/python/mllib/dataset_example.py index fcbf56cbf0c52..e23ecc0c5d302 100644 --- a/examples/src/main/python/mllib/dataset_example.py +++ b/examples/src/main/python/mllib/dataset_example.py @@ -19,6 +19,7 @@ An example of how to use DataFrame as a dataset for ML. Run with:: bin/spark-submit examples/src/main/python/mllib/dataset_example.py """ +from __future__ import print_function import os import sys @@ -32,16 +33,16 @@ def summarize(dataset): - print "schema: %s" % dataset.schema().json() + print("schema: %s" % dataset.schema().json()) labels = dataset.map(lambda r: r.label) - print "label average: %f" % labels.mean() + print("label average: %f" % labels.mean()) features = dataset.map(lambda r: r.features) summary = Statistics.colStats(features) - print "features average: %r" % summary.mean() + print("features average: %r" % summary.mean()) if __name__ == "__main__": if len(sys.argv) > 2: - print >> sys.stderr, "Usage: dataset_example.py " + print("Usage: dataset_example.py ", file=sys.stderr) exit(-1) sc = SparkContext(appName="DatasetExample") sqlContext = SQLContext(sc) @@ -54,9 +55,9 @@ def summarize(dataset): summarize(dataset0) tempdir = tempfile.NamedTemporaryFile(delete=False).name os.unlink(tempdir) - print "Save dataset as a Parquet file to %s." % tempdir + print("Save dataset as a Parquet file to %s." % tempdir) dataset0.saveAsParquetFile(tempdir) - print "Load it back and summarize it again." + print("Load it back and summarize it again.") dataset1 = sqlContext.parquetFile(tempdir).setName("dataset1").cache() summarize(dataset1) shutil.rmtree(tempdir) diff --git a/examples/src/main/python/mllib/decision_tree_runner.py b/examples/src/main/python/mllib/decision_tree_runner.py index fccabd841b139..513ed8fd51450 100755 --- a/examples/src/main/python/mllib/decision_tree_runner.py +++ b/examples/src/main/python/mllib/decision_tree_runner.py @@ -20,6 +20,7 @@ This example requires NumPy (http://www.numpy.org/). """ +from __future__ import print_function import numpy import os @@ -83,18 +84,17 @@ def reindexClassLabels(data): numClasses = len(classCounts) # origToNewLabels: class --> index in 0,...,numClasses-1 if (numClasses < 2): - print >> sys.stderr, \ - "Dataset for classification should have at least 2 classes." 
+ \ - " The given dataset had only %d classes." % numClasses + print("Dataset for classification should have at least 2 classes." + " The given dataset had only %d classes." % numClasses, file=sys.stderr) exit(1) origToNewLabels = dict([(sortedClasses[i], i) for i in range(0, numClasses)]) - print "numClasses = %d" % numClasses - print "Per-class example fractions, counts:" - print "Class\tFrac\tCount" + print("numClasses = %d" % numClasses) + print("Per-class example fractions, counts:") + print("Class\tFrac\tCount") for c in sortedClasses: frac = classCounts[c] / (numExamples + 0.0) - print "%g\t%g\t%d" % (c, frac, classCounts[c]) + print("%g\t%g\t%d" % (c, frac, classCounts[c])) if (sortedClasses[0] == 0 and sortedClasses[-1] == numClasses - 1): return (data, origToNewLabels) @@ -105,8 +105,7 @@ def reindexClassLabels(data): def usage(): - print >> sys.stderr, \ - "Usage: decision_tree_runner [libsvm format data filepath]" + print("Usage: decision_tree_runner [libsvm format data filepath]", file=sys.stderr) exit(1) @@ -133,13 +132,13 @@ def usage(): model = DecisionTree.trainClassifier(reindexedData, numClasses=numClasses, categoricalFeaturesInfo=categoricalFeaturesInfo) # Print learned tree and stats. - print "Trained DecisionTree for classification:" - print " Model numNodes: %d" % model.numNodes() - print " Model depth: %d" % model.depth() - print " Training accuracy: %g" % getAccuracy(model, reindexedData) + print("Trained DecisionTree for classification:") + print(" Model numNodes: %d" % model.numNodes()) + print(" Model depth: %d" % model.depth()) + print(" Training accuracy: %g" % getAccuracy(model, reindexedData)) if model.numNodes() < 20: - print model.toDebugString() + print(model.toDebugString()) else: - print model + print(model) sc.stop() diff --git a/examples/src/main/python/mllib/gaussian_mixture_model.py b/examples/src/main/python/mllib/gaussian_mixture_model.py index a2cd626c9f19d..2cb8010cdc07f 100644 --- a/examples/src/main/python/mllib/gaussian_mixture_model.py +++ b/examples/src/main/python/mllib/gaussian_mixture_model.py @@ -18,7 +18,8 @@ """ A Gaussian Mixture Model clustering program using MLlib. """ -import sys +from __future__ import print_function + import random import argparse import numpy as np @@ -59,7 +60,7 @@ def parseVector(line): model = GaussianMixture.train(data, args.k, args.convergenceTol, args.maxIterations, args.seed) for i in range(args.k): - print ("weight = ", model.weights[i], "mu = ", model.gaussians[i].mu, - "sigma = ", model.gaussians[i].sigma.toArray()) - print ("Cluster labels (first 100): ", model.predict(data).take(100)) + print(("weight = ", model.weights[i], "mu = ", model.gaussians[i].mu, + "sigma = ", model.gaussians[i].sigma.toArray())) + print(("Cluster labels (first 100): ", model.predict(data).take(100))) sc.stop() diff --git a/examples/src/main/python/mllib/gradient_boosted_trees.py b/examples/src/main/python/mllib/gradient_boosted_trees.py index e647773ad9060..781bd61c9d2b5 100644 --- a/examples/src/main/python/mllib/gradient_boosted_trees.py +++ b/examples/src/main/python/mllib/gradient_boosted_trees.py @@ -18,6 +18,7 @@ """ Gradient boosted Trees classification and regression using MLlib. 
""" +from __future__ import print_function import sys @@ -34,7 +35,7 @@ def testClassification(trainingData, testData): # Evaluate model on test instances and compute test error predictions = model.predict(testData.map(lambda x: x.features)) labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) - testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() \ + testErr = labelsAndPredictions.filter(lambda v_p: v_p[0] != v_p[1]).count() \ / float(testData.count()) print('Test Error = ' + str(testErr)) print('Learned classification ensemble model:') @@ -49,7 +50,7 @@ def testRegression(trainingData, testData): # Evaluate model on test instances and compute test error predictions = model.predict(testData.map(lambda x: x.features)) labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) - testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() \ + testMSE = labelsAndPredictions.map(lambda vp: (vp[0] - vp[1]) * (vp[0] - vp[1])).sum() \ / float(testData.count()) print('Test Mean Squared Error = ' + str(testMSE)) print('Learned regression ensemble model:') @@ -58,7 +59,7 @@ def testRegression(trainingData, testData): if __name__ == "__main__": if len(sys.argv) > 1: - print >> sys.stderr, "Usage: gradient_boosted_trees" + print("Usage: gradient_boosted_trees", file=sys.stderr) exit(1) sc = SparkContext(appName="PythonGradientBoostedTrees") diff --git a/examples/src/main/python/mllib/kmeans.py b/examples/src/main/python/mllib/kmeans.py index 2eeb1abeeb12b..f901a87fa63ac 100755 --- a/examples/src/main/python/mllib/kmeans.py +++ b/examples/src/main/python/mllib/kmeans.py @@ -20,6 +20,7 @@ This example requires NumPy (http://www.numpy.org/). """ +from __future__ import print_function import sys @@ -34,12 +35,12 @@ def parseVector(line): if __name__ == "__main__": if len(sys.argv) != 3: - print >> sys.stderr, "Usage: kmeans " + print("Usage: kmeans ", file=sys.stderr) exit(-1) sc = SparkContext(appName="KMeans") lines = sc.textFile(sys.argv[1]) data = lines.map(parseVector) k = int(sys.argv[2]) model = KMeans.train(data, k) - print "Final centers: " + str(model.clusterCenters) + print("Final centers: " + str(model.clusterCenters)) sc.stop() diff --git a/examples/src/main/python/mllib/logistic_regression.py b/examples/src/main/python/mllib/logistic_regression.py index 8cae27fc4a52d..d4f1d34e2d8cf 100755 --- a/examples/src/main/python/mllib/logistic_regression.py +++ b/examples/src/main/python/mllib/logistic_regression.py @@ -20,11 +20,10 @@ This example requires NumPy (http://www.numpy.org/). 
""" +from __future__ import print_function -from math import exp import sys -import numpy as np from pyspark import SparkContext from pyspark.mllib.regression import LabeledPoint from pyspark.mllib.classification import LogisticRegressionWithSGD @@ -42,12 +41,12 @@ def parsePoint(line): if __name__ == "__main__": if len(sys.argv) != 3: - print >> sys.stderr, "Usage: logistic_regression " + print("Usage: logistic_regression ", file=sys.stderr) exit(-1) sc = SparkContext(appName="PythonLR") points = sc.textFile(sys.argv[1]).map(parsePoint) iterations = int(sys.argv[2]) model = LogisticRegressionWithSGD.train(points, iterations) - print "Final weights: " + str(model.weights) - print "Final intercept: " + str(model.intercept) + print("Final weights: " + str(model.weights)) + print("Final intercept: " + str(model.intercept)) sc.stop() diff --git a/examples/src/main/python/mllib/random_forest_example.py b/examples/src/main/python/mllib/random_forest_example.py index d3c24f7664329..4cfdad868c66e 100755 --- a/examples/src/main/python/mllib/random_forest_example.py +++ b/examples/src/main/python/mllib/random_forest_example.py @@ -22,6 +22,7 @@ For information on multiclass classification, please refer to the decision_tree_runner.py example. """ +from __future__ import print_function import sys @@ -43,7 +44,7 @@ def testClassification(trainingData, testData): # Evaluate model on test instances and compute test error predictions = model.predict(testData.map(lambda x: x.features)) labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) - testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count()\ + testErr = labelsAndPredictions.filter(lambda v_p: v_p[0] != v_p[1]).count()\ / float(testData.count()) print('Test Error = ' + str(testErr)) print('Learned classification forest model:') @@ -62,8 +63,8 @@ def testRegression(trainingData, testData): # Evaluate model on test instances and compute test error predictions = model.predict(testData.map(lambda x: x.features)) labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) - testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum()\ - / float(testData.count()) + testMSE = labelsAndPredictions.map(lambda v_p1: (v_p1[0] - v_p1[1]) * (v_p1[0] - v_p1[1]))\ + .sum() / float(testData.count()) print('Test Mean Squared Error = ' + str(testMSE)) print('Learned regression forest model:') print(model.toDebugString()) @@ -71,7 +72,7 @@ def testRegression(trainingData, testData): if __name__ == "__main__": if len(sys.argv) > 1: - print >> sys.stderr, "Usage: random_forest_example" + print("Usage: random_forest_example", file=sys.stderr) exit(1) sc = SparkContext(appName="PythonRandomForestExample") diff --git a/examples/src/main/python/mllib/random_rdd_generation.py b/examples/src/main/python/mllib/random_rdd_generation.py index 1e8892741e714..729bae30b152c 100755 --- a/examples/src/main/python/mllib/random_rdd_generation.py +++ b/examples/src/main/python/mllib/random_rdd_generation.py @@ -18,6 +18,7 @@ """ Randomly generated RDDs. 
""" +from __future__ import print_function import sys @@ -27,7 +28,7 @@ if __name__ == "__main__": if len(sys.argv) not in [1, 2]: - print >> sys.stderr, "Usage: random_rdd_generation" + print("Usage: random_rdd_generation", file=sys.stderr) exit(-1) sc = SparkContext(appName="PythonRandomRDDGeneration") @@ -37,19 +38,19 @@ # Example: RandomRDDs.normalRDD normalRDD = RandomRDDs.normalRDD(sc, numExamples) - print 'Generated RDD of %d examples sampled from the standard normal distribution'\ - % normalRDD.count() - print ' First 5 samples:' + print('Generated RDD of %d examples sampled from the standard normal distribution' + % normalRDD.count()) + print(' First 5 samples:') for sample in normalRDD.take(5): - print ' ' + str(sample) - print + print(' ' + str(sample)) + print() # Example: RandomRDDs.normalVectorRDD normalVectorRDD = RandomRDDs.normalVectorRDD(sc, numRows=numExamples, numCols=2) - print 'Generated RDD of %d examples of length-2 vectors.' % normalVectorRDD.count() - print ' First 5 samples:' + print('Generated RDD of %d examples of length-2 vectors.' % normalVectorRDD.count()) + print(' First 5 samples:') for sample in normalVectorRDD.take(5): - print ' ' + str(sample) - print + print(' ' + str(sample)) + print() sc.stop() diff --git a/examples/src/main/python/mllib/sampled_rdds.py b/examples/src/main/python/mllib/sampled_rdds.py index 92af3af5ebd1e..b7033ab7daeb3 100755 --- a/examples/src/main/python/mllib/sampled_rdds.py +++ b/examples/src/main/python/mllib/sampled_rdds.py @@ -18,6 +18,7 @@ """ Randomly sampled RDDs. """ +from __future__ import print_function import sys @@ -27,7 +28,7 @@ if __name__ == "__main__": if len(sys.argv) not in [1, 2]: - print >> sys.stderr, "Usage: sampled_rdds " + print("Usage: sampled_rdds ", file=sys.stderr) exit(-1) if len(sys.argv) == 2: datapath = sys.argv[1] @@ -41,24 +42,24 @@ examples = MLUtils.loadLibSVMFile(sc, datapath) numExamples = examples.count() if numExamples == 0: - print >> sys.stderr, "Error: Data file had no samples to load." + print("Error: Data file had no samples to load.", file=sys.stderr) exit(1) - print 'Loaded data with %d examples from file: %s' % (numExamples, datapath) + print('Loaded data with %d examples from file: %s' % (numExamples, datapath)) # Example: RDD.sample() and RDD.takeSample() expectedSampleSize = int(numExamples * fraction) - print 'Sampling RDD using fraction %g. Expected sample size = %d.' \ - % (fraction, expectedSampleSize) + print('Sampling RDD using fraction %g. Expected sample size = %d.' + % (fraction, expectedSampleSize)) sampledRDD = examples.sample(withReplacement=True, fraction=fraction) - print ' RDD.sample(): sample has %d examples' % sampledRDD.count() + print(' RDD.sample(): sample has %d examples' % sampledRDD.count()) sampledArray = examples.takeSample(withReplacement=True, num=expectedSampleSize) - print ' RDD.takeSample(): sample has %d examples' % len(sampledArray) + print(' RDD.takeSample(): sample has %d examples' % len(sampledArray)) - print + print() # Example: RDD.sampleByKey() keyedRDD = examples.map(lambda lp: (int(lp.label), lp.features)) - print ' Keyed data using label (Int) as key ==> Orig' + print(' Keyed data using label (Int) as key ==> Orig') # Count examples per label in original data. 
keyCountsA = keyedRDD.countByKey() @@ -69,18 +70,18 @@ sampledByKeyRDD = keyedRDD.sampleByKey(withReplacement=True, fractions=fractions) keyCountsB = sampledByKeyRDD.countByKey() sizeB = sum(keyCountsB.values()) - print ' Sampled %d examples using approximate stratified sampling (by label). ==> Sample' \ - % sizeB + print(' Sampled %d examples using approximate stratified sampling (by label). ==> Sample' + % sizeB) # Compare samples - print ' \tFractions of examples with key' - print 'Key\tOrig\tSample' + print(' \tFractions of examples with key') + print('Key\tOrig\tSample') for k in sorted(keyCountsA.keys()): fracA = keyCountsA[k] / float(numExamples) if sizeB != 0: fracB = keyCountsB.get(k, 0) / float(sizeB) else: fracB = 0 - print '%d\t%g\t%g' % (k, fracA, fracB) + print('%d\t%g\t%g' % (k, fracA, fracB)) sc.stop() diff --git a/examples/src/main/python/mllib/word2vec.py b/examples/src/main/python/mllib/word2vec.py index 99fef4276a369..40d1b887927e0 100644 --- a/examples/src/main/python/mllib/word2vec.py +++ b/examples/src/main/python/mllib/word2vec.py @@ -23,6 +23,7 @@ # grep -o -E '\w+(\W+\w+){0,15}' text8 > text8_lines # This was done so that the example can be run in local mode +from __future__ import print_function import sys @@ -34,7 +35,7 @@ if __name__ == "__main__": if len(sys.argv) < 2: - print USAGE + print(USAGE) sys.exit("Argument for file not provided") file_path = sys.argv[1] sc = SparkContext(appName='Word2Vec') @@ -46,5 +47,5 @@ synonyms = model.findSynonyms('china', 40) for word, cosine_distance in synonyms: - print "{}: {}".format(word, cosine_distance) + print("{}: {}".format(word, cosine_distance)) sc.stop() diff --git a/examples/src/main/python/pagerank.py b/examples/src/main/python/pagerank.py index a5f25d78c1146..2fdc9773d4eb1 100755 --- a/examples/src/main/python/pagerank.py +++ b/examples/src/main/python/pagerank.py @@ -19,6 +19,7 @@ This is an example implementation of PageRank. For more conventional use, Please refer to PageRank implementation provided by graphx """ +from __future__ import print_function import re import sys @@ -42,11 +43,12 @@ def parseNeighbors(urls): if __name__ == "__main__": if len(sys.argv) != 3: - print >> sys.stderr, "Usage: pagerank " + print("Usage: pagerank ", file=sys.stderr) exit(-1) - print >> sys.stderr, """WARN: This is a naive implementation of PageRank and is - given as an example! Please refer to PageRank implementation provided by graphx""" + print("""WARN: This is a naive implementation of PageRank and is + given as an example! Please refer to PageRank implementation provided by graphx""", + file=sys.stderr) # Initialize the spark context. sc = SparkContext(appName="PythonPageRank") @@ -62,19 +64,19 @@ def parseNeighbors(urls): links = lines.map(lambda urls: parseNeighbors(urls)).distinct().groupByKey().cache() # Loads all URLs with other URL(s) link to from input file and initialize ranks of them to one. - ranks = links.map(lambda (url, neighbors): (url, 1.0)) + ranks = links.map(lambda url_neighbors: (url_neighbors[0], 1.0)) # Calculates and updates URL ranks continuously using PageRank algorithm. - for iteration in xrange(int(sys.argv[2])): + for iteration in range(int(sys.argv[2])): # Calculates URL contributions to the rank of other URLs. contribs = links.join(ranks).flatMap( - lambda (url, (urls, rank)): computeContribs(urls, rank)) + lambda url_urls_rank: computeContribs(url_urls_rank[1][0], url_urls_rank[1][1])) # Re-calculates URL ranks based on neighbor contributions. 
ranks = contribs.reduceByKey(add).mapValues(lambda rank: rank * 0.85 + 0.15) # Collects all URL ranks and dump them to console. for (link, rank) in ranks.collect(): - print "%s has rank: %s." % (link, rank) + print("%s has rank: %s." % (link, rank)) sc.stop() diff --git a/examples/src/main/python/parquet_inputformat.py b/examples/src/main/python/parquet_inputformat.py index fa4c20ab20281..96ddac761d698 100644 --- a/examples/src/main/python/parquet_inputformat.py +++ b/examples/src/main/python/parquet_inputformat.py @@ -1,3 +1,4 @@ +from __future__ import print_function # # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with @@ -35,14 +36,14 @@ """ if __name__ == "__main__": if len(sys.argv) != 2: - print >> sys.stderr, """ + print(""" Usage: parquet_inputformat.py Run with example jar: ./bin/spark-submit --driver-class-path /path/to/example/jar \\ /path/to/examples/parquet_inputformat.py Assumes you have Parquet data stored in . - """ + """, file=sys.stderr) exit(-1) path = sys.argv[1] @@ -56,6 +57,6 @@ valueConverter='org.apache.spark.examples.pythonconverters.IndexedRecordToJavaConverter') output = parquet_rdd.map(lambda x: x[1]).collect() for k in output: - print k + print(k) sc.stop() diff --git a/examples/src/main/python/pi.py b/examples/src/main/python/pi.py index a7c74e969cdb9..92e5cf45abc8b 100755 --- a/examples/src/main/python/pi.py +++ b/examples/src/main/python/pi.py @@ -1,3 +1,4 @@ +from __future__ import print_function # # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with @@ -35,7 +36,7 @@ def f(_): y = random() * 2 - 1 return 1 if x ** 2 + y ** 2 < 1 else 0 - count = sc.parallelize(xrange(1, n + 1), partitions).map(f).reduce(add) - print "Pi is roughly %f" % (4.0 * count / n) + count = sc.parallelize(range(1, n + 1), partitions).map(f).reduce(add) + print("Pi is roughly %f" % (4.0 * count / n)) sc.stop() diff --git a/examples/src/main/python/sort.py b/examples/src/main/python/sort.py index bb686f17518a0..f6b0ecb02c100 100755 --- a/examples/src/main/python/sort.py +++ b/examples/src/main/python/sort.py @@ -15,6 +15,8 @@ # limitations under the License. # +from __future__ import print_function + import sys from pyspark import SparkContext @@ -22,7 +24,7 @@ if __name__ == "__main__": if len(sys.argv) != 2: - print >> sys.stderr, "Usage: sort " + print("Usage: sort ", file=sys.stderr) exit(-1) sc = SparkContext(appName="PythonSort") lines = sc.textFile(sys.argv[1], 1) @@ -33,6 +35,6 @@ # In reality, we wouldn't want to collect all the data to the driver node. output = sortedCount.collect() for (num, unitcount) in output: - print num + print(num) sc.stop() diff --git a/examples/src/main/python/sql.py b/examples/src/main/python/sql.py index d89361f324917..87d7b088f077b 100644 --- a/examples/src/main/python/sql.py +++ b/examples/src/main/python/sql.py @@ -15,6 +15,8 @@ # limitations under the License. 
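[Editor's note] Two mechanical rewrites recur throughout the example diffs above. First, Python 3 removed tuple parameter unpacking (PEP 3113), so a lambda such as "lambda (url, rank): ..." becomes a one-argument lambda that indexes into its tuple, as seen in pagerank.py and the MLlib tree examples. Second, xrange() no longer exists; range() works on both lines, lazily on Python 3 and as a materialized list on Python 2, which is harmless at these sizes. A toy sketch with made-up data, no Spark required:

from __future__ import print_function

# 1) Tuple-unpacking parameters: `lambda (url, rank): ...` is a
#    SyntaxError on Python 3, so the patched code indexes a single
#    argument instead.
ranks = [("a", 1.0), ("b", 2.0)]
damped = [(url_rank[0], url_rank[1] * 0.85 + 0.15) for url_rank in ranks]
print(damped)  # [('a', 1.0), ('b', 1.85)]

# 2) xrange() is gone on Python 3; range() runs on both interpreters.
total = 0
for i in range(1, 5):
    total += i
print(total)  # 10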
 #
+from __future__ import print_function
+
 import os
 from pyspark import SparkContext
@@ -68,6 +70,6 @@
     teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19")
     for each in teenagers.collect():
-        print each[0]
+        print(each[0])
     sc.stop()
diff --git a/examples/src/main/python/status_api_demo.py b/examples/src/main/python/status_api_demo.py
index a33bdc475a06d..49b7902185aaa 100644
--- a/examples/src/main/python/status_api_demo.py
+++ b/examples/src/main/python/status_api_demo.py
@@ -15,6 +15,8 @@
 # limitations under the License.
 #
+from __future__ import print_function
+
 import time
 import threading
 import Queue
@@ -52,15 +54,15 @@ def run():
         ids = status.getJobIdsForGroup()
         for id in ids:
             job = status.getJobInfo(id)
-            print "Job", id, "status: ", job.status
+            print("Job", id, "status: ", job.status)
             for sid in job.stageIds:
                 info = status.getStageInfo(sid)
                 if info:
-                    print "Stage %d: %d tasks total (%d active, %d complete)" % \
-                        (sid, info.numTasks, info.numActiveTasks, info.numCompletedTasks)
+                    print("Stage %d: %d tasks total (%d active, %d complete)" %
+                          (sid, info.numTasks, info.numActiveTasks, info.numCompletedTasks))
         time.sleep(1)
-    print "Job results are:", result.get()
+    print("Job results are:", result.get())
     sc.stop()
 if __name__ == "__main__":
diff --git a/examples/src/main/python/streaming/hdfs_wordcount.py b/examples/src/main/python/streaming/hdfs_wordcount.py
index f7ffb5379681e..f815dd26823d1 100644
--- a/examples/src/main/python/streaming/hdfs_wordcount.py
+++ b/examples/src/main/python/streaming/hdfs_wordcount.py
@@ -25,6 +25,7 @@
 Then create a text file in `localdir` and the words in the file will get counted.
 """
+from __future__ import print_function
 import sys
@@ -33,7 +34,7 @@
 if __name__ == "__main__":
     if len(sys.argv) != 2:
-        print >> sys.stderr, "Usage: hdfs_wordcount.py <directory>"
+        print("Usage: hdfs_wordcount.py <directory>", file=sys.stderr)
         exit(-1)
     sc = SparkContext(appName="PythonStreamingHDFSWordCount")
diff --git a/examples/src/main/python/streaming/kafka_wordcount.py b/examples/src/main/python/streaming/kafka_wordcount.py
index 51e1ff822fc55..b178e7899b5e1 100644
--- a/examples/src/main/python/streaming/kafka_wordcount.py
+++ b/examples/src/main/python/streaming/kafka_wordcount.py
@@ -27,6 +27,7 @@
       spark-streaming-kafka-assembly-*.jar examples/src/main/python/streaming/kafka_wordcount.py \
       localhost:2181 test`
 """
+from __future__ import print_function
 import sys
@@ -36,7 +37,7 @@
 if __name__ == "__main__":
     if len(sys.argv) != 3:
-        print >> sys.stderr, "Usage: kafka_wordcount.py <zk> <topic>"
+        print("Usage: kafka_wordcount.py <zk> <topic>", file=sys.stderr)
         exit(-1)
     sc = SparkContext(appName="PythonStreamingKafkaWordCount")
diff --git a/examples/src/main/python/streaming/network_wordcount.py b/examples/src/main/python/streaming/network_wordcount.py
index cfa9c1ff5bfbc..2b48bcfd55db0 100644
--- a/examples/src/main/python/streaming/network_wordcount.py
+++ b/examples/src/main/python/streaming/network_wordcount.py
@@ -25,6 +25,7 @@
 and then run the example
    `$ bin/spark-submit examples/src/main/python/streaming/network_wordcount.py localhost 9999`
 """
+from __future__ import print_function
 import sys
@@ -33,7 +34,7 @@
 if __name__ == "__main__":
     if len(sys.argv) != 3:
-        print >> sys.stderr, "Usage: network_wordcount.py <hostname> <port>"
+        print("Usage: network_wordcount.py <hostname> <port>", file=sys.stderr)
         exit(-1)
     sc = SparkContext(appName="PythonStreamingNetworkWordCount")
     ssc = StreamingContext(sc, 1)
diff --git a/examples/src/main/python/streaming/recoverable_network_wordcount.py b/examples/src/main/python/streaming/recoverable_network_wordcount.py
index fc6827c82bf9b..ac91f0a06b172 100644
--- a/examples/src/main/python/streaming/recoverable_network_wordcount.py
+++ b/examples/src/main/python/streaming/recoverable_network_wordcount.py
@@ -35,6 +35,7 @@
 checkpoint data exists in ~/checkpoint/, then it will create StreamingContext from
 the checkpoint data.
 """
+from __future__ import print_function
 import os
 import sys
@@ -46,7 +47,7 @@ def createContext(host, port, outputPath):
     # If you do not see this printed, that means the StreamingContext has been loaded
     # from the new checkpoint
-    print "Creating new context"
+    print("Creating new context")
     if os.path.exists(outputPath):
         os.remove(outputPath)
     sc = SparkContext(appName="PythonStreamingRecoverableNetworkWordCount")
@@ -60,8 +61,8 @@ def createContext(host, port, outputPath):
     def echo(time, rdd):
         counts = "Counts at time %s %s" % (time, rdd.collect())
-        print counts
-        print "Appending to " + os.path.abspath(outputPath)
+        print(counts)
+        print("Appending to " + os.path.abspath(outputPath))
         with open(outputPath, 'a') as f:
             f.write(counts + "\n")
@@ -70,8 +71,8 @@ def echo(time, rdd):
 if __name__ == "__main__":
     if len(sys.argv) != 5:
-        print >> sys.stderr, "Usage: recoverable_network_wordcount.py <hostname> <port> "\
-                             "<checkpoint-directory> <output-file>"
+        print("Usage: recoverable_network_wordcount.py <hostname> <port> "
+              "<checkpoint-directory> <output-file>", file=sys.stderr)
         exit(-1)
     host, port, checkpoint, output = sys.argv[1:]
     ssc = StreamingContext.getOrCreate(checkpoint,
diff --git a/examples/src/main/python/streaming/sql_network_wordcount.py b/examples/src/main/python/streaming/sql_network_wordcount.py
index f89bc562d856b..da90c07dbd82f 100644
--- a/examples/src/main/python/streaming/sql_network_wordcount.py
+++ b/examples/src/main/python/streaming/sql_network_wordcount.py
@@ -27,6 +27,7 @@
 and then run the example
    `$ bin/spark-submit examples/src/main/python/streaming/sql_network_wordcount.py localhost 9999`
 """
+from __future__ import print_function
 import os
 import sys
@@ -44,7 +45,7 @@ def getSqlContextInstance(sparkContext):
 if __name__ == "__main__":
     if len(sys.argv) != 3:
-        print >> sys.stderr, "Usage: sql_network_wordcount.py <hostname> <port>"
+        print("Usage: sql_network_wordcount.py <hostname> <port>", file=sys.stderr)
         exit(-1)
     host, port = sys.argv[1:]
     sc = SparkContext(appName="PythonSqlNetworkWordCount")
@@ -57,7 +58,7 @@ def getSqlContextInstance(sparkContext):
     # Convert RDDs of the words DStream to DataFrame and run SQL query
     def process(time, rdd):
-        print "========= %s =========" % str(time)
+        print("========= %s =========" % str(time))
         try:
             # Get the singleton instance of SQLContext
diff --git a/examples/src/main/python/streaming/stateful_network_wordcount.py b/examples/src/main/python/streaming/stateful_network_wordcount.py
index 18a9a5a452ffb..16ef646b7c42e 100644
--- a/examples/src/main/python/streaming/stateful_network_wordcount.py
+++ b/examples/src/main/python/streaming/stateful_network_wordcount.py
@@ -29,6 +29,7 @@
 `$ bin/spark-submit examples/src/main/python/streaming/stateful_network_wordcount.py \
    localhost 9999`
 """
+from __future__ import print_function
 import sys
@@ -37,7 +38,7 @@
 if __name__ == "__main__":
     if len(sys.argv) != 3:
-        print >> sys.stderr, "Usage: stateful_network_wordcount.py <hostname> <port>"
+        print("Usage: stateful_network_wordcount.py <hostname> <port>", file=sys.stderr)
         exit(-1)
     sc = SparkContext(appName="PythonStreamingStatefulNetworkWordCount")
     ssc = StreamingContext(sc, 1)
diff --git a/examples/src/main/python/transitive_closure.py b/examples/src/main/python/transitive_closure.py
index 00a281bfb6506..7bf5fb6ddfe29 100755
--- a/examples/src/main/python/transitive_closure.py
+++ b/examples/src/main/python/transitive_closure.py
@@ -15,6 +15,8 @@
 # limitations under the License.
 #
+from __future__ import print_function
+
 import sys
 from random import Random
@@ -49,20 +51,20 @@ def generateGraph():
     # the graph to obtain the path (x, z).
     # Because join() joins on keys, the edges are stored in reversed order.
-    edges = tc.map(lambda (x, y): (y, x))
+    edges = tc.map(lambda x_y: (x_y[1], x_y[0]))
-    oldCount = 0L
+    oldCount = 0
     nextCount = tc.count()
     while True:
         oldCount = nextCount
         # Perform the join, obtaining an RDD of (y, (z, x)) pairs,
         # then project the result to obtain the new (x, z) paths.
-        new_edges = tc.join(edges).map(lambda (_, (a, b)): (b, a))
+        new_edges = tc.join(edges).map(lambda __a_b: (__a_b[1][1], __a_b[1][0]))
         tc = tc.union(new_edges).distinct().cache()
         nextCount = tc.count()
         if nextCount == oldCount:
             break
-    print "TC has %i edges" % tc.count()
+    print("TC has %i edges" % tc.count())
     sc.stop()
diff --git a/examples/src/main/python/wordcount.py b/examples/src/main/python/wordcount.py
index ae6cd13b83d92..7c0143607b61d 100755
--- a/examples/src/main/python/wordcount.py
+++ b/examples/src/main/python/wordcount.py
@@ -15,6 +15,8 @@
 # limitations under the License.
 #
+from __future__ import print_function
+
 import sys
 from operator import add
@@ -23,7 +25,7 @@
 if __name__ == "__main__":
     if len(sys.argv) != 2:
-        print >> sys.stderr, "Usage: wordcount <file>"
+        print("Usage: wordcount <file>", file=sys.stderr)
         exit(-1)
     sc = SparkContext(appName="PythonWordCount")
     lines = sc.textFile(sys.argv[1], 1)
@@ -32,6 +34,6 @@
                   .reduceByKey(add)
     output = counts.collect()
     for (word, count) in output:
-        print "%s: %i" % (word, count)
+        print("%s: %i" % (word, count))
     sc.stop()
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala
index ecd3b16598438..534edac56bc5a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/MatrixFactorizationModelWrapper.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.mllib.api.python
 import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.recommendation.{MatrixFactorizationModel, Rating}
 import org.apache.spark.rdd.RDD
@@ -31,10 +32,14 @@ private[python] class MatrixFactorizationModelWrapper(model: MatrixFactorization
     predict(SerDe.asTupleRDD(userAndProducts.rdd))
   def getUserFeatures: RDD[Array[Any]] = {
-    SerDe.fromTuple2RDD(userFeatures.asInstanceOf[RDD[(Any, Any)]])
+    SerDe.fromTuple2RDD(userFeatures.map {
+      case (user, feature) => (user, Vectors.dense(feature))
+    }.asInstanceOf[RDD[(Any, Any)]])
   }
   def getProductFeatures: RDD[Array[Any]] = {
-    SerDe.fromTuple2RDD(productFeatures.asInstanceOf[RDD[(Any, Any)]])
+    SerDe.fromTuple2RDD(productFeatures.map {
+      case (product, feature) => (product, Vectors.dense(feature))
+    }.asInstanceOf[RDD[(Any, Any)]])
   }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index ab15f0f36a14b..f976d2f97b043 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -28,7 +28,6 @@ import scala.reflect.ClassTag
 import
net.razorvine.pickle._ -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.api.python.SerDeUtil import org.apache.spark.mllib.classification._ @@ -40,15 +39,15 @@ import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.random.{RandomRDDs => RG} import org.apache.spark.mllib.recommendation._ import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics} import org.apache.spark.mllib.stat.correlation.CorrelationNames import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.stat.test.ChiSqTestResult -import org.apache.spark.mllib.tree.{GradientBoostedTrees, RandomForest, DecisionTree} -import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo, Strategy} +import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics} +import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy, Strategy} import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.loss.Losses -import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel, RandomForestModel, DecisionTreeModel} +import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel, RandomForestModel} +import org.apache.spark.mllib.tree.{DecisionTree, GradientBoostedTrees, RandomForest} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -279,7 +278,7 @@ private[python] class PythonMLLibAPI extends Serializable { data: JavaRDD[LabeledPoint], lambda: Double): JList[Object] = { val model = NaiveBayes.train(data.rdd, lambda) - List(Vectors.dense(model.labels), Vectors.dense(model.pi), model.theta). + List(Vectors.dense(model.labels), Vectors.dense(model.pi), model.theta.map(Vectors.dense)). map(_.asInstanceOf[Object]).asJava } @@ -335,7 +334,7 @@ private[python] class PythonMLLibAPI extends Serializable { mu += model.gaussians(i).mu sigma += model.gaussians(i).sigma } - List(wt.toArray, mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava + List(Vectors.dense(wt.toArray), mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava } finally { data.rdd.unpersist(blocking = false) } @@ -346,20 +345,20 @@ private[python] class PythonMLLibAPI extends Serializable { */ def predictSoftGMM( data: JavaRDD[Vector], - wt: Object, + wt: Vector, mu: Array[Object], - si: Array[Object]): RDD[Array[Double]] = { + si: Array[Object]): RDD[Vector] = { - val weight = wt.asInstanceOf[Array[Double]] + val weight = wt.toArray val mean = mu.map(_.asInstanceOf[DenseVector]) val sigma = si.map(_.asInstanceOf[DenseMatrix]) val gaussians = Array.tabulate(weight.length){ i => new MultivariateGaussian(mean(i), sigma(i)) } val model = new GaussianMixtureModel(weight, gaussians) - model.predictSoft(data) + model.predictSoft(data).map(Vectors.dense) } - + /** * Java stub for Python mllib ALS.train(). This stub returns a handle * to the Java object instead of the content of the Java object. 
Extra care @@ -936,6 +935,14 @@ private[spark] object SerDe extends Serializable { out.write(code) } + protected def getBytes(obj: Object): Array[Byte] = { + if (obj.getClass.isArray) { + obj.asInstanceOf[Array[Byte]] + } else { + obj.asInstanceOf[String].getBytes(LATIN1) + } + } + private[python] def saveState(obj: Object, out: OutputStream, pickler: Pickler) } @@ -961,7 +968,7 @@ private[spark] object SerDe extends Serializable { if (args.length != 1) { throw new PickleException("should be 1") } - val bytes = args(0).asInstanceOf[String].getBytes(LATIN1) + val bytes = getBytes(args(0)) val bb = ByteBuffer.wrap(bytes, 0, bytes.length) bb.order(ByteOrder.nativeOrder()) val db = bb.asDoubleBuffer() @@ -994,7 +1001,7 @@ private[spark] object SerDe extends Serializable { if (args.length != 3) { throw new PickleException("should be 3") } - val bytes = args(2).asInstanceOf[String].getBytes(LATIN1) + val bytes = getBytes(args(2)) val n = bytes.length / 8 val values = new Array[Double](n) val order = ByteOrder.nativeOrder() @@ -1031,8 +1038,8 @@ private[spark] object SerDe extends Serializable { throw new PickleException("should be 3") } val size = args(0).asInstanceOf[Int] - val indiceBytes = args(1).asInstanceOf[String].getBytes(LATIN1) - val valueBytes = args(2).asInstanceOf[String].getBytes(LATIN1) + val indiceBytes = getBytes(args(1)) + val valueBytes = getBytes(args(2)) val n = indiceBytes.length / 4 val indices = new Array[Int](n) val values = new Array[Double](n) diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index ccbca67656c8d..7271809e43880 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -54,7 +54,7 @@ ... def zero(self, value): ... return [0.0] * len(value) ... def addInPlace(self, val1, val2): -... for i in xrange(len(val1)): +... for i in range(len(val1)): ... val1[i] += val2[i] ... return val1 >>> va = sc.accumulator([1.0, 2.0, 3.0], VectorAccumulatorParam()) @@ -86,9 +86,13 @@ Exception:... 
""" +import sys import select import struct -import SocketServer +if sys.version < '3': + import SocketServer +else: + import socketserver as SocketServer import threading from pyspark.cloudpickle import CloudPickler from pyspark.serializers import read_int, PickleSerializer @@ -247,6 +251,7 @@ class AccumulatorServer(SocketServer.TCPServer): def shutdown(self): self.server_shutdown = True SocketServer.TCPServer.shutdown(self) + self.server_close() def _start_update_server(): diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 6b8a8b256a891..3de4615428bb6 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -16,10 +16,15 @@ # import os -import cPickle +import sys import gc from tempfile import NamedTemporaryFile +if sys.version < '3': + import cPickle as pickle +else: + import pickle + unicode = str __all__ = ['Broadcast'] @@ -70,33 +75,19 @@ def __init__(self, sc=None, value=None, pickle_registry=None, path=None): self._path = path def dump(self, value, f): - if isinstance(value, basestring): - if isinstance(value, unicode): - f.write('U') - value = value.encode('utf8') - else: - f.write('S') - f.write(value) - else: - f.write('P') - cPickle.dump(value, f, 2) + pickle.dump(value, f, 2) f.close() return f.name def load(self, path): with open(path, 'rb', 1 << 20) as f: - flag = f.read(1) - data = f.read() - if flag == 'P': - # cPickle.loads() may create lots of objects, disable GC - # temporary for better performance - gc.disable() - try: - return cPickle.loads(data) - finally: - gc.enable() - else: - return data.decode('utf8') if flag == 'U' else data + # pickle.load() may create lots of objects, disable GC + # temporary for better performance + gc.disable() + try: + return pickle.load(f) + finally: + gc.enable() @property def value(self): diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index bb0783555aa77..9ef93071d2e77 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -40,164 +40,126 @@ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
""" - +from __future__ import print_function import operator import os +import io import pickle import struct import sys import types from functools import partial import itertools -from copy_reg import _extension_registry, _inverted_registry, _extension_cache -import new import dis import traceback -import platform - -PyImp = platform.python_implementation() - -import logging -cloudLog = logging.getLogger("Cloud.Transport") +if sys.version < '3': + from pickle import Pickler + try: + from cStringIO import StringIO + except ImportError: + from StringIO import StringIO + PY3 = False +else: + types.ClassType = type + from pickle import _Pickler as Pickler + from io import BytesIO as StringIO + PY3 = True #relevant opcodes -STORE_GLOBAL = chr(dis.opname.index('STORE_GLOBAL')) -DELETE_GLOBAL = chr(dis.opname.index('DELETE_GLOBAL')) -LOAD_GLOBAL = chr(dis.opname.index('LOAD_GLOBAL')) +STORE_GLOBAL = dis.opname.index('STORE_GLOBAL') +DELETE_GLOBAL = dis.opname.index('DELETE_GLOBAL') +LOAD_GLOBAL = dis.opname.index('LOAD_GLOBAL') GLOBAL_OPS = [STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL] +HAVE_ARGUMENT = dis.HAVE_ARGUMENT +EXTENDED_ARG = dis.EXTENDED_ARG -HAVE_ARGUMENT = chr(dis.HAVE_ARGUMENT) -EXTENDED_ARG = chr(dis.EXTENDED_ARG) - -if PyImp == "PyPy": - # register builtin type in `new` - new.method = types.MethodType - -try: - from cStringIO import StringIO -except ImportError: - from StringIO import StringIO -# These helper functions were copied from PiCloud's util module. def islambda(func): - return getattr(func,'func_name') == '' + return getattr(func,'__name__') == '' -def xrange_params(xrangeobj): - """Returns a 3 element tuple describing the xrange start, step, and len - respectively - Note: Only guarentees that elements of xrange are the same. parameters may - be different. - e.g. xrange(1,1) is interpretted as xrange(0,0); both behave the same - though w/ iteration - """ - - xrange_len = len(xrangeobj) - if not xrange_len: #empty - return (0,1,0) - start = xrangeobj[0] - if xrange_len == 1: #one element - return start, 1, 1 - return (start, xrangeobj[1] - xrangeobj[0], xrange_len) - -#debug variables intended for developer use: -printSerialization = False -printMemoization = False +_BUILTIN_TYPE_NAMES = {} +for k, v in types.__dict__.items(): + if type(v) is type: + _BUILTIN_TYPE_NAMES[v] = k -useForcedImports = True #Should I use forced imports for tracking? +def _builtin_type(name): + return getattr(types, name) -class CloudPickler(pickle.Pickler): +class CloudPickler(Pickler): - dispatch = pickle.Pickler.dispatch.copy() - savedForceImports = False - savedDjangoEnv = False #hack tro transport django environment + dispatch = Pickler.dispatch.copy() - def __init__(self, file, protocol=None, min_size_to_save= 0): - pickle.Pickler.__init__(self,file,protocol) - self.modules = set() #set of modules needed to depickle - self.globals_ref = {} # map ids to dictionary. used to ensure that functions can share global env + def __init__(self, file, protocol=None): + Pickler.__init__(self, file, protocol) + # set of modules to unpickle + self.modules = set() + # map ids to dictionary. 
used to ensure that functions can share global env + self.globals_ref = {} def dump(self, obj): - # note: not thread safe - # minimal side-effects, so not fixing - recurse_limit = 3000 - base_recurse = sys.getrecursionlimit() - if base_recurse < recurse_limit: - sys.setrecursionlimit(recurse_limit) self.inject_addons() try: - return pickle.Pickler.dump(self, obj) - except RuntimeError, e: + return Pickler.dump(self, obj) + except RuntimeError as e: if 'recursion' in e.args[0]: - msg = """Could not pickle object as excessively deep recursion required. - Try _fast_serialization=2 or contact PiCloud support""" + msg = """Could not pickle object as excessively deep recursion required.""" raise pickle.PicklingError(msg) - finally: - new_recurse = sys.getrecursionlimit() - if new_recurse == recurse_limit: - sys.setrecursionlimit(base_recurse) + + def save_memoryview(self, obj): + """Fallback to save_string""" + Pickler.save_string(self, str(obj)) def save_buffer(self, obj): """Fallback to save_string""" - pickle.Pickler.save_string(self,str(obj)) - dispatch[buffer] = save_buffer + Pickler.save_string(self,str(obj)) + if PY3: + dispatch[memoryview] = save_memoryview + else: + dispatch[buffer] = save_buffer - #block broken objects - def save_unsupported(self, obj, pack=None): + def save_unsupported(self, obj): raise pickle.PicklingError("Cannot pickle objects of type %s" % type(obj)) dispatch[types.GeneratorType] = save_unsupported - #python2.6+ supports slice pickling. some py2.5 extensions might as well. We just test it - try: - slice(0,1).__reduce__() - except TypeError: #can't pickle - - dispatch[slice] = save_unsupported - - #itertools objects do not pickle! + # itertools objects do not pickle! for v in itertools.__dict__.values(): if type(v) is type: dispatch[v] = save_unsupported - - def save_dict(self, obj): - """hack fix - If the dict is a global, deal with it in a special way - """ - #print 'saving', obj - if obj is __builtins__: - self.save_reduce(_get_module_builtins, (), obj=obj) - else: - pickle.Pickler.save_dict(self, obj) - dispatch[pickle.DictionaryType] = save_dict - - - def save_module(self, obj, pack=struct.pack): + def save_module(self, obj): """ Save a module as an import """ - #print 'try save import', obj.__name__ self.modules.add(obj) - self.save_reduce(subimport,(obj.__name__,), obj=obj) - dispatch[types.ModuleType] = save_module #new type + self.save_reduce(subimport, (obj.__name__,), obj=obj) + dispatch[types.ModuleType] = save_module - def save_codeobject(self, obj, pack=struct.pack): + def save_codeobject(self, obj): """ Save a code object """ - #print 'try to save codeobj: ', obj - args = ( - obj.co_argcount, obj.co_nlocals, obj.co_stacksize, obj.co_flags, obj.co_code, - obj.co_consts, obj.co_names, obj.co_varnames, obj.co_filename, obj.co_name, - obj.co_firstlineno, obj.co_lnotab, obj.co_freevars, obj.co_cellvars - ) + if PY3: + args = ( + obj.co_argcount, obj.co_kwonlyargcount, obj.co_nlocals, obj.co_stacksize, + obj.co_flags, obj.co_code, obj.co_consts, obj.co_names, obj.co_varnames, + obj.co_filename, obj.co_name, obj.co_firstlineno, obj.co_lnotab, obj.co_freevars, + obj.co_cellvars + ) + else: + args = ( + obj.co_argcount, obj.co_nlocals, obj.co_stacksize, obj.co_flags, obj.co_code, + obj.co_consts, obj.co_names, obj.co_varnames, obj.co_filename, obj.co_name, + obj.co_firstlineno, obj.co_lnotab, obj.co_freevars, obj.co_cellvars + ) self.save_reduce(types.CodeType, args, obj=obj) - dispatch[types.CodeType] = save_codeobject #new type + dispatch[types.CodeType] = 
save_codeobject - def save_function(self, obj, name=None, pack=struct.pack): + def save_function(self, obj, name=None): """ Registered with the dispatch to handle all function types. Determines what kind of function obj is (e.g. lambda, defined at @@ -205,12 +167,14 @@ def save_function(self, obj, name=None, pack=struct.pack): """ write = self.write - name = obj.__name__ + if name is None: + name = obj.__name__ modname = pickle.whichmodule(obj, name) - #print 'which gives %s %s %s' % (modname, obj, name) + # print('which gives %s %s %s' % (modname, obj, name)) try: themodule = sys.modules[modname] - except KeyError: # eval'd items such as namedtuple give invalid items for their function __module__ + except KeyError: + # eval'd items such as namedtuple give invalid items for their function __module__ modname = '__main__' if modname == '__main__': @@ -221,37 +185,18 @@ def save_function(self, obj, name=None, pack=struct.pack): if getattr(themodule, name, None) is obj: return self.save_global(obj, name) - if not self.savedDjangoEnv: - #hack for django - if we detect the settings module, we transport it - django_settings = os.environ.get('DJANGO_SETTINGS_MODULE', '') - if django_settings: - django_mod = sys.modules.get(django_settings) - if django_mod: - cloudLog.debug('Transporting django settings %s during save of %s', django_mod, name) - self.savedDjangoEnv = True - self.modules.add(django_mod) - write(pickle.MARK) - self.save_reduce(django_settings_load, (django_mod.__name__,), obj=django_mod) - write(pickle.POP_MARK) - - # if func is lambda, def'ed at prompt, is in main, or is nested, then # we'll pickle the actual function object rather than simply saving a # reference (as is done in default pickler), via save_function_tuple. - if islambda(obj) or obj.func_code.co_filename == '' or themodule is None: - #Force server to import modules that have been imported in main - modList = None - if themodule is None and not self.savedForceImports: - mainmod = sys.modules['__main__'] - if useForcedImports and hasattr(mainmod,'___pyc_forcedImports__'): - modList = list(mainmod.___pyc_forcedImports__) - self.savedForceImports = True - self.save_function_tuple(obj, modList) + if islambda(obj) or obj.__code__.co_filename == '' or themodule is None: + #print("save global", islambda(obj), obj.__code__.co_filename, modname, themodule) + self.save_function_tuple(obj) return - else: # func is nested + else: + # func is nested klass = getattr(themodule, name, None) if klass is None or klass is not obj: - self.save_function_tuple(obj, [themodule]) + self.save_function_tuple(obj) return if obj.__dict__: @@ -266,7 +211,7 @@ def save_function(self, obj, name=None, pack=struct.pack): self.memoize(obj) dispatch[types.FunctionType] = save_function - def save_function_tuple(self, func, forced_imports): + def save_function_tuple(self, func): """ Pickles an actual func object. A func comprises: code, globals, defaults, closure, and dict. 
We @@ -281,19 +226,6 @@ def save_function_tuple(self, func, forced_imports): save = self.save write = self.write - # save the modules (if any) - if forced_imports: - write(pickle.MARK) - save(_modules_to_main) - #print 'forced imports are', forced_imports - - forced_names = map(lambda m: m.__name__, forced_imports) - save((forced_names,)) - - #save((forced_imports,)) - write(pickle.REDUCE) - write(pickle.POP_MARK) - code, f_globals, defaults, closure, dct, base_globals = self.extract_func_data(func) save(_fill_function) # skeleton function updater @@ -318,6 +250,8 @@ def extract_code_globals(co): Find all globals names read or written to by codeblock co """ code = co.co_code + if not PY3: + code = [ord(c) for c in code] names = co.co_names out_names = set() @@ -327,18 +261,18 @@ def extract_code_globals(co): while i < n: op = code[i] - i = i+1 + i += 1 if op >= HAVE_ARGUMENT: - oparg = ord(code[i]) + ord(code[i+1])*256 + extended_arg + oparg = code[i] + code[i+1] * 256 + extended_arg extended_arg = 0 - i = i+2 + i += 2 if op == EXTENDED_ARG: - extended_arg = oparg*65536L + extended_arg = oparg*65536 if op in GLOBAL_OPS: out_names.add(names[oparg]) - #print 'extracted', out_names, ' from ', names - if co.co_consts: # see if nested function have any global refs + # see if nested function have any global refs + if co.co_consts: for const in co.co_consts: if type(const) is types.CodeType: out_names |= CloudPickler.extract_code_globals(const) @@ -350,46 +284,28 @@ def extract_func_data(self, func): Turn the function into a tuple of data necessary to recreate it: code, globals, defaults, closure, dict """ - code = func.func_code + code = func.__code__ # extract all global ref's - func_global_refs = CloudPickler.extract_code_globals(code) + func_global_refs = self.extract_code_globals(code) # process all variables referenced by global environment f_globals = {} for var in func_global_refs: - #Some names, such as class functions are not global - we don't need them - if func.func_globals.has_key(var): - f_globals[var] = func.func_globals[var] + if var in func.__globals__: + f_globals[var] = func.__globals__[var] # defaults requires no processing - defaults = func.func_defaults - - def get_contents(cell): - try: - return cell.cell_contents - except ValueError, e: #cell is empty error on not yet assigned - raise pickle.PicklingError('Function to be pickled has free variables that are referenced before assignment in enclosing scope') - + defaults = func.__defaults__ # process closure - if func.func_closure: - closure = map(get_contents, func.func_closure) - else: - closure = [] + closure = [c.cell_contents for c in func.__closure__] if func.__closure__ else [] # save the dict - dct = func.func_dict - - if printSerialization: - outvars = ['code: ' + str(code) ] - outvars.append('globals: ' + str(f_globals)) - outvars.append('defaults: ' + str(defaults)) - outvars.append('closure: ' + str(closure)) - print 'function ', func, 'is extracted to: ', ', '.join(outvars) + dct = func.__dict__ - base_globals = self.globals_ref.get(id(func.func_globals), {}) - self.globals_ref[id(func.func_globals)] = base_globals + base_globals = self.globals_ref.get(id(func.__globals__), {}) + self.globals_ref[id(func.__globals__)] = base_globals return (code, f_globals, defaults, closure, dct, base_globals) @@ -400,8 +316,9 @@ def save_builtin_function(self, obj): dispatch[types.BuiltinFunctionType] = save_builtin_function def save_global(self, obj, name=None, pack=struct.pack): - write = self.write - memo = self.memo + if 
obj.__module__ == "__builtin__" or obj.__module__ == "builtins": + if obj in _BUILTIN_TYPE_NAMES: + return self.save_reduce(_builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj) if name is None: name = obj.__name__ @@ -410,98 +327,57 @@ def save_global(self, obj, name=None, pack=struct.pack): if modname is None: modname = pickle.whichmodule(obj, name) - try: - __import__(modname) - themodule = sys.modules[modname] - except (ImportError, KeyError, AttributeError): #should never occur - raise pickle.PicklingError( - "Can't pickle %r: Module %s cannot be found" % - (obj, modname)) - if modname == '__main__': themodule = None - - if themodule: + else: + __import__(modname) + themodule = sys.modules[modname] self.modules.add(themodule) - sendRef = True - typ = type(obj) - #print 'saving', obj, typ - try: - try: #Deal with case when getattribute fails with exceptions - klass = getattr(themodule, name) - except (AttributeError): - if modname == '__builtin__': #new.* are misrepeported - modname = 'new' - __import__(modname) - themodule = sys.modules[modname] - try: - klass = getattr(themodule, name) - except AttributeError, a: - # print themodule, name, obj, type(obj) - raise pickle.PicklingError("Can't pickle builtin %s" % obj) - else: - raise + if hasattr(themodule, name) and getattr(themodule, name) is obj: + return Pickler.save_global(self, obj, name) - except (ImportError, KeyError, AttributeError): - if typ == types.TypeType or typ == types.ClassType: - sendRef = False - else: #we can't deal with this - raise - else: - if klass is not obj and (typ == types.TypeType or typ == types.ClassType): - sendRef = False - if not sendRef: - #note: Third party types might crash this - add better checks! - d = dict(obj.__dict__) #copy dict proxy to a dict - if not isinstance(d.get('__dict__', None), property): # don't extract dict that are properties - d.pop('__dict__',None) - d.pop('__weakref__',None) + typ = type(obj) + if typ is not obj and isinstance(obj, (type, types.ClassType)): + d = dict(obj.__dict__) # copy dict proxy to a dict + if not isinstance(d.get('__dict__', None), property): + # don't extract dict that are properties + d.pop('__dict__', None) + d.pop('__weakref__', None) # hack as __new__ is stored differently in the __dict__ new_override = d.get('__new__', None) if new_override: d['__new__'] = obj.__new__ - self.save_reduce(type(obj),(obj.__name__,obj.__bases__, - d),obj=obj) - #print 'internal reduce dask %s %s' % (obj, d) - return - - if self.proto >= 2: - code = _extension_registry.get((modname, name)) - if code: - assert code > 0 - if code <= 0xff: - write(pickle.EXT1 + chr(code)) - elif code <= 0xffff: - write("%c%c%c" % (pickle.EXT2, code&0xff, code>>8)) - else: - write(pickle.EXT4 + pack("> sys.stderr, 'Cloud not import django settings %s:' % (name) - print_exec(sys.stderr) - if modified_env: - del os.environ['DJANGO_SETTINGS_MODULE'] - else: - #add project directory to sys,path: - if hasattr(module,'__file__'): - dirname = os.path.split(module.__file__)[0] + '/' - sys.path.append(dirname) # restores function attributes def _restore_attr(obj, attr): @@ -851,13 +636,16 @@ def _restore_attr(obj, attr): setattr(obj, key, val) return obj + def _get_module_builtins(): return pickle.__builtins__ + def print_exec(stream): ei = sys.exc_info() traceback.print_exception(ei[0], ei[1], ei[2], None, stream) + def _modules_to_main(modList): """Force every module in modList to be placed into main""" if not modList: @@ -868,22 +656,16 @@ def _modules_to_main(modList): if type(modname) is str: 
try: mod = __import__(modname) - except Exception, i: #catch all... - sys.stderr.write('warning: could not import %s\n. Your function may unexpectedly error due to this import failing; \ -A version mismatch is likely. Specific error was:\n' % modname) + except Exception as e: + sys.stderr.write('warning: could not import %s\n. ' + 'Your function may unexpectedly error due to this import failing;' + 'A version mismatch is likely. Specific error was:\n' % modname) print_exec(sys.stderr) else: - setattr(main,mod.__name__, mod) - else: - #REVERSE COMPATIBILITY FOR CLOUD CLIENT 1.5 (WITH EPD) - #In old version actual module was sent - setattr(main,modname.__name__, modname) + setattr(main, mod.__name__, mod) -#object generators: -def _build_xrange(start, step, len): - """Built xrange explicitly""" - return xrange(start, start + step*len, step) +#object generators: def _genpartial(func, args, kwds): if not args: args = () @@ -891,22 +673,26 @@ def _genpartial(func, args, kwds): kwds = {} return partial(func, *args, **kwds) + def _fill_function(func, globals, defaults, dict): """ Fills in the rest of function data into the skeleton function object that were created via _make_skel_func(). """ - func.func_globals.update(globals) - func.func_defaults = defaults - func.func_dict = dict + func.__globals__.update(globals) + func.__defaults__ = defaults + func.__dict__ = dict return func + def _make_cell(value): - return (lambda: value).func_closure[0] + return (lambda: value).__closure__[0] + def _reconstruct_closure(values): return tuple([_make_cell(v) for v in values]) + def _make_skel_func(code, closures, base_globals = None): """ Creates a skeleton function object that contains just the provided code and the correct number of cells in func_closure. All other @@ -928,40 +714,3 @@ def _make_skel_func(code, closures, base_globals = None): def _getobject(modname, attribute): mod = __import__(modname, fromlist=[attribute]) return mod.__dict__[attribute] - -def _generateImage(size, mode, str_rep): - """Generate image from string representation""" - import Image - i = Image.new(mode, size) - i.fromstring(str_rep) - return i - -def _lazyloadImage(fp): - import Image - fp.seek(0) #works in almost any case - return Image.open(fp) - -"""Timeseries""" -def _genTimeSeries(reduce_args, state): - import scikits.timeseries.tseries as ts - from numpy import ndarray - from numpy.ma import MaskedArray - - - time_series = ts._tsreconstruct(*reduce_args) - - #from setstate modified - (ver, shp, typ, isf, raw, msk, flv, dsh, dtm, dtyp, frq, infodict) = state - #print 'regenerating %s' % dtyp - - MaskedArray.__setstate__(time_series, (ver, shp, typ, isf, raw, msk, flv)) - _dates = time_series._dates - #_dates.__setstate__((ver, dsh, typ, isf, dtm, frq)) #use remote typ - ndarray.__setstate__(_dates,(dsh,dtyp, isf, dtm)) - _dates.freq = frq - _dates._cachedinfo.update(dict(full=None, hasdups=None, steps=None, - toobj=None, toord=None, tostr=None)) - # Update the _optinfo dictionary - time_series._optinfo.update(infodict) - return time_series - diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py index dc7cd0bce56f3..924da3eecf214 100644 --- a/python/pyspark/conf.py +++ b/python/pyspark/conf.py @@ -44,7 +44,7 @@ >>> conf.get("spark.executorEnv.VAR1") u'value1' ->>> print conf.toDebugString() +>>> print(conf.toDebugString()) spark.executorEnv.VAR1=value1 spark.executorEnv.VAR3=value3 spark.executorEnv.VAR4=value4 @@ -56,6 +56,13 @@ __all__ = ['SparkConf'] +import sys +import re + +if sys.version > '3': + unicode = 
str + __doc__ = re.sub(r"(\W|^)[uU](['])", r'\1\2', __doc__) + class SparkConf(object): diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 78dccc40470e3..1dc2fec0ae5c8 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -15,6 +15,8 @@ # limitations under the License. # +from __future__ import print_function + import os import shutil import sys @@ -32,11 +34,14 @@ from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \ PairDeserializer, AutoBatchedSerializer, NoOpSerializer from pyspark.storagelevel import StorageLevel -from pyspark.rdd import RDD, _load_from_socket +from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix from pyspark.traceback_utils import CallSite, first_spark_call from pyspark.status import StatusTracker from pyspark.profiler import ProfilerCollector, BasicProfiler +if sys.version > '3': + xrange = range + __all__ = ['SparkContext'] @@ -133,7 +138,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, if sparkHome: self._conf.setSparkHome(sparkHome) if environment: - for key, value in environment.iteritems(): + for key, value in environment.items(): self._conf.setExecutorEnv(key, value) for key, value in DEFAULT_CONFIGS.items(): self._conf.setIfMissing(key, value) @@ -153,6 +158,10 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, if k.startswith("spark.executorEnv."): varName = k[len("spark.executorEnv."):] self.environment[varName] = v + if sys.version >= '3.3' and 'PYTHONHASHSEED' not in os.environ: + # disable randomness of hash of string in worker, if this is not + # launched by spark-submit + self.environment["PYTHONHASHSEED"] = "0" # Create the Java SparkContext through Py4J self._jsc = jsc or self._initialize_context(self._conf._jconf) @@ -323,7 +332,7 @@ def parallelize(self, c, numSlices=None): start0 = c[0] def getStart(split): - return start0 + (split * size / numSlices) * step + return start0 + int((split * size / numSlices)) * step def f(split, iterator): return xrange(getStart(split), getStart(split + 1), step) @@ -357,6 +366,7 @@ def pickleFile(self, name, minPartitions=None): minPartitions = minPartitions or self.defaultMinPartitions return RDD(self._jsc.objectFile(name, minPartitions), self) + @ignore_unicode_prefix def textFile(self, name, minPartitions=None, use_unicode=True): """ Read a text file from HDFS, a local file system (available on all @@ -369,7 +379,7 @@ def textFile(self, name, minPartitions=None, use_unicode=True): >>> path = os.path.join(tempdir, "sample-text.txt") >>> with open(path, "w") as testFile: - ... testFile.write("Hello world!") + ... _ = testFile.write("Hello world!") >>> textFile = sc.textFile(path) >>> textFile.collect() [u'Hello world!'] @@ -378,6 +388,7 @@ def textFile(self, name, minPartitions=None, use_unicode=True): return RDD(self._jsc.textFile(name, minPartitions), self, UTF8Deserializer(use_unicode)) + @ignore_unicode_prefix def wholeTextFiles(self, path, minPartitions=None, use_unicode=True): """ Read a directory of text files from HDFS, a local file system @@ -411,9 +422,9 @@ def wholeTextFiles(self, path, minPartitions=None, use_unicode=True): >>> dirPath = os.path.join(tempdir, "files") >>> os.mkdir(dirPath) >>> with open(os.path.join(dirPath, "1.txt"), "w") as file1: - ... file1.write("1") + ... _ = file1.write("1") >>> with open(os.path.join(dirPath, "2.txt"), "w") as file2: - ... file2.write("2") + ... 
_ = file2.write("2") >>> textFiles = sc.wholeTextFiles(dirPath) >>> sorted(textFiles.collect()) [(u'.../1.txt', u'1'), (u'.../2.txt', u'2')] @@ -456,7 +467,7 @@ def _dictToJavaMap(self, d): jm = self._jvm.java.util.HashMap() if not d: d = {} - for k, v in d.iteritems(): + for k, v in d.items(): jm[k] = v return jm @@ -608,6 +619,7 @@ def _checkpointFile(self, name, input_deserializer): jrdd = self._jsc.checkpointFile(name) return RDD(jrdd, self, input_deserializer) + @ignore_unicode_prefix def union(self, rdds): """ Build the union of a list of RDDs. @@ -618,7 +630,7 @@ def union(self, rdds): >>> path = os.path.join(tempdir, "union-text.txt") >>> with open(path, "w") as testFile: - ... testFile.write("Hello") + ... _ = testFile.write("Hello") >>> textFile = sc.textFile(path) >>> textFile.collect() [u'Hello'] @@ -677,7 +689,7 @@ def addFile(self, path): >>> from pyspark import SparkFiles >>> path = os.path.join(tempdir, "test.txt") >>> with open(path, "w") as testFile: - ... testFile.write("100") + ... _ = testFile.write("100") >>> sc.addFile(path) >>> def func(iterator): ... with open(SparkFiles.get("test.txt")) as testFile: @@ -705,11 +717,13 @@ def addPyFile(self, path): """ self.addFile(path) (dirname, filename) = os.path.split(path) # dirname may be directory or HDFS/S3 prefix - if filename[-4:].lower() in self.PACKAGE_EXTENSIONS: self._python_includes.append(filename) # for tests in local mode sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename)) + if sys.version > '3': + import importlib + importlib.invalidate_caches() def setCheckpointDir(self, dirName): """ @@ -744,7 +758,7 @@ def setJobGroup(self, groupId, description, interruptOnCancel=False): The application can use L{SparkContext.cancelJobGroup} to cancel all running jobs in this group. - >>> import thread, threading + >>> import threading >>> from time import sleep >>> result = "Not Set" >>> lock = threading.Lock() @@ -763,10 +777,10 @@ def setJobGroup(self, groupId, description, interruptOnCancel=False): ... sleep(5) ... sc.cancelJobGroup("job_to_cancel") >>> supress = lock.acquire() - >>> supress = thread.start_new_thread(start_job, (10,)) - >>> supress = thread.start_new_thread(stop_job, tuple()) + >>> supress = threading.Thread(target=start_job, args=(10,)).start() + >>> supress = threading.Thread(target=stop_job).start() >>> supress = lock.acquire() - >>> print result + >>> print(result) Cancelled If interruptOnCancel is set to true for the job group, then job cancellation will result diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index 93885985fe377..7f06d4288c872 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -24,9 +24,10 @@ import traceback import time import gc -from errno import EINTR, ECHILD, EAGAIN +from errno import EINTR, EAGAIN from socket import AF_INET, SOCK_STREAM, SOMAXCONN from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN, SIGINT + from pyspark.worker import main as worker_main from pyspark.serializers import read_int, write_int @@ -53,8 +54,8 @@ def worker(sock): # Read the socket using fdopen instead of socket.makefile() because the latter # seems to be very slow; note that we need to dup() the file descriptor because # otherwise writes also cause a seek that makes us miss data on the read side. 
- infile = os.fdopen(os.dup(sock.fileno()), "a+", 65536) - outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536) + infile = os.fdopen(os.dup(sock.fileno()), "rb", 65536) + outfile = os.fdopen(os.dup(sock.fileno()), "wb", 65536) exit_code = 0 try: worker_main(infile, outfile) @@ -68,17 +69,6 @@ def worker(sock): return exit_code -# Cleanup zombie children -def cleanup_dead_children(): - try: - while True: - pid, _ = os.waitpid(0, os.WNOHANG) - if not pid: - break - except: - pass - - def manager(): # Create a new process group to corral our children os.setpgid(0, 0) @@ -88,8 +78,12 @@ def manager(): listen_sock.bind(('127.0.0.1', 0)) listen_sock.listen(max(1024, SOMAXCONN)) listen_host, listen_port = listen_sock.getsockname() - write_int(listen_port, sys.stdout) - sys.stdout.flush() + + # re-open stdin/stdout in 'wb' mode + stdin_bin = os.fdopen(sys.stdin.fileno(), 'rb', 4) + stdout_bin = os.fdopen(sys.stdout.fileno(), 'wb', 4) + write_int(listen_port, stdout_bin) + stdout_bin.flush() def shutdown(code): signal.signal(SIGTERM, SIG_DFL) @@ -101,6 +95,7 @@ def handle_sigterm(*args): shutdown(1) signal.signal(SIGTERM, handle_sigterm) # Gracefully exit on SIGTERM signal.signal(SIGHUP, SIG_IGN) # Don't die on SIGHUP + signal.signal(SIGCHLD, SIG_IGN) reuse = os.environ.get("SPARK_REUSE_WORKER") @@ -115,12 +110,9 @@ def handle_sigterm(*args): else: raise - # cleanup in signal handler will cause deadlock - cleanup_dead_children() - if 0 in ready_fds: try: - worker_pid = read_int(sys.stdin) + worker_pid = read_int(stdin_bin) except EOFError: # Spark told us to exit by closing stdin shutdown(0) @@ -145,7 +137,7 @@ def handle_sigterm(*args): time.sleep(1) pid = os.fork() # error here will shutdown daemon else: - outfile = sock.makefile('w') + outfile = sock.makefile(mode='wb') write_int(e.errno, outfile) # Signal that the fork failed outfile.flush() outfile.close() @@ -157,7 +149,7 @@ def handle_sigterm(*args): listen_sock.close() try: # Acknowledge that the fork was successful - outfile = sock.makefile("w") + outfile = sock.makefile(mode="wb") write_int(os.getpid(), outfile) outfile.flush() outfile.close() diff --git a/python/pyspark/heapq3.py b/python/pyspark/heapq3.py index bc441f138f7fc..4ef2afe03544f 100644 --- a/python/pyspark/heapq3.py +++ b/python/pyspark/heapq3.py @@ -627,51 +627,49 @@ def merge(iterables, key=None, reverse=False): if key is None: for order, it in enumerate(map(iter, iterables)): try: - next = it.next - h_append([next(), order * direction, next]) + h_append([next(it), order * direction, it]) except StopIteration: pass _heapify(h) while len(h) > 1: try: while True: - value, order, next = s = h[0] + value, order, it = s = h[0] yield value - s[0] = next() # raises StopIteration when exhausted + s[0] = next(it) # raises StopIteration when exhausted _heapreplace(h, s) # restore heap condition except StopIteration: _heappop(h) # remove empty iterator if h: # fast case when only a single iterator remains - value, order, next = h[0] + value, order, it = h[0] yield value - for value in next.__self__: + for value in it: yield value return for order, it in enumerate(map(iter, iterables)): try: - next = it.next - value = next() - h_append([key(value), order * direction, value, next]) + value = next(it) + h_append([key(value), order * direction, value, it]) except StopIteration: pass _heapify(h) while len(h) > 1: try: while True: - key_value, order, value, next = s = h[0] + key_value, order, value, it = s = h[0] yield value - value = next() + value = next(it) s[0] = key(value) s[2] = 
value _heapreplace(h, s) except StopIteration: _heappop(h) if h: - key_value, order, value, next = h[0] + key_value, order, value, it = h[0] yield value - for value in next.__self__: + for value in it: yield value diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 2a5e84a7dfdb4..45bc38f7e61f8 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -69,7 +69,7 @@ def preexec_func(): if callback_socket in readable: gateway_connection = callback_socket.accept()[0] # Determine which ephemeral port the server started on: - gateway_port = read_int(gateway_connection.makefile()) + gateway_port = read_int(gateway_connection.makefile(mode="rb")) gateway_connection.close() callback_socket.close() if gateway_port is None: diff --git a/python/pyspark/join.py b/python/pyspark/join.py index c3491defb2b29..94df3990164d6 100644 --- a/python/pyspark/join.py +++ b/python/pyspark/join.py @@ -32,6 +32,7 @@ """ from pyspark.resultiterable import ResultIterable +from functools import reduce def _do_python_join(rdd, other, numPartitions, dispatch): diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index d7bc09fd77adb..45754bc9d4b10 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -39,10 +39,10 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti >>> lr = LogisticRegression(maxIter=5, regParam=0.01) >>> model = lr.fit(df) >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF() - >>> print model.transform(test0).head().prediction + >>> model.transform(test0).head().prediction 0.0 >>> test1 = sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]).toDF() - >>> print model.transform(test1).head().prediction + >>> model.transform(test1).head().prediction 1.0 >>> lr.setParams("vector") Traceback (most recent call last): diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 263fe2a5bcc41..4e4614b859ac6 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -15,6 +15,7 @@ # limitations under the License. # +from pyspark.rdd import ignore_unicode_prefix from pyspark.ml.param.shared import HasInputCol, HasOutputCol, HasNumFeatures from pyspark.ml.util import keyword_only from pyspark.ml.wrapper import JavaTransformer @@ -24,6 +25,7 @@ @inherit_doc +@ignore_unicode_prefix class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol): """ A tokenizer that converts the input string to lowercase and then @@ -32,15 +34,15 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol): >>> from pyspark.sql import Row >>> df = sc.parallelize([Row(text="a b c")]).toDF() >>> tokenizer = Tokenizer(inputCol="text", outputCol="words") - >>> print tokenizer.transform(df).head() + >>> tokenizer.transform(df).head() Row(text=u'a b c', words=[u'a', u'b', u'c']) >>> # Change a parameter. - >>> print tokenizer.setParams(outputCol="tokens").transform(df).head() + >>> tokenizer.setParams(outputCol="tokens").transform(df).head() Row(text=u'a b c', tokens=[u'a', u'b', u'c']) >>> # Temporarily modify a parameter. - >>> print tokenizer.transform(df, {tokenizer.outputCol: "words"}).head() + >>> tokenizer.transform(df, {tokenizer.outputCol: "words"}).head() Row(text=u'a b c', words=[u'a', u'b', u'c']) - >>> print tokenizer.transform(df).head() + >>> tokenizer.transform(df).head() Row(text=u'a b c', tokens=[u'a', u'b', u'c']) >>> # Must use keyword arguments to specify params. 
>>> tokenizer.setParams("text") @@ -79,13 +81,13 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures): >>> from pyspark.sql import Row >>> df = sc.parallelize([Row(words=["a", "b", "c"])]).toDF() >>> hashingTF = HashingTF(numFeatures=10, inputCol="words", outputCol="features") - >>> print hashingTF.transform(df).head().features - (10,[7,8,9],[1.0,1.0,1.0]) - >>> print hashingTF.setParams(outputCol="freqs").transform(df).head().freqs - (10,[7,8,9],[1.0,1.0,1.0]) + >>> hashingTF.transform(df).head().features + SparseVector(10, {7: 1.0, 8: 1.0, 9: 1.0}) + >>> hashingTF.setParams(outputCol="freqs").transform(df).head().freqs + SparseVector(10, {7: 1.0, 8: 1.0, 9: 1.0}) >>> params = {hashingTF.numFeatures: 5, hashingTF.outputCol: "vector"} - >>> print hashingTF.transform(df, params).head().vector - (5,[2,3,4],[1.0,1.0,1.0]) + >>> hashingTF.transform(df, params).head().vector + SparseVector(5, {2: 1.0, 3: 1.0, 4: 1.0}) """ _java_class = "org.apache.spark.ml.feature.HashingTF" diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index 5c62620562a84..9fccb65675185 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -63,8 +63,8 @@ def params(self): uses :py:func:`dir` to get all attributes of type :py:class:`Param`. """ - return filter(lambda attr: isinstance(attr, Param), - [getattr(self, x) for x in dir(self) if x != "params"]) + return list(filter(lambda attr: isinstance(attr, Param), + [getattr(self, x) for x in dir(self) if x != "params"])) def _explain(self, param): """ @@ -185,7 +185,7 @@ def _set(self, **kwargs): """ Sets user-supplied params. """ - for param, value in kwargs.iteritems(): + for param, value in kwargs.items(): self.paramMap[getattr(self, param)] = value return self @@ -193,6 +193,6 @@ def _setDefault(self, **kwargs): """ Sets default params. """ - for param, value in kwargs.iteritems(): + for param, value in kwargs.items(): self.defaultParamMap[getattr(self, param)] = value return self diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index 55f422497672f..6a3192465d66d 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -15,6 +15,8 @@ # limitations under the License. # +from __future__ import print_function + header = """# # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with @@ -82,9 +84,9 @@ def get$Name(self): .replace("$defaultValueStr", str(defaultValueStr)) if __name__ == "__main__": - print header - print "\n# DO NOT MODIFY THIS FILE! It was generated by _shared_params_code_gen.py.\n" - print "from pyspark.ml.param import Param, Params\n\n" + print(header) + print("\n# DO NOT MODIFY THIS FILE! 
It was generated by _shared_params_code_gen.py.\n") + print("from pyspark.ml.param import Param, Params\n\n") shared = [ ("maxIter", "max number of iterations", None), ("regParam", "regularization constant", None), @@ -97,4 +99,4 @@ def get$Name(self): code = [] for name, doc, defaultValueStr in shared: code.append(_gen_param_code(name, doc, defaultValueStr)) - print "\n\n\n".join(code) + print("\n\n\n".join(code)) diff --git a/python/pyspark/mllib/__init__.py b/python/pyspark/mllib/__init__.py index f2ef573fe9f6f..07507b2ad0d05 100644 --- a/python/pyspark/mllib/__init__.py +++ b/python/pyspark/mllib/__init__.py @@ -18,6 +18,7 @@ """ Python bindings for MLlib. """ +from __future__ import absolute_import # MLlib currently needs NumPy 1.4+, so complain if lower @@ -29,7 +30,9 @@ 'recommendation', 'regression', 'stat', 'tree', 'util'] import sys -import rand as random -random.__name__ = 'random' -random.RandomRDDs.__module__ = __name__ + '.random' -sys.modules[__name__ + '.random'] = random +from . import rand as random +modname = __name__ + '.random' +random.__name__ = modname +random.RandomRDDs.__module__ = modname +sys.modules[modname] = random +del modname, sys diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 2466e8ac43458..eda0b60f8b1e7 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -510,9 +510,10 @@ def save(self, sc, path): def load(cls, sc, path): java_model = sc._jvm.org.apache.spark.mllib.classification.NaiveBayesModel.load( sc._jsc.sc(), path) - py_labels = _java2py(sc, java_model.labels()) - py_pi = _java2py(sc, java_model.pi()) - py_theta = _java2py(sc, java_model.theta()) + # Can not unpickle array.array from Pyrolite in Python3 with "bytes" + py_labels = _java2py(sc, java_model.labels(), "latin1") + py_pi = _java2py(sc, java_model.pi(), "latin1") + py_theta = _java2py(sc, java_model.theta(), "latin1") return NaiveBayesModel(py_labels, py_pi, numpy.array(py_theta)) diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 464f49aeee3cd..abbb7cf60eece 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -15,6 +15,12 @@ # limitations under the License. # +import sys +import array as pyarray + +if sys.version > '3': + xrange = range + from numpy import array from pyspark import RDD @@ -55,8 +61,8 @@ class KMeansModel(Saveable, Loader): True >>> model.predict(sparse_data[2]) == model.predict(sparse_data[3]) True - >>> type(model.clusterCenters) - + >>> isinstance(model.clusterCenters, list) + True >>> import os, tempfile >>> path = tempfile.mkdtemp() >>> model.save(sc, path) @@ -90,7 +96,7 @@ def predict(self, x): return best def save(self, sc, path): - java_centers = _py2java(sc, map(_convert_to_vector, self.centers)) + java_centers = _py2java(sc, [_convert_to_vector(c) for c in self.centers]) java_model = sc._jvm.org.apache.spark.mllib.clustering.KMeansModel(java_centers) java_model.save(sc._jsc.sc(), path) @@ -133,7 +139,7 @@ class GaussianMixtureModel(object): ... 5.7048, 4.6567, 5.5026, ... 4.5605, 5.2043, 6.2734]).reshape(5, 3)) >>> model = GaussianMixture.train(clusterdata_2, 2, convergenceTol=0.0001, - ... maxIterations=150, seed=10) + ... 
maxIterations=150, seed=10) >>> labels = model.predict(clusterdata_2).collect() >>> labels[0]==labels[1]==labels[2] True @@ -168,8 +174,8 @@ def predictSoft(self, x): if isinstance(x, RDD): means, sigmas = zip(*[(g.mu, g.sigma) for g in self.gaussians]) membership_matrix = callMLlibFunc("predictSoftGMM", x.map(_convert_to_vector), - self.weights, means, sigmas) - return membership_matrix + _convert_to_vector(self.weights), means, sigmas) + return membership_matrix.map(lambda x: pyarray.array('d', x)) class GaussianMixture(object): diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py index a539d2f2846f9..ba6058978880a 100644 --- a/python/pyspark/mllib/common.py +++ b/python/pyspark/mllib/common.py @@ -15,6 +15,11 @@ # limitations under the License. # +import sys +if sys.version >= '3': + long = int + unicode = str + import py4j.protocol from py4j.protocol import Py4JJavaError from py4j.java_gateway import JavaObject @@ -36,7 +41,7 @@ def _new_smart_decode(obj): if isinstance(obj, float): - s = unicode(obj) + s = str(obj) return _float_str_mapping.get(s, s) return _old_smart_decode(obj) @@ -74,15 +79,15 @@ def _py2java(sc, obj): obj = ListConverter().convert([_py2java(sc, x) for x in obj], sc._gateway._gateway_client) elif isinstance(obj, JavaObject): pass - elif isinstance(obj, (int, long, float, bool, basestring)): + elif isinstance(obj, (int, long, float, bool, bytes, unicode)): pass else: - bytes = bytearray(PickleSerializer().dumps(obj)) - obj = sc._jvm.SerDe.loads(bytes) + data = bytearray(PickleSerializer().dumps(obj)) + obj = sc._jvm.SerDe.loads(data) return obj -def _java2py(sc, r): +def _java2py(sc, r, encoding="bytes"): if isinstance(r, JavaObject): clsName = r.getClass().getSimpleName() # convert RDD into JavaRDD @@ -102,8 +107,8 @@ def _java2py(sc, r): except Py4JJavaError: pass # not pickable - if isinstance(r, bytearray): - r = PickleSerializer().loads(str(r)) + if isinstance(r, (bytearray, bytes)): + r = PickleSerializer().loads(bytes(r), encoding=encoding) return r diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 8be819aceec24..1140539a24e95 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -23,12 +23,17 @@ import sys import warnings import random +import binascii +if sys.version >= '3': + basestring = str + unicode = str from py4j.protocol import Py4JJavaError -from pyspark import RDD, SparkContext +from pyspark import SparkContext +from pyspark.rdd import RDD, ignore_unicode_prefix from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper -from pyspark.mllib.linalg import Vectors, Vector, _convert_to_vector +from pyspark.mllib.linalg import Vectors, _convert_to_vector __all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler', 'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel'] @@ -206,7 +211,7 @@ class HashingTF(object): >>> htf = HashingTF(100) >>> doc = "a a b b c d".split(" ") >>> htf.transform(doc) - SparseVector(100, {1: 1.0, 14: 1.0, 31: 2.0, 44: 2.0}) + SparseVector(100, {...}) """ def __init__(self, numFeatures=1 << 20): """ @@ -360,6 +365,7 @@ def getVectors(self): return self.call("getVectors") +@ignore_unicode_prefix class Word2Vec(object): """ Word2Vec creates vector representation of words in a text corpus. 
@@ -382,7 +388,7 @@ class Word2Vec(object): >>> sentence = "a b " * 100 + "a c " * 10 >>> localDoc = [sentence, sentence] >>> doc = sc.parallelize(localDoc).map(lambda line: line.split(" ")) - >>> model = Word2Vec().setVectorSize(10).setSeed(42L).fit(doc) + >>> model = Word2Vec().setVectorSize(10).setSeed(42).fit(doc) >>> syms = model.findSynonyms("a", 2) >>> [s[0] for s in syms] @@ -400,7 +406,7 @@ def __init__(self): self.learningRate = 0.025 self.numPartitions = 1 self.numIterations = 1 - self.seed = random.randint(0, sys.maxint) + self.seed = random.randint(0, sys.maxsize) self.minCount = 5 def setVectorSize(self, vectorSize): @@ -459,7 +465,7 @@ def fit(self, data): raise TypeError("data should be an RDD of list of string") jmodel = callMLlibFunc("trainWord2Vec", data, int(self.vectorSize), float(self.learningRate), int(self.numPartitions), - int(self.numIterations), long(self.seed), + int(self.numIterations), int(self.seed), int(self.minCount)) return Word2VecModel(jmodel) diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py index 3aa6d79d7093c..628ccc01cf3cc 100644 --- a/python/pyspark/mllib/fpm.py +++ b/python/pyspark/mllib/fpm.py @@ -16,12 +16,14 @@ # from pyspark import SparkContext +from pyspark.rdd import ignore_unicode_prefix from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc __all__ = ['FPGrowth', 'FPGrowthModel'] @inherit_doc +@ignore_unicode_prefix class FPGrowthModel(JavaModelWrapper): """ diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index a80320c52d1d0..38b3aa3ad460e 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -25,7 +25,13 @@ import sys import array -import copy_reg + +if sys.version >= '3': + basestring = str + xrange = range + import copyreg as copy_reg +else: + import copy_reg import numpy as np @@ -57,7 +63,7 @@ def fast_pickle_array(ar): def _convert_to_vector(l): if isinstance(l, Vector): return l - elif type(l) in (array.array, np.array, np.ndarray, list, tuple): + elif type(l) in (array.array, np.array, np.ndarray, list, tuple, xrange): return DenseVector(l) elif _have_scipy and scipy.sparse.issparse(l): assert l.shape[1] == 1, "Expected column vector" @@ -88,7 +94,7 @@ def _vector_size(v): """ if isinstance(v, Vector): return len(v) - elif type(v) in (array.array, list, tuple): + elif type(v) in (array.array, list, tuple, xrange): return len(v) elif type(v) == np.ndarray: if v.ndim == 1 or (v.ndim == 2 and v.shape[1] == 1): @@ -193,7 +199,7 @@ class DenseVector(Vector): DenseVector([1.0, 0.0]) """ def __init__(self, ar): - if isinstance(ar, basestring): + if isinstance(ar, bytes): ar = np.frombuffer(ar, dtype=np.float64) elif not isinstance(ar, np.ndarray): ar = np.array(ar, dtype=np.float64) @@ -321,11 +327,13 @@ def func(self, other): __sub__ = _delegate("__sub__") __mul__ = _delegate("__mul__") __div__ = _delegate("__div__") + __truediv__ = _delegate("__truediv__") __mod__ = _delegate("__mod__") __radd__ = _delegate("__radd__") __rsub__ = _delegate("__rsub__") __rmul__ = _delegate("__rmul__") __rdiv__ = _delegate("__rdiv__") + __rtruediv__ = _delegate("__rtruediv__") __rmod__ = _delegate("__rmod__") @@ -344,12 +352,12 @@ def __init__(self, size, *args): :param args: Non-zero entries, as a dictionary, list of tupes, or two sorted lists containing indices and values. 
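# The conditional imports added to linalg.py above (copyreg vs copy_reg,
# basestring, xrange) are the shim pattern this patch applies across mllib:
# define the missing Python 2 names once at the top of a module so the body
# stays version-neutral. A minimal sketch of the same idiom:
import sys

if sys.version >= '3':
    import copyreg as copy_reg   # stdlib module renamed in Python 3
    basestring = str             # builtins that no longer exist in Python 3
    xrange = range
else:
    import copy_reg              # kept for the Python 2 path

def is_textual(value):
    # runs unchanged on both interpreters
    return isinstance(value, basestring)

assert is_textual("spark") and not is_textual(42)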
- >>> print SparseVector(4, {1: 1.0, 3: 5.5}) - (4,[1,3],[1.0,5.5]) - >>> print SparseVector(4, [(1, 1.0), (3, 5.5)]) - (4,[1,3],[1.0,5.5]) - >>> print SparseVector(4, [1, 3], [1.0, 5.5]) - (4,[1,3],[1.0,5.5]) + >>> SparseVector(4, {1: 1.0, 3: 5.5}) + SparseVector(4, {1: 1.0, 3: 5.5}) + >>> SparseVector(4, [(1, 1.0), (3, 5.5)]) + SparseVector(4, {1: 1.0, 3: 5.5}) + >>> SparseVector(4, [1, 3], [1.0, 5.5]) + SparseVector(4, {1: 1.0, 3: 5.5}) """ self.size = int(size) assert 1 <= len(args) <= 2, "must pass either 2 or 3 arguments" @@ -361,8 +369,8 @@ def __init__(self, size, *args): self.indices = np.array([p[0] for p in pairs], dtype=np.int32) self.values = np.array([p[1] for p in pairs], dtype=np.float64) else: - if isinstance(args[0], basestring): - assert isinstance(args[1], str), "values should be string too" + if isinstance(args[0], bytes): + assert isinstance(args[1], bytes), "values should be string too" if args[0]: self.indices = np.frombuffer(args[0], np.int32) self.values = np.frombuffer(args[1], np.float64) @@ -591,12 +599,12 @@ def sparse(size, *args): :param args: Non-zero entries, as a dictionary, list of tupes, or two sorted lists containing indices and values. - >>> print Vectors.sparse(4, {1: 1.0, 3: 5.5}) - (4,[1,3],[1.0,5.5]) - >>> print Vectors.sparse(4, [(1, 1.0), (3, 5.5)]) - (4,[1,3],[1.0,5.5]) - >>> print Vectors.sparse(4, [1, 3], [1.0, 5.5]) - (4,[1,3],[1.0,5.5]) + >>> Vectors.sparse(4, {1: 1.0, 3: 5.5}) + SparseVector(4, {1: 1.0, 3: 5.5}) + >>> Vectors.sparse(4, [(1, 1.0), (3, 5.5)]) + SparseVector(4, {1: 1.0, 3: 5.5}) + >>> Vectors.sparse(4, [1, 3], [1.0, 5.5]) + SparseVector(4, {1: 1.0, 3: 5.5}) """ return SparseVector(size, *args) @@ -645,7 +653,7 @@ def _convert_to_array(array_like, dtype): """ Convert Matrix attributes which are array-like or buffer to array. """ - if isinstance(array_like, basestring): + if isinstance(array_like, bytes): return np.frombuffer(array_like, dtype=dtype) return np.asarray(array_like, dtype=dtype) @@ -677,7 +685,7 @@ def toArray(self): def toSparse(self): """Convert to SparseMatrix""" indices = np.nonzero(self.values)[0] - colCounts = np.bincount(indices / self.numRows) + colCounts = np.bincount(indices // self.numRows) colPtrs = np.cumsum(np.hstack( (0, colCounts, np.zeros(self.numCols - colCounts.size)))) values = self.values[indices] diff --git a/python/pyspark/mllib/rand.py b/python/pyspark/mllib/rand.py index 20ee9d78bf5b0..06fbc0eb6aef0 100644 --- a/python/pyspark/mllib/rand.py +++ b/python/pyspark/mllib/rand.py @@ -88,10 +88,10 @@ def normalRDD(sc, size, numPartitions=None, seed=None): :param seed: Random seed (default: a random long integer). :return: RDD of float comprised of i.i.d. samples ~ N(0.0, 1.0). - >>> x = RandomRDDs.normalRDD(sc, 1000, seed=1L) + >>> x = RandomRDDs.normalRDD(sc, 1000, seed=1) >>> stats = x.stats() >>> stats.count() - 1000L + 1000 >>> abs(stats.mean() - 0.0) < 0.1 True >>> abs(stats.stdev() - 1.0) < 0.1 @@ -118,10 +118,10 @@ def logNormalRDD(sc, mean, std, size, numPartitions=None, seed=None): >>> std = 1.0 >>> expMean = exp(mean + 0.5 * std * std) >>> expStd = sqrt((exp(std * std) - 1.0) * exp(2.0 * mean + std * std)) - >>> x = RandomRDDs.logNormalRDD(sc, mean, std, 1000, seed=2L) + >>> x = RandomRDDs.logNormalRDD(sc, mean, std, 1000, seed=2) >>> stats = x.stats() >>> stats.count() - 1000L + 1000 >>> abs(stats.mean() - expMean) < 0.5 True >>> from math import sqrt @@ -145,10 +145,10 @@ def poissonRDD(sc, mean, size, numPartitions=None, seed=None): :return: RDD of float comprised of i.i.d. 
samples ~ Pois(mean). >>> mean = 100.0 - >>> x = RandomRDDs.poissonRDD(sc, mean, 1000, seed=2L) + >>> x = RandomRDDs.poissonRDD(sc, mean, 1000, seed=2) >>> stats = x.stats() >>> stats.count() - 1000L + 1000 >>> abs(stats.mean() - mean) < 0.5 True >>> from math import sqrt @@ -171,10 +171,10 @@ def exponentialRDD(sc, mean, size, numPartitions=None, seed=None): :return: RDD of float comprised of i.i.d. samples ~ Exp(mean). >>> mean = 2.0 - >>> x = RandomRDDs.exponentialRDD(sc, mean, 1000, seed=2L) + >>> x = RandomRDDs.exponentialRDD(sc, mean, 1000, seed=2) >>> stats = x.stats() >>> stats.count() - 1000L + 1000 >>> abs(stats.mean() - mean) < 0.5 True >>> from math import sqrt @@ -202,10 +202,10 @@ def gammaRDD(sc, shape, scale, size, numPartitions=None, seed=None): >>> scale = 2.0 >>> expMean = shape * scale >>> expStd = sqrt(shape * scale * scale) - >>> x = RandomRDDs.gammaRDD(sc, shape, scale, 1000, seed=2L) + >>> x = RandomRDDs.gammaRDD(sc, shape, scale, 1000, seed=2) >>> stats = x.stats() >>> stats.count() - 1000L + 1000 >>> abs(stats.mean() - expMean) < 0.5 True >>> abs(stats.stdev() - expStd) < 0.5 @@ -254,7 +254,7 @@ def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): :return: RDD of Vector with vectors containing i.i.d. samples ~ `N(0.0, 1.0)`. >>> import numpy as np - >>> mat = np.matrix(RandomRDDs.normalVectorRDD(sc, 100, 100, seed=1L).collect()) + >>> mat = np.matrix(RandomRDDs.normalVectorRDD(sc, 100, 100, seed=1).collect()) >>> mat.shape (100, 100) >>> abs(mat.mean() - 0.0) < 0.1 @@ -286,8 +286,8 @@ def logNormalVectorRDD(sc, mean, std, numRows, numCols, numPartitions=None, seed >>> std = 1.0 >>> expMean = exp(mean + 0.5 * std * std) >>> expStd = sqrt((exp(std * std) - 1.0) * exp(2.0 * mean + std * std)) - >>> mat = np.matrix(RandomRDDs.logNormalVectorRDD(sc, mean, std, \ - 100, 100, seed=1L).collect()) + >>> m = RandomRDDs.logNormalVectorRDD(sc, mean, std, 100, 100, seed=1).collect() + >>> mat = np.matrix(m) >>> mat.shape (100, 100) >>> abs(mat.mean() - expMean) < 0.1 @@ -315,7 +315,7 @@ def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): >>> import numpy as np >>> mean = 100.0 - >>> rdd = RandomRDDs.poissonVectorRDD(sc, mean, 100, 100, seed=1L) + >>> rdd = RandomRDDs.poissonVectorRDD(sc, mean, 100, 100, seed=1) >>> mat = np.mat(rdd.collect()) >>> mat.shape (100, 100) @@ -345,7 +345,7 @@ def exponentialVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=No >>> import numpy as np >>> mean = 0.5 - >>> rdd = RandomRDDs.exponentialVectorRDD(sc, mean, 100, 100, seed=1L) + >>> rdd = RandomRDDs.exponentialVectorRDD(sc, mean, 100, 100, seed=1) >>> mat = np.mat(rdd.collect()) >>> mat.shape (100, 100) @@ -380,8 +380,7 @@ def gammaVectorRDD(sc, shape, scale, numRows, numCols, numPartitions=None, seed= >>> scale = 2.0 >>> expMean = shape * scale >>> expStd = sqrt(shape * scale * scale) - >>> mat = np.matrix(RandomRDDs.gammaVectorRDD(sc, shape, scale, \ - 100, 100, seed=1L).collect()) + >>> mat = np.matrix(RandomRDDs.gammaVectorRDD(sc, shape, scale, 100, 100, seed=1).collect()) >>> mat.shape (100, 100) >>> abs(mat.mean() - expMean) < 0.1 diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index c5c4c13dae105..80e0a356bb78a 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -15,6 +15,7 @@ # limitations under the License. 
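# The doctest edits above (seed=1L -> seed=1, 1000L -> 1000) follow from
# Python 3 unifying int and long: the "L" literal and repr suffix are gone,
# so expected output like "1000L" could never match under Python 3.
# Illustration (Python 3):
big = 10 ** 19                     # would have been a long in Python 2
assert isinstance(big, int)
assert repr(big) == "10000000000000000000"   # no trailing "L"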
# +import array from collections import namedtuple from pyspark import SparkContext @@ -104,14 +105,14 @@ def predictAll(self, user_product): assert isinstance(user_product, RDD), "user_product should be RDD of (user, product)" first = user_product.first() assert len(first) == 2, "user_product should be RDD of (user, product)" - user_product = user_product.map(lambda (u, p): (int(u), int(p))) + user_product = user_product.map(lambda u_p: (int(u_p[0]), int(u_p[1]))) return self.call("predict", user_product) def userFeatures(self): - return self.call("getUserFeatures") + return self.call("getUserFeatures").mapValues(lambda v: array.array('d', v)) def productFeatures(self): - return self.call("getProductFeatures") + return self.call("getProductFeatures").mapValues(lambda v: array.array('d', v)) @classmethod def load(cls, sc, path): diff --git a/python/pyspark/mllib/stat/_statistics.py b/python/pyspark/mllib/stat/_statistics.py index 1d83e9d483f8e..b475be4b4d953 100644 --- a/python/pyspark/mllib/stat/_statistics.py +++ b/python/pyspark/mllib/stat/_statistics.py @@ -15,7 +15,7 @@ # limitations under the License. # -from pyspark import RDD +from pyspark.rdd import RDD, ignore_unicode_prefix from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper from pyspark.mllib.linalg import Matrix, _convert_to_vector from pyspark.mllib.regression import LabeledPoint @@ -38,7 +38,7 @@ def variance(self): return self.call("variance").toArray() def count(self): - return self.call("count") + return int(self.call("count")) def numNonzeros(self): return self.call("numNonzeros").toArray() @@ -78,7 +78,7 @@ def colStats(rdd): >>> cStats.variance() array([ 4., 13., 0., 25.]) >>> cStats.count() - 3L + 3 >>> cStats.numNonzeros() array([ 3., 2., 0., 3.]) >>> cStats.max() @@ -124,20 +124,20 @@ def corr(x, y=None, method=None): >>> rdd = sc.parallelize([Vectors.dense([1, 0, 0, -2]), Vectors.dense([4, 5, 0, 3]), ... Vectors.dense([6, 7, 0, 8]), Vectors.dense([9, 0, 0, 1])]) >>> pearsonCorr = Statistics.corr(rdd) - >>> print str(pearsonCorr).replace('nan', 'NaN') + >>> print(str(pearsonCorr).replace('nan', 'NaN')) [[ 1. 0.05564149 NaN 0.40047142] [ 0.05564149 1. NaN 0.91359586] [ NaN NaN 1. NaN] [ 0.40047142 0.91359586 NaN 1. ]] >>> spearmanCorr = Statistics.corr(rdd, method="spearman") - >>> print str(spearmanCorr).replace('nan', 'NaN') + >>> print(str(spearmanCorr).replace('nan', 'NaN')) [[ 1. 0.10540926 NaN 0.4 ] [ 0.10540926 1. NaN 0.9486833 ] [ NaN NaN 1. NaN] [ 0.4 0.9486833 NaN 1. ]] >>> try: ... Statistics.corr(rdd, "spearman") - ... print "Method name as second argument without 'method=' shouldn't be allowed." + ... print("Method name as second argument without 'method=' shouldn't be allowed.") ... except TypeError: ... pass """ @@ -153,6 +153,7 @@ def corr(x, y=None, method=None): return callMLlibFunc("corr", x.map(float), y.map(float), method) @staticmethod + @ignore_unicode_prefix def chiSqTest(observed, expected=None): """ .. 
note:: Experimental @@ -188,11 +189,11 @@ def chiSqTest(observed, expected=None): >>> from pyspark.mllib.linalg import Vectors, Matrices >>> observed = Vectors.dense([4, 6, 5]) >>> pearson = Statistics.chiSqTest(observed) - >>> print pearson.statistic + >>> print(pearson.statistic) 0.4 >>> pearson.degreesOfFreedom 2 - >>> print round(pearson.pValue, 4) + >>> print(round(pearson.pValue, 4)) 0.8187 >>> pearson.method u'pearson' @@ -202,12 +203,12 @@ def chiSqTest(observed, expected=None): >>> observed = Vectors.dense([21, 38, 43, 80]) >>> expected = Vectors.dense([3, 5, 7, 20]) >>> pearson = Statistics.chiSqTest(observed, expected) - >>> print round(pearson.pValue, 4) + >>> print(round(pearson.pValue, 4)) 0.0027 >>> data = [40.0, 24.0, 29.0, 56.0, 32.0, 42.0, 31.0, 10.0, 0.0, 30.0, 15.0, 12.0] >>> chi = Statistics.chiSqTest(Matrices.dense(3, 4, data)) - >>> print round(chi.statistic, 4) + >>> print(round(chi.statistic, 4)) 21.9958 >>> data = [LabeledPoint(0.0, Vectors.dense([0.5, 10.0])), @@ -218,9 +219,9 @@ def chiSqTest(observed, expected=None): ... LabeledPoint(1.0, Vectors.dense([3.5, 40.0])),] >>> rdd = sc.parallelize(data, 4) >>> chi = Statistics.chiSqTest(rdd) - >>> print chi[0].statistic + >>> print(chi[0].statistic) 0.75 - >>> print chi[1].statistic + >>> print(chi[1].statistic) 1.5 """ if isinstance(observed, RDD): diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 8eaddcf8b9b5e..c6ed5acd1770e 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -72,11 +72,11 @@ class VectorTests(PySparkTestCase): def _test_serialize(self, v): self.assertEqual(v, ser.loads(ser.dumps(v))) jvec = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(v))) - nv = ser.loads(str(self.sc._jvm.SerDe.dumps(jvec))) + nv = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jvec))) self.assertEqual(v, nv) vs = [v] * 100 jvecs = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(vs))) - nvs = ser.loads(str(self.sc._jvm.SerDe.dumps(jvecs))) + nvs = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jvecs))) self.assertEqual(vs, nvs) def test_serialize(self): @@ -412,11 +412,11 @@ def test_col_norms(self): self.assertEqual(10, len(summary.normL1())) self.assertEqual(10, len(summary.normL2())) - data2 = self.sc.parallelize(xrange(10)).map(lambda x: Vectors.dense(x)) + data2 = self.sc.parallelize(range(10)).map(lambda x: Vectors.dense(x)) summary2 = Statistics.colStats(data2) self.assertEqual(array([45.0]), summary2.normL1()) import math - expectedNormL2 = math.sqrt(sum(map(lambda x: x*x, xrange(10)))) + expectedNormL2 = math.sqrt(sum(map(lambda x: x*x, range(10)))) self.assertTrue(math.fabs(summary2.normL2()[0] - expectedNormL2) < 1e-14) @@ -438,11 +438,11 @@ def test_serialization(self): def test_infer_schema(self): sqlCtx = SQLContext(self.sc) rdd = self.sc.parallelize([LabeledPoint(1.0, self.dv1), LabeledPoint(0.0, self.sv1)]) - srdd = sqlCtx.inferSchema(rdd) - schema = srdd.schema + df = rdd.toDF() + schema = df.schema field = [f for f in schema.fields if f.name == "features"][0] self.assertEqual(field.dataType, self.udt) - vectors = srdd.map(lambda p: p.features).collect() + vectors = df.map(lambda p: p.features).collect() self.assertEqual(len(vectors), 2) for v in vectors: if isinstance(v, SparseVector): @@ -695,7 +695,7 @@ def test_right_number_of_results(self): class SerDeTest(PySparkTestCase): def test_to_java_object_rdd(self): # SPARK-6660 - data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0L) + data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0) 
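+        data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0)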
self.assertEqual(_to_java_object_rdd(data).count(), 10) @@ -771,7 +771,7 @@ def test_model_transform(self): if __name__ == "__main__": if not _have_scipy: - print "NOTE: Skipping SciPy tests as it does not seem to be installed" + print("NOTE: Skipping SciPy tests as it does not seem to be installed") unittest.main() if not _have_scipy: - print "NOTE: SciPy tests were skipped as it does not seem to be installed" + print("NOTE: SciPy tests were skipped as it does not seem to be installed") diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index a7a4d2aaf855b..0fe6e4fabe43a 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -163,14 +163,16 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, ... LabeledPoint(1.0, [3.0]) ... ] >>> model = DecisionTree.trainClassifier(sc.parallelize(data), 2, {}) - >>> print model, # it already has newline + >>> print(model) DecisionTreeModel classifier of depth 1 with 3 nodes - >>> print model.toDebugString(), # it already has newline + + >>> print(model.toDebugString()) DecisionTreeModel classifier of depth 1 with 3 nodes If (feature 0 <= 0.0) Predict: 0.0 Else (feature 0 > 0.0) Predict: 1.0 + >>> model.predict(array([1.0])) 1.0 >>> model.predict(array([0.0])) @@ -318,9 +320,10 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, 3 >>> model.totalNumNodes() 7 - >>> print model, + >>> print(model) TreeEnsembleModel classifier with 3 trees - >>> print model.toDebugString(), + + >>> print(model.toDebugString()) TreeEnsembleModel classifier with 3 trees Tree 0: @@ -335,6 +338,7 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees, Predict: 0.0 Else (feature 0 > 1.0) Predict: 1.0 + >>> model.predict([2.0]) 1.0 >>> model.predict([0.0]) @@ -483,8 +487,9 @@ def trainClassifier(cls, data, categoricalFeaturesInfo, 100 >>> model.totalNumNodes() 300 - >>> print model, # it already has newline + >>> print(model) # it already has newline TreeEnsembleModel classifier with 100 trees + >>> model.predict([2.0]) 1.0 >>> model.predict([0.0]) diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index c5c3468eb95e9..16a90db146ef0 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -15,10 +15,14 @@ # limitations under the License. 
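# Why the tree.py doctests above gain a blank expected line: Python 2's
# "print model," suppressed the trailing newline via softspace, while Python
# 3's print() always emits its end string, so the model's own newline now
# shows up. The Python 3 spelling of the old trailing-comma idiom, for
# reference:
from __future__ import print_function   # no-op on Python 3, enables this on 2.6+

print("fields:", end=" ")   # roughly what `print "fields:",` did in Python 2
print("a", "b")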
# +import sys import numpy as np import warnings -from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper, inherit_doc +if sys.version > '3': + xrange = range + +from pyspark.mllib.common import callMLlibFunc, inherit_doc from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector @@ -94,22 +98,16 @@ def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None, multiclass=None >>> from pyspark.mllib.util import MLUtils >>> from pyspark.mllib.regression import LabeledPoint >>> tempFile = NamedTemporaryFile(delete=True) - >>> tempFile.write("+1 1:1.0 3:2.0 5:3.0\\n-1\\n-1 2:4.0 4:5.0 6:6.0") + >>> _ = tempFile.write(b"+1 1:1.0 3:2.0 5:3.0\\n-1\\n-1 2:4.0 4:5.0 6:6.0") >>> tempFile.flush() >>> examples = MLUtils.loadLibSVMFile(sc, tempFile.name).collect() >>> tempFile.close() - >>> type(examples[0]) == LabeledPoint - True - >>> print examples[0] - (1.0,(6,[0,2,4],[1.0,2.0,3.0])) - >>> type(examples[1]) == LabeledPoint - True - >>> print examples[1] - (-1.0,(6,[],[])) - >>> type(examples[2]) == LabeledPoint - True - >>> print examples[2] - (-1.0,(6,[1,3,5],[4.0,5.0,6.0])) + >>> examples[0] + LabeledPoint(1.0, (6,[0,2,4],[1.0,2.0,3.0])) + >>> examples[1] + LabeledPoint(-1.0, (6,[],[])) + >>> examples[2] + LabeledPoint(-1.0, (6,[1,3,5],[4.0,5.0,6.0])) """ from pyspark.mllib.regression import LabeledPoint if multiclass is not None: diff --git a/python/pyspark/profiler.py b/python/pyspark/profiler.py index 4408996db0790..d18daaabfcb3c 100644 --- a/python/pyspark/profiler.py +++ b/python/pyspark/profiler.py @@ -84,11 +84,11 @@ class Profiler(object): >>> from pyspark import BasicProfiler >>> class MyCustomProfiler(BasicProfiler): ... def show(self, id): - ... print "My custom profiles for RDD:%s" % id + ... print("My custom profiles for RDD:%s" % id) ... 
>>> conf = SparkConf().set("spark.python.profile", "true") >>> sc = SparkContext('local', 'test', conf=conf, profiler_cls=MyCustomProfiler) - >>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10) + >>> sc.parallelize(range(1000)).map(lambda x: 2 * x).take(10) [0, 2, 4, 6, 8, 10, 12, 14, 16, 18] >>> sc.show_profiles() My custom profiles for RDD:1 @@ -111,9 +111,9 @@ def show(self, id): """ Print the profile stats to stdout, id is the RDD id """ stats = self.stats() if stats: - print "=" * 60 - print "Profile of RDD" % id - print "=" * 60 + print("=" * 60) + print("Profile of RDD" % id) + print("=" * 60) stats.sort_stats("time", "cumulative").print_stats() def dump(self, id, path): diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 93e658eded9e2..d9cdbb666f92a 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -16,21 +16,29 @@ # import copy -from collections import defaultdict -from itertools import chain, ifilter, imap -import operator import sys +import os +import re +import operator import shlex -from subprocess import Popen, PIPE -from tempfile import NamedTemporaryFile -from threading import Thread import warnings import heapq import bisect import random import socket +from subprocess import Popen, PIPE +from tempfile import NamedTemporaryFile +from threading import Thread +from collections import defaultdict +from itertools import chain +from functools import reduce from math import sqrt, log, isinf, isnan, pow, ceil +if sys.version > '3': + basestring = unicode = str +else: + from itertools import imap as map, ifilter as filter + from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ BatchedSerializer, CloudPickleSerializer, PairDeserializer, \ PickleSerializer, pack_long, AutoBatchedSerializer @@ -50,20 +58,21 @@ __all__ = ["RDD"] -# TODO: for Python 3.3+, PYTHONHASHSEED should be reset to disable randomized -# hash for string def portable_hash(x): """ - This function returns consistant hash code for builtin types, especially + This function returns consistent hash code for builtin types, especially for None and tuple with None. 
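# Context for the PYTHONHASHSEED guard added to portable_hash in this hunk:
# from Python 3.3 on, str hashes are salted per interpreter process, so two
# workers would partition the same key differently unless the seed is pinned.
# A sketch of the effect, assuming a Unix-like system where the interpreter
# tolerates a minimal environment:
import subprocess, sys

cmd = [sys.executable, "-c", "print(hash('spark'))"]
pinned = {"PYTHONHASHSEED": "0"}
a = subprocess.check_output(cmd, env=pinned)
b = subprocess.check_output(cmd, env=pinned)
assert a == b   # stable across processes once the seed is fixed
# without env=pinned the two outputs would differ on almost every run (3.3+)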
- The algrithm is similar to that one used by CPython 2.7 + The algorithm is similar to that one used by CPython 2.7 >>> portable_hash(None) 0 >>> portable_hash((None, 1)) & 0xffffffff 219750521 """ + if sys.version >= '3.3' and 'PYTHONHASHSEED' not in os.environ: + raise Exception("Randomness of hash of string should be disabled via PYTHONHASHSEED") + if x is None: return 0 if isinstance(x, tuple): @@ -71,7 +80,7 @@ def portable_hash(x): for i in x: h ^= portable_hash(i) h *= 1000003 - h &= sys.maxint + h &= sys.maxsize h ^= len(x) if h == -1: h = -2 @@ -123,6 +132,19 @@ def _load_from_socket(port, serializer): sock.close() +def ignore_unicode_prefix(f): + """ + Ignore the 'u' prefix of string in doc tests, to make it works + in both python 2 and 3 + """ + if sys.version >= '3': + # the representation of unicode string in Python 3 does not have prefix 'u', + # so remove the prefix 'u' for doc tests + literal_re = re.compile(r"(\W|^)[uU](['])", re.UNICODE) + f.__doc__ = literal_re.sub(r'\1\2', f.__doc__) + return f + + class Partitioner(object): def __init__(self, numPartitions, partitionFunc): self.numPartitions = numPartitions @@ -251,7 +273,7 @@ def map(self, f, preservesPartitioning=False): [('a', 1), ('b', 1), ('c', 1)] """ def func(_, iterator): - return imap(f, iterator) + return map(f, iterator) return self.mapPartitionsWithIndex(func, preservesPartitioning) def flatMap(self, f, preservesPartitioning=False): @@ -266,7 +288,7 @@ def flatMap(self, f, preservesPartitioning=False): [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] """ def func(s, iterator): - return chain.from_iterable(imap(f, iterator)) + return chain.from_iterable(map(f, iterator)) return self.mapPartitionsWithIndex(func, preservesPartitioning) def mapPartitions(self, f, preservesPartitioning=False): @@ -329,7 +351,7 @@ def filter(self, f): [2, 4] """ def func(iterator): - return ifilter(f, iterator) + return filter(f, iterator) return self.mapPartitions(func, True) def distinct(self, numPartitions=None): @@ -341,7 +363,7 @@ def distinct(self, numPartitions=None): """ return self.map(lambda x: (x, None)) \ .reduceByKey(lambda x, _: x, numPartitions) \ - .map(lambda (x, _): x) + .map(lambda x: x[0]) def sample(self, withReplacement, fraction, seed=None): """ @@ -354,8 +376,8 @@ def sample(self, withReplacement, fraction, seed=None): :param seed: seed for the random number generator >>> rdd = sc.parallelize(range(100), 4) - >>> rdd.sample(False, 0.1, 81).count() - 10 + >>> 6 <= rdd.sample(False, 0.1, 81).count() <= 14 + True """ assert fraction >= 0.0, "Negative fraction value: %s" % fraction return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True) @@ -368,12 +390,14 @@ def randomSplit(self, weights, seed=None): :param seed: random seed :return: split RDDs in a list - >>> rdd = sc.parallelize(range(5), 1) + >>> rdd = sc.parallelize(range(500), 1) >>> rdd1, rdd2 = rdd.randomSplit([2, 3], 17) - >>> rdd1.collect() - [1, 3] - >>> rdd2.collect() - [0, 2, 4] + >>> len(rdd1.collect() + rdd2.collect()) + 500 + >>> 150 < rdd1.count() < 250 + True + >>> 250 < rdd2.count() < 350 + True """ s = float(sum(weights)) cweights = [0.0] @@ -416,7 +440,7 @@ def takeSample(self, withReplacement, num, seed=None): rand.shuffle(samples) return samples - maxSampleSize = sys.maxint - int(numStDev * sqrt(sys.maxint)) + maxSampleSize = sys.maxsize - int(numStDev * sqrt(sys.maxsize)) if num > maxSampleSize: raise ValueError( "Sample size cannot be greater than %d." 
% maxSampleSize) @@ -430,7 +454,7 @@ def takeSample(self, withReplacement, num, seed=None): # See: scala/spark/RDD.scala while len(samples) < num: # TODO: add log warning for when more than one iteration was run - seed = rand.randint(0, sys.maxint) + seed = rand.randint(0, sys.maxsize) samples = self.sample(withReplacement, fraction, seed).collect() rand.shuffle(samples) @@ -507,7 +531,7 @@ def intersection(self, other): """ return self.map(lambda v: (v, None)) \ .cogroup(other.map(lambda v: (v, None))) \ - .filter(lambda (k, vs): all(vs)) \ + .filter(lambda k_vs: all(k_vs[1])) \ .keys() def _reserialize(self, serializer=None): @@ -549,7 +573,7 @@ def repartitionAndSortWithinPartitions(self, numPartitions=None, partitionFunc=p def sortPartition(iterator): sort = ExternalSorter(memory * 0.9, serializer).sorted if spill else sorted - return iter(sort(iterator, key=lambda (k, v): keyfunc(k), reverse=(not ascending))) + return iter(sort(iterator, key=lambda k_v: keyfunc(k_v[0]), reverse=(not ascending))) return self.partitionBy(numPartitions, partitionFunc).mapPartitions(sortPartition, True) @@ -579,7 +603,7 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x): def sortPartition(iterator): sort = ExternalSorter(memory * 0.9, serializer).sorted if spill else sorted - return iter(sort(iterator, key=lambda (k, v): keyfunc(k), reverse=(not ascending))) + return iter(sort(iterator, key=lambda kv: keyfunc(kv[0]), reverse=(not ascending))) if numPartitions == 1: if self.getNumPartitions() > 1: @@ -594,12 +618,12 @@ def sortPartition(iterator): return self # empty RDD maxSampleSize = numPartitions * 20.0 # constant from Spark's RangePartitioner fraction = min(maxSampleSize / max(rddSize, 1), 1.0) - samples = self.sample(False, fraction, 1).map(lambda (k, v): k).collect() + samples = self.sample(False, fraction, 1).map(lambda kv: kv[0]).collect() samples = sorted(samples, key=keyfunc) # we have numPartitions many parts but one of the them has # an implicit boundary - bounds = [samples[len(samples) * (i + 1) / numPartitions] + bounds = [samples[int(len(samples) * (i + 1) / numPartitions)] for i in range(0, numPartitions - 1)] def rangePartitioner(k): @@ -662,12 +686,13 @@ def groupBy(self, f, numPartitions=None): """ return self.map(lambda x: (f(x), x)).groupByKey(numPartitions) + @ignore_unicode_prefix def pipe(self, command, env={}): """ Return an RDD created by piping elements to a forked external process. >>> sc.parallelize(['1', '2', '', '3']).pipe('cat').collect() - ['1', '2', '', '3'] + [u'1', u'2', u'', u'3'] """ def func(iterator): pipe = Popen( @@ -675,17 +700,18 @@ def func(iterator): def pipe_objs(out): for obj in iterator: - out.write(str(obj).rstrip('\n') + '\n') + s = str(obj).rstrip('\n') + '\n' + out.write(s.encode('utf-8')) out.close() Thread(target=pipe_objs, args=[pipe.stdin]).start() - return (x.rstrip('\n') for x in iter(pipe.stdout.readline, '')) + return (x.rstrip(b'\n').decode('utf-8') for x in iter(pipe.stdout.readline, b'')) return self.mapPartitions(func) def foreach(self, f): """ Applies a function to all elements of this RDD. - >>> def f(x): print x + >>> def f(x): print(x) >>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f) """ def processPartition(iterator): @@ -700,7 +726,7 @@ def foreachPartition(self, f): >>> def f(iterator): ... for x in iterator: - ... print x + ... print(x) >>> sc.parallelize([1, 2, 3, 4, 5]).foreachPartition(f) """ def func(it): @@ -874,7 +900,7 @@ def aggregatePartition(iterator): # aggregation. 
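# The dict-API rewrites running through rdd.py here (iteritems() -> items(),
# and the list(buckets.keys()) snapshot just below) reflect Python 3 returning
# live views rather than lists: a view cannot be consumed while the dict
# shrinks. Minimal reproduction of the trap the shuffle code avoids:
buckets = {0: "a", 1: "b", 2: "c"}

# for k in buckets.keys():      # Python 3: RuntimeError, dict changed size
#     del buckets[k]

for k in list(buckets.keys()):  # snapshot first, as add_shuffle_key now does
    del buckets[k]

assert buckets == {}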
while numPartitions > scale + numPartitions / scale: numPartitions /= scale - curNumPartitions = numPartitions + curNumPartitions = int(numPartitions) def mapPartition(i, iterator): for obj in iterator: @@ -984,7 +1010,7 @@ def histogram(self, buckets): (('a', 'b', 'c'), [2, 2]) """ - if isinstance(buckets, (int, long)): + if isinstance(buckets, int): if buckets < 1: raise ValueError("number of buckets must be >= 1") @@ -1020,6 +1046,7 @@ def minmax(a, b): raise ValueError("Can not generate buckets with infinite value") # keep them as integer if possible + inc = int(inc) if inc * buckets != maxv - minv: inc = (maxv - minv) * 1.0 / buckets @@ -1137,7 +1164,7 @@ def countPartition(iterator): yield counts def mergeMaps(m1, m2): - for k, v in m2.iteritems(): + for k, v in m2.items(): m1[k] += v return m1 return self.mapPartitions(countPartition).reduce(mergeMaps) @@ -1378,8 +1405,8 @@ def saveAsPickleFile(self, path, batchSize=10): >>> tmpFile = NamedTemporaryFile(delete=True) >>> tmpFile.close() >>> sc.parallelize([1, 2, 'spark', 'rdd']).saveAsPickleFile(tmpFile.name, 3) - >>> sorted(sc.pickleFile(tmpFile.name, 5).collect()) - [1, 2, 'rdd', 'spark'] + >>> sorted(sc.pickleFile(tmpFile.name, 5).map(str).collect()) + ['1', '2', 'rdd', 'spark'] """ if batchSize == 0: ser = AutoBatchedSerializer(PickleSerializer()) @@ -1387,6 +1414,7 @@ def saveAsPickleFile(self, path, batchSize=10): ser = BatchedSerializer(PickleSerializer(), batchSize) self._reserialize(ser)._jrdd.saveAsObjectFile(path) + @ignore_unicode_prefix def saveAsTextFile(self, path, compressionCodecClass=None): """ Save this RDD as a text file, using string representations of elements. @@ -1418,12 +1446,13 @@ def saveAsTextFile(self, path, compressionCodecClass=None): >>> codec = "org.apache.hadoop.io.compress.GzipCodec" >>> sc.parallelize(['foo', 'bar']).saveAsTextFile(tempFile3.name, codec) >>> from fileinput import input, hook_compressed - >>> ''.join(sorted(input(glob(tempFile3.name + "/part*.gz"), openhook=hook_compressed))) - 'bar\\nfoo\\n' + >>> result = sorted(input(glob(tempFile3.name + "/part*.gz"), openhook=hook_compressed)) + >>> b''.join(result).decode('utf-8') + u'bar\\nfoo\\n' """ def func(split, iterator): for x in iterator: - if not isinstance(x, basestring): + if not isinstance(x, (unicode, bytes)): x = unicode(x) if isinstance(x, unicode): x = x.encode("utf-8") @@ -1458,7 +1487,7 @@ def keys(self): >>> m.collect() [1, 3] """ - return self.map(lambda (k, v): k) + return self.map(lambda x: x[0]) def values(self): """ @@ -1468,7 +1497,7 @@ def values(self): >>> m.collect() [2, 4] """ - return self.map(lambda (k, v): v) + return self.map(lambda x: x[1]) def reduceByKey(self, func, numPartitions=None): """ @@ -1507,7 +1536,7 @@ def reducePartition(iterator): yield m def mergeMaps(m1, m2): - for k, v in m2.iteritems(): + for k, v in m2.items(): m1[k] = func(m1[k], v) if k in m1 else v return m1 return self.mapPartitions(reducePartition).reduce(mergeMaps) @@ -1604,8 +1633,8 @@ def partitionBy(self, numPartitions, partitionFunc=portable_hash): >>> pairs = sc.parallelize([1, 2, 3, 4, 2, 4, 1]).map(lambda x: (x, x)) >>> sets = pairs.partitionBy(2).glom().collect() - >>> set(sets[0]).intersection(set(sets[1])) - set([]) + >>> len(set(sets[0]).intersection(set(sets[1]))) + 0 """ if numPartitions is None: numPartitions = self._defaultReducePartitions() @@ -1637,22 +1666,22 @@ def add_shuffle_key(split, iterator): if (c % 1000 == 0 and get_used_memory() > limit or c > batch): n, size = len(buckets), 0 - for split in buckets.keys(): 
+ for split in list(buckets.keys()): yield pack_long(split) d = outputSerializer.dumps(buckets[split]) del buckets[split] yield d size += len(d) - avg = (size / n) >> 20 + avg = int(size / n) >> 20 # let 1M < avg < 10M if avg < 1: batch *= 1.5 elif avg > 10: - batch = max(batch / 1.5, 1) + batch = max(int(batch / 1.5), 1) c = 0 - for split, items in buckets.iteritems(): + for split, items in buckets.items(): yield pack_long(split) yield outputSerializer.dumps(items) @@ -1707,7 +1736,7 @@ def combineLocally(iterator): merger = ExternalMerger(agg, memory * 0.9, serializer) \ if spill else InMemoryMerger(agg) merger.mergeValues(iterator) - return merger.iteritems() + return merger.items() locally_combined = self.mapPartitions(combineLocally, preservesPartitioning=True) shuffled = locally_combined.partitionBy(numPartitions) @@ -1716,7 +1745,7 @@ def _mergeCombiners(iterator): merger = ExternalMerger(agg, memory, serializer) \ if spill else InMemoryMerger(agg) merger.mergeCombiners(iterator) - return merger.iteritems() + return merger.items() return shuffled.mapPartitions(_mergeCombiners, preservesPartitioning=True) @@ -1745,7 +1774,7 @@ def foldByKey(self, zeroValue, func, numPartitions=None): >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) >>> from operator import add - >>> rdd.foldByKey(0, add).collect() + >>> sorted(rdd.foldByKey(0, add).collect()) [('a', 2), ('b', 1)] """ def createZero(): @@ -1769,10 +1798,10 @@ def groupByKey(self, numPartitions=None): sum or average) over each key, using reduceByKey or aggregateByKey will provide much better performance. - >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) - >>> sorted(x.groupByKey().mapValues(len).collect()) + >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) + >>> sorted(rdd.groupByKey().mapValues(len).collect()) [('a', 2), ('b', 1)] - >>> sorted(x.groupByKey().mapValues(list).collect()) + >>> sorted(rdd.groupByKey().mapValues(list).collect()) [('a', [1, 1]), ('b', [1])] """ def createCombiner(x): @@ -1795,7 +1824,7 @@ def combine(iterator): merger = ExternalMerger(agg, memory * 0.9, serializer) \ if spill else InMemoryMerger(agg) merger.mergeValues(iterator) - return merger.iteritems() + return merger.items() locally_combined = self.mapPartitions(combine, preservesPartitioning=True) shuffled = locally_combined.partitionBy(numPartitions) @@ -1804,7 +1833,7 @@ def groupByKey(it): merger = ExternalGroupBy(agg, memory, serializer)\ if spill else InMemoryMerger(agg) merger.mergeCombiners(it) - return merger.iteritems() + return merger.items() return shuffled.mapPartitions(groupByKey, True).mapValues(ResultIterable) @@ -1819,7 +1848,7 @@ def flatMapValues(self, f): >>> x.flatMapValues(f).collect() [('a', 'x'), ('a', 'y'), ('a', 'z'), ('b', 'p'), ('b', 'r')] """ - flat_map_fn = lambda (k, v): ((k, x) for x in f(v)) + flat_map_fn = lambda kv: ((kv[0], x) for x in f(kv[1])) return self.flatMap(flat_map_fn, preservesPartitioning=True) def mapValues(self, f): @@ -1833,7 +1862,7 @@ def mapValues(self, f): >>> x.mapValues(f).collect() [('a', 3), ('b', 1)] """ - map_values_fn = lambda (k, v): (k, f(v)) + map_values_fn = lambda kv: (kv[0], f(kv[1])) return self.map(map_values_fn, preservesPartitioning=True) def groupWith(self, other, *others): @@ -1844,8 +1873,7 @@ def groupWith(self, other, *others): >>> x = sc.parallelize([("a", 1), ("b", 4)]) >>> y = sc.parallelize([("a", 2)]) >>> z = sc.parallelize([("b", 42)]) - >>> map((lambda (x,y): (x, (list(y[0]), list(y[1]), list(y[2]), list(y[3])))), \ - sorted(list(w.groupWith(x, y, 
z).collect()))) + >>> [(x, tuple(map(list, y))) for x, y in sorted(list(w.groupWith(x, y, z).collect()))] [('a', ([5], [1], [2], [])), ('b', ([6], [4], [], [42]))] """ @@ -1860,7 +1888,7 @@ def cogroup(self, other, numPartitions=None): >>> x = sc.parallelize([("a", 1), ("b", 4)]) >>> y = sc.parallelize([("a", 2)]) - >>> map((lambda (x,y): (x, (list(y[0]), list(y[1])))), sorted(list(x.cogroup(y).collect()))) + >>> [(x, tuple(map(list, y))) for x, y in sorted(list(x.cogroup(y).collect()))] [('a', ([1], [2])), ('b', ([4], []))] """ return python_cogroup((self, other), numPartitions) @@ -1896,8 +1924,9 @@ def subtractByKey(self, other, numPartitions=None): >>> sorted(x.subtractByKey(y).collect()) [('b', 4), ('b', 5)] """ - def filter_func((key, vals)): - return vals[0] and not vals[1] + def filter_func(pair): + key, (val1, val2) = pair + return val1 and not val2 return self.cogroup(other, numPartitions).filter(filter_func).flatMapValues(lambda x: x[0]) def subtract(self, other, numPartitions=None): @@ -1919,8 +1948,8 @@ def keyBy(self, f): >>> x = sc.parallelize(range(0,3)).keyBy(lambda x: x*x) >>> y = sc.parallelize(zip(range(0,5), range(0,5))) - >>> map((lambda (x,y): (x, (list(y[0]), (list(y[1]))))), sorted(x.cogroup(y).collect())) - [(0, ([0], [0])), (1, ([1], [1])), (2, ([], [2])), (3, ([], [3])), (4, ([2], [4]))] + >>> [(x, list(map(list, y))) for x, y in sorted(x.cogroup(y).collect())] + [(0, [[0], [0]]), (1, [[1], [1]]), (2, [[], [2]]), (3, [[], [3]]), (4, [[2], [4]])] """ return self.map(lambda x: (f(x), x)) @@ -2049,17 +2078,18 @@ def name(self): """ Return the name of this RDD. """ - name_ = self._jrdd.name() - if name_: - return name_.encode('utf-8') + n = self._jrdd.name() + if n: + return n + @ignore_unicode_prefix def setName(self, name): """ Assign a name to this RDD. - >>> rdd1 = sc.parallelize([1,2]) + >>> rdd1 = sc.parallelize([1, 2]) >>> rdd1.setName('RDD1').name() - 'RDD1' + u'RDD1' """ self._jrdd.setName(name) return self @@ -2121,7 +2151,7 @@ def lookup(self, key): >>> sorted.lookup(1024) [] """ - values = self.filter(lambda (k, v): k == key).values() + values = self.filter(lambda kv: kv[0] == key).values() if self.partitioner is not None: return self.ctx.runJob(values, lambda x: x, [self.partitioner(key)], False) @@ -2159,7 +2189,7 @@ def sumApprox(self, timeout, confidence=0.95): or meet the confidence. >>> rdd = sc.parallelize(range(1000), 10) - >>> r = sum(xrange(1000)) + >>> r = sum(range(1000)) >>> (rdd.sumApprox(1000) - r) / r < 0.05 True """ @@ -2176,7 +2206,7 @@ def meanApprox(self, timeout, confidence=0.95): or meet the confidence. >>> rdd = sc.parallelize(range(1000), 10) - >>> r = sum(xrange(1000)) / 1000.0 + >>> r = sum(range(1000)) / 1000.0 >>> (rdd.meanApprox(1000) - r) / r < 0.05 True """ @@ -2201,10 +2231,10 @@ def countApproxDistinct(self, relativeSD=0.05): It must be greater than 0.000017. 
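# The lambda and def rewrites above (keys, values, subtractByKey, lookup) are
# mandated by PEP 3113: Python 3 removed tuple parameters, so
# `lambda (k, v): k` is now a SyntaxError. The mechanical translation used
# throughout this patch takes one argument and indexes or unpacks inside:
pairs = [("a", ([1], [])), ("b", ([], [2]))]

def keep(pair):                  # was: def filter_func((key, vals)): ...
    key, (left, right) = pair    # unpack in the body instead of the signature
    return bool(left) and not right

assert [k for k, _ in filter(keep, pairs)] == ["a"]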
>>> n = sc.parallelize(range(1000)).map(str).countApproxDistinct() - >>> 950 < n < 1050 + >>> 900 < n < 1100 True >>> n = sc.parallelize([i % 20 for i in range(1000)]).countApproxDistinct() - >>> 18 < n < 22 + >>> 16 < n < 24 True """ if relativeSD < 0.000017: @@ -2223,8 +2253,7 @@ def toLocalIterator(self): >>> [x for x in rdd.toLocalIterator()] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] """ - partitions = xrange(self.getNumPartitions()) - for partition in partitions: + for partition in range(self.getNumPartitions()): rows = self.context.runJob(self, lambda x: x, [partition]) for row in rows: yield row diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py index 459e1427803cb..fe8f87324804b 100644 --- a/python/pyspark/rddsampler.py +++ b/python/pyspark/rddsampler.py @@ -23,7 +23,7 @@ class RDDSamplerBase(object): def __init__(self, withReplacement, seed=None): - self._seed = seed if seed is not None else random.randint(0, sys.maxint) + self._seed = seed if seed is not None else random.randint(0, sys.maxsize) self._withReplacement = withReplacement self._random = None @@ -31,7 +31,7 @@ def initRandomGenerator(self, split): self._random = random.Random(self._seed ^ split) # mixing because the initial seeds are close to each other - for _ in xrange(10): + for _ in range(10): self._random.randint(0, 1) def getUniformSample(self): diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 4afa82f4b2973..d8cdcda3a3783 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -49,16 +49,24 @@ >>> sc.stop() """ -import cPickle -from itertools import chain, izip, product +import sys +from itertools import chain, product import marshal import struct -import sys import types import collections import zlib import itertools +if sys.version < '3': + import cPickle as pickle + protocol = 2 + from itertools import izip as zip +else: + import pickle + protocol = 3 + xrange = range + from pyspark import cloudpickle @@ -97,7 +105,7 @@ def _load_stream_without_unbatching(self, stream): # subclasses should override __eq__ as appropriate. 
def __eq__(self, other): - return isinstance(other, self.__class__) + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ def __ne__(self, other): return not self.__eq__(other) @@ -212,10 +220,6 @@ def load_stream(self, stream): def _load_stream_without_unbatching(self, stream): return self.serializer.load_stream(stream) - def __eq__(self, other): - return (isinstance(other, BatchedSerializer) and - other.serializer == self.serializer and other.batchSize == self.batchSize) - def __repr__(self): return "BatchedSerializer(%s, %d)" % (str(self.serializer), self.batchSize) @@ -233,14 +237,14 @@ def __init__(self, serializer, batchSize=10): def _batched(self, iterator): n = self.batchSize for key, values in iterator: - for i in xrange(0, len(values), n): + for i in range(0, len(values), n): yield key, values[i:i + n] def load_stream(self, stream): return self.serializer.load_stream(stream) def __repr__(self): - return "FlattenedValuesSerializer(%d)" % self.batchSize + return "FlattenedValuesSerializer(%s, %d)" % (self.serializer, self.batchSize) class AutoBatchedSerializer(BatchedSerializer): @@ -270,12 +274,8 @@ def dump_stream(self, iterator, stream): elif size > best * 10 and batch > 1: batch /= 2 - def __eq__(self, other): - return (isinstance(other, AutoBatchedSerializer) and - other.serializer == self.serializer and other.bestSize == self.bestSize) - def __repr__(self): - return "AutoBatchedSerializer(%s)" % str(self.serializer) + return "AutoBatchedSerializer(%s)" % self.serializer class CartesianDeserializer(FramedSerializer): @@ -285,6 +285,7 @@ class CartesianDeserializer(FramedSerializer): """ def __init__(self, key_ser, val_ser): + FramedSerializer.__init__(self) self.key_ser = key_ser self.val_ser = val_ser @@ -293,7 +294,7 @@ def prepare_keys_values(self, stream): val_stream = self.val_ser._load_stream_without_unbatching(stream) key_is_batched = isinstance(self.key_ser, BatchedSerializer) val_is_batched = isinstance(self.val_ser, BatchedSerializer) - for (keys, vals) in izip(key_stream, val_stream): + for (keys, vals) in zip(key_stream, val_stream): keys = keys if key_is_batched else [keys] vals = vals if val_is_batched else [vals] yield (keys, vals) @@ -303,10 +304,6 @@ def load_stream(self, stream): for pair in product(keys, vals): yield pair - def __eq__(self, other): - return (isinstance(other, CartesianDeserializer) and - self.key_ser == other.key_ser and self.val_ser == other.val_ser) - def __repr__(self): return "CartesianDeserializer(%s, %s)" % \ (str(self.key_ser), str(self.val_ser)) @@ -318,22 +315,14 @@ class PairDeserializer(CartesianDeserializer): Deserializes the JavaRDD zip() of two PythonRDDs. 
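# The per-class __eq__ methods deleted above are subsumed by the new base
# Serializer.__eq__, which compares the concrete class plus all instance
# state; PairDeserializer can then drop its redundant __init__ as well.
# The pattern in isolation:
class Serializer(object):
    def __eq__(self, other):
        return (isinstance(other, self.__class__) and
                self.__dict__ == other.__dict__)

    def __ne__(self, other):
        return not self.__eq__(other)

class BatchedSerializer(Serializer):
    def __init__(self, serializer, batchSize):
        self.serializer = serializer
        self.batchSize = batchSize

assert BatchedSerializer(None, 10) == BatchedSerializer(None, 10)
assert BatchedSerializer(None, 10) != BatchedSerializer(None, 16)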
""" - def __init__(self, key_ser, val_ser): - self.key_ser = key_ser - self.val_ser = val_ser - def load_stream(self, stream): for (keys, vals) in self.prepare_keys_values(stream): if len(keys) != len(vals): raise ValueError("Can not deserialize RDD with different number of items" " in pair: (%d, %d)" % (len(keys), len(vals))) - for pair in izip(keys, vals): + for pair in zip(keys, vals): yield pair - def __eq__(self, other): - return (isinstance(other, PairDeserializer) and - self.key_ser == other.key_ser and self.val_ser == other.val_ser) - def __repr__(self): return "PairDeserializer(%s, %s)" % (str(self.key_ser), str(self.val_ser)) @@ -382,8 +371,8 @@ def _hijack_namedtuple(): global _old_namedtuple # or it will put in closure def _copy_func(f): - return types.FunctionType(f.func_code, f.func_globals, f.func_name, - f.func_defaults, f.func_closure) + return types.FunctionType(f.__code__, f.__globals__, f.__name__, + f.__defaults__, f.__closure__) _old_namedtuple = _copy_func(collections.namedtuple) @@ -392,15 +381,15 @@ def namedtuple(*args, **kwargs): return _hack_namedtuple(cls) # replace namedtuple with new one - collections.namedtuple.func_globals["_old_namedtuple"] = _old_namedtuple - collections.namedtuple.func_globals["_hack_namedtuple"] = _hack_namedtuple - collections.namedtuple.func_code = namedtuple.func_code + collections.namedtuple.__globals__["_old_namedtuple"] = _old_namedtuple + collections.namedtuple.__globals__["_hack_namedtuple"] = _hack_namedtuple + collections.namedtuple.__code__ = namedtuple.__code__ collections.namedtuple.__hijack = 1 # hack the cls already generated by namedtuple # those created in other module can be pickled as normal, # so only hack those in __main__ module - for n, o in sys.modules["__main__"].__dict__.iteritems(): + for n, o in sys.modules["__main__"].__dict__.items(): if (type(o) is type and o.__base__ is tuple and hasattr(o, "_fields") and "__reduce__" not in o.__dict__): @@ -413,7 +402,7 @@ def namedtuple(*args, **kwargs): class PickleSerializer(FramedSerializer): """ - Serializes objects using Python's cPickle serializer: + Serializes objects using Python's pickle serializer: http://docs.python.org/2/library/pickle.html @@ -422,10 +411,14 @@ class PickleSerializer(FramedSerializer): """ def dumps(self, obj): - return cPickle.dumps(obj, 2) + return pickle.dumps(obj, protocol) - def loads(self, obj): - return cPickle.loads(obj) + if sys.version >= '3': + def loads(self, obj, encoding="bytes"): + return pickle.loads(obj, encoding=encoding) + else: + def loads(self, obj, encoding=None): + return pickle.loads(obj) class CloudPickleSerializer(PickleSerializer): @@ -454,7 +447,7 @@ def loads(self, obj): class AutoSerializer(FramedSerializer): """ - Choose marshal or cPickle as serialization protocol automatically + Choose marshal or pickle as serialization protocol automatically """ def __init__(self): @@ -463,19 +456,19 @@ def __init__(self): def dumps(self, obj): if self._type is not None: - return 'P' + cPickle.dumps(obj, -1) + return b'P' + pickle.dumps(obj, -1) try: - return 'M' + marshal.dumps(obj) + return b'M' + marshal.dumps(obj) except Exception: - self._type = 'P' - return 'P' + cPickle.dumps(obj, -1) + self._type = b'P' + return b'P' + pickle.dumps(obj, -1) def loads(self, obj): _type = obj[0] - if _type == 'M': + if _type == b'M': return marshal.loads(obj[1:]) - elif _type == 'P': - return cPickle.loads(obj[1:]) + elif _type == b'P': + return pickle.loads(obj[1:]) else: raise ValueError("invalid sevialization type: %s" % _type) @@ 
-495,8 +488,8 @@ def dumps(self, obj): def loads(self, obj): return self.serializer.loads(zlib.decompress(obj)) - def __eq__(self, other): - return isinstance(other, CompressedSerializer) and self.serializer == other.serializer + def __repr__(self): + return "CompressedSerializer(%s)" % self.serializer class UTF8Deserializer(Serializer): @@ -505,7 +498,7 @@ class UTF8Deserializer(Serializer): Deserializes streams written by String.getBytes. """ - def __init__(self, use_unicode=False): + def __init__(self, use_unicode=True): self.use_unicode = use_unicode def loads(self, stream): @@ -526,13 +519,13 @@ def load_stream(self, stream): except EOFError: return - def __eq__(self, other): - return isinstance(other, UTF8Deserializer) and self.use_unicode == other.use_unicode + def __repr__(self): + return "UTF8Deserializer(%s)" % self.use_unicode def read_long(stream): length = stream.read(8) - if length == "": + if not length: raise EOFError return struct.unpack("!q", length)[0] @@ -547,7 +540,7 @@ def pack_long(value): def read_int(stream): length = stream.read(4) - if length == "": + if not length: raise EOFError return struct.unpack("!i", length)[0] diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index 81aa970a32f76..144cdf0b0cdd5 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -21,13 +21,6 @@ This file is designed to be launched as a PYTHONSTARTUP script. """ -import sys -if sys.version_info[0] != 2: - print("Error: Default Python used is Python%s" % sys.version_info.major) - print("\tSet env variable PYSPARK_PYTHON to Python2 binary and re-run it.") - sys.exit(1) - - import atexit import os import platform @@ -53,9 +46,14 @@ try: # Try to access HiveConf, it will raise exception if Hive is not added sc._jvm.org.apache.hadoop.hive.conf.HiveConf() - sqlCtx = sqlContext = HiveContext(sc) + sqlContext = HiveContext(sc) except py4j.protocol.Py4JError: - sqlCtx = sqlContext = SQLContext(sc) + sqlContext = SQLContext(sc) +except TypeError: + sqlContext = SQLContext(sc) + +# for compatibility +sqlCtx = sqlContext print("""Welcome to ____ __ diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 8a6fc627eb383..b54baa57ec28a 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -78,8 +78,8 @@ def _get_local_dirs(sub): # global stats -MemoryBytesSpilled = 0L -DiskBytesSpilled = 0L +MemoryBytesSpilled = 0 +DiskBytesSpilled = 0 class Aggregator(object): @@ -126,7 +126,7 @@ def mergeCombiners(self, iterator): """ Merge the combined items by mergeCombiner """ raise NotImplementedError - def iteritems(self): + def items(self): """ Return the merged items ad iterator """ raise NotImplementedError @@ -156,9 +156,9 @@ def mergeCombiners(self, iterator): for k, v in iterator: d[k] = comb(d[k], v) if k in d else v - def iteritems(self): - """ Return the merged items as iterator """ - return self.data.iteritems() + def items(self): + """ Return the merged items ad iterator """ + return iter(self.data.items()) def _compressed_serializer(self, serializer=None): @@ -208,15 +208,15 @@ class ExternalMerger(Merger): >>> agg = SimpleAggregator(lambda x, y: x + y) >>> merger = ExternalMerger(agg, 10) >>> N = 10000 - >>> merger.mergeValues(zip(xrange(N), xrange(N))) + >>> merger.mergeValues(zip(range(N), range(N))) >>> assert merger.spills > 0 - >>> sum(v for k,v in merger.iteritems()) + >>> sum(v for k,v in merger.items()) 49995000 >>> merger = ExternalMerger(agg, 10) - >>> merger.mergeCombiners(zip(xrange(N), xrange(N))) + >>> 
merger.mergeCombiners(zip(range(N), range(N))) >>> assert merger.spills > 0 - >>> sum(v for k,v in merger.iteritems()) + >>> sum(v for k,v in merger.items()) 49995000 """ @@ -335,10 +335,10 @@ def _spill(self): # above limit at the first time. # open all the files for writing - streams = [open(os.path.join(path, str(i)), 'w') + streams = [open(os.path.join(path, str(i)), 'wb') for i in range(self.partitions)] - for k, v in self.data.iteritems(): + for k, v in self.data.items(): h = self._partition(k) # put one item in batch, make it compatible with load_stream # it will increase the memory if dump them in batch @@ -354,9 +354,9 @@ def _spill(self): else: for i in range(self.partitions): p = os.path.join(path, str(i)) - with open(p, "w") as f: + with open(p, "wb") as f: # dump items in batch - self.serializer.dump_stream(self.pdata[i].iteritems(), f) + self.serializer.dump_stream(iter(self.pdata[i].items()), f) self.pdata[i].clear() DiskBytesSpilled += os.path.getsize(p) @@ -364,10 +364,10 @@ def _spill(self): gc.collect() # release the memory as much as possible MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 - def iteritems(self): + def items(self): """ Return all merged items as iterator """ if not self.pdata and not self.spills: - return self.data.iteritems() + return iter(self.data.items()) return self._external_items() def _external_items(self): @@ -398,7 +398,8 @@ def _merged_items(self, index): path = self._get_spill_dir(j) p = os.path.join(path, str(index)) # do not check memory during merging - self.mergeCombiners(self.serializer.load_stream(open(p)), 0) + with open(p, "rb") as f: + self.mergeCombiners(self.serializer.load_stream(f), 0) # limit the total partitions if (self.scale * self.partitions < self.MAX_TOTAL_PARTITIONS @@ -408,7 +409,7 @@ def _merged_items(self, index): gc.collect() # release the memory as much as possible return self._recursive_merged_items(index) - return self.data.iteritems() + return self.data.items() def _recursive_merged_items(self, index): """ @@ -426,7 +427,8 @@ def _recursive_merged_items(self, index): for j in range(self.spills): path = self._get_spill_dir(j) p = os.path.join(path, str(index)) - m.mergeCombiners(self.serializer.load_stream(open(p)), 0) + with open(p, 'rb') as f: + m.mergeCombiners(self.serializer.load_stream(f), 0) if get_used_memory() > limit: m._spill() @@ -451,7 +453,7 @@ class ExternalSorter(object): >>> sorter = ExternalSorter(1) # 1M >>> import random - >>> l = range(1024) + >>> l = list(range(1024)) >>> random.shuffle(l) >>> sorted(l) == list(sorter.sorted(l)) True @@ -499,9 +501,16 @@ def sorted(self, iterator, key=None, reverse=False): # sort them inplace will save memory current_chunk.sort(key=key, reverse=reverse) path = self._get_path(len(chunks)) - with open(path, 'w') as f: + with open(path, 'wb') as f: self.serializer.dump_stream(current_chunk, f) - chunks.append(self.serializer.load_stream(open(path))) + + def load(f): + for v in self.serializer.load_stream(f): + yield v + # close the file explicit once we consume all the items + # to avoid ResourceWarning in Python3 + f.close() + chunks.append(load(open(path, 'rb'))) current_chunk = [] gc.collect() limit = self._next_limit() @@ -527,7 +536,7 @@ class ExternalList(object): ExternalList can have many items which cannot be hold in memory in the same time. 
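# The file-mode changes above and just below ('w' -> 'wb', 'r' -> 'rb', plus
# b'' for the empty payload) follow from pickled frames being bytes in Python
# 3: a text-mode file only accepts str, raises TypeError on bytes, and would
# also newline-translate the stream. Minimal sketch of the spill-file
# discipline:
import pickle, tempfile

with tempfile.TemporaryFile(mode="w+b") as f:   # binary, like the spill files
    pickle.dump({"k": 1}, f)
    f.seek(0)
    assert pickle.load(f) == {"k": 1}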
- >>> l = ExternalList(range(100)) + >>> l = ExternalList(list(range(100))) >>> len(l) 100 >>> l.append(10) @@ -555,11 +564,11 @@ def __init__(self, values): def __getstate__(self): if self._file is not None: self._file.flush() - f = os.fdopen(os.dup(self._file.fileno())) - f.seek(0) - serialized = f.read() + with os.fdopen(os.dup(self._file.fileno()), "rb") as f: + f.seek(0) + serialized = f.read() else: - serialized = '' + serialized = b'' return self.values, self.count, serialized def __setstate__(self, item): @@ -575,7 +584,7 @@ def __iter__(self): if self._file is not None: self._file.flush() # read all items from disks first - with os.fdopen(os.dup(self._file.fileno()), 'r') as f: + with os.fdopen(os.dup(self._file.fileno()), 'rb') as f: f.seek(0) for v in self._ser.load_stream(f): yield v @@ -598,11 +607,16 @@ def _open_file(self): d = dirs[id(self) % len(dirs)] if not os.path.exists(d): os.makedirs(d) - p = os.path.join(d, str(id)) - self._file = open(p, "w+", 65536) + p = os.path.join(d, str(id(self))) + self._file = open(p, "wb+", 65536) self._ser = BatchedSerializer(CompressedSerializer(PickleSerializer()), 1024) os.unlink(p) + def __del__(self): + if self._file: + self._file.close() + self._file = None + def _spill(self): """ dump the values into disk """ global MemoryBytesSpilled, DiskBytesSpilled @@ -651,33 +665,28 @@ class GroupByKey(object): """ Group a sorted iterator as [(k1, it1), (k2, it2), ...] - >>> k = [i/3 for i in range(6)] + >>> k = [i // 3 for i in range(6)] >>> v = [[i] for i in range(6)] - >>> g = GroupByKey(iter(zip(k, v))) + >>> g = GroupByKey(zip(k, v)) >>> [(k, list(it)) for k, it in g] [(0, [0, 1, 2]), (1, [3, 4, 5])] """ def __init__(self, iterator): - self.iterator = iter(iterator) - self.next_item = None + self.iterator = iterator def __iter__(self): - return self - - def next(self): - key, value = self.next_item if self.next_item else next(self.iterator) - values = ExternalListOfList([value]) - try: - while True: - k, v = next(self.iterator) - if k != key: - self.next_item = (k, v) - break + key, values = None, None + for k, v in self.iterator: + if values is not None and k == key: values.append(v) - except StopIteration: - self.next_item = None - return key, values + else: + if values is not None: + yield (key, values) + key = k + values = ExternalListOfList([v]) + if values is not None: + yield (key, values) class ExternalGroupBy(ExternalMerger): @@ -744,7 +753,7 @@ def _spill(self): # above limit at the first time. 
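# The GroupByKey rewrite above replaces a hand-rolled next() iterator with a
# generator: Python 3 calls __next__ rather than next, and a generator-style
# __iter__ sidesteps the rename while also simplifying the lookahead state.
# The shape of the rewrite, with plain lists standing in for
# ExternalListOfList:
class GroupBySortedKey(object):
    def __init__(self, pairs):        # pairs must already be sorted by key
        self.pairs = pairs

    def __iter__(self):
        key, values = None, None
        for k, v in self.pairs:
            if values is not None and k == key:
                values.append(v)
            else:
                if values is not None:
                    yield key, values
                key, values = k, [v]
        if values is not None:
            yield key, values

groups = [(k, vs) for k, vs in GroupBySortedKey([(0, 1), (0, 2), (1, 3)])]
assert groups == [(0, [1, 2]), (1, [3])]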
# open all the files for writing - streams = [open(os.path.join(path, str(i)), 'w') + streams = [open(os.path.join(path, str(i)), 'wb') for i in range(self.partitions)] # If the number of keys is small, then the overhead of sort is small @@ -756,7 +765,7 @@ def _spill(self): h = self._partition(k) self.serializer.dump_stream([(k, self.data[k])], streams[h]) else: - for k, v in self.data.iteritems(): + for k, v in self.data.items(): h = self._partition(k) self.serializer.dump_stream([(k, v)], streams[h]) @@ -771,14 +780,14 @@ def _spill(self): else: for i in range(self.partitions): p = os.path.join(path, str(i)) - with open(p, "w") as f: + with open(p, "wb") as f: # dump items in batch if self._sorted: # sort by key only (stable) - sorted_items = sorted(self.pdata[i].iteritems(), key=operator.itemgetter(0)) + sorted_items = sorted(self.pdata[i].items(), key=operator.itemgetter(0)) self.serializer.dump_stream(sorted_items, f) else: - self.serializer.dump_stream(self.pdata[i].iteritems(), f) + self.serializer.dump_stream(self.pdata[i].items(), f) self.pdata[i].clear() DiskBytesSpilled += os.path.getsize(p) @@ -792,7 +801,7 @@ def _merged_items(self, index): # if the memory can not hold all the partition, # then use sort based merge. Because of compression, # the data on disks will be much smaller than needed memory - if (size >> 20) >= self.memory_limit / 10: + if size >= self.memory_limit << 17: # * 1M / 8 return self._merge_sorted_items(index) self.data = {} @@ -800,15 +809,18 @@ def _merged_items(self, index): path = self._get_spill_dir(j) p = os.path.join(path, str(index)) # do not check memory during merging - self.mergeCombiners(self.serializer.load_stream(open(p)), 0) - return self.data.iteritems() + with open(p, "rb") as f: + self.mergeCombiners(self.serializer.load_stream(f), 0) + return self.data.items() def _merge_sorted_items(self, index): """ load a partition from disk, then sort and group by key """ def load_partition(j): path = self._get_spill_dir(j) p = os.path.join(path, str(index)) - return self.serializer.load_stream(open(p, 'r', 65536)) + with open(p, 'rb', 65536) as f: + for v in self.serializer.load_stream(f): + yield v disk_items = [load_partition(j) for j in range(self.spills)] diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py index 65abb24eed823..6d54b9e49ed10 100644 --- a/python/pyspark/sql/__init__.py +++ b/python/pyspark/sql/__init__.py @@ -37,9 +37,22 @@ - L{types} List of data types available. """ +from __future__ import absolute_import + +# fix the module name conflict for Python 3+ +import sys +from . import _types as types +modname = __name__ + '.types' +types.__name__ = modname +# update the __module__ for all objects, make them picklable +for v in types.__dict__.values(): + if hasattr(v, "__module__") and v.__module__.endswith('._types'): + v.__module__ = modname +sys.modules[modname] = types +del modname, sys -from pyspark.sql.context import SQLContext, HiveContext from pyspark.sql.types import Row +from pyspark.sql.context import SQLContext, HiveContext from pyspark.sql.dataframe import DataFrame, GroupedData, Column, SchemaRDD, DataFrameNaFunctions __all__ = [ diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/_types.py similarity index 97% rename from python/pyspark/sql/types.py rename to python/pyspark/sql/_types.py index ef76d84c00481..492c0cbdcf693 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/_types.py @@ -15,6 +15,7 @@ # limitations under the License. 
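# The rename of types.py to _types.py above works together with the aliasing
# block added to pyspark/sql/__init__.py; a rough sketch of that pattern, with
# `pkg` and `_impl` as placeholder names rather than the actual Spark code:
#
#     import sys
#     from . import _impl as impl           # real module lives at pkg/_impl.py
#     impl.__name__ = __name__ + '.impl'    # present it under the public name
#     sys.modules[impl.__name__] = impl     # so `import pkg.impl` keeps working
#
# registering the public dotted name in sys.modules is what lets pickled
# objects keep referring to `pkg.impl` instead of the private `pkg._impl`.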
# +import sys import decimal import datetime import keyword @@ -25,6 +26,9 @@ from array import array from operator import itemgetter +if sys.version >= "3": + long = int + unicode = str __all__ = [ "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType", @@ -410,7 +414,7 @@ def fromJson(cls, json): split = pyUDT.rfind(".") pyModule = pyUDT[:split] pyClass = pyUDT[split+1:] - m = __import__(pyModule, globals(), locals(), [pyClass], -1) + m = __import__(pyModule, globals(), locals(), [pyClass]) UDT = getattr(m, pyClass) return UDT() @@ -419,10 +423,9 @@ def __eq__(self, other): _all_primitive_types = dict((v.typeName(), v) - for v in globals().itervalues() - if type(v) is PrimitiveTypeSingleton and - v.__base__ == PrimitiveType) - + for v in list(globals().values()) + if (type(v) is type or type(v) is PrimitiveTypeSingleton) + and v.__base__ == PrimitiveType) _all_complex_types = dict((v.typeName(), v) for v in [ArrayType, MapType, StructType]) @@ -486,10 +489,10 @@ def _parse_datatype_json_string(json_string): def _parse_datatype_json_value(json_value): - if type(json_value) is unicode: + if not isinstance(json_value, dict): if json_value in _all_primitive_types.keys(): return _all_primitive_types[json_value]() - elif json_value == u'decimal': + elif json_value == 'decimal': return DecimalType() elif _FIXED_DECIMAL.match(json_value): m = _FIXED_DECIMAL.match(json_value) @@ -511,10 +514,8 @@ def _parse_datatype_json_value(json_value): type(None): NullType, bool: BooleanType, int: LongType, - long: LongType, float: DoubleType, str: StringType, - unicode: StringType, bytearray: BinaryType, decimal.Decimal: DecimalType, datetime.date: DateType, @@ -522,6 +523,12 @@ def _parse_datatype_json_value(json_value): datetime.time: TimestampType, } +if sys.version < "3": + _type_mappings.update({ + unicode: StringType, + long: LongType, + }) + def _infer_type(obj): """Infer the DataType from obj @@ -541,7 +548,7 @@ def _infer_type(obj): return dataType() if isinstance(obj, dict): - for key, value in obj.iteritems(): + for key, value in obj.items(): if key is not None and value is not None: return MapType(_infer_type(key), _infer_type(value), True) else: @@ -565,10 +572,10 @@ def _infer_schema(row): items = sorted(row.items()) elif isinstance(row, (tuple, list)): - if hasattr(row, "_fields"): # namedtuple - items = zip(row._fields, tuple(row)) - elif hasattr(row, "__fields__"): # Row + if hasattr(row, "__fields__"): # Row items = zip(row.__fields__, tuple(row)) + elif hasattr(row, "_fields"): # namedtuple + items = zip(row._fields, tuple(row)) else: names = ['_%d' % i for i in range(1, len(row) + 1)] items = zip(names, row) @@ -647,7 +654,7 @@ def converter(obj): if isinstance(obj, dict): return tuple(c(obj.get(n)) for n, c in zip(names, converters)) elif isinstance(obj, tuple): - if hasattr(obj, "_fields") or hasattr(obj, "__fields__"): + if hasattr(obj, "__fields__") or hasattr(obj, "_fields"): return tuple(c(v) for c, v in zip(converters, obj)) elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs d = dict(obj) @@ -733,12 +740,12 @@ def _create_converter(dataType): if isinstance(dataType, ArrayType): conv = _create_converter(dataType.elementType) - return lambda row: map(conv, row) + return lambda row: [conv(v) for v in row] elif isinstance(dataType, MapType): kconv = _create_converter(dataType.keyType) vconv = _create_converter(dataType.valueType) - return lambda row: dict((kconv(k), vconv(v)) for k, v in row.iteritems()) + return lambda row: dict((kconv(k), 
vconv(v)) for k, v in row.items()) elif isinstance(dataType, NullType): return lambda x: None @@ -881,7 +888,7 @@ def _infer_schema_type(obj, dataType): >>> _infer_schema_type(row, schema) StructType...a,ArrayType...b,MapType(StringType,...c,LongType... """ - if dataType is NullType(): + if isinstance(dataType, NullType): return _infer_type(obj) if not obj: @@ -892,7 +899,7 @@ def _infer_schema_type(obj, dataType): return ArrayType(eType, True) elif isinstance(dataType, MapType): - k, v = obj.iteritems().next() + k, v = next(iter(obj.items())) return MapType(_infer_schema_type(k, dataType.keyType), _infer_schema_type(v, dataType.valueType)) @@ -935,7 +942,7 @@ def _verify_type(obj, dataType): >>> _verify_type(None, StructType([])) >>> _verify_type("", StringType()) >>> _verify_type(0, LongType()) - >>> _verify_type(range(3), ArrayType(ShortType())) + >>> _verify_type(list(range(3)), ArrayType(ShortType())) >>> _verify_type(set(), ArrayType(StringType())) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... @@ -976,7 +983,7 @@ def _verify_type(obj, dataType): _verify_type(i, dataType.elementType) elif isinstance(dataType, MapType): - for k, v in obj.iteritems(): + for k, v in obj.items(): _verify_type(k, dataType.keyType) _verify_type(v, dataType.valueType) @@ -1213,6 +1220,8 @@ def __getattr__(self, item): return self[idx] except IndexError: raise AttributeError(item) + except ValueError: + raise AttributeError(item) def __reduce__(self): if hasattr(self, "__fields__"): diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index e8529a8f8e3a4..c90afc326ca0e 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -15,14 +15,19 @@ # limitations under the License. # +import sys import warnings import json -from itertools import imap + +if sys.version >= '3': + basestring = unicode = str +else: + from itertools import imap as map from py4j.protocol import Py4JError from py4j.java_collections import MapConverter -from pyspark.rdd import RDD, _prepare_for_python_RDD +from pyspark.rdd import RDD, _prepare_for_python_RDD, ignore_unicode_prefix from pyspark.serializers import AutoBatchedSerializer, PickleSerializer from pyspark.sql.types import Row, StringType, StructType, _verify_type, \ _infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter @@ -62,31 +67,27 @@ class SQLContext(object): A SQLContext can be used create :class:`DataFrame`, register :class:`DataFrame` as tables, execute SQL over tables, cache tables, and read parquet files. - When created, :class:`SQLContext` adds a method called ``toDF`` to :class:`RDD`, - which could be used to convert an RDD into a DataFrame, it's a shorthand for - :func:`SQLContext.createDataFrame`. - :param sparkContext: The :class:`SparkContext` backing this SQLContext. :param sqlContext: An optional JVM Scala SQLContext. If set, we do not instantiate a new SQLContext in the JVM, instead we make all calls to this object. """ + @ignore_unicode_prefix def __init__(self, sparkContext, sqlContext=None): """Creates a new SQLContext. >>> from datetime import datetime >>> sqlContext = SQLContext(sc) - >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L, + >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1, ... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1), ... time=datetime(2014, 8, 1, 14, 1, 5))]) >>> df = allTypes.toDF() >>> df.registerTempTable("allTypes") >>> sqlContext.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' ... 
'from allTypes where b and i > 0').collect() - [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)] - >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, - ... x.row.a, x.list)).collect() - [(1, u'string', 1.0, 1, True, ...(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])] + [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0, time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)] + >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, x.row.a, x.list)).collect() + [(1, u'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])] """ self._sc = sparkContext self._jsc = self._sc._jsc @@ -122,6 +123,7 @@ def udf(self): """Returns a :class:`UDFRegistration` for UDF registration.""" return UDFRegistration(self) + @ignore_unicode_prefix def registerFunction(self, name, f, returnType=StringType()): """Registers a lambda function as a UDF so it can be used in SQL statements. @@ -147,7 +149,7 @@ def registerFunction(self, name, f, returnType=StringType()): >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() [Row(c0=4)] """ - func = lambda _, it: imap(lambda x: f(*x), it) + func = lambda _, it: map(lambda x: f(*x), it) ser = AutoBatchedSerializer(PickleSerializer()) command = (func, None, ser, ser) pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self._sc, command, self) @@ -185,6 +187,7 @@ def _inferSchema(self, rdd, samplingRatio=None): schema = rdd.map(_infer_schema).reduce(_merge_type) return schema + @ignore_unicode_prefix def inferSchema(self, rdd, samplingRatio=None): """::note: Deprecated in 1.3, use :func:`createDataFrame` instead. """ @@ -195,6 +198,7 @@ def inferSchema(self, rdd, samplingRatio=None): return self.createDataFrame(rdd, None, samplingRatio) + @ignore_unicode_prefix def applySchema(self, rdd, schema): """::note: Deprecated in 1.3, use :func:`createDataFrame` instead. """ @@ -208,6 +212,7 @@ def applySchema(self, rdd, schema): return self.createDataFrame(rdd, schema) + @ignore_unicode_prefix def createDataFrame(self, data, schema=None, samplingRatio=None): """ Creates a :class:`DataFrame` from an :class:`RDD` of :class:`tuple`/:class:`list`, @@ -380,6 +385,7 @@ def jsonFile(self, path, schema=None, samplingRatio=1.0): df = self._ssql_ctx.jsonFile(path, scala_datatype) return DataFrame(df, self) + @ignore_unicode_prefix def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): """Loads an RDD storing one JSON object per string as a :class:`DataFrame`. @@ -477,6 +483,7 @@ def createExternalTable(self, tableName, path=None, source=None, joptions) return DataFrame(df, self) + @ignore_unicode_prefix def sql(self, sqlQuery): """Returns a :class:`DataFrame` representing the result of the given query. @@ -497,6 +504,7 @@ def table(self, tableName): """ return DataFrame(self._ssql_ctx.table(tableName), self) + @ignore_unicode_prefix def tables(self, dbName=None): """Returns a :class:`DataFrame` containing names of tables in the given database. 
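The `@ignore_unicode_prefix` decorator applied throughout the sql/context.py hunks above and the sql/dataframe.py hunks below is defined in pyspark/rdd.py and is not shown in this diff; a minimal sketch of the idea (the regex here is illustrative, not necessarily the exact implementation):

    import re
    import sys

    def ignore_unicode_prefix(f):
        # Python 3's repr() drops the u'' prefix, so doctest output written
        # against Python 2 would no longer match; rewrite the expected output
        # in the docstring once instead of editing every doctest.
        if sys.version >= '3' and f.__doc__ is not None:
            f.__doc__ = re.sub(r"(\W|^)[uU](')", r'\1\2', f.__doc__)
        return f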
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index f2c3b74a185cf..d76504f986270 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -16,14 +16,19 @@ # import sys -import itertools import warnings import random +if sys.version >= '3': + basestring = unicode = str + long = int +else: + from itertools import imap as map + from py4j.java_collections import ListConverter, MapConverter from pyspark.context import SparkContext -from pyspark.rdd import RDD, _load_from_socket +from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync @@ -65,19 +70,20 @@ def __init__(self, jdf, sql_ctx): self._sc = sql_ctx and sql_ctx._sc self.is_cached = False self._schema = None # initialized lazily + self._lazy_rdd = None @property def rdd(self): """Returns the content as an :class:`pyspark.RDD` of :class:`Row`. """ - if not hasattr(self, '_lazy_rdd'): + if self._lazy_rdd is None: jrdd = self._jdf.javaToPython() rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer())) schema = self.schema def applySchema(it): cls = _create_cls(schema) - return itertools.imap(cls, it) + return map(cls, it) self._lazy_rdd = rdd.mapPartitions(applySchema) @@ -89,13 +95,14 @@ def na(self): """ return DataFrameNaFunctions(self) - def toJSON(self, use_unicode=False): + @ignore_unicode_prefix + def toJSON(self, use_unicode=True): """Converts a :class:`DataFrame` into a :class:`RDD` of string. Each row is turned into a JSON document as one element in the returned RDD. >>> df.toJSON().first() - '{"age":2,"name":"Alice"}' + u'{"age":2,"name":"Alice"}' """ rdd = self._jdf.toJSON() return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode)) @@ -228,7 +235,7 @@ def printSchema(self): |-- name: string (nullable = true) """ - print (self._jdf.schema().treeString()) + print(self._jdf.schema().treeString()) def explain(self, extended=False): """Prints the (logical and physical) plans to the console for debugging purpose. @@ -250,9 +257,9 @@ def explain(self, extended=False): == RDD == """ if extended: - print self._jdf.queryExecution().toString() + print(self._jdf.queryExecution().toString()) else: - print self._jdf.queryExecution().executedPlan().toString() + print(self._jdf.queryExecution().executedPlan().toString()) def isLocal(self): """Returns ``True`` if the :func:`collect` and :func:`take` methods can be run locally @@ -270,7 +277,7 @@ def show(self, n=20): 2 Alice 5 Bob """ - print self._jdf.showString(n).encode('utf8', 'ignore') + print(self._jdf.showString(n)) def __repr__(self): return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes)) @@ -279,10 +286,11 @@ def count(self): """Returns the number of rows in this :class:`DataFrame`. >>> df.count() - 2L + 2 """ - return self._jdf.count() + return int(self._jdf.count()) + @ignore_unicode_prefix def collect(self): """Returns all the records as a list of :class:`Row`. @@ -295,6 +303,7 @@ def collect(self): cls = _create_cls(self.schema) return [cls(r) for r in rs] + @ignore_unicode_prefix def limit(self, num): """Limits the result count to the number specified. @@ -306,6 +315,7 @@ def limit(self, num): jdf = self._jdf.limit(num) return DataFrame(jdf, self.sql_ctx) + @ignore_unicode_prefix def take(self, num): """Returns the first ``num`` rows as a :class:`list` of :class:`Row`. 
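The `rdd` property in the hunk above replaces a `hasattr(self, '_lazy_rdd')` check with an explicit `_lazy_rdd = None` set in `__init__`; a small standalone sketch of that sentinel pattern (class and names hypothetical):

    class Table(object):
        def __init__(self, compute):
            self._compute = compute
            self._cached = None                  # sentinel: not materialized yet

        @property
        def rows(self):
            if self._cached is None:
                self._cached = self._compute()   # runs at most once
            return self._cached

Pre-declaring the attribute matters for a class like DataFrame that also defines `__getattr__` for column access, since a lookup of a missing `_lazy_rdd` would be routed through `__getattr__` and treated as a column name.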
@@ -314,6 +324,7 @@ def take(self, num): """ return self.limit(num).collect() + @ignore_unicode_prefix def map(self, f): """ Returns a new :class:`RDD` by applying the ``f`` function to each :class:`Row`. @@ -324,6 +335,7 @@ def map(self, f): """ return self.rdd.map(f) + @ignore_unicode_prefix def flatMap(self, f): """ Returns a new :class:`RDD` by first applying the ``f`` function to each :class:`Row`, and then flattening the results. @@ -353,7 +365,7 @@ def foreach(self, f): This is a shorthand for ``df.rdd.foreach()``. >>> def f(person): - ... print person.name + ... print(person.name) >>> df.foreach(f) """ return self.rdd.foreach(f) @@ -365,7 +377,7 @@ def foreachPartition(self, f): >>> def f(people): ... for person in people: - ... print person.name + ... print(person.name) >>> df.foreachPartition(f) """ return self.rdd.foreachPartition(f) @@ -412,7 +424,7 @@ def distinct(self): """Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`. >>> df.distinct().count() - 2L + 2 """ return DataFrame(self._jdf.distinct(), self.sql_ctx) @@ -420,10 +432,10 @@ def sample(self, withReplacement, fraction, seed=None): """Returns a sampled subset of this :class:`DataFrame`. >>> df.sample(False, 0.5, 97).count() - 1L + 1 """ assert fraction >= 0.0, "Negative fraction value: %s" % fraction - seed = seed if seed is not None else random.randint(0, sys.maxint) + seed = seed if seed is not None else random.randint(0, sys.maxsize) rdd = self._jdf.sample(withReplacement, fraction, long(seed)) return DataFrame(rdd, self.sql_ctx) @@ -437,6 +449,7 @@ def dtypes(self): return [(str(f.name), f.dataType.simpleString()) for f in self.schema.fields] @property + @ignore_unicode_prefix def columns(self): """Returns all column names as a list. @@ -445,6 +458,7 @@ def columns(self): """ return [f.name for f in self.schema.fields] + @ignore_unicode_prefix def join(self, other, joinExprs=None, joinType=None): """Joins with another :class:`DataFrame`, using the given join expression. @@ -470,6 +484,7 @@ def join(self, other, joinExprs=None, joinType=None): jdf = self._jdf.join(other._jdf, joinExprs._jc, joinType) return DataFrame(jdf, self.sql_ctx) + @ignore_unicode_prefix def sort(self, *cols): """Returns a new :class:`DataFrame` sorted by the specified column(s). @@ -513,6 +528,7 @@ def describe(self, *cols): jdf = self._jdf.describe(self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols)) return DataFrame(jdf, self.sql_ctx) + @ignore_unicode_prefix def head(self, n=None): """ Returns the first ``n`` rows as a list of :class:`Row`, @@ -528,6 +544,7 @@ def head(self, n=None): return rs[0] if rs else None return self.take(n) + @ignore_unicode_prefix def first(self): """Returns the first row as a :class:`Row`. @@ -536,6 +553,7 @@ def first(self): """ return self.head() + @ignore_unicode_prefix def __getitem__(self, item): """Returns the column as a :class:`Column`. @@ -567,6 +585,7 @@ def __getattr__(self, name): jc = self._jdf.apply(name) return Column(jc) + @ignore_unicode_prefix def select(self, *cols): """Projects a set of expressions and returns a new :class:`DataFrame`. @@ -598,6 +617,7 @@ def selectExpr(self, *expr): jdf = self._jdf.selectExpr(self._sc._jvm.PythonUtils.toSeq(jexpr)) return DataFrame(jdf, self.sql_ctx) + @ignore_unicode_prefix def filter(self, condition): """Filters rows using the given condition.
@@ -626,6 +646,7 @@ where = filter + @ignore_unicode_prefix def groupBy(self, *cols): """Groups the :class:`DataFrame` using the specified columns, so we can run aggregation on them. See :class:`GroupedData` @@ -775,6 +796,7 @@ def fillna(self, value, subset=None): cols = self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols) return DataFrame(self._jdf.na().fill(value, cols), self.sql_ctx) + @ignore_unicode_prefix def withColumn(self, colName, col): """Returns a new :class:`DataFrame` by adding a column. @@ -786,6 +808,7 @@ def withColumn(self, colName, col): """ return self.select('*', col.alias(colName)) + @ignore_unicode_prefix def withColumnRenamed(self, existing, new): """Returns a new :class:`DataFrame` by renaming an existing column. @@ -852,6 +875,7 @@ def __init__(self, jdf, sql_ctx): self._jdf = jdf self.sql_ctx = sql_ctx + @ignore_unicode_prefix def agg(self, *exprs): """Compute aggregates and return the result as a :class:`DataFrame`. @@ -1041,11 +1065,13 @@ def __init__(self, jc): __sub__ = _bin_op("minus") __mul__ = _bin_op("multiply") __div__ = _bin_op("divide") + __truediv__ = _bin_op("divide") __mod__ = _bin_op("mod") __radd__ = _bin_op("plus") __rsub__ = _reverse_op("minus") __rmul__ = _bin_op("multiply") __rdiv__ = _reverse_op("divide") + __rtruediv__ = _reverse_op("divide") __rmod__ = _reverse_op("mod") # logical operators @@ -1075,6 +1101,7 @@ def __init__(self, jc): startswith = _bin_op("startsWith") endswith = _bin_op("endsWith") + @ignore_unicode_prefix def substr(self, startPos, length): """ Return a :class:`Column` which is a substring of the column @@ -1097,6 +1124,7 @@ def substr(self, startPos, length): __getslice__ = substr + @ignore_unicode_prefix def inSet(self, *cols): """ A boolean expression that is evaluated to true if the value of this expression is contained by the evaluated values of the arguments.
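The `__truediv__`/`__rtruediv__` aliases added above bind Python 3's `/` operator to the same JVM `divide` operation as `__div__`, because Python 3 dispatches `/` to `__truediv__` and never consults `__div__`; a self-contained illustration (toy class, not the Column implementation):

    class Num(object):
        def __init__(self, x):
            self.x = x

        def _divide(self, other):
            return Num(self.x / other.x)

        __div__ = _divide       # consulted by `/` on Python 2 only
        __truediv__ = _divide   # consulted by `/` on Python 3

    print((Num(6) / Num(3)).x)  # 2.0 on Python 3; 2 on Python 2 without future division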
@@ -1131,6 +1159,7 @@ def alias(self, alias): """ return Column(getattr(self._jc, "as")(alias)) + @ignore_unicode_prefix def cast(self, dataType): """ Convert the column into type `dataType` diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index daeb6916b58bc..1d6536952810f 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -18,8 +18,10 @@ """ A collection of builtin functions """ +import sys -from itertools import imap +if sys.version < "3": + from itertools import imap as map from py4j.java_collections import ListConverter @@ -116,7 +118,7 @@ def __init__(self, func, returnType): def _create_judf(self): f = self.func # put it in closure `func` - func = lambda _, it: imap(lambda x: f(*x), it) + func = lambda _, it: map(lambda x: f(*x), it) ser = AutoBatchedSerializer(PickleSerializer()) command = (func, None, ser, ser) sc = SparkContext._active_spark_context diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index b3a6a2c6a9229..7c09a0cfe30ab 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -157,13 +157,13 @@ def test_udf2(self): self.assertEqual(4, res[0]) def test_udf_with_array_type(self): - d = [Row(l=range(3), d={"key": range(5)})] + d = [Row(l=list(range(3)), d={"key": list(range(5))})] rdd = self.sc.parallelize(d) self.sqlCtx.createDataFrame(rdd).registerTempTable("test") self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType())) self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType()) [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect() - self.assertEqual(range(3), l1) + self.assertEqual(list(range(3)), l1) self.assertEqual(1, l2) def test_broadcast_in_udf(self): @@ -266,7 +266,7 @@ def test_infer_nested_schema(self): def test_apply_schema(self): from datetime import date, datetime - rdd = self.sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0, + rdd = self.sc.parallelize([(127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1), datetime(2010, 1, 1, 1, 1, 1), {"a": 1}, (2,), [1, 2, 3], None)]) schema = StructType([ @@ -309,7 +309,7 @@ def test_apply_schema(self): def test_struct_in_map(self): d = [Row(m={Row(i=1): Row(s="")})] df = self.sc.parallelize(d).toDF() - k, v = df.head().m.items()[0] + k, v = list(df.head().m.items())[0] self.assertEqual(1, k.i) self.assertEqual("", v.s) @@ -554,6 +554,9 @@ def setUpClass(cls): except py4j.protocol.Py4JError: cls.sqlCtx = None return + except TypeError: + cls.sqlCtx = None + return os.unlink(cls.tempdir.name) _scala_HiveContext =\ cls.sc._jvm.org.apache.spark.sql.hive.test.TestHiveContext(cls.sc._jsc.sc()) diff --git a/python/pyspark/statcounter.py b/python/pyspark/statcounter.py index 1e597d64e03fe..944fa414b0c0e 100644 --- a/python/pyspark/statcounter.py +++ b/python/pyspark/statcounter.py @@ -31,7 +31,7 @@ class StatCounter(object): def __init__(self, values=[]): - self.n = 0L # Running count of our values + self.n = 0 # Running count of our values self.mu = 0.0 # Running mean of our values self.m2 = 0.0 # Running variance numerator (sum of (x - mean)^2) self.maxValue = float("-inf") @@ -87,7 +87,7 @@ def copy(self): return copy.deepcopy(self) def count(self): - return self.n + return int(self.n) def mean(self): return self.mu diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 2c73083c9f9a8..4590c58839266 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py
@@ -14,6 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # + +from __future__ import print_function + import os import sys @@ -157,7 +160,7 @@ def getOrCreate(cls, checkpointPath, setupFunc): try: jssc = gw.jvm.JavaStreamingContext(checkpointPath) except Exception: - print >>sys.stderr, "failed to load StreamingContext from checkpoint" + print("failed to load StreamingContext from checkpoint", file=sys.stderr) raise jsc = jssc.sparkContext() diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 3fa42444239f7..ff097985fae3e 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -15,11 +15,15 @@ # limitations under the License. # -from itertools import chain, ifilter, imap +import sys import operator import time +from itertools import chain from datetime import datetime +if sys.version < "3": + from itertools import imap as map, ifilter as filter + from py4j.protocol import Py4JJavaError from pyspark import RDD @@ -76,7 +80,7 @@ def filter(self, f): Return a new DStream containing only the elements that satisfy the predicate. """ def func(iterator): - return ifilter(f, iterator) + return filter(f, iterator) return self.mapPartitions(func, True) def flatMap(self, f, preservesPartitioning=False): @@ -85,7 +89,7 @@ def flatMap(self, f, preservesPartitioning=False): this DStream, and then flattening the results """ def func(s, iterator): - return chain.from_iterable(imap(f, iterator)) + return chain.from_iterable(map(f, iterator)) return self.mapPartitionsWithIndex(func, preservesPartitioning) def map(self, f, preservesPartitioning=False): @@ -93,7 +97,7 @@ def map(self, f, preservesPartitioning=False): Return a new DStream by applying a function to each element of this DStream. """ def func(iterator): - return imap(f, iterator) + return map(f, iterator) return self.mapPartitions(func, preservesPartitioning) def mapPartitions(self, f, preservesPartitioning=False): @@ -150,7 +154,7 @@ def foreachRDD(self, func): """ Apply a function to each RDD in this DStream. """ - if func.func_code.co_argcount == 1: + if func.__code__.co_argcount == 1: old_func = func func = lambda t, rdd: old_func(rdd) jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer) @@ -165,14 +169,14 @@ def pprint(self, num=10): """ def takeAndPrint(time, rdd): taken = rdd.take(num + 1) - print "-------------------------------------------" - print "Time: %s" % time - print "-------------------------------------------" + print("-------------------------------------------") + print("Time: %s" % time) + print("-------------------------------------------") for record in taken[:num]: - print record + print(record) if len(taken) > num: - print "..." - print + print("...") + print() self.foreachRDD(takeAndPrint) @@ -181,7 +185,7 @@ def mapValues(self, f): Return a new DStream by applying a map function to the value of each key-value pair in this DStream without changing the key. """ - map_values_fn = lambda (k, v): (k, f(v)) + map_values_fn = lambda kv: (kv[0], f(kv[1])) return self.map(map_values_fn, preservesPartitioning=True) def flatMapValues(self, f): """ Return a new DStream by applying a flatmap function to the value of each key-value pair in this DStream without changing the key.
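        Python 3 removed tuple parameter unpacking (PEP 3113), so the old
        ``lambda (k, v): ...`` spelling removed in this hunk is a syntax error
        there; the portable form indexes the pair instead, for example:

        >>> kvs = [("a", 1), ("b", 2)]
        >>> [(kv[0], kv[1] + 10) for kv in kvs]
        [('a', 11), ('b', 12)]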
""" - flat_map_fn = lambda (k, v): ((k, x) for x in f(v)) + flat_map_fn = lambda kv: ((kv[0], x) for x in f(kv[1])) return self.flatMap(flat_map_fn, preservesPartitioning=True) def glom(self): @@ -286,10 +290,10 @@ def transform(self, func): `func` can have one argument of `rdd`, or have two arguments of (`time`, `rdd`) """ - if func.func_code.co_argcount == 1: + if func.__code__.co_argcount == 1: oldfunc = func func = lambda t, rdd: oldfunc(rdd) - assert func.func_code.co_argcount == 2, "func should take one or two arguments" + assert func.__code__.co_argcount == 2, "func should take one or two arguments" return TransformedDStream(self, func) def transformWith(self, func, other, keepSerializer=False): @@ -300,10 +304,10 @@ def transformWith(self, func, other, keepSerializer=False): `func` can have two arguments of (`rdd_a`, `rdd_b`) or have three arguments of (`time`, `rdd_a`, `rdd_b`) """ - if func.func_code.co_argcount == 2: + if func.__code__.co_argcount == 2: oldfunc = func func = lambda t, a, b: oldfunc(a, b) - assert func.func_code.co_argcount == 3, "func should take two or three arguments" + assert func.__code__.co_argcount == 3, "func should take two or three arguments" jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer, other._jrdd_deserializer) dstream = self._sc._jvm.PythonTransformed2DStream(self._jdstream.dstream(), other._jdstream.dstream(), jfunc) @@ -460,7 +464,7 @@ def reduceByWindow(self, reduceFunc, invReduceFunc, windowDuration, slideDuratio keyed = self.map(lambda x: (1, x)) reduced = keyed.reduceByKeyAndWindow(reduceFunc, invReduceFunc, windowDuration, slideDuration, 1) - return reduced.map(lambda (k, v): v) + return reduced.map(lambda kv: kv[1]) def countByWindow(self, windowDuration, slideDuration): """ @@ -489,7 +493,7 @@ def countByValueAndWindow(self, windowDuration, slideDuration, numPartitions=Non keyed = self.map(lambda x: (x, 1)) counted = keyed.reduceByKeyAndWindow(operator.add, operator.sub, windowDuration, slideDuration, numPartitions) - return counted.filter(lambda (k, v): v > 0).count() + return counted.filter(lambda kv: kv[1] > 0).count() def groupByKeyAndWindow(self, windowDuration, slideDuration, numPartitions=None): """ @@ -548,7 +552,8 @@ def reduceFunc(t, a, b): def invReduceFunc(t, a, b): b = b.reduceByKey(func, numPartitions) joined = a.leftOuterJoin(b, numPartitions) - return joined.mapValues(lambda (v1, v2): invFunc(v1, v2) if v2 is not None else v1) + return joined.mapValues(lambda kv: invFunc(kv[0], kv[1]) + if kv[1] is not None else kv[0]) jreduceFunc = TransformFunction(self._sc, reduceFunc, reduced._jrdd_deserializer) if invReduceFunc: @@ -579,9 +584,9 @@ def reduceFunc(t, a, b): g = b.groupByKey(numPartitions).mapValues(lambda vs: (list(vs), None)) else: g = a.cogroup(b.partitionBy(numPartitions), numPartitions) - g = g.mapValues(lambda (va, vb): (list(vb), list(va)[0] if len(va) else None)) - state = g.mapValues(lambda (vs, s): updateFunc(vs, s)) - return state.filter(lambda (k, v): v is not None) + g = g.mapValues(lambda ab: (list(ab[1]), list(ab[0])[0] if len(ab[0]) else None)) + state = g.mapValues(lambda vs_s: updateFunc(vs_s[0], vs_s[1])) + return state.filter(lambda k_v: k_v[1] is not None) jreduceFunc = TransformFunction(self._sc, reduceFunc, self._sc.serializer, self._jrdd_deserializer) diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index f083ed149effb..7a7b6e1d9a527 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -67,10 +67,10 @@ def 
createStream(ssc, zkQuorum, groupId, topics, kafkaParams={}, .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper") helper = helperClass.newInstance() jstream = helper.createStream(ssc._jssc, jparam, jtopics, jlevel) - except Py4JJavaError, e: + except Py4JJavaError as e: # TODO: use --jar once it also work on driver if 'ClassNotFoundException' in str(e.java_exception): - print """ + print(""" ________________________________________________________________________________________________ Spark Streaming's Kafka libraries not found in class path. Try one of the following. @@ -88,8 +88,8 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={}, ________________________________________________________________________________________________ -""" % (ssc.sparkContext.version, ssc.sparkContext.version) +""" % (ssc.sparkContext.version, ssc.sparkContext.version)) raise e ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) stream = DStream(jstream, ssc, ser) - return stream.map(lambda (k, v): (keyDecoder(k), valueDecoder(v))) + return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 9b4635e49020b..06d22154373bc 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -22,6 +22,7 @@ import unittest import tempfile import struct +from functools import reduce from py4j.java_collections import MapConverter @@ -51,7 +52,7 @@ def wait_for(self, result, n): while len(result) < n and time.time() - start_time < self.timeout: time.sleep(0.01) if len(result) < n: - print "timeout after", self.timeout + print("timeout after", self.timeout) def _take(self, dstream, n): """ @@ -131,7 +132,7 @@ def test_map(self): def func(dstream): return dstream.map(str) - expected = map(lambda x: map(str, x), input) + expected = [list(map(str, x)) for x in input] self._test_func(input, func, expected) def test_flatMap(self): @@ -140,8 +141,8 @@ def test_flatMap(self): def func(dstream): return dstream.flatMap(lambda x: (x, x * 2)) - expected = map(lambda x: list(chain.from_iterable((map(lambda y: [y, y * 2], x)))), - input) + expected = [list(chain.from_iterable((map(lambda y: [y, y * 2], x)))) + for x in input] self._test_func(input, func, expected) def test_filter(self): @@ -150,7 +151,7 @@ def test_filter(self): def func(dstream): return dstream.filter(lambda x: x % 2 == 0) - expected = map(lambda x: filter(lambda y: y % 2 == 0, x), input) + expected = [[y for y in x if y % 2 == 0] for x in input] self._test_func(input, func, expected) def test_count(self): @@ -159,7 +160,7 @@ def test_count(self): def func(dstream): return dstream.count() - expected = map(lambda x: [len(x)], input) + expected = [[len(x)] for x in input] self._test_func(input, func, expected) def test_reduce(self): @@ -168,7 +169,7 @@ def test_reduce(self): def func(dstream): return dstream.reduce(operator.add) - expected = map(lambda x: [reduce(operator.add, x)], input) + expected = [[reduce(operator.add, x)] for x in input] self._test_func(input, func, expected) def test_reduceByKey(self): @@ -185,27 +186,27 @@ def func(dstream): def test_mapValues(self): """Basic operation test for DStream.mapValues.""" input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)], - [("", 4), (1, 1), (2, 2), (3, 3)], + [(0, 4), (1, 1), (2, 2), (3, 3)], [(1, 1), (2, 1), (3, 1), (4, 1)]] def func(dstream): return dstream.mapValues(lambda x: x + 10) expected = [[("a", 12), ("b", 12), ("c", 11), ("d", 11)], - [("", 
14), (1, 11), (2, 12), (3, 13)], + [(0, 14), (1, 11), (2, 12), (3, 13)], [(1, 11), (2, 11), (3, 11), (4, 11)]] self._test_func(input, func, expected, sort=True) def test_flatMapValues(self): """Basic operation test for DStream.flatMapValues.""" input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)], - [("", 4), (1, 1), (2, 1), (3, 1)], + [(0, 4), (1, 1), (2, 1), (3, 1)], [(1, 1), (2, 1), (3, 1), (4, 1)]] def func(dstream): return dstream.flatMapValues(lambda x: (x, x + 10)) expected = [[("a", 2), ("a", 12), ("b", 2), ("b", 12), ("c", 1), ("c", 11), ("d", 1), ("d", 11)], - [("", 4), ("", 14), (1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11)], + [(0, 4), (0, 14), (1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11)], [(1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11), (4, 1), (4, 11)]] self._test_func(input, func, expected) @@ -233,7 +234,7 @@ def f(iterator): def test_countByValue(self): """Basic operation test for DStream.countByValue.""" - input = [range(1, 5) * 2, range(5, 7) + range(5, 9), ["a", "a", "b", ""]] + input = [list(range(1, 5)) * 2, list(range(5, 7)) + list(range(5, 9)), ["a", "a", "b", ""]] def func(dstream): return dstream.countByValue() @@ -285,7 +286,7 @@ def test_union(self): def func(d1, d2): return d1.union(d2) - expected = [range(6), range(6), range(6)] + expected = [list(range(6)), list(range(6)), list(range(6))] self._test_func(input1, func, expected, input2=input2) def test_cogroup(self): @@ -424,7 +425,7 @@ class StreamingContextTests(PySparkStreamingTestCase): duration = 0.1 def _add_input_stream(self): - inputs = map(lambda x: range(1, x), range(101)) + inputs = [range(1, x) for x in range(101)] stream = self.ssc.queueStream(inputs) self._collect(stream, 1, block=False) @@ -441,7 +442,7 @@ def test_stop_multiple_times(self): self.ssc.stop() def test_queue_stream(self): - input = [range(i + 1) for i in range(3)] + input = [list(range(i + 1)) for i in range(3)] dstream = self.ssc.queueStream(input) result = self._collect(dstream, 3) self.assertEqual(input, result) @@ -457,13 +458,13 @@ def test_text_file_stream(self): with open(os.path.join(d, name), "w") as f: f.writelines(["%d\n" % i for i in range(10)]) self.wait_for(result, 2) - self.assertEqual([range(10), range(10)], result) + self.assertEqual([list(range(10)), list(range(10))], result) def test_binary_records_stream(self): d = tempfile.mkdtemp() self.ssc = StreamingContext(self.sc, self.duration) dstream = self.ssc.binaryRecordsStream(d, 10).map( - lambda v: struct.unpack("10b", str(v))) + lambda v: struct.unpack("10b", bytes(v))) result = self._collect(dstream, 2, block=False) self.ssc.start() for name in ('a', 'b'): @@ -471,10 +472,10 @@ def test_binary_records_stream(self): with open(os.path.join(d, name), "wb") as f: f.write(bytearray(range(10))) self.wait_for(result, 2) - self.assertEqual([range(10), range(10)], map(lambda v: list(v[0]), result)) + self.assertEqual([list(range(10)), list(range(10))], [list(v[0]) for v in result]) def test_union(self): - input = [range(i + 1) for i in range(3)] + input = [list(range(i + 1)) for i in range(3)] dstream = self.ssc.queueStream(input) dstream2 = self.ssc.queueStream(input) dstream3 = self.ssc.union(dstream, dstream2) diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index 86ee5aa04f252..34291f30a5652 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -91,9 +91,9 @@ def dumps(self, id): except Exception: traceback.print_exc() - def loads(self, bytes): + def loads(self, data): try: - f, deserializers 
= self.serializer.loads(str(bytes)) + f, deserializers = self.serializer.loads(bytes(data)) return TransformFunction(self.ctx, f, *deserializers) except Exception: traceback.print_exc() @@ -116,7 +116,7 @@ def rddToFileName(prefix, suffix, timestamp): """ if isinstance(timestamp, datetime): seconds = time.mktime(timestamp.timetuple()) - timestamp = long(seconds * 1000) + timestamp.microsecond / 1000 + timestamp = int(seconds * 1000) + timestamp.microsecond // 1000 if suffix is None: return prefix + "-" + str(timestamp) else: diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index ee67e80d539f8..75f39d9e75f38 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -19,8 +19,8 @@ Unit tests for PySpark; additional tests are implemented as doctests in individual modules. """ + from array import array -from fileinput import input from glob import glob import os import re @@ -45,6 +45,9 @@ sys.exit(1) else: import unittest + if sys.version_info[0] >= 3: + xrange = range + basestring = str from pyspark.conf import SparkConf @@ -52,7 +55,9 @@ from pyspark.rdd import RDD from pyspark.files import SparkFiles from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \ - CloudPickleSerializer, CompressedSerializer, UTF8Deserializer, NoOpSerializer + CloudPickleSerializer, CompressedSerializer, UTF8Deserializer, NoOpSerializer, \ + PairDeserializer, CartesianDeserializer, AutoBatchedSerializer, AutoSerializer, \ + FlattenedValuesSerializer from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter from pyspark import shuffle from pyspark.profiler import BasicProfiler @@ -81,7 +86,7 @@ class MergerTests(unittest.TestCase): def setUp(self): self.N = 1 << 12 self.l = [i for i in xrange(self.N)] - self.data = zip(self.l, self.l) + self.data = list(zip(self.l, self.l)) self.agg = Aggregator(lambda x: [x], lambda x, y: x.append(y) or x, lambda x, y: x.extend(y) or x) @@ -89,45 +94,45 @@ def setUp(self): def test_in_memory(self): m = InMemoryMerger(self.agg) m.mergeValues(self.data) - self.assertEqual(sum(sum(v) for k, v in m.iteritems()), + self.assertEqual(sum(sum(v) for k, v in m.items()), sum(xrange(self.N))) m = InMemoryMerger(self.agg) - m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data)) - self.assertEqual(sum(sum(v) for k, v in m.iteritems()), + m.mergeCombiners(map(lambda x_y: (x_y[0], [x_y[1]]), self.data)) + self.assertEqual(sum(sum(v) for k, v in m.items()), sum(xrange(self.N))) def test_small_dataset(self): m = ExternalMerger(self.agg, 1000) m.mergeValues(self.data) self.assertEqual(m.spills, 0) - self.assertEqual(sum(sum(v) for k, v in m.iteritems()), + self.assertEqual(sum(sum(v) for k, v in m.items()), sum(xrange(self.N))) m = ExternalMerger(self.agg, 1000) - m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data)) + m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), self.data)) self.assertEqual(m.spills, 0) - self.assertEqual(sum(sum(v) for k, v in m.iteritems()), + self.assertEqual(sum(sum(v) for k, v in m.items()), sum(xrange(self.N))) def test_medium_dataset(self): - m = ExternalMerger(self.agg, 30) + m = ExternalMerger(self.agg, 20) m.mergeValues(self.data) self.assertTrue(m.spills >= 1) - self.assertEqual(sum(sum(v) for k, v in m.iteritems()), + self.assertEqual(sum(sum(v) for k, v in m.items()), sum(xrange(self.N))) m = ExternalMerger(self.agg, 10) - m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data * 3)) + m.mergeCombiners(map(lambda x_y2: (x_y2[0], [x_y2[1]]), self.data * 
3)) self.assertTrue(m.spills >= 1) - self.assertEqual(sum(sum(v) for k, v in m.iteritems()), + self.assertEqual(sum(sum(v) for k, v in m.items()), sum(xrange(self.N)) * 3) def test_huge_dataset(self): - m = ExternalMerger(self.agg, 10, partitions=3) - m.mergeCombiners(map(lambda (k, v): (k, [str(v)]), self.data * 10)) + m = ExternalMerger(self.agg, 5, partitions=3) + m.mergeCombiners(map(lambda k_v: (k_v[0], [str(k_v[1])]), self.data * 10)) self.assertTrue(m.spills >= 1) - self.assertEqual(sum(len(v) for k, v in m.iteritems()), + self.assertEqual(sum(len(v) for k, v in m.items()), self.N * 10) m._cleanup() @@ -144,55 +149,55 @@ def gen_gs(N, step=1): self.assertEqual(1, len(list(gen_gs(1)))) self.assertEqual(2, len(list(gen_gs(2)))) self.assertEqual(100, len(list(gen_gs(100)))) - self.assertEqual(range(1, 101), [k for k, _ in gen_gs(100)]) - self.assertTrue(all(range(k) == list(vs) for k, vs in gen_gs(100))) + self.assertEqual(list(range(1, 101)), [k for k, _ in gen_gs(100)]) + self.assertTrue(all(list(range(k)) == list(vs) for k, vs in gen_gs(100))) for k, vs in gen_gs(50002, 10000): self.assertEqual(k, len(vs)) - self.assertEqual(range(k), list(vs)) + self.assertEqual(list(range(k)), list(vs)) ser = PickleSerializer() l = ser.loads(ser.dumps(list(gen_gs(50002, 30000)))) for k, vs in l: self.assertEqual(k, len(vs)) - self.assertEqual(range(k), list(vs)) + self.assertEqual(list(range(k)), list(vs)) class SorterTests(unittest.TestCase): def test_in_memory_sort(self): - l = range(1024) + l = list(range(1024)) random.shuffle(l) sorter = ExternalSorter(1024) - self.assertEquals(sorted(l), list(sorter.sorted(l))) - self.assertEquals(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True))) - self.assertEquals(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x))) - self.assertEquals(sorted(l, key=lambda x: -x, reverse=True), - list(sorter.sorted(l, key=lambda x: -x, reverse=True))) + self.assertEqual(sorted(l), list(sorter.sorted(l))) + self.assertEqual(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True))) + self.assertEqual(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x))) + self.assertEqual(sorted(l, key=lambda x: -x, reverse=True), + list(sorter.sorted(l, key=lambda x: -x, reverse=True))) def test_external_sort(self): - l = range(1024) + l = list(range(1024)) random.shuffle(l) sorter = ExternalSorter(1) - self.assertEquals(sorted(l), list(sorter.sorted(l))) + self.assertEqual(sorted(l), list(sorter.sorted(l))) self.assertGreater(shuffle.DiskBytesSpilled, 0) last = shuffle.DiskBytesSpilled - self.assertEquals(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True))) + self.assertEqual(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True))) self.assertGreater(shuffle.DiskBytesSpilled, last) last = shuffle.DiskBytesSpilled - self.assertEquals(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x))) + self.assertEqual(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x))) self.assertGreater(shuffle.DiskBytesSpilled, last) last = shuffle.DiskBytesSpilled - self.assertEquals(sorted(l, key=lambda x: -x, reverse=True), - list(sorter.sorted(l, key=lambda x: -x, reverse=True))) + self.assertEqual(sorted(l, key=lambda x: -x, reverse=True), + list(sorter.sorted(l, key=lambda x: -x, reverse=True))) self.assertGreater(shuffle.DiskBytesSpilled, last) def test_external_sort_in_rdd(self): conf = SparkConf().set("spark.python.worker.memory", "1m") sc = SparkContext(conf=conf) - l = range(10240) + l = 
list(range(10240)) random.shuffle(l) - rdd = sc.parallelize(l, 10) - self.assertEquals(sorted(l), rdd.sortBy(lambda x: x).collect()) + rdd = sc.parallelize(l, 2) + self.assertEqual(sorted(l), rdd.sortBy(lambda x: x).collect()) sc.stop() @@ -200,11 +205,11 @@ class SerializationTestCase(unittest.TestCase): def test_namedtuple(self): from collections import namedtuple - from cPickle import dumps, loads + from pickle import dumps, loads P = namedtuple("P", "x y") p1 = P(1, 3) p2 = loads(dumps(p1, 2)) - self.assertEquals(p1, p2) + self.assertEqual(p1, p2) def test_itemgetter(self): from operator import itemgetter @@ -246,7 +251,7 @@ def test_pickling_file_handles(self): ser = CloudPickleSerializer() out1 = sys.stderr out2 = ser.loads(ser.dumps(out1)) - self.assertEquals(out1, out2) + self.assertEqual(out1, out2) def test_func_globals(self): @@ -263,19 +268,36 @@ def __reduce__(self): def foo(): sys.exit(0) - self.assertTrue("exit" in foo.func_code.co_names) + self.assertTrue("exit" in foo.__code__.co_names) ser.dumps(foo) def test_compressed_serializer(self): ser = CompressedSerializer(PickleSerializer()) - from StringIO import StringIO + try: + from StringIO import StringIO + except ImportError: + from io import BytesIO as StringIO io = StringIO() ser.dump_stream(["abc", u"123", range(5)], io) io.seek(0) self.assertEqual(["abc", u"123", range(5)], list(ser.load_stream(io))) ser.dump_stream(range(1000), io) io.seek(0) - self.assertEqual(["abc", u"123", range(5)] + range(1000), list(ser.load_stream(io))) + self.assertEqual(["abc", u"123", range(5)] + list(range(1000)), list(ser.load_stream(io))) + io.close() + + def test_hash_serializer(self): + hash(NoOpSerializer()) + hash(UTF8Deserializer()) + hash(PickleSerializer()) + hash(MarshalSerializer()) + hash(AutoSerializer()) + hash(BatchedSerializer(PickleSerializer())) + hash(AutoBatchedSerializer(MarshalSerializer())) + hash(PairDeserializer(NoOpSerializer(), UTF8Deserializer())) + hash(CartesianDeserializer(NoOpSerializer(), UTF8Deserializer())) + hash(CompressedSerializer(PickleSerializer())) + hash(FlattenedValuesSerializer(PickleSerializer())) class PySparkTestCase(unittest.TestCase): @@ -340,7 +362,7 @@ def test_checkpoint_and_restore(self): self.assertTrue(flatMappedRDD.getCheckpointFile() is not None) recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile(), flatMappedRDD._jrdd_deserializer) - self.assertEquals([1, 2, 3, 4], recovered.collect()) + self.assertEqual([1, 2, 3, 4], recovered.collect()) class AddFileTests(PySparkTestCase): @@ -356,8 +378,7 @@ def test_add_py_file(self): def func(x): from userlibrary import UserClass return UserClass().hello() - self.assertRaises(Exception, - self.sc.parallelize(range(2)).map(func).first) + self.assertRaises(Exception, self.sc.parallelize(range(2)).map(func).first) log4j.LogManager.getRootLogger().setLevel(old_level) # Add the file, so the job should now succeed: @@ -372,7 +393,7 @@ def test_add_file_locally(self): download_path = SparkFiles.get("hello.txt") self.assertNotEqual(path, download_path) with open(download_path) as test_file: - self.assertEquals("Hello World!\n", test_file.readline()) + self.assertEqual("Hello World!\n", test_file.readline()) def test_add_py_file_locally(self): # To ensure that we're actually testing addPyFile's effects, check that @@ -381,7 +402,7 @@ def func(): from userlibrary import UserClass self.assertRaises(ImportError, func) path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py") - self.sc.addFile(path) + self.sc.addPyFile(path) 
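        # addPyFile, unlike addFile, also puts the file's location on sys.path
        # in this process, which is what lets the import below succeed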
from userlibrary import UserClass self.assertEqual("Hello World!", UserClass().hello()) @@ -391,7 +412,7 @@ def test_add_egg_file_locally(self): def func(): from userlib import UserClass self.assertRaises(ImportError, func) - path = os.path.join(SPARK_HOME, "python/test_support/userlib-0.1-py2.7.egg") + path = os.path.join(SPARK_HOME, "python/test_support/userlib-0.1.zip") self.sc.addPyFile(path) from userlib import UserClass self.assertEqual("Hello World from inside a package!", UserClass().hello()) @@ -427,8 +448,9 @@ def test_save_as_textfile_with_unicode(self): tempFile = tempfile.NamedTemporaryFile(delete=True) tempFile.close() data.saveAsTextFile(tempFile.name) - raw_contents = ''.join(input(glob(tempFile.name + "/part-0000*"))) - self.assertEqual(x, unicode(raw_contents.strip(), "utf-8")) + raw_contents = b''.join(open(p, 'rb').read() + for p in glob(tempFile.name + "/part-0000*")) + self.assertEqual(x, raw_contents.strip().decode("utf-8")) def test_save_as_textfile_with_utf8(self): x = u"\u00A1Hola, mundo!" @@ -436,19 +458,20 @@ def test_save_as_textfile_with_utf8(self): tempFile = tempfile.NamedTemporaryFile(delete=True) tempFile.close() data.saveAsTextFile(tempFile.name) - raw_contents = ''.join(input(glob(tempFile.name + "/part-0000*"))) - self.assertEqual(x, unicode(raw_contents.strip(), "utf-8")) + raw_contents = b''.join(open(p, 'rb').read() + for p in glob(tempFile.name + "/part-0000*")) + self.assertEqual(x, raw_contents.strip().decode('utf8')) def test_transforming_cartesian_result(self): # Regression test for SPARK-1034 rdd1 = self.sc.parallelize([1, 2]) rdd2 = self.sc.parallelize([3, 4]) cart = rdd1.cartesian(rdd2) - result = cart.map(lambda (x, y): x + y).collect() + result = cart.map(lambda x_y3: x_y3[0] + x_y3[1]).collect() def test_transforming_pickle_file(self): # Regression test for SPARK-2601 - data = self.sc.parallelize(["Hello", "World!"]) + data = self.sc.parallelize([u"Hello", u"World!"]) tempFile = tempfile.NamedTemporaryFile(delete=True) tempFile.close() data.saveAsPickleFile(tempFile.name) @@ -461,13 +484,13 @@ def test_cartesian_on_textfile(self): a = self.sc.textFile(path) result = a.cartesian(a).collect() (x, y) = result[0] - self.assertEqual("Hello World!", x.strip()) - self.assertEqual("Hello World!", y.strip()) + self.assertEqual(u"Hello World!", x.strip()) + self.assertEqual(u"Hello World!", y.strip()) def test_deleting_input_files(self): # Regression test for SPARK-1025 tempFile = tempfile.NamedTemporaryFile(delete=False) - tempFile.write("Hello World!") + tempFile.write(b"Hello World!") tempFile.close() data = self.sc.textFile(tempFile.name) filtered_data = data.filter(lambda x: True) @@ -510,21 +533,21 @@ def test_namedtuple_in_rdd(self): jon = Person(1, "Jon", "Doe") jane = Person(2, "Jane", "Doe") theDoes = self.sc.parallelize([jon, jane]) - self.assertEquals([jon, jane], theDoes.collect()) + self.assertEqual([jon, jane], theDoes.collect()) def test_large_broadcast(self): N = 100000 data = [[float(i) for i in range(300)] for i in range(N)] bdata = self.sc.broadcast(data) # 270MB m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() - self.assertEquals(N, m) + self.assertEqual(N, m) def test_multiple_broadcasts(self): N = 1 << 21 b1 = self.sc.broadcast(set(range(N))) # multiple blocks in JVM - r = range(1 << 15) + r = list(range(1 << 15)) random.shuffle(r) - s = str(r) + s = str(r).encode() checksum = hashlib.md5(s).hexdigest() b2 = self.sc.broadcast(s) r = list(set(self.sc.parallelize(range(10), 10).map( @@ -535,7 +558,7 @@ 
def test_multiple_broadcasts(self): self.assertEqual(checksum, csum) random.shuffle(r) - s = str(r) + s = str(r).encode() checksum = hashlib.md5(s).hexdigest() b2 = self.sc.broadcast(s) r = list(set(self.sc.parallelize(range(10), 10).map( @@ -549,7 +572,7 @@ def test_large_closure(self): N = 1000000 data = [float(i) for i in xrange(N)] rdd = self.sc.parallelize(range(1), 1).map(lambda x: len(data)) - self.assertEquals(N, rdd.first()) + self.assertEqual(N, rdd.first()) # regression test for SPARK-6886 self.assertEqual(1, rdd.map(lambda x: (x, 1)).groupByKey().count()) @@ -590,15 +613,15 @@ def test_zip_with_different_number_of_items(self): # same total number of items, but different distributions a = self.sc.parallelize([2, 3], 2).flatMap(range) b = self.sc.parallelize([3, 2], 2).flatMap(range) - self.assertEquals(a.count(), b.count()) + self.assertEqual(a.count(), b.count()) self.assertRaises(Exception, lambda: a.zip(b).count()) def test_count_approx_distinct(self): rdd = self.sc.parallelize(range(1000)) - self.assertTrue(950 < rdd.countApproxDistinct(0.04) < 1050) - self.assertTrue(950 < rdd.map(float).countApproxDistinct(0.04) < 1050) - self.assertTrue(950 < rdd.map(str).countApproxDistinct(0.04) < 1050) - self.assertTrue(950 < rdd.map(lambda x: (x, -x)).countApproxDistinct(0.04) < 1050) + self.assertTrue(950 < rdd.countApproxDistinct(0.03) < 1050) + self.assertTrue(950 < rdd.map(float).countApproxDistinct(0.03) < 1050) + self.assertTrue(950 < rdd.map(str).countApproxDistinct(0.03) < 1050) + self.assertTrue(950 < rdd.map(lambda x: (x, -x)).countApproxDistinct(0.03) < 1050) rdd = self.sc.parallelize([i % 20 for i in range(1000)], 7) self.assertTrue(18 < rdd.countApproxDistinct() < 22) @@ -612,59 +635,59 @@ def test_count_approx_distinct(self): def test_histogram(self): # empty rdd = self.sc.parallelize([]) - self.assertEquals([0], rdd.histogram([0, 10])[1]) - self.assertEquals([0, 0], rdd.histogram([0, 4, 10])[1]) + self.assertEqual([0], rdd.histogram([0, 10])[1]) + self.assertEqual([0, 0], rdd.histogram([0, 4, 10])[1]) self.assertRaises(ValueError, lambda: rdd.histogram(1)) # out of range rdd = self.sc.parallelize([10.01, -0.01]) - self.assertEquals([0], rdd.histogram([0, 10])[1]) - self.assertEquals([0, 0], rdd.histogram((0, 4, 10))[1]) + self.assertEqual([0], rdd.histogram([0, 10])[1]) + self.assertEqual([0, 0], rdd.histogram((0, 4, 10))[1]) # in range with one bucket rdd = self.sc.parallelize(range(1, 5)) - self.assertEquals([4], rdd.histogram([0, 10])[1]) - self.assertEquals([3, 1], rdd.histogram([0, 4, 10])[1]) + self.assertEqual([4], rdd.histogram([0, 10])[1]) + self.assertEqual([3, 1], rdd.histogram([0, 4, 10])[1]) # in range with one bucket exact match - self.assertEquals([4], rdd.histogram([1, 4])[1]) + self.assertEqual([4], rdd.histogram([1, 4])[1]) # out of range with two buckets rdd = self.sc.parallelize([10.01, -0.01]) - self.assertEquals([0, 0], rdd.histogram([0, 5, 10])[1]) + self.assertEqual([0, 0], rdd.histogram([0, 5, 10])[1]) # out of range with two uneven buckets rdd = self.sc.parallelize([10.01, -0.01]) - self.assertEquals([0, 0], rdd.histogram([0, 4, 10])[1]) + self.assertEqual([0, 0], rdd.histogram([0, 4, 10])[1]) # in range with two buckets rdd = self.sc.parallelize([1, 2, 3, 5, 6]) - self.assertEquals([3, 2], rdd.histogram([0, 5, 10])[1]) + self.assertEqual([3, 2], rdd.histogram([0, 5, 10])[1]) # in range with two bucket and None rdd = self.sc.parallelize([1, 2, 3, 5, 6, None, float('nan')]) - self.assertEquals([3, 2], rdd.histogram([0, 5, 10])[1]) + 
self.assertEqual([3, 2], rdd.histogram([0, 5, 10])[1]) # in range with two uneven buckets rdd = self.sc.parallelize([1, 2, 3, 5, 6]) - self.assertEquals([3, 2], rdd.histogram([0, 5, 11])[1]) + self.assertEqual([3, 2], rdd.histogram([0, 5, 11])[1]) # mixed range with two uneven buckets rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.0, 11.01]) - self.assertEquals([4, 3], rdd.histogram([0, 5, 11])[1]) + self.assertEqual([4, 3], rdd.histogram([0, 5, 11])[1]) # mixed range with four uneven buckets rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, 200.0, 200.1]) - self.assertEquals([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1]) + self.assertEqual([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1]) # mixed range with uneven buckets and NaN rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, 200.0, 200.1, None, float('nan')]) - self.assertEquals([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1]) + self.assertEqual([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1]) # out of range with infinite buckets rdd = self.sc.parallelize([10.01, -0.01, float('nan'), float("inf")]) - self.assertEquals([1, 2], rdd.histogram([float('-inf'), 0, float('inf')])[1]) + self.assertEqual([1, 2], rdd.histogram([float('-inf'), 0, float('inf')])[1]) # invalid buckets self.assertRaises(ValueError, lambda: rdd.histogram([])) @@ -674,25 +697,25 @@ def test_histogram(self): # without buckets rdd = self.sc.parallelize(range(1, 5)) - self.assertEquals(([1, 4], [4]), rdd.histogram(1)) + self.assertEqual(([1, 4], [4]), rdd.histogram(1)) # without buckets single element rdd = self.sc.parallelize([1]) - self.assertEquals(([1, 1], [1]), rdd.histogram(1)) + self.assertEqual(([1, 1], [1]), rdd.histogram(1)) # without bucket no range rdd = self.sc.parallelize([1] * 4) - self.assertEquals(([1, 1], [4]), rdd.histogram(1)) + self.assertEqual(([1, 1], [4]), rdd.histogram(1)) # without buckets basic two rdd = self.sc.parallelize(range(1, 5)) - self.assertEquals(([1, 2.5, 4], [2, 2]), rdd.histogram(2)) + self.assertEqual(([1, 2.5, 4], [2, 2]), rdd.histogram(2)) # without buckets with more requested than elements rdd = self.sc.parallelize([1, 2]) buckets = [1 + 0.2 * i for i in range(6)] hist = [1, 0, 0, 0, 1] - self.assertEquals((buckets, hist), rdd.histogram(5)) + self.assertEqual((buckets, hist), rdd.histogram(5)) # invalid RDDs rdd = self.sc.parallelize([1, float('inf')]) @@ -702,15 +725,8 @@ def test_histogram(self): # string rdd = self.sc.parallelize(["ab", "ac", "b", "bd", "ef"], 2) - self.assertEquals([2, 2], rdd.histogram(["a", "b", "c"])[1]) - self.assertEquals((["ab", "ef"], [5]), rdd.histogram(1)) - self.assertRaises(TypeError, lambda: rdd.histogram(2)) - - # mixed RDD - rdd = self.sc.parallelize([1, 4, "ab", "ac", "b"], 2) - self.assertEquals([1, 1], rdd.histogram([0, 4, 10])[1]) - self.assertEquals([2, 1], rdd.histogram(["a", "b", "c"])[1]) - self.assertEquals(([1, "b"], [5]), rdd.histogram(1)) + self.assertEqual([2, 2], rdd.histogram(["a", "b", "c"])[1]) + self.assertEqual((["ab", "ef"], [5]), rdd.histogram(1)) self.assertRaises(TypeError, lambda: rdd.histogram(2)) def test_repartitionAndSortWithinPartitions(self): @@ -718,31 +734,31 @@ def test_repartitionAndSortWithinPartitions(self): repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2) partitions = repartitioned.glom().collect() - self.assertEquals(partitions[0], [(0, 5), (0, 8), (2, 6)]) - self.assertEquals(partitions[1], [(1, 3), (3, 8), (3, 
8)]) + self.assertEqual(partitions[0], [(0, 5), (0, 8), (2, 6)]) + self.assertEqual(partitions[1], [(1, 3), (3, 8), (3, 8)]) def test_distinct(self): rdd = self.sc.parallelize((1, 2, 3)*10, 10) - self.assertEquals(rdd.getNumPartitions(), 10) - self.assertEquals(rdd.distinct().count(), 3) + self.assertEqual(rdd.getNumPartitions(), 10) + self.assertEqual(rdd.distinct().count(), 3) result = rdd.distinct(5) - self.assertEquals(result.getNumPartitions(), 5) - self.assertEquals(result.count(), 3) + self.assertEqual(result.getNumPartitions(), 5) + self.assertEqual(result.count(), 3) def test_external_group_by_key(self): - self.sc._conf.set("spark.python.worker.memory", "5m") + self.sc._conf.set("spark.python.worker.memory", "1m") N = 200001 kv = self.sc.parallelize(range(N)).map(lambda x: (x % 3, x)) gkv = kv.groupByKey().cache() self.assertEqual(3, gkv.count()) - filtered = gkv.filter(lambda (k, vs): k == 1) + filtered = gkv.filter(lambda kv: kv[0] == 1) self.assertEqual(1, filtered.count()) - self.assertEqual([(1, N/3)], filtered.mapValues(len).collect()) - self.assertEqual([(N/3, N/3)], + self.assertEqual([(1, N // 3)], filtered.mapValues(len).collect()) + self.assertEqual([(N // 3, N // 3)], filtered.values().map(lambda x: (len(x), len(list(x)))).collect()) result = filtered.collect()[0][1] - self.assertEqual(N/3, len(result)) - self.assertTrue(isinstance(result.data, shuffle.ExternalList)) + self.assertEqual(N // 3, len(result)) + self.assertTrue(isinstance(result.data, shuffle.ExternalListOfList)) def test_sort_on_empty_rdd(self): self.assertEqual([], self.sc.parallelize(zip([], [])).sortByKey().collect()) @@ -767,7 +783,7 @@ def test_null_in_rdd(self): rdd = RDD(jrdd, self.sc, UTF8Deserializer()) self.assertEqual([u"a", None, u"b"], rdd.collect()) rdd = RDD(jrdd, self.sc, NoOpSerializer()) - self.assertEqual(["a", None, "b"], rdd.collect()) + self.assertEqual([b"a", None, b"b"], rdd.collect()) def test_multiple_python_java_RDD_conversions(self): # Regression test for SPARK-5361 @@ -813,14 +829,14 @@ def test_narrow_dependency_in_join(self): self.sc.setJobGroup("test3", "test", True) d = sorted(parted.cogroup(parted).collect()) self.assertEqual(10, len(d)) - self.assertEqual([[0], [0]], map(list, d[0][1])) + self.assertEqual([[0], [0]], list(map(list, d[0][1]))) jobId = tracker.getJobIdsForGroup("test3")[0] self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds)) self.sc.setJobGroup("test4", "test", True) d = sorted(parted.cogroup(rdd).collect()) self.assertEqual(10, len(d)) - self.assertEqual([[0], [0]], map(list, d[0][1])) + self.assertEqual([[0], [0]], list(map(list, d[0][1]))) jobId = tracker.getJobIdsForGroup("test4")[0] self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds)) @@ -906,6 +922,7 @@ def tearDownClass(cls): ReusedPySparkTestCase.tearDownClass() shutil.rmtree(cls.tempdir.name) + @unittest.skipIf(sys.version >= "3", "serialize array of byte") def test_sequencefiles(self): basepath = self.tempdir.name ints = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfint/", @@ -954,15 +971,16 @@ def test_sequencefiles(self): en = [(1, None), (1, None), (2, None), (2, None), (2, None), (3, None)] self.assertEqual(nulls, en) - maps = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfmap/", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.MapWritable").collect()) + maps = self.sc.sequenceFile(basepath + "/sftestdata/sfmap/", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.MapWritable").collect() em = [(1, {}), (1, {3.0: u'bb'}), (2, {1.0: 
u'aa'}), (2, {1.0: u'cc'}), (3, {2.0: u'dd'})] - self.assertEqual(maps, em) + for v in maps: + self.assertTrue(v in em) # arrays get pickled to tuples by default tuples = sorted(self.sc.sequenceFile( @@ -1089,8 +1107,8 @@ def test_converters(self): def test_binary_files(self): path = os.path.join(self.tempdir.name, "binaryfiles") os.mkdir(path) - data = "short binary data" - with open(os.path.join(path, "part-0000"), 'w') as f: + data = b"short binary data" + with open(os.path.join(path, "part-0000"), 'wb') as f: f.write(data) [(p, d)] = self.sc.binaryFiles(path).collect() self.assertTrue(p.endswith("part-0000")) @@ -1103,7 +1121,7 @@ def test_binary_records(self): for i in range(100): f.write('%04d' % i) result = self.sc.binaryRecords(path, 4).map(int).collect() - self.assertEqual(range(100), result) + self.assertEqual(list(range(100)), result) class OutputFormatTests(ReusedPySparkTestCase): @@ -1115,6 +1133,7 @@ def setUp(self): def tearDown(self): shutil.rmtree(self.tempdir.name, ignore_errors=True) + @unittest.skipIf(sys.version >= "3", "serialize array of byte") def test_sequencefiles(self): basepath = self.tempdir.name ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] @@ -1155,8 +1174,9 @@ def test_sequencefiles(self): (2, {1.0: u'cc'}), (3, {2.0: u'dd'})] self.sc.parallelize(em).saveAsSequenceFile(basepath + "/sfmap/") - maps = sorted(self.sc.sequenceFile(basepath + "/sfmap/").collect()) - self.assertEqual(maps, em) + maps = self.sc.sequenceFile(basepath + "/sfmap/").collect() + for v in maps: + self.assertTrue(v, em) def test_oldhadoop(self): basepath = self.tempdir.name @@ -1168,12 +1188,13 @@ def test_oldhadoop(self): "org.apache.hadoop.mapred.SequenceFileOutputFormat", "org.apache.hadoop.io.IntWritable", "org.apache.hadoop.io.MapWritable") - result = sorted(self.sc.hadoopFile( + result = self.sc.hadoopFile( basepath + "/oldhadoop/", "org.apache.hadoop.mapred.SequenceFileInputFormat", "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.MapWritable").collect()) - self.assertEqual(result, dict_data) + "org.apache.hadoop.io.MapWritable").collect() + for v in result: + self.assertTrue(v, dict_data) conf = { "mapred.output.format.class": "org.apache.hadoop.mapred.SequenceFileOutputFormat", @@ -1183,12 +1204,13 @@ def test_oldhadoop(self): } self.sc.parallelize(dict_data).saveAsHadoopDataset(conf) input_conf = {"mapred.input.dir": basepath + "/olddataset/"} - old_dataset = sorted(self.sc.hadoopRDD( + result = self.sc.hadoopRDD( "org.apache.hadoop.mapred.SequenceFileInputFormat", "org.apache.hadoop.io.IntWritable", "org.apache.hadoop.io.MapWritable", - conf=input_conf).collect()) - self.assertEqual(old_dataset, dict_data) + conf=input_conf).collect() + for v in result: + self.assertTrue(v, dict_data) def test_newhadoop(self): basepath = self.tempdir.name @@ -1223,6 +1245,7 @@ def test_newhadoop(self): conf=input_conf).collect()) self.assertEqual(new_dataset, data) + @unittest.skipIf(sys.version >= "3", "serialize of array") def test_newhadoop_with_array(self): basepath = self.tempdir.name # use custom ArrayWritable types and converters to handle arrays @@ -1303,7 +1326,7 @@ def test_reserialization(self): basepath = self.tempdir.name x = range(1, 5) y = range(1001, 1005) - data = zip(x, y) + data = list(zip(x, y)) rdd = self.sc.parallelize(x).zip(self.sc.parallelize(y)) rdd.saveAsSequenceFile(basepath + "/reserialize/sequence") result1 = sorted(self.sc.sequenceFile(basepath + "/reserialize/sequence").collect()) @@ -1354,7 +1377,7 @@ def connect(self, 
port): sock = socket(AF_INET, SOCK_STREAM) sock.connect(('127.0.0.1', port)) # send a split index of -1 to shutdown the worker - sock.send("\xFF\xFF\xFF\xFF") + sock.send(b"\xFF\xFF\xFF\xFF") sock.close() return True @@ -1395,7 +1418,6 @@ def test_termination_sigterm(self): class WorkerTests(PySparkTestCase): - def test_cancel_task(self): temp = tempfile.NamedTemporaryFile(delete=True) temp.close() @@ -1410,7 +1432,7 @@ def sleep(x): # start job in background thread def run(): - self.sc.parallelize(range(1)).foreach(sleep) + self.sc.parallelize(range(1), 1).foreach(sleep) import threading t = threading.Thread(target=run) t.daemon = True @@ -1419,7 +1441,8 @@ def run(): daemon_pid, worker_pid = 0, 0 while True: if os.path.exists(path): - data = open(path).read().split(' ') + with open(path) as f: + data = f.read().split(' ') daemon_pid, worker_pid = map(int, data) break time.sleep(0.1) @@ -1455,7 +1478,7 @@ def raise_exception(_): def test_after_jvm_exception(self): tempFile = tempfile.NamedTemporaryFile(delete=False) - tempFile.write("Hello World!") + tempFile.write(b"Hello World!") tempFile.close() data = self.sc.textFile(tempFile.name, 1) filtered_data = data.filter(lambda x: True) @@ -1577,12 +1600,12 @@ def test_single_script(self): |from pyspark import SparkContext | |sc = SparkContext() - |print sc.parallelize([1, 2, 3]).map(lambda x: x * 2).collect() + |print(sc.parallelize([1, 2, 3]).map(lambda x: x * 2).collect()) """) proc = subprocess.Popen([self.sparkSubmit, script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) - self.assertIn("[2, 4, 6]", out) + self.assertIn("[2, 4, 6]", out.decode('utf-8')) def test_script_with_local_functions(self): """Submit and test a single script file calling a global function""" @@ -1593,12 +1616,12 @@ def test_script_with_local_functions(self): | return x * 3 | |sc = SparkContext() - |print sc.parallelize([1, 2, 3]).map(foo).collect() + |print(sc.parallelize([1, 2, 3]).map(foo).collect()) """) proc = subprocess.Popen([self.sparkSubmit, script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) - self.assertIn("[3, 6, 9]", out) + self.assertIn("[3, 6, 9]", out.decode('utf-8')) def test_module_dependency(self): """Submit and test a script with a dependency on another module""" @@ -1607,7 +1630,7 @@ def test_module_dependency(self): |from mylib import myfunc | |sc = SparkContext() - |print sc.parallelize([1, 2, 3]).map(myfunc).collect() + |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) """) zip = self.createFileInZip("mylib.py", """ |def myfunc(x): @@ -1617,7 +1640,7 @@ def test_module_dependency(self): stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) - self.assertIn("[2, 3, 4]", out) + self.assertIn("[2, 3, 4]", out.decode('utf-8')) def test_module_dependency_on_cluster(self): """Submit and test a script with a dependency on another module on a cluster""" @@ -1626,7 +1649,7 @@ def test_module_dependency_on_cluster(self): |from mylib import myfunc | |sc = SparkContext() - |print sc.parallelize([1, 2, 3]).map(myfunc).collect() + |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) """) zip = self.createFileInZip("mylib.py", """ |def myfunc(x): @@ -1637,7 +1660,7 @@ def test_module_dependency_on_cluster(self): stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) - self.assertIn("[2, 3, 4]", out) + self.assertIn("[2, 3, 4]", out.decode('utf-8')) def test_package_dependency(self): 
"""Submit and test a script with a dependency on a Spark Package""" @@ -1646,14 +1669,14 @@ def test_package_dependency(self): |from mylib import myfunc | |sc = SparkContext() - |print sc.parallelize([1, 2, 3]).map(myfunc).collect() + |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) """) self.create_spark_package("a:mylib:0.1") proc = subprocess.Popen([self.sparkSubmit, "--packages", "a:mylib:0.1", "--repositories", "file:" + self.programDir, script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) - self.assertIn("[2, 3, 4]", out) + self.assertIn("[2, 3, 4]", out.decode('utf-8')) def test_package_dependency_on_cluster(self): """Submit and test a script with a dependency on a Spark Package on a cluster""" @@ -1662,7 +1685,7 @@ def test_package_dependency_on_cluster(self): |from mylib import myfunc | |sc = SparkContext() - |print sc.parallelize([1, 2, 3]).map(myfunc).collect() + |print(sc.parallelize([1, 2, 3]).map(myfunc).collect()) """) self.create_spark_package("a:mylib:0.1") proc = subprocess.Popen([self.sparkSubmit, "--packages", "a:mylib:0.1", "--repositories", @@ -1670,7 +1693,7 @@ def test_package_dependency_on_cluster(self): "local-cluster[1,1,512]", script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) - self.assertIn("[2, 3, 4]", out) + self.assertIn("[2, 3, 4]", out.decode('utf-8')) def test_single_script_on_cluster(self): """Submit and test a single script on a cluster""" @@ -1681,7 +1704,7 @@ def test_single_script_on_cluster(self): | return x * 2 | |sc = SparkContext() - |print sc.parallelize([1, 2, 3]).map(foo).collect() + |print(sc.parallelize([1, 2, 3]).map(foo).collect()) """) # this will fail if you have different spark.executor.memory # in conf/spark-defaults.conf @@ -1690,7 +1713,7 @@ def test_single_script_on_cluster(self): stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) - self.assertIn("[2, 4, 6]", out) + self.assertIn("[2, 4, 6]", out.decode('utf-8')) class ContextTests(unittest.TestCase): @@ -1765,7 +1788,7 @@ class SciPyTests(PySparkTestCase): def test_serialize(self): from scipy.special import gammaln x = range(1, 5) - expected = map(gammaln, x) + expected = list(map(gammaln, x)) observed = self.sc.parallelize(x).map(gammaln).collect() self.assertEqual(expected, observed) @@ -1786,11 +1809,11 @@ def test_statcounter_array(self): if __name__ == "__main__": if not _have_scipy: - print "NOTE: Skipping SciPy tests as it does not seem to be installed" + print("NOTE: Skipping SciPy tests as it does not seem to be installed") if not _have_numpy: - print "NOTE: Skipping NumPy tests as it does not seem to be installed" + print("NOTE: Skipping NumPy tests as it does not seem to be installed") unittest.main() if not _have_scipy: - print "NOTE: SciPy tests were skipped as it does not seem to be installed" + print("NOTE: SciPy tests were skipped as it does not seem to be installed") if not _have_numpy: - print "NOTE: NumPy tests were skipped as it does not seem to be installed" + print("NOTE: NumPy tests were skipped as it does not seem to be installed") diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 452d6fabdcc17..fbdaf3a5814cd 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -18,6 +18,7 @@ """ Worker that receives input from Piped RDD. 
""" +from __future__ import print_function import os import sys import time @@ -37,9 +38,9 @@ def report_times(outfile, boot, init, finish): write_int(SpecialLengths.TIMING_DATA, outfile) - write_long(1000 * boot, outfile) - write_long(1000 * init, outfile) - write_long(1000 * finish, outfile) + write_long(int(1000 * boot), outfile) + write_long(int(1000 * init), outfile) + write_long(int(1000 * finish), outfile) def add_path(path): @@ -72,6 +73,9 @@ def main(infile, outfile): for _ in range(num_python_includes): filename = utf8_deserializer.loads(infile) add_path(os.path.join(spark_files_dir, filename)) + if sys.version > '3': + import importlib + importlib.invalidate_caches() # fetch names and values of broadcast variables num_broadcast_variables = read_int(infile) @@ -106,14 +110,14 @@ def process(): except Exception: try: write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile) - write_with_length(traceback.format_exc(), outfile) + write_with_length(traceback.format_exc().encode("utf-8"), outfile) except IOError: # JVM close the socket pass except Exception: # Write the error to stderr if it happened while serializing - print >> sys.stderr, "PySpark worker failed with exception:" - print >> sys.stderr, traceback.format_exc() + print("PySpark worker failed with exception:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) exit(-1) finish_time = time.time() report_times(outfile, boot_time, init_time, finish_time) diff --git a/python/run-tests b/python/run-tests index f3a07d8aba562..ed3e819ef30c1 100755 --- a/python/run-tests +++ b/python/run-tests @@ -66,7 +66,7 @@ function run_core_tests() { function run_sql_tests() { echo "Run sql tests ..." - run_test "pyspark/sql/types.py" + run_test "pyspark/sql/_types.py" run_test "pyspark/sql/context.py" run_test "pyspark/sql/dataframe.py" run_test "pyspark/sql/functions.py" @@ -136,6 +136,19 @@ run_mllib_tests run_ml_tests run_streaming_tests +# Try to test with Python 3 +if [ $(which python3.4) ]; then + export PYSPARK_PYTHON="python3.4" + echo "Testing with Python3.4 version:" + $PYSPARK_PYTHON --version + + run_core_tests + run_sql_tests + run_mllib_tests + run_ml_tests + run_streaming_tests +fi + # Try to test with PyPy if [ $(which pypy) ]; then export PYSPARK_PYTHON="pypy" diff --git a/python/test_support/userlib-0.1-py2.7.egg b/python/test_support/userlib-0.1-py2.7.egg deleted file mode 100644 index 1674c9cb2227e160aec55c74e9ba1ddba850f580..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1945 zcmWIWW@Zs#U|`^2Sdq%-{OPC4zkVQZ3lQ@IacOaCQBG!(er{rBo?bzvZr?^OCPM+2 z-!(^Pyju5>Ifws9kif>ilh5?x&a}Ph`t|+UfpA{q4-(mHcXF|6ID3k0S`rmJ^Pu27 zX2pP|PDj;+_Dt5>5pw-nXZI!d?|Lntzt1q@eSGhRp0S;^U`hJyZ6+W7I>sGuc2o~I z;$-<*VantEcQTf&4@>aB-66eQMnKwrV%L&?hZb7x-!-%B+8>4hxZ7LO_?&A_9oL=$ zbn|W?2Kfl)_W1bByv&mLc%b`}lmDGLz4eV;d4LhKhymk9W&!bThE5TMq?8kh2`5rh zI+9Wnd=is5t|hQ}rTk-O;&oG4-6HJKBDP%2^0|tb`0vbyFVikRj(5)S6>%#*g z4>4Xznml`c(5%T>?3=boqzIt-qG4rZelQ~gLjn^6g8-5*pfQlVbi#eF!v-Slm#_J* zt#~NyvVK)UwfP*@J=O@3Im_WVolA58N~>g61`pSxe0_vE)< z^ZdsjUjEk5J^74tL*7Q7qYqkq3~ia@R@xk@{p5N=T`@m+{)dK?i}POB9N~!PDabP2 znxJE9d~uu8)sXNlS7bEW#Y9#FH7t}rbS~rCraZfArw&~@*7N!K)x8r0<{n>Yd*t

diff --git a/python/test_support/userlib-0.1.zip b/python/test_support/userlib-0.1.zip new file mode 100644 index 0000000000000000000000000000000000000000..496e1349aa9674cfa4f232ea8b5e17fcdb49f4fd Binary files /dev/null and b/python/test_support/userlib-0.1.zip differ
From 5fe43433529346788e8c343d338a5b7dc169cf58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BA=91=E5=B3=A4?= Date: Thu, 16 Apr 2015 17:32:42 -0700 Subject: [PATCH 023/191] SPARK-6927 [SQL] Sorting Error when codegen on MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fix this error by adding a BinaryType comparator in GenerateOrdering. JIRA https://issues.apache.org/jira/browse/SPARK-6927 Author: 云峤 Closes #5524 from kaka1992/fix-codegen-sort and squashes the following commits: d7e2afe [云峤] fix codegen sorting error --- .../codegen/GenerateOrdering.scala | 14 ++++++++++++- .../org/apache/spark/sql/SQLQuerySuite.scala | 20 +++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 0db29eb404bd1..fc2a2b60703e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{StringType, NumericType} +import org.apache.spark.sql.types.{BinaryType, StringType, NumericType} /** * Generates bytecode for an [[Ordering]] of [[Row Rows]] for a given set of @@ -43,6 +43,18 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit val evalB = expressionEvaluator(order.child) val compare = order.child.dataType match { + case BinaryType => + q""" + val x = ${if (order.direction == Ascending) evalA.primitiveTerm else evalB.primitiveTerm} + val y = ${if (order.direction != Ascending) evalB.primitiveTerm else evalA.primitiveTerm} + var i = 0 + while (i < x.length && i < y.length) { + val res = x(i).compareTo(y(i)) + if (res != 0) return res + i = i+1 + } + return x.length - y.length + """ case _: NumericType => q""" val comp = ${evalA.primitiveTerm} - ${evalB.primitiveTerm}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index d739e550f3e56..9e02e69fda3f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -398,6 +398,26 @@ class
SQLQuerySuite extends QueryTest with BeforeAndAfterAll { setConf(SQLConf.EXTERNAL_SORT, before.toString) } + test("SPARK-6927 sorting with codegen on") { + val externalbefore = conf.externalSortEnabled + val codegenbefore = conf.codegenEnabled + setConf(SQLConf.EXTERNAL_SORT, "false") + setConf(SQLConf.CODEGEN_ENABLED, "true") + sortTest() + setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) + setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) + } + + test("SPARK-6927 external sorting with codegen on") { + val externalbefore = conf.externalSortEnabled + val codegenbefore = conf.codegenEnabled + setConf(SQLConf.CODEGEN_ENABLED, "true") + setConf(SQLConf.EXTERNAL_SORT, "true") + sortTest() + setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) + setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) + } + test("limit") { checkAnswer( sql("SELECT * FROM testData LIMIT 10"),
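A plain-Scala sketch of the ordering that the new BinaryType branch emits, written as an ordinary `Ordering` rather than quasiquoted codegen (illustrative only, not the generated code; note that bytes compare as signed values here, exactly as in the generated comparison):

```
val binaryOrdering: Ordering[Array[Byte]] = new Ordering[Array[Byte]] {
  override def compare(x: Array[Byte], y: Array[Byte]): Int = {
    var i = 0
    while (i < x.length && i < y.length) {
      val res = x(i).compareTo(y(i))
      if (res != 0) return res  // first differing byte decides the order
      i += 1
    }
    x.length - y.length         // equal prefix: the shorter array sorts first
  }
}

// Prints List(List(0, 9), List(1), List(1, 2))
println(Seq(Array[Byte](1, 2), Array[Byte](1), Array[Byte](0, 9))
  .sorted(binaryOrdering).map(_.toList))
```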
From 6183b5e2caedd074073d0f6cb6609a634e2f5194 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 16 Apr 2015 17:33:57 -0700 Subject: [PATCH 024/191] [SPARK-6911] [SQL] improve accessor for nested types Support accessing columns by index in Python: ``` >>> df[df[0] > 3].collect() [Row(age=5, name=u'Bob')] ``` Access items in an ArrayType or MapType: ``` >>> df.select(df.l.getItem(0), df.d.getItem("key")).show() >>> df.select(df.l[0], df.d["key"]).show() ``` Access a field in a StructType: ``` >>> df.select(df.r.getField("b")).show() >>> df.select(df.r.a).show() ``` Author: Davies Liu Closes #5513 from davies/access and squashes the following commits: e04d5a0 [Davies Liu] Update run-tests-jenkins 7ada9eb [Davies Liu] update timeout d125ac4 [Davies Liu] check column name, improve scala tests 6b62540 [Davies Liu] fix test db15b42 [Davies Liu] Merge branch 'master' of github.com:apache/spark into access 6c32e79 [Davies Liu] add scala tests 11f1df3 [Davies Liu] improve accessor for nested types --- python/pyspark/sql/dataframe.py | 49 +++++++++++++++++-- python/pyspark/sql/tests.py | 18 +++++++ .../scala/org/apache/spark/sql/Column.scala | 7 +-- .../org/apache/spark/sql/DataFrameSuite.scala | 6 +++ .../scala/org/apache/spark/sql/TestData.scala | 9 ++-- 5 files changed, 76 insertions(+), 13 deletions(-)
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index d76504f986270..b9a3e6cfe7f49 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -563,16 +563,23 @@ def __getitem__(self, item): [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)] >>> df[ df.age > 3 ].collect() [Row(age=5, name=u'Bob')] + >>> df[df[0] > 3].collect() + [Row(age=5, name=u'Bob')] """ if isinstance(item, basestring): + if item not in self.columns: + raise IndexError("no such column: %s" % item) jc = self._jdf.apply(item) return Column(jc) elif isinstance(item, Column): return self.filter(item) - elif isinstance(item, list): + elif isinstance(item, (list, tuple)): return self.select(*item) + elif isinstance(item, int): + jc = self._jdf.apply(self.columns[item]) + return Column(jc) else: - raise IndexError("unexpected index: %s" % item) + raise TypeError("unexpected type: %s" % type(item)) def __getattr__(self, name): """Returns the :class:`Column` denoted by ``name``. @@ -580,8 +587,8 @@ def __getattr__(self, name): >>> df.select(df.age).collect() [Row(age=2), Row(age=5)] """ - if name.startswith("__"): - raise AttributeError(name) + if name not in self.columns: + raise AttributeError("No such column: %s" % name) jc = self._jdf.apply(name) return Column(jc) @@ -1093,7 +1100,39 @@ def __init__(self, jc): # container operators __contains__ = _bin_op("contains") __getitem__ = _bin_op("getItem") - getField = _bin_op("getField", "An expression that gets a field by name in a StructField.") + + def getItem(self, key): + """An expression that gets an item at position `ordinal` out of a list, + or gets an item by key out of a dict. + + >>> df = sc.parallelize([([1, 2], {"key": "value"})]).toDF(["l", "d"]) + >>> df.select(df.l.getItem(0), df.d.getItem("key")).show() + l[0] d[key] + 1 value + >>> df.select(df.l[0], df.d["key"]).show() + l[0] d[key] + 1 value + """ + return self[key] + + def getField(self, name): + """An expression that gets a field by name in a StructField. + + >>> from pyspark.sql import Row + >>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF() + >>> df.select(df.r.getField("b")).show() + r.b + b + >>> df.select(df.r.a).show() + r.a + 1 + """ + return Column(self._jc.getField(name)) + + def __getattr__(self, item): + if item.startswith("__"): + raise AttributeError(item) + return self.getField(item) # string methods rlike = _bin_op("rlike")
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 7c09a0cfe30ab..6691e8c8dc44b 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -426,6 +426,24 @@ def test_help_command(self): pydoc.render_doc(df.foo) pydoc.render_doc(df.take(1)) + def test_access_column(self): + df = self.df + self.assertTrue(isinstance(df.key, Column)) + self.assertTrue(isinstance(df['key'], Column)) + self.assertTrue(isinstance(df[0], Column)) + self.assertRaises(IndexError, lambda: df[2]) + self.assertRaises(IndexError, lambda: df["bad_key"]) + self.assertRaises(TypeError, lambda: df[{}]) + + def test_access_nested_types(self): + df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF() + self.assertEqual(1, df.select(df.l[0]).first()[0]) + self.assertEqual(1, df.select(df.l.getItem(0)).first()[0]) + self.assertEqual(1, df.select(df.r.a).first()[0]) + self.assertEqual("b", df.select(df.r.getField("b")).first()[0]) + self.assertEqual("v", df.select(df.d["k"]).first()[0]) + self.assertEqual("v", df.select(df.d.getItem("k")).first()[0]) + def test_infer_long_type(self): longrow = [Row(f1='a', f2=100000000000000)] df = self.sc.parallelize(longrow).toDF()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 3cd7adf8cab5e..edb229c059e6b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -515,14 +515,15 @@ class Column(protected[sql] val expr: Expression) extends Logging { def rlike(literal: String): Column = RLike(expr, lit(literal).expr) /** - * An expression that gets an item at position `ordinal` out of an array. + * An expression that gets an item at position `ordinal` out of an array, + * or gets a value by key `key` in a [[MapType]]. * * @group expr_ops */ - def getItem(ordinal: Int): Column = GetItem(expr, Literal(ordinal)) + def getItem(key: Any): Column = GetItem(expr, Literal(key)) /** - * An expression that gets a field by name in a [[StructField]].
+ * An expression that gets a field by name in a [[StructType]]. * * @group expr_ops */
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index b26e22f6229fe..34b2cb054a3e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -86,6 +86,12 @@ class DataFrameSuite extends QueryTest { TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString) } + test("access complex data") { + assert(complexData.filter(complexData("a").getItem(0) === 2).count() == 1) + assert(complexData.filter(complexData("m").getItem("1") === 1).count() == 1) + assert(complexData.filter(complexData("s").getField("key") === 1).count() == 1) + } + test("table scan") { checkAnswer( testData,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 637f59b2e68ca..225b51bd73d6c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -20,9 +20,8 @@ package org.apache.spark.sql import java.sql.Timestamp import org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test._ import org.apache.spark.sql.test.TestSQLContext.implicits._ +import org.apache.spark.sql.test._ case class TestData(key: Int, value: String) @@ -199,11 +198,11 @@ object TestData { Salary(1, 1000.0) :: Nil).toDF() salary.registerTempTable("salary") - case class ComplexData(m: Map[Int, String], s: TestData, a: Seq[Int], b: Boolean) + case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean) val complexData = TestSQLContext.sparkContext.parallelize( - ComplexData(Map(1 -> "1"), TestData(1, "1"), Seq(1), true) - :: ComplexData(Map(2 -> "2"), TestData(2, "2"), Seq(2), false) + ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1), true) + :: ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2), false) :: Nil).toDF() complexData.registerTempTable("complexData") }
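The widened accessors can be exercised from Scala as well as Python. A minimal sketch, assuming a running `SparkContext` named `sc`; the `Rec` case class and its column names are made up for illustration:

```
import org.apache.spark.sql.SQLContext

// Hypothetical row type: an array column "l" and a map column "d".
case class Rec(l: Seq[Int], d: Map[String, String])

val sqlContext = new SQLContext(sc)
import sqlContext.implicits._

val df = sc.parallelize(Seq(Rec(Seq(1, 2), Map("key" -> "value")))).toDF()
// getItem now accepts either an array ordinal or a map key.
df.select(df("l").getItem(0), df("d").getItem("key")).show()
```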
From d96608674f6c2ff3abb13c65d80c1a3872206710 Mon Sep 17 00:00:00 2001 From: scwf Date: Thu, 16 Apr 2015 17:35:51 -0700 Subject: [PATCH 025/191] [SQL][Minor] Fix foreachUp of treenode `foreachUp` should run the given function recursively on [[children]] first and then on this node (just like transformUp). The current implementation does not follow this; as a result, checkAnalysis does not check the logical tree from the bottom up. Author: scwf Author: Fei Wang Closes #5518 from scwf/patch-1 and squashes the following commits: 18e28b2 [scwf] added a test case 1ccbfa8 [Fei Wang] fix foreachUp --- .../apache/spark/sql/catalyst/trees/TreeNode.scala | 2 +- .../spark/sql/catalyst/trees/TreeNodeSuite.scala | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index a2df51e598a2b..97502ed3afe72 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -85,7 +85,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { * @param f the function to be applied to each node in the tree. */ def foreachUp(f: BaseType => Unit): Unit = { - children.foreach(_.foreach(f)) + children.foreach(_.foreachUp(f)) f(this) }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 4eb8708335dcf..6b393327cc97a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -117,5 +117,17 @@ class TreeNodeSuite extends FunSuite { assert(transformed.origin.startPosition.isDefined) } + test("foreach up") { + val actual = new ArrayBuffer[String]() + val expected = Seq("1", "2", "3", "4", "-", "*", "+") + val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))) + expression foreachUp { + case b: BinaryExpression => actual.append(b.symbol); + case l: Literal => actual.append(l.toString); + } + + assert(expected === actual) + } + }
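The visiting order is the whole point of the fix, so here is a tiny stand-in tree (not Catalyst's `TreeNode`, just a sketch) showing why `foreachUp` must recurse with itself rather than with `foreach`:

```
case class Node(name: String, children: Seq[Node]) {
  // Pre-order walk: this node first, then the subtrees.
  def foreach(f: Node => Unit): Unit = { f(this); children.foreach(_.foreach(f)) }
  // The buggy shape recursed with foreach, so subtree roots came before their children.
  def foreachUpBuggy(f: Node => Unit): Unit = { children.foreach(_.foreach(f)); f(this) }
  // The fixed shape: a true bottom-up (post-order) walk.
  def foreachUp(f: Node => Unit): Unit = { children.foreach(_.foreachUp(f)); f(this) }
}

val tree = Node("+", Seq(Node("1", Nil), Node("*", Seq(Node("2", Nil), Node("3", Nil)))))
tree.foreachUp(n => print(n.name + " "))      // 1 2 3 * +  (children strictly first)
tree.foreachUpBuggy(n => print(n.name + " ")) // 1 * 2 3 +  ("*" is visited too early)
```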
From 1e43851d6455f65b850ea0327d0e92f65395d23f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 16 Apr 2015 17:50:20 -0700 Subject: [PATCH 026/191] [SPARK-6899][SQL] Fix type mismatch when using codegen with Average on DecimalType JIRA https://issues.apache.org/jira/browse/SPARK-6899 Author: Liang-Chi Hsieh Closes #5517 from viirya/fix_codegen_average and squashes the following commits: 8ae5f65 [Liang-Chi Hsieh] Add the case of DecimalType.Unlimited to Average. --- .../spark/sql/catalyst/expressions/aggregates.scala | 2 +- .../test/scala/org/apache/spark/sql/DataFrameSuite.scala | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 14a855054b94d..f3830c6d3bcf2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -326,7 +326,7 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN override def asPartial: SplitEvaluation = { child.dataType match { - case DecimalType.Fixed(_, _) => + case DecimalType.Fixed(_, _) | DecimalType.Unlimited => // Turn the child to unlimited decimals for calculation, before going back to fixed val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")() val partialCount = Alias(Count(child), "PartialCount")()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 34b2cb054a3e7..44a7d1e7bbb6a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -537,4 +537,13 @@ class DataFrameSuite extends QueryTest { val df = TestSQLContext.createDataFrame(rowRDD, schema) df.rdd.collect() } + + test("SPARK-6899") { + val originalValue = TestSQLContext.conf.codegenEnabled + TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, "true") + checkAnswer( + decimalData.agg(avg('a)), + Row(new java.math.BigDecimal(2.0))) + TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) + } }
From e5949c287ed19e78b6eecc61c3e88a07ad452eb9 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 16 Apr 2015 17:59:49 -0700 Subject: [PATCH 027/191] [SPARK-6966][SQL] Use correct ClassLoader for JDBC Driver Otherwise we cannot add jars with drivers after the fact. Author: Michael Armbrust Closes #5543 from marmbrus/jdbcClassloader and squashes the following commits: d9930f3 [Michael Armbrust] fix imports 73d0614 [Michael Armbrust] [SPARK-6966][SQL] Use correct ClassLoader for JDBC Driver --- .../main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala index 99b755c9f25d0..5f480083d5a49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils /** * Data corresponding to one partition of a JDBCRDD. @@ -99,7 +100,7 @@ private[sql] class DefaultSource extends RelationProvider { val upperBound = parameters.getOrElse("upperBound", null) val numPartitions = parameters.getOrElse("numPartitions", null) - if (driver != null) Class.forName(driver) + if (driver != null) Utils.getContextOrSparkClassLoader.loadClass(driver) if (partitionColumn != null && (lowerBound == null || upperBound == null || numPartitions == null)) {
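Why the context class loader matters here: jars added after startup (for example via `--jars` or `SparkContext.addJar`) are visible through the thread's context class loader, but not necessarily through the loader that defined the calling class, which is what `Class.forName` consults. A rough sketch of the lookup the patch switches to; `Utils.getContextOrSparkClassLoader` behaves along these lines, but the helper below is illustrative rather than Spark's code:

```
// Prefer the context class loader, falling back to this class's own loader.
def loadDriver(driverClass: String): Class[_] = {
  val loader = Option(Thread.currentThread().getContextClassLoader)
    .getOrElse(getClass.getClassLoader)
  // Succeeds for driver jars added at runtime, where Class.forName(driverClass)
  // may throw ClassNotFoundException.
  loader.loadClass(driverClass)
}
```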
From 8220d5265f1bbea9dfdaeec4f2d06d7fe24c0bc3 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 16 Apr 2015 21:49:26 -0500 Subject: [PATCH 028/191] [SPARK-6972][SQL] Add Coalesce to DataFrame Author: Michael Armbrust Closes #5545 from marmbrus/addCoalesce and squashes the following commits: 9fdf3f6 [Michael Armbrust] [SPARK-6972][SQL] Add Coalesce to DataFrame --- .../scala/org/apache/spark/sql/DataFrame.scala | 14 ++++++++++++++ .../main/scala/org/apache/spark/sql/RDDApi.scala | 2 ++ .../org/apache/spark/sql/DataFrameSuite.scala | 8 ++++++++ 3 files changed, 24 insertions(+)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 3235f85d5bbd2..17c21f6e3a0e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -908,6 +908,20 @@ class DataFrame private[sql]( schema, needsConversion = false) } + /** + * Returns a new [[DataFrame]] that has exactly `numPartitions` partitions. + * Similar to coalesce defined on an [[RDD]], this operation results in a narrow dependency, e.g. + * if you go from 1000 partitions to 100 partitions, there will not be a shuffle, instead each of + * the 100 new partitions will claim 10 of the current partitions. + * @group rdd + */ + override def coalesce(numPartitions: Int): DataFrame = { + sqlContext.createDataFrame( + queryExecution.toRdd.coalesce(numPartitions), + schema, + needsConversion = false) + } + /** * Returns a new [[DataFrame]] that contains only the unique rows from this [[DataFrame]]. * @group dfops
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RDDApi.scala b/sql/core/src/main/scala/org/apache/spark/sql/RDDApi.scala index ba4373f0124b4..63dbab19947c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RDDApi.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RDDApi.scala @@ -61,5 +61,7 @@ private[sql] trait RDDApi[T] { def repartition(numPartitions: Int): DataFrame + def coalesce(numPartitions: Int): DataFrame + def distinct: DataFrame }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 44a7d1e7bbb6a..3250ab476aeb4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -178,6 +178,14 @@ class DataFrameSuite extends QueryTest { testData.select('key).collect().toSeq) } + test("coalesce") { + assert(testData.select('key).coalesce(1).rdd.partitions.size === 1) + + checkAnswer( + testData.select('key).coalesce(1).select('key), + testData.select('key).collect().toSeq) + } + test("groupBy") { checkAnswer( testData2.groupBy("a").agg($"a", sum($"b")),
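Usage mirrors `RDD.coalesce`. A small sketch, assuming a `SparkContext` named `sc` and `sqlContext.implicits._` in scope so that `toDF` is available:

```
val df = sc.parallelize(1 to 1000, 100).toDF("key")

// Narrow dependency: the 100 input partitions are merged into 10 without a shuffle.
assert(df.coalesce(10).rdd.partitions.size == 10)

// repartition reaches the same partition count, but through a full shuffle.
assert(df.repartition(10).rdd.partitions.size == 10)
```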
From f7a25644ed5b3b49fe7f33743bec3d95cdf7913e Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 17 Apr 2015 11:02:31 +0100 Subject: [PATCH 029/191] SPARK-6846 [WEBUI] Stage kill URL easy to accidentally trigger and possibility for security issue kill endpoints now only accept a POST (kill stage, master kill app, master kill driver); kill link now POSTs Author: Sean Owen Closes #5528 from srowen/SPARK-6846 and squashes the following commits: 137ac9f [Sean Owen] Oops, fix scalastyle line length problem 7c5f961 [Sean Owen] Add Imran's test of kill link 59f447d [Sean Owen] kill endpoints now only accept a POST (kill stage, master kill app, master kill driver); kill link now POSTs --- .../org/apache/spark/ui/static/webui.css | 6 +-- .../spark/deploy/master/ui/MasterPage.scala | 28 +++++++------ .../spark/deploy/master/ui/MasterWebUI.scala | 8 ++-- .../org/apache/spark/ui/JettyUtils.scala | 17 +++++++- .../scala/org/apache/spark/ui/SparkUI.scala | 4 +- .../org/apache/spark/ui/jobs/StageTable.scala | 27 ++++++------- .../org/apache/spark/ui/UISeleniumSuite.scala | 40 +++++++++++++------ 7 files changed, 78 insertions(+), 52 deletions(-)
diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index 6c37cc8b98236..4910744d1d790 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -85,17 +85,13 @@ table.sortable td { filter: progid:dximagetransform.microsoft.gradient(startColorstr='#FFA4EDFF', endColorstr='#FF94DDFF', GradientType=0); } -span.kill-link { +a.kill-link { margin-right: 2px; margin-left: 20px; color: gray; float: right; } -span.kill-link a { - color: gray; -} - span.expand-details { font-size: 10pt; cursor: pointer;
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index 399f07399a0aa..1f2c3fdbfb2bc 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -190,12 +190,14 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { private def appRow(app: ApplicationInfo): Seq[Node] = { val killLink = if (parent.killEnabled && (app.state == ApplicationState.RUNNING || app.state == ApplicationState.WAITING)) { - val killLinkUri = s"app/kill?id=${app.id}&terminate=true" - val confirm = "return window.confirm(" + - s"'Are you sure you want to kill application ${app.id} ?');" - <a href={killLinkUri} - onclick={confirm} class="kill-link">(kill)</a> - + val confirm = + s"if (window.confirm('Are you sure you want to kill application ${app.id} ?')) " + + "{ this.parentNode.submit(); return true; } else { return false; }" +
<form action="app/kill/" method="POST" style="display:inline"> + <input type="hidden" name="id" value={app.id.toString}/> + <input type="hidden" name="terminate" value="true"/> + <a href="#" onclick={confirm} class="kill-link">(kill)</a> + </form> }
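The new behavior can be probed by hand the same way the `UISeleniumSuite` test added further below does it; a rough sketch, assuming an application UI listening on the default port 4040:

```
import java.net.{HttpURLConnection, URL}

// The kill endpoints must now reject GET (405) and accept only POST.
def status(url: String, method: String): Int = {
  val conn = new URL(url).openConnection().asInstanceOf[HttpURLConnection]
  conn.setRequestMethod(method)
  conn.connect()
  try conn.getResponseCode finally conn.disconnect()
}

val kill = "http://localhost:4040/stages/stage/kill/?id=0&terminate=true"
println(status(kill, "GET"))  // expected: 405
println(status(kill, "POST")) // expected: 200, and stage 0 is killed
```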

diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index 1b670418ab1ff..bb11e0642ddc6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -43,10 +43,10 @@ class MasterWebUI(val master: Master, requestedPort: Int) attachPage(new HistoryNotFoundPage(this)) attachPage(masterPage) attachHandler(createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR, "/static")) - attachHandler( - createRedirectHandler("/app/kill", "/", masterPage.handleAppKillRequest)) - attachHandler( - createRedirectHandler("/driver/kill", "/", masterPage.handleDriverKillRequest)) + attachHandler(createRedirectHandler( + "/app/kill", "/", masterPage.handleAppKillRequest, httpMethod = "POST")) + attachHandler(createRedirectHandler( + "/driver/kill", "/", masterPage.handleDriverKillRequest, httpMethod = "POST")) } /** Attach a reconstructed UI to this Master UI. Only valid after bind(). */ diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 95f254a9ef22a..a091ca650c60c 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -114,10 +114,23 @@ private[spark] object JettyUtils extends Logging { srcPath: String, destPath: String, beforeRedirect: HttpServletRequest => Unit = x => (), - basePath: String = ""): ServletContextHandler = { + basePath: String = "", + httpMethod: String = "GET"): ServletContextHandler = { val prefixedDestPath = attachPrefix(basePath, destPath) val servlet = new HttpServlet { - override def doGet(request: HttpServletRequest, response: HttpServletResponse) { + override def doGet(request: HttpServletRequest, response: HttpServletResponse): Unit = { + httpMethod match { + case "GET" => doRequest(request, response) + case _ => response.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED) + } + } + override def doPost(request: HttpServletRequest, response: HttpServletResponse): Unit = { + httpMethod match { + case "POST" => doRequest(request, response) + case _ => response.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED) + } + } + private def doRequest(request: HttpServletRequest, response: HttpServletResponse): Unit = { beforeRedirect(request) // Make sure we don't end up with "//" in the middle val newUrl = new URL(new URL(request.getRequestURL.toString), prefixedDestPath).toString diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index adfa6bbada256..580ab8b1325f8 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -55,8 +55,8 @@ private[spark] class SparkUI private ( attachTab(new ExecutorsTab(this)) attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) attachHandler(createRedirectHandler("/", "/jobs", basePath = basePath)) - attachHandler( - createRedirectHandler("/stages/stage/kill", "/stages", stagesTab.handleKillRequest)) + attachHandler(createRedirectHandler( + "/stages/stage/kill", "/stages", stagesTab.handleKillRequest, httpMethod = "POST")) } initialize() diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 5865850fa09b5..cb72890a0fd20 100644 --- 
a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -73,20 +73,21 @@ private[ui] class StageTableBase( } private def makeDescription(s: StageInfo): Seq[Node] = { - // scalastyle:off + val basePathUri = UIUtils.prependBaseUri(basePath) + val killLink = if (killEnabled) { - val killLinkUri = "%s/stages/stage/kill?id=%s&terminate=true" - .format(UIUtils.prependBaseUri(basePath), s.stageId) - val confirm = "return window.confirm('Are you sure you want to kill stage %s ?');" - .format(s.stageId) - <span class="kill-link"> - (<a href={killLinkUri} onclick={confirm}>kill</a>) - </span> + val killLinkUri = s"$basePathUri/stages/stage/kill/" + val confirm = + s"if (window.confirm('Are you sure you want to kill stage ${s.stageId} ?')) " + + "{ this.parentNode.submit(); return true; } else { return false; }" + <form action={killLinkUri} method="POST" style="display:inline"> + <input type="hidden" name="id" value={s.stageId.toString}/> + <input type="hidden" name="terminate" value="true"/> + <a href="#" onclick={confirm} class="kill-link">(kill)</a> + </form> } - // scalastyle:on - val nameLinkUri ="%s/stages/stage?id=%s&attempt=%s" - .format(UIUtils.prependBaseUri(basePath), s.stageId, s.attemptId) + val nameLinkUri = s"$basePathUri/stages/stage?id=${s.stageId}&attempt=${s.attemptId}" val nameLink = <a href={nameLinkUri}>{s.name}</a> val cachedRddInfos = s.rddInfos.filter(_.numCachedPartitions > 0) @@ -98,11 +99,9 @@ private[ui] class StageTableBase(
diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index 1cb594633f331..eb9db550fd74c 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.ui +import java.net.{HttpURLConnection, URL} import javax.servlet.http.HttpServletRequest import scala.collection.JavaConversions._ @@ -56,12 +57,13 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers with Before * Create a test SparkContext with the SparkUI enabled. * It is safe to `get` the SparkUI directly from the SparkContext returned here.
*/ - private def newSparkContext(): SparkContext = { + private def newSparkContext(killEnabled: Boolean = true): SparkContext = { val conf = new SparkConf() .setMaster("local") .setAppName("test") .set("spark.ui.enabled", "true") .set("spark.ui.port", "0") + .set("spark.ui.killEnabled", killEnabled.toString) val sc = new SparkContext(conf) assert(sc.ui.isDefined) sc @@ -128,21 +130,12 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers with Before } test("spark.ui.killEnabled should properly control kill button display") { - def getSparkContext(killEnabled: Boolean): SparkContext = { - val conf = new SparkConf() - .setMaster("local") - .setAppName("test") - .set("spark.ui.enabled", "true") - .set("spark.ui.killEnabled", killEnabled.toString) - new SparkContext(conf) - } - def hasKillLink: Boolean = find(className("kill-link")).isDefined def runSlowJob(sc: SparkContext) { sc.parallelize(1 to 10).map{x => Thread.sleep(10000); x}.countAsync() } - withSpark(getSparkContext(killEnabled = true)) { sc => + withSpark(newSparkContext(killEnabled = true)) { sc => runSlowJob(sc) eventually(timeout(5 seconds), interval(50 milliseconds)) { go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/stages") @@ -150,7 +143,7 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers with Before } } - withSpark(getSparkContext(killEnabled = false)) { sc => + withSpark(newSparkContext(killEnabled = false)) { sc => runSlowJob(sc) eventually(timeout(5 seconds), interval(50 milliseconds)) { go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/stages") @@ -233,7 +226,7 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers with Before // because someone could change the error message and cause this test to pass by accident. // Instead, it's safer to check that each row contains a link to a stage details page. findAll(cssSelector("tbody tr")).foreach { row => - val link = row.underlying.findElement(By.xpath(".//a")) + val link = row.underlying.findElement(By.xpath("./td/div/a")) link.getAttribute("href") should include ("stage") } } @@ -356,4 +349,25 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers with Before } } } + + test("kill stage is POST only") { + def getResponseCode(url: URL, method: String): Int = { + val connection = url.openConnection().asInstanceOf[HttpURLConnection] + connection.setRequestMethod(method) + connection.connect() + val code = connection.getResponseCode() + connection.disconnect() + code + } + + withSpark(newSparkContext(killEnabled = true)) { sc => + sc.parallelize(1 to 10).map{x => Thread.sleep(10000); x}.countAsync() + eventually(timeout(5 seconds), interval(50 milliseconds)) { + val url = new URL( + sc.ui.get.appUIAddress.stripSuffix("/") + "/stages/stage/kill/?id=0&terminate=true") + getResponseCode(url, "GET") should be (405) + getResponseCode(url, "POST") should be (200) + } + } + } } From 4527761bcd6501c362baf2780905a0018b9a74ba Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 17 Apr 2015 11:06:01 +0100 Subject: [PATCH 030/191] [SPARK-6046] [core] Reorganize deprecated config support in SparkConf. This change tries to follow the chosen way for handling deprecated configs in SparkConf: all values (old and new) are kept in the conf object, and newer names take precedence over older ones when retrieving the value. Warnings are logged when config options are set, which generally happens on the driver node (where the logs are most visible). 
Author: Marcelo Vanzin Closes #5514 from vanzin/SPARK-6046 and squashes the following commits: 9371529 [Marcelo Vanzin] Avoid math. 6cf3f11 [Marcelo Vanzin] Review feedback. 2445d48 [Marcelo Vanzin] Fix (and cleanup) update interval initialization. b6824be [Marcelo Vanzin] Clean up the other deprecated config use also. ab20351 [Marcelo Vanzin] Update FsHistoryProvider to only retrieve new config key. 2c93209 [Marcelo Vanzin] [SPARK-6046] [core] Reorganize deprecated config support in SparkConf. --- .../scala/org/apache/spark/SparkConf.scala | 174 ++++++++++-------- .../deploy/history/FsHistoryProvider.scala | 9 +- .../org/apache/spark/executor/Executor.scala | 5 +- .../org/apache/spark/SparkConfSuite.scala | 22 +++ docs/monitoring.md | 6 +- .../org/apache/spark/deploy/yarn/Client.scala | 3 +- 6 files changed, 124 insertions(+), 95 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 390e631647bd6..b0186e9a007b8 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -68,6 +68,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { if (value == null) { throw new NullPointerException("null value for " + key) } + logDeprecationWarning(key) settings.put(key, value) this } @@ -134,13 +135,15 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { /** Set multiple parameters together */ def setAll(settings: Traversable[(String, String)]): SparkConf = { - this.settings.putAll(settings.toMap.asJava) + settings.foreach { case (k, v) => set(k, v) } this } /** Set a parameter if it isn't already configured */ def setIfMissing(key: String, value: String): SparkConf = { - settings.putIfAbsent(key, value) + if (settings.putIfAbsent(key, value) == null) { + logDeprecationWarning(key) + } this } @@ -174,8 +177,8 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { getOption(key).getOrElse(defaultValue) } - /** - * Get a time parameter as seconds; throws a NoSuchElementException if it's not set. If no + /** + * Get a time parameter as seconds; throws a NoSuchElementException if it's not set. If no * suffix is provided then seconds are assumed. * @throws NoSuchElementException */ @@ -183,36 +186,35 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { Utils.timeStringAsSeconds(get(key)) } - /** - * Get a time parameter as seconds, falling back to a default if not set. If no + /** + * Get a time parameter as seconds, falling back to a default if not set. If no * suffix is provided then seconds are assumed. - * */ def getTimeAsSeconds(key: String, defaultValue: String): Long = { Utils.timeStringAsSeconds(get(key, defaultValue)) } - /** - * Get a time parameter as milliseconds; throws a NoSuchElementException if it's not set. If no - * suffix is provided then milliseconds are assumed. + /** + * Get a time parameter as milliseconds; throws a NoSuchElementException if it's not set. If no + * suffix is provided then milliseconds are assumed. * @throws NoSuchElementException */ def getTimeAsMs(key: String): Long = { Utils.timeStringAsMs(get(key)) } - /** - * Get a time parameter as milliseconds, falling back to a default if not set. If no - * suffix is provided then milliseconds are assumed. + /** + * Get a time parameter as milliseconds, falling back to a default if not set. If no + * suffix is provided then milliseconds are assumed. 
*/ def getTimeAsMs(key: String, defaultValue: String): Long = { Utils.timeStringAsMs(get(key, defaultValue)) } - + /** Get a parameter as an Option */ def getOption(key: String): Option[String] = { - Option(settings.get(key)) + Option(settings.get(key)).orElse(getDeprecatedConfig(key, this)) } /** Get all parameters as a list of pairs */ @@ -379,13 +381,6 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { } } } - - // Warn against the use of deprecated configs - deprecatedConfigs.values.foreach { dc => - if (contains(dc.oldName)) { - dc.warn() - } - } } /** @@ -400,19 +395,44 @@ private[spark] object SparkConf extends Logging { + /** + * Maps deprecated config keys to information about the deprecation. + * + * The extra information is logged as a warning when the config is present in the user's + * configuration. + */ private val deprecatedConfigs: Map[String, DeprecatedConfig] = { val configs = Seq( - DeprecatedConfig("spark.files.userClassPathFirst", "spark.executor.userClassPathFirst", - "1.3"), - DeprecatedConfig("spark.yarn.user.classpath.first", null, "1.3", - "Use spark.{driver,executor}.userClassPathFirst instead."), - DeprecatedConfig("spark.history.fs.updateInterval", - "spark.history.fs.update.interval.seconds", - "1.3", "Use spark.history.fs.update.interval.seconds instead"), - DeprecatedConfig("spark.history.updateInterval", - "spark.history.fs.update.interval.seconds", - "1.3", "Use spark.history.fs.update.interval.seconds instead")) - configs.map { x => (x.oldName, x) }.toMap + DeprecatedConfig("spark.yarn.user.classpath.first", "1.3", + "Please use spark.{driver,executor}.userClassPathFirst instead.")) + Map(configs.map { cfg => (cfg.key -> cfg) }:_*) + } + + /** + * Maps a current config key to alternate keys that were used in previous versions of Spark. + * + * The alternates are used in the order defined in this map. If deprecated configs are + * present in the user's configuration, a warning is logged. + */ + private val configsWithAlternatives = Map[String, Seq[AlternateConfig]]( + "spark.executor.userClassPathFirst" -> Seq( + AlternateConfig("spark.files.userClassPathFirst", "1.3")), + "spark.history.fs.update.interval" -> Seq( + AlternateConfig("spark.history.fs.update.interval.seconds", "1.4"), + AlternateConfig("spark.history.fs.updateInterval", "1.3"), + AlternateConfig("spark.history.updateInterval", "1.3")) + ) + + /** + * A view of `configsWithAlternatives` that makes it more efficient to look up deprecated + * config keys. + * + * Maps the deprecated config name to a 2-tuple (new config name, alternate config info). + */ + private val allAlternatives: Map[String, (String, AlternateConfig)] = { + configsWithAlternatives.keys.flatMap { key => + configsWithAlternatives(key).map { cfg => (cfg.key -> (key -> cfg)) } + }.toMap } /** @@ -443,61 +463,57 @@ private[spark] object SparkConf extends Logging { } /** - * Translate the configuration key if it is deprecated and has a replacement, otherwise just - * returns the provided key. - * - * @param userKey Configuration key from the user / caller. - * @param warn Whether to print a warning if the key is deprecated. Warnings will be printed - * only once for each key. + * Looks for available deprecated keys for the given config option, and returns the first + * value available.
*/ - private def translateConfKey(userKey: String, warn: Boolean = false): String = { - deprecatedConfigs.get(userKey) - .map { deprecatedKey => - if (warn) { - deprecatedKey.warn() - } - deprecatedKey.newName.getOrElse(userKey) - }.getOrElse(userKey) + def getDeprecatedConfig(key: String, conf: SparkConf): Option[String] = { + configsWithAlternatives.get(key).flatMap { alts => + alts.collectFirst { case alt if conf.contains(alt.key) => + val value = conf.get(alt.key) + alt.translation.map(_(value)).getOrElse(value) + } + } } /** - * Holds information about keys that have been deprecated or renamed. + * Logs a warning message if the given config key is deprecated. + */ + def logDeprecationWarning(key: String): Unit = { + deprecatedConfigs.get(key).foreach { cfg => + logWarning( + s"The configuration key '$key' has been deprecated as of Spark ${cfg.version} and " + + s"may be removed in the future. ${cfg.deprecationMessage}") + } + + allAlternatives.get(key).foreach { case (newKey, cfg) => + logWarning( + s"The configuration key '$key' has been deprecated as of Spark ${cfg.version} and " + + s"may be removed in the future. Please use the new key '$newKey' instead.") + } + } + + /** + * Holds information about keys that have been deprecated and do not have a replacement. * - * @param oldName Old configuration key. - * @param newName New configuration key, or `null` if key has no replacement, in which case the - * deprecated key will be used (but the warning message will still be printed). + * @param key The deprecated key. * @param version Version of Spark where key was deprecated. - * @param deprecationMessage Message to include in the deprecation warning; mandatory when - * `newName` is not provided. + * @param deprecationMessage Message to include in the deprecation warning. */ private case class DeprecatedConfig( - oldName: String, - _newName: String, + key: String, version: String, - deprecationMessage: String = null) { - - private val warned = new AtomicBoolean(false) - val newName = Option(_newName) + deprecationMessage: String) - if (newName == null && (deprecationMessage == null || deprecationMessage.isEmpty())) { - throw new IllegalArgumentException("Need new config name or deprecation message.") - } - - def warn(): Unit = { - if (warned.compareAndSet(false, true)) { - if (newName != null) { - val message = Option(deprecationMessage).getOrElse( - s"Please use the alternative '$newName' instead.") - logWarning( - s"The configuration option '$oldName' has been replaced as of Spark $version and " + - s"may be removed in the future. $message") - } else { - logWarning( - s"The configuration option '$oldName' has been deprecated as of Spark $version and " + - s"may be removed in the future. $deprecationMessage") - } - } - } + /** + * Information about an alternate configuration key that has been deprecated. + * + * @param key The deprecated config key. + * @param version The Spark version in which the key was deprecated. + * @param translation A translation function for converting old config values into new ones.
+ */ + private case class AlternateConfig( + key: String, + version: String, + translation: Option[String => String] = None) - } } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 9d40d8c8fd7a8..985545742df67 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -49,11 +49,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis private val NOT_STARTED = "" // Interval between each check for event log updates - private val UPDATE_INTERVAL_MS = conf.getOption("spark.history.fs.update.interval.seconds") - .orElse(conf.getOption("spark.history.fs.updateInterval")) - .orElse(conf.getOption("spark.history.updateInterval")) - .map(_.toInt) - .getOrElse(10) * 1000 + private val UPDATE_INTERVAL_S = conf.getTimeAsSeconds("spark.history.fs.update.interval", "10s") // Interval between each cleaner checks for event logs to delete private val CLEAN_INTERVAL_MS = conf.getLong("spark.history.fs.cleaner.interval.seconds", @@ -130,8 +126,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis // Disable the background thread during tests. if (!conf.contains("spark.testing")) { // A task that periodically checks for event log updates on disk. - pool.scheduleAtFixedRate(getRunner(checkForLogs), 0, UPDATE_INTERVAL_MS, - TimeUnit.MILLISECONDS) + pool.scheduleAtFixedRate(getRunner(checkForLogs), 0, UPDATE_INTERVAL_S, TimeUnit.SECONDS) if (conf.getBoolean("spark.history.fs.cleaner.enabled", false)) { // A task that periodically cleans event logs on disk. diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 1b5fdeba28ee2..327d155b38c22 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -89,10 +89,7 @@ private[spark] class Executor( ExecutorEndpoint.EXECUTOR_ENDPOINT_NAME, new ExecutorEndpoint(env.rpcEnv, executorId)) // Whether to load classes in user jars before those in Spark jars - private val userClassPathFirst: Boolean = { - conf.getBoolean("spark.executor.userClassPathFirst", - conf.getBoolean("spark.files.userClassPathFirst", false)) - } + private val userClassPathFirst = conf.getBoolean("spark.executor.userClassPathFirst", false) // Create our ClassLoader // do this after SparkEnv creation so can access the SecurityManager diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index e08210ae60d17..7d87ba5fd2610 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -197,6 +197,28 @@ class SparkConfSuite extends FunSuite with LocalSparkContext with ResetSystemPro serializer.newInstance().serialize(new StringBuffer()) } + test("deprecated configs") { + val conf = new SparkConf() + val newName = "spark.history.fs.update.interval" + + assert(!conf.contains(newName)) + + conf.set("spark.history.updateInterval", "1") + assert(conf.get(newName) === "1") + + conf.set("spark.history.fs.updateInterval", "2") + assert(conf.get(newName) === "2") + + conf.set("spark.history.fs.update.interval.seconds", "3") + assert(conf.get(newName) === "3") + + conf.set(newName, "4") + 
assert(conf.get(newName) === "4") + + val count = conf.getAll.filter { case (k, v) => k.startsWith("spark.history.") }.size + assert(count === 4) + } + } class Class1 {} diff --git a/docs/monitoring.md b/docs/monitoring.md index 6816671ffbf46..2a130224591ca 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -86,10 +86,10 @@ follows: <tr> - <td>spark.history.fs.update.interval.seconds</td> - <td>10</td> + <td>spark.history.fs.update.interval</td> + <td>10s</td> <td> - The period, in seconds, at which information displayed by this history server is updated. + The period at which information displayed by this history server is updated. Each update checks for any changes made to the event logs in persisted storage. </td> </tr> diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 1091ff54b0463..52e4dee46c535 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -1052,8 +1052,7 @@ object Client extends Logging { if (isDriver) { conf.getBoolean("spark.driver.userClassPathFirst", false) } else { - conf.getBoolean("spark.executor.userClassPathFirst", - conf.getBoolean("spark.files.userClassPathFirst", false)) + conf.getBoolean("spark.executor.userClassPathFirst", false) } } From f6a9a57a72767f48fcc02e5fda4d6eafa67aebde Mon Sep 17 00:00:00 2001 From: Punya Biswal Date: Fri, 17 Apr 2015 11:08:37 +0100 Subject: [PATCH 031/191] [SPARK-6952] Handle long args when detecting PID reuse sbin/spark-daemon.sh used ps -p "$TARGET_PID" -o args= to figure out whether the process running with the expected PID is actually a Spark daemon. When running with a large classpath, the output of ps gets truncated and the check fails spuriously. This change weakens the check to only verify that the process is a java command (something we already do in other parts of the script) rather than looking for the specific main class name. This means that SPARK-4832 might happen under a slightly broader range of circumstances (a java program that happened to reuse the same PID), but it seems worthwhile compared to failing consistently with a large classpath. Author: Punya Biswal Closes #5535 from punya/feature/SPARK-6952 and squashes the following commits: 7ea12d1 [Punya Biswal] Handle long args when detecting PID reuse --- sbin/spark-daemon.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sbin/spark-daemon.sh b/sbin/spark-daemon.sh index d8e0facb81169..de762acc8fa0e 100755 --- a/sbin/spark-daemon.sh +++ b/sbin/spark-daemon.sh @@ -129,7 +129,7 @@ run_command() { if [ -f "$pid" ]; then TARGET_ID="$(cat "$pid")" - if [[ $(ps -p "$TARGET_ID" -o args=) =~ $command ]]; then + if [[ $(ps -p "$TARGET_ID" -o comm=) =~ "java" ]]; then echo "$command running as process $TARGET_ID. Stop it first." exit 1 fi @@ -163,7 +163,7 @@ run_command() { echo "$newpid" > "$pid" sleep 2 # Check if the process has died; in that case we'll tail the log so the user can see - if [[ ! $(ps -p "$newpid" -o args=) =~ $command ]]; then + if [[ ! $(ps -p "$newpid" -o comm=) =~ "java" ]]; then echo "failed to launch $command:" tail -2 "$log" | sed 's/^/ /' echo "full log in $log" From dc48ba9f9f7449dd2f12cbad288b65c8119d9284 Mon Sep 17 00:00:00 2001 From: linweizhong Date: Fri, 17 Apr 2015 12:04:02 +0100 Subject: [PATCH 032/191] [SPARK-6604][PySpark] Specify IP of Python server socket The driver currently starts a server socket bound to a wildcard IP; binding it to 127.0.0.1 is more reasonable, since the socket is only used by the local Python process.
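As a standalone illustration of the fix (not the patch itself, which follows below), the change amounts to passing an explicit loopback bind address to the `ServerSocket` constructor:

```scala
import java.net.{InetAddress, ServerSocket}

// new ServerSocket(0, 1) binds an ephemeral port on the wildcard address,
// making it reachable from other hosts even though only the local Python
// process ever connects. Binding to loopback avoids that exposure.
val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
println(serverSocket.getLocalPort) // the port handed to the Python side
```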
/cc davies Author: linweizhong Closes #5256 from Sephiroth-Lin/SPARK-6604 and squashes the following commits: 7b3c633 [linweizhong] rephrase --- core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index b1ffba4c546bf..7409dc2d866f6 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -604,7 +604,7 @@ private[spark] object PythonRDD extends Logging { * The thread will terminate after all the data are sent or any exceptions happen. */ private def serveIterator[T](items: Iterator[T], threadName: String): Int = { - val serverSocket = new ServerSocket(0, 1) + val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) // Close the socket if no connection in 3 seconds serverSocket.setSoTimeout(3000) From c84d91692aa25c01882bcc3f9fd5de3cfa786195 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 17 Apr 2015 11:29:27 -0500 Subject: [PATCH 033/191] [SPARK-6957] [SPARK-6958] [SQL] improve API compatibility with pandas ``` select(['colA', 'colB']) groupby(['colA', 'colB']) groupby([df.colA, df.colB]) df.sort('A', ascending=True) df.sort(['A', 'B'], ascending=True) df.sort(['A', 'B'], ascending=[1, 0]) ``` cc rxin Author: Davies Liu Closes #5544 from davies/compatibility and squashes the following commits: 4944058 [Davies Liu] add docstrings adb2816 [Davies Liu] Merge branch 'master' of github.com:apache/spark into compatibility bcbbcab [Davies Liu] support ascending as list 8dabdf0 [Davies Liu] improve API compatibility with pandas --- python/pyspark/sql/dataframe.py | 96 ++++++++++++++++++++++----------- python/pyspark/sql/functions.py | 11 ++-- python/pyspark/sql/tests.py | 2 +- 3 files changed, 70 insertions(+), 39 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index b9a3e6cfe7f49..326d22e72f104 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -485,13 +485,17 @@ def join(self, other, joinExprs=None, joinType=None): return DataFrame(jdf, self.sql_ctx) @ignore_unicode_prefix - def sort(self, *cols): + def sort(self, *cols, **kwargs): """Returns a new :class:`DataFrame` sorted by the specified column(s). - :param cols: list of :class:`Column` to sort by. + :param cols: list of :class:`Column` or column names to sort by. + :param ascending: whether to sort in ascending order; can be a bool or int, + or a list of bool/int (default: True).
>>> df.sort(df.age.desc()).collect() [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] + >>> df.sort("age", ascending=False).collect() + [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] >>> df.orderBy(df.age.desc()).collect() [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] >>> from pyspark.sql.functions import * @@ -499,16 +503,42 @@ def sort(self, *cols): [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] >>> df.orderBy(desc("age"), "name").collect() [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] + >>> df.orderBy(["age", "name"], ascending=[0, 1]).collect() + [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] """ if not cols: raise ValueError("should sort by at least one column") - jcols = ListConverter().convert([_to_java_column(c) for c in cols], - self._sc._gateway._gateway_client) - jdf = self._jdf.sort(self._sc._jvm.PythonUtils.toSeq(jcols)) + if len(cols) == 1 and isinstance(cols[0], list): + cols = cols[0] + jcols = [_to_java_column(c) for c in cols] + ascending = kwargs.get('ascending', True) + if isinstance(ascending, (bool, int)): + if not ascending: + jcols = [jc.desc() for jc in jcols] + elif isinstance(ascending, list): + jcols = [jc if asc else jc.desc() + for asc, jc in zip(ascending, jcols)] + else: + raise TypeError("ascending can only be bool or list, but got %s" % type(ascending)) + + jdf = self._jdf.sort(self._jseq(jcols)) return DataFrame(jdf, self.sql_ctx) orderBy = sort + def _jseq(self, cols, converter=None): + """Return a JVM Seq of Columns from a list of Column or names""" + return _to_seq(self.sql_ctx._sc, cols, converter) + + def _jcols(self, *cols): + """Return a JVM Seq of Columns from a list of Column or column names + + If `cols` has only one list in it, cols[0] will be used as the list. + """ + if len(cols) == 1 and isinstance(cols[0], list): + cols = cols[0] + return self._jseq(cols, _to_java_column) + def describe(self, *cols): """Computes statistics for numeric columns. @@ -523,9 +553,7 @@ def describe(self, *cols): min 2 max 5 """ - cols = ListConverter().convert(cols, - self.sql_ctx._sc._gateway._gateway_client) - jdf = self._jdf.describe(self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols)) + jdf = self._jdf.describe(self._jseq(cols)) return DataFrame(jdf, self.sql_ctx) @ignore_unicode_prefix @@ -607,9 +635,7 @@ def select(self, *cols): >>> df.select(df.name, (df.age + 10).alias('age')).collect() [Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)] """ - jcols = ListConverter().convert([_to_java_column(c) for c in cols], - self._sc._gateway._gateway_client) - jdf = self._jdf.select(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) + jdf = self._jdf.select(self._jcols(*cols)) return DataFrame(jdf, self.sql_ctx) def selectExpr(self, *expr): @@ -620,8 +646,9 @@ def selectExpr(self, *expr): >>> df.selectExpr("age * 2", "abs(age)").collect() [Row((age * 2)=4, Abs(age)=2), Row((age * 2)=10, Abs(age)=5)] """ - jexpr = ListConverter().convert(expr, self._sc._gateway._gateway_client) - jdf = self._jdf.selectExpr(self._sc._jvm.PythonUtils.toSeq(jexpr)) + if len(expr) == 1 and isinstance(expr[0], list): + expr = expr[0] + jdf = self._jdf.selectExpr(self._jseq(expr)) return DataFrame(jdf, self.sql_ctx) @ignore_unicode_prefix @@ -659,6 +686,8 @@ def groupBy(self, *cols): so we can run aggregation on them. See :class:`GroupedData` for all the available aggregate functions. + :func:`groupby` is an alias for :func:`groupBy`. + :param cols: list of columns to group by. Each element should be a column name (string) or an expression (:class:`Column`). 
@@ -668,12 +697,14 @@ def groupBy(self, *cols): [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)] >>> df.groupBy(df.name).avg().collect() [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)] + >>> df.groupBy(['name', df.age]).count().collect() + [Row(name=u'Bob', age=5, count=1), Row(name=u'Alice', age=2, count=1)] """ - jcols = ListConverter().convert([_to_java_column(c) for c in cols], - self._sc._gateway._gateway_client) - jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) + jdf = self._jdf.groupBy(self._jcols(*cols)) return GroupedData(jdf, self.sql_ctx) + groupby = groupBy + def agg(self, *exprs): """ Aggregate on the entire :class:`DataFrame` without groups (shorthand for ``df.groupBy.agg()``). @@ -744,9 +775,7 @@ def dropna(self, how='any', thresh=None, subset=None): if thresh is None: thresh = len(subset) if how == 'any' else 1 - cols = ListConverter().convert(subset, self.sql_ctx._sc._gateway._gateway_client) - cols = self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols) - return DataFrame(self._jdf.na().drop(thresh, cols), self.sql_ctx) + return DataFrame(self._jdf.na().drop(thresh, self._jseq(subset)), self.sql_ctx) def fillna(self, value, subset=None): """Replace null values, alias for ``na.fill()``. @@ -799,9 +828,7 @@ def fillna(self, value, subset=None): elif not isinstance(subset, (list, tuple)): raise ValueError("subset should be a list or tuple of column names") - cols = ListConverter().convert(subset, self.sql_ctx._sc._gateway._gateway_client) - cols = self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols) - return DataFrame(self._jdf.na().fill(value, cols), self.sql_ctx) + return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx) @ignore_unicode_prefix def withColumn(self, colName, col): @@ -862,10 +889,8 @@ def _api(self): def df_varargs_api(f): def _api(self, *args): - jargs = ListConverter().convert(args, - self.sql_ctx._sc._gateway._gateway_client) name = f.__name__ - jdf = getattr(self._jdf, name)(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jargs)) + jdf = getattr(self._jdf, name)(_to_seq(self.sql_ctx._sc, args)) return DataFrame(jdf, self.sql_ctx) _api.__name__ = f.__name__ _api.__doc__ = f.__doc__ @@ -912,9 +937,8 @@ def agg(self, *exprs): else: # Columns assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column" - jcols = ListConverter().convert([c._jc for c in exprs[1:]], - self.sql_ctx._sc._gateway._gateway_client) - jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) + jdf = self._jdf.agg(exprs[0]._jc, + _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]])) return DataFrame(jdf, self.sql_ctx) @dfapi @@ -1006,6 +1030,19 @@ def _to_java_column(col): return jcol +def _to_seq(sc, cols, converter=None): + """ + Convert a list of Column (or names) into a JVM Seq of Column. + + An optional `converter` could be used to convert items in `cols` + into JVM Column objects. 
+ """ + if converter: + cols = [converter(c) for c in cols] + jcols = ListConverter().convert(cols, sc._gateway._gateway_client) + return sc._jvm.PythonUtils.toSeq(jcols) + + def _unary_op(name, doc="unary operator"): """ Create a method for given unary operator """ def _(self): @@ -1177,8 +1214,7 @@ def inSet(self, *cols): cols = cols[0] cols = [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols] sc = SparkContext._active_spark_context - jcols = ListConverter().convert(cols, sc._gateway._gateway_client) - jc = getattr(self._jc, "in")(sc._jvm.PythonUtils.toSeq(jcols)) + jc = getattr(self._jc, "in")(_to_seq(sc, cols)) return Column(jc) # order diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 1d6536952810f..bb47923f24b82 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -23,13 +23,11 @@ if sys.version < "3": from itertools import imap as map -from py4j.java_collections import ListConverter - from pyspark import SparkContext from pyspark.rdd import _prepare_for_python_RDD from pyspark.serializers import PickleSerializer, AutoBatchedSerializer from pyspark.sql.types import StringType -from pyspark.sql.dataframe import Column, _to_java_column +from pyspark.sql.dataframe import Column, _to_java_column, _to_seq __all__ = ['countDistinct', 'approxCountDistinct', 'udf'] @@ -87,8 +85,7 @@ def countDistinct(col, *cols): [Row(c=2)] """ sc = SparkContext._active_spark_context - jcols = ListConverter().convert([_to_java_column(c) for c in cols], sc._gateway._gateway_client) - jc = sc._jvm.functions.countDistinct(_to_java_column(col), sc._jvm.PythonUtils.toSeq(jcols)) + jc = sc._jvm.functions.countDistinct(_to_java_column(col), _to_seq(sc, cols, _to_java_column)) return Column(jc) @@ -138,9 +135,7 @@ def __del__(self): def __call__(self, *cols): sc = SparkContext._active_spark_context - jcols = ListConverter().convert([_to_java_column(c) for c in cols], - sc._gateway._gateway_client) - jc = self._judf.apply(sc._jvm.PythonUtils.toSeq(jcols)) + jc = self._judf.apply(_to_seq(sc, cols, _to_java_column)) return Column(jc) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 6691e8c8dc44b..aa3aa1d164d9f 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -282,7 +282,7 @@ def test_apply_schema(self): StructField("struct1", StructType([StructField("b", ShortType(), False)]), False), StructField("list1", ArrayType(ByteType(), False), False), StructField("null1", DoubleType(), True)]) - df = self.sqlCtx.applySchema(rdd, schema) + df = self.sqlCtx.createDataFrame(rdd, schema) results = df.map(lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1, x.date1, x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1)) r = (127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1), From 50ab8a6543ad5c31e89c16df374d0cb13222fd1e Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 17 Apr 2015 14:21:51 -0500 Subject: [PATCH 034/191] [SPARK-2669] [yarn] Distribute client configuration to AM. Currently, when Spark launches the Yarn AM, the process will use the local Hadoop configuration on the node where the AM launches, if one is present. A more correct approach is to use the same configuration used to launch the Spark job, since the user may have made modifications (such as adding app-specific configs). The approach taken here is to use the distributed cache to make all files in the Hadoop configuration directory available to the AM. 
This is a little overkill since only the AM needs them (the executors use the broadcast Hadoop configuration from the driver), but is the easier approach. Even though only a few files in that directory may end up being used, all of them are uploaded. This allows supporting use cases such as when auxiliary configuration files are used for SSL configuration, or when uploading a Hive configuration directory. Not all of these may be reflected in a o.a.h.conf.Configuration object, but may be needed when a driver in cluster mode instantiates, for example, a HiveConf object instead. Author: Marcelo Vanzin Closes #4142 from vanzin/SPARK-2669 and squashes the following commits: f5434b9 [Marcelo Vanzin] Merge branch 'master' into SPARK-2669 013f0fb [Marcelo Vanzin] Review feedback. f693152 [Marcelo Vanzin] Le sigh. ed45b7d [Marcelo Vanzin] Zip all config files and upload them as an archive. 5927b6b [Marcelo Vanzin] Merge branch 'master' into SPARK-2669 cbb9fb3 [Marcelo Vanzin] Remove stale test. e3e58d0 [Marcelo Vanzin] Merge branch 'master' into SPARK-2669 e3d0613 [Marcelo Vanzin] Review feedback. 34bdbd8 [Marcelo Vanzin] Fix test. 022a688 [Marcelo Vanzin] Merge branch 'master' into SPARK-2669 a77ddd5 [Marcelo Vanzin] Merge branch 'master' into SPARK-2669 79221c7 [Marcelo Vanzin] [SPARK-2669] [yarn] Distribute client configuration to AM. --- docs/running-on-yarn.md | 6 +- .../org/apache/spark/deploy/yarn/Client.scala | 125 +++++++++++++++--- .../spark/deploy/yarn/ExecutorRunnable.scala | 2 +- .../spark/deploy/yarn/ClientSuite.scala | 29 ++-- .../spark/deploy/yarn/YarnClusterSuite.scala | 6 +- 5 files changed, 132 insertions(+), 36 deletions(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 853c9f26b0ec9..0968fc5ad632b 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -211,7 +211,11 @@ Most of the configs are the same for Spark on YARN as for other deployment modes # Launching Spark on YARN Ensure that `HADOOP_CONF_DIR` or `YARN_CONF_DIR` points to the directory which contains the (client side) configuration files for the Hadoop cluster. -These configs are used to write to the dfs and connect to the YARN ResourceManager. +These configs are used to write to the dfs and connect to the YARN ResourceManager. The +configuration contained in this directory will be distributed to the YARN cluster so that all +containers used by the application use the same configuration. If the configuration references +Java system properties or environment variables not managed by YARN, they should also be set in the +Spark application's configuration (driver, executors, and the AM when running in client mode). There are two deploy modes that can be used to launch Spark applications on YARN. In yarn-cluster mode, the Spark driver runs inside an application master process which is managed by YARN on the cluster, and the client can go away after initiating the application. In yarn-client mode, the driver runs in the client process, and the application master is only used for requesting resources from YARN. 
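As a rough standalone sketch of the archive step that the `createConfArchive()` method in the diff below implements (simplified; error handling and the temp-file location are elided):

```scala
import java.io.{File, FileOutputStream}
import java.util.zip.{ZipEntry, ZipOutputStream}
import com.google.common.io.Files

// Shallow-zip the files under HADOOP_CONF_DIR and YARN_CONF_DIR so the
// archive can be shipped to the AM via the distributed cache; directories
// listed first win when file names collide.
val confFiles = scala.collection.mutable.LinkedHashMap.empty[String, File]
Seq("HADOOP_CONF_DIR", "YARN_CONF_DIR").flatMap(sys.env.get).foreach { path =>
  Option(new File(path).listFiles()).getOrElse(Array.empty[File]).foreach { f =>
    if (f.isFile) confFiles.getOrElseUpdate(f.getName, f)
  }
}
val archive = File.createTempFile("__hadoop_conf__", ".zip")
val zip = new ZipOutputStream(new FileOutputStream(archive))
try {
  confFiles.values.foreach { f =>
    zip.putNextEntry(new ZipEntry(f.getName))
    Files.copy(f, zip) // Guava: copy file contents into the zip stream
    zip.closeEntry()
  }
} finally {
  zip.close()
}
```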
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 52e4dee46c535..019afbd1a1743 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -17,15 +17,18 @@ package org.apache.spark.deploy.yarn +import java.io.{File, FileOutputStream} import java.net.{InetAddress, UnknownHostException, URI, URISyntaxException} import java.nio.ByteBuffer +import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.JavaConversions._ -import scala.collection.mutable.{ArrayBuffer, HashMap, ListBuffer, Map} +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, ListBuffer, Map} import scala.reflect.runtime.universe import scala.util.{Try, Success, Failure} import com.google.common.base.Objects +import com.google.common.io.Files import org.apache.hadoop.io.DataOutputBuffer import org.apache.hadoop.conf.Configuration @@ -77,12 +80,6 @@ private[spark] class Client( def stop(): Unit = yarnClient.stop() - /* ------------------------------------------------------------------------------------- * - | The following methods have much in common in the stable and alpha versions of Client, | - | but cannot be implemented in the parent trait due to subtle API differences across | - | hadoop versions. | - * ------------------------------------------------------------------------------------- */ - /** * Submit an application running our ApplicationMaster to the ResourceManager. * @@ -223,6 +220,10 @@ private[spark] class Client( val fs = FileSystem.get(hadoopConf) val dst = new Path(fs.getHomeDirectory(), appStagingDir) val nns = getNameNodesToAccess(sparkConf) + dst + // Used to keep track of URIs added to the distributed cache. If the same URI is added + // multiple times, YARN will fail to launch containers for the app with an internal + // error. + val distributedUris = new HashSet[String] obtainTokensForNamenodes(nns, hadoopConf, credentials) obtainTokenForHiveMetastore(hadoopConf, credentials) @@ -241,6 +242,17 @@ private[spark] class Client( "for alternatives.") } + def addDistributedUri(uri: URI): Boolean = { + val uriStr = uri.toString() + if (distributedUris.contains(uriStr)) { + logWarning(s"Resource $uri added multiple times to distributed cache.") + false + } else { + distributedUris += uriStr + true + } + } + /** * Copy the given main resource to the distributed cache if the scheme is not "local". * Otherwise, set the corresponding key in our SparkConf to handle it downstream. 
@@ -258,11 +270,13 @@ private[spark] class Client( if (!localPath.isEmpty()) { val localURI = new URI(localPath) if (localURI.getScheme != LOCAL_SCHEME) { - val src = getQualifiedLocalPath(localURI, hadoopConf) - val destPath = copyFileToRemote(dst, src, replication) - val destFs = FileSystem.get(destPath.toUri(), hadoopConf) - distCacheMgr.addResource(destFs, hadoopConf, destPath, - localResources, LocalResourceType.FILE, destName, statCache) + if (addDistributedUri(localURI)) { + val src = getQualifiedLocalPath(localURI, hadoopConf) + val destPath = copyFileToRemote(dst, src, replication) + val destFs = FileSystem.get(destPath.toUri(), hadoopConf) + distCacheMgr.addResource(destFs, hadoopConf, destPath, + localResources, LocalResourceType.FILE, destName, statCache) + } } else if (confKey != null) { // If the resource is intended for local use only, handle this downstream // by setting the appropriate property @@ -271,6 +285,13 @@ private[spark] class Client( } } + createConfArchive().foreach { file => + require(addDistributedUri(file.toURI())) + val destPath = copyFileToRemote(dst, new Path(file.toURI()), replication) + distCacheMgr.addResource(fs, hadoopConf, destPath, localResources, LocalResourceType.ARCHIVE, + LOCALIZED_HADOOP_CONF_DIR, statCache, appMasterOnly = true) + } + /** * Do the same for any additional resources passed in through ClientArguments. * Each resource category is represented by a 3-tuple of: @@ -288,13 +309,15 @@ private[spark] class Client( flist.split(',').foreach { file => val localURI = new URI(file.trim()) if (localURI.getScheme != LOCAL_SCHEME) { - val localPath = new Path(localURI) - val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName()) - val destPath = copyFileToRemote(dst, localPath, replication) - distCacheMgr.addResource( - fs, hadoopConf, destPath, localResources, resType, linkname, statCache) - if (addToClasspath) { - cachedSecondaryJarLinks += linkname + if (addDistributedUri(localURI)) { + val localPath = new Path(localURI) + val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName()) + val destPath = copyFileToRemote(dst, localPath, replication) + distCacheMgr.addResource( + fs, hadoopConf, destPath, localResources, resType, linkname, statCache) + if (addToClasspath) { + cachedSecondaryJarLinks += linkname + } } } else if (addToClasspath) { // Resource is intended for local use only and should be added to the class path @@ -310,6 +333,57 @@ private[spark] class Client( localResources } + /** + * Create an archive with the Hadoop config files for distribution. + * + * These are only used by the AM, since executors will use the configuration object broadcast by + * the driver. The files are zipped and added to the job as an archive, so that YARN will explode + * it when distributing to the AM. This directory is then added to the classpath of the AM + * process, just to make sure that everybody is using the same default config. + * + * This follows the order of precedence set by the startup scripts, in which HADOOP_CONF_DIR + * shows up in the classpath before YARN_CONF_DIR. + * + * Currently this makes a shallow copy of the conf directory. If there are cases where a + * Hadoop config directory contains subdirectories, this code will have to be fixed. 
+ */ + private def createConfArchive(): Option[File] = { + val hadoopConfFiles = new HashMap[String, File]() + Seq("HADOOP_CONF_DIR", "YARN_CONF_DIR").foreach { envKey => + sys.env.get(envKey).foreach { path => + val dir = new File(path) + if (dir.isDirectory()) { + dir.listFiles().foreach { file => + if (!hadoopConfFiles.contains(file.getName())) { + hadoopConfFiles(file.getName()) = file + } + } + } + } + } + + if (!hadoopConfFiles.isEmpty) { + val hadoopConfArchive = File.createTempFile(LOCALIZED_HADOOP_CONF_DIR, ".zip", + new File(Utils.getLocalDir(sparkConf))) + + val hadoopConfStream = new ZipOutputStream(new FileOutputStream(hadoopConfArchive)) + try { + hadoopConfStream.setLevel(0) + hadoopConfFiles.foreach { case (name, file) => + hadoopConfStream.putNextEntry(new ZipEntry(name)) + Files.copy(file, hadoopConfStream) + hadoopConfStream.closeEntry() + } + } finally { + hadoopConfStream.close() + } + + Some(hadoopConfArchive) + } else { + None + } + } + /** * Set up the environment for launching our ApplicationMaster container. */ @@ -317,7 +391,7 @@ private[spark] class Client( logInfo("Setting up the launch environment for our AM container") val env = new HashMap[String, String]() val extraCp = sparkConf.getOption("spark.driver.extraClassPath") - populateClasspath(args, yarnConf, sparkConf, env, extraCp) + populateClasspath(args, yarnConf, sparkConf, env, true, extraCp) env("SPARK_YARN_MODE") = "true" env("SPARK_YARN_STAGING_DIR") = stagingDir env("SPARK_USER") = UserGroupInformation.getCurrentUser().getShortUserName() @@ -718,6 +792,9 @@ object Client extends Logging { // Distribution-defined classpath to add to processes val ENV_DIST_CLASSPATH = "SPARK_DIST_CLASSPATH" + // Subdirectory where the user's hadoop config files will be placed. + val LOCALIZED_HADOOP_CONF_DIR = "__hadoop_conf__" + /** * Find the user-defined Spark jar if configured, or return the jar containing this * class if not. 
@@ -831,11 +908,19 @@ object Client extends Logging { conf: Configuration, sparkConf: SparkConf, env: HashMap[String, String], + isAM: Boolean, extraClassPath: Option[String] = None): Unit = { extraClassPath.foreach(addClasspathEntry(_, env)) addClasspathEntry( YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), env ) + + if (isAM) { + addClasspathEntry( + YarnSparkHadoopUtil.expandEnvironment(Environment.PWD) + Path.SEPARATOR + + LOCALIZED_HADOOP_CONF_DIR, env) + } + if (sparkConf.getBoolean("spark.yarn.user.classpath.first", false)) { val userClassPath = if (args != null) { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index b06069c07f451..9d04d241dae9e 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -277,7 +277,7 @@ class ExecutorRunnable( private def prepareEnvironment(container: Container): HashMap[String, String] = { val env = new HashMap[String, String]() val extraCp = sparkConf.getOption("spark.executor.extraClassPath") - Client.populateClasspath(null, yarnConf, sparkConf, env, extraCp) + Client.populateClasspath(null, yarnConf, sparkConf, env, false, extraCp) sparkConf.getExecutorEnv.foreach { case (key, value) => // This assumes each executor environment variable set here is a path diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index c1b94ac9c5bdd..a51c2005cb472 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -20,6 +20,11 @@ package org.apache.spark.deploy.yarn import java.io.File import java.net.URI +import scala.collection.JavaConversions._ +import scala.collection.mutable.{ HashMap => MutableHashMap } +import scala.reflect.ClassTag +import scala.util.Try + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.MRJobConfig @@ -30,11 +35,6 @@ import org.mockito.Matchers._ import org.mockito.Mockito._ import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers} -import scala.collection.JavaConversions._ -import scala.collection.mutable.{ HashMap => MutableHashMap } -import scala.reflect.ClassTag -import scala.util.Try - import org.apache.spark.{SparkException, SparkConf} import org.apache.spark.util.Utils @@ -93,7 +93,7 @@ class ClientSuite extends FunSuite with Matchers with BeforeAndAfterAll { val env = new MutableHashMap[String, String]() val args = new ClientArguments(Array("--jar", USER, "--addJars", ADDED), sparkConf) - Client.populateClasspath(args, conf, sparkConf, env) + Client.populateClasspath(args, conf, sparkConf, env, true) val cp = env("CLASSPATH").split(":|;|") s"$SPARK,$USER,$ADDED".split(",").foreach({ entry => @@ -104,13 +104,16 @@ class ClientSuite extends FunSuite with Matchers with BeforeAndAfterAll { cp should not contain (uri.getPath()) } }) - if (classOf[Environment].getMethods().exists(_.getName == "$$")) { - cp should contain("{{PWD}}") - } else if (Utils.isWindows) { - cp should contain("%PWD%") - } else { - cp should contain(Environment.PWD.$()) - } + val pwdVar = + if (classOf[Environment].getMethods().exists(_.getName == "$$")) { + "{{PWD}}" + } else if (Utils.isWindows) { + "%PWD%" + } else { + Environment.PWD.$() + } + cp should contain(pwdVar) + cp 
should contain (s"$pwdVar${Path.SEPARATOR}${Client.LOCALIZED_HADOOP_CONF_DIR}") cp should not contain (Client.SPARK_JAR) cp should not contain (Client.APP_JAR) } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index a18c94d4ab4a8..3877da4120e7c 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -77,6 +77,7 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit private var yarnCluster: MiniYARNCluster = _ private var tempDir: File = _ private var fakeSparkJar: File = _ + private var hadoopConfDir: File = _ private var logConfDir: File = _ override def beforeAll() { @@ -120,6 +121,9 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit logInfo(s"RM address in configuration is ${config.get(YarnConfiguration.RM_ADDRESS)}") fakeSparkJar = File.createTempFile("sparkJar", null, tempDir) + hadoopConfDir = new File(tempDir, Client.LOCALIZED_HADOOP_CONF_DIR) + assert(hadoopConfDir.mkdir()) + File.createTempFile("token", ".txt", hadoopConfDir) } override def afterAll() { @@ -258,7 +262,7 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit appArgs Utils.executeAndGetOutput(argv, - extraEnvironment = Map("YARN_CONF_DIR" -> tempDir.getAbsolutePath())) + extraEnvironment = Map("YARN_CONF_DIR" -> hadoopConfDir.getAbsolutePath())) } /** From a83571acc938582865efb41645aa1e414f339e46 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Fri, 17 Apr 2015 13:15:36 -0700 Subject: [PATCH 035/191] [SPARK-6113] [ml] Stabilize DecisionTree API MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This is a PR for cleaning up and finalizing the DecisionTree API. PRs for ensembles will follow once this is merged. ### Goal Here is the description copied from the JIRA (for both trees and ensembles): > **Issue**: The APIs for DecisionTree and ensembles (RandomForests and GradientBoostedTrees) have been experimental for a long time. The API has become very convoluted because trees and ensembles have many, many variants, some of which we have added incrementally without a long-term design. > **Proposal**: This JIRA is for discussing changes required to finalize the APIs. After we discuss, I will make a PR to update the APIs and make them non-Experimental. This will require making many breaking changes; see the design doc for details. > **[Design doc](https://docs.google.com/document/d/1rJ_DZinyDG3PkYkAKSsQlY0QgCeefn4hUv7GsPkzBP4)** : This outlines current issues and the proposed API. Overall code layout: * The old API in mllib.tree.* will remain the same. * The new API will reside in ml.classification.* and ml.regression.* ### Summary of changes Old API * Exactly the same, except I made 1 method in Loss private (but that is not a breaking change since that method was introduced after the Spark 1.3 release). New APIs * Under Pipeline API * The new API preserves functionality, except: * New API does NOT store prob (probability of label in classification). I want to have it store the full vector of probabilities but feel that should be in a later PR. 
* Use abstractions for parameters, estimators, and models to avoid code duplication * Limit parameters to relevant algorithms * For enum-like types, only expose Strings * We can make these pluggable later on by adding new parameters. That is a far-future item. Test suites * I organized DecisionTreeSuite, but I made absolutely no changes to the tests themselves. * The test suites for the new API only test (a) similarity with the results of the old API and (b) elements of the new API. * After code is moved to this new API, we should move the tests from the old suites which test the internals. ### Details #### Changed names Parameters * useNodeIdCache -> cacheNodeIds #### Other changes * Split: Changed categories to set instead of list #### Non-decision tree changes * AttributeGroup * Added parentheses to toMetadata, toStructField methods (These were removed in a previous PR, but I ran into 1 issue with the Scala compiler not being able to disambiguate between a toMetadata method with no parentheses and a toMetadata method which takes 1 argument.) * Attributes * Renamed: toMetadata -> toMetadataImpl * Added toMetadata methods which return ML metadata (keyed with “ML_ATTR”) * NominalAttribute: Added getNumValues method which examines both numValues and values. * Params.inheritValues: Checks whether the parent param really belongs to the child (to allow Estimator-Model pairs with different sets of parameters) ### Questions for reviewers * Is "DecisionTreeClassificationModel" too long a name? * Is this OK in the docs? ``` class DecisionTreeRegressor extends TreeRegressor[DecisionTreeRegressionModel] with DecisionTreeParams[DecisionTreeRegressor] with TreeRegressorParams[DecisionTreeRegressor] ``` ### Future We should open up the abstractions at some point. E.g., it would be useful to be able to set tree-related parameters in 1 place and then pass those to multiple tree-based algorithms. Follow-up JIRAs will be (in this order): * Tree ensembles * Deprecate old tree code * Move DecisionTree implementation code to new API. * Move tests from the old suites which test the internals. * Update programming guide * Python API * Change RandomForest* to always use bootstrapping, even when numTrees = 1 * Provide the probability of the predicted label for classification. After we move code to the new API and update it to maintain probabilities for all labels, then we can add the probabilities to the new API. CC: mengxr manishamde codedeft chouqin MechCoder Author: Joseph K. Bradley Closes #5530 from jkbradley/dt-api-dt and squashes the following commits: 6aae255 [Joseph K. Bradley] Changed tree abstractions not to take type parameters, and for setters to return this.type instead ec17947 [Joseph K. Bradley] Updates based on code review. Main changes were: moving public types from ml.impl.tree to ml.tree, modifying CategoricalSplit to take an Array of categories but store a Set internally, making more types sealed or final 5626c81 [Joseph K. Bradley] style fixes f8fbd24 [Joseph K. Bradley] imported reorg of DecisionTreeSuite from old PR. small cleanups 7ef63ed [Joseph K. Bradley] Added DecisionTreeRegressor, test suites, and example (for real this time) e11673f [Joseph K. Bradley] Added DecisionTreeRegressor, test suites, and example 119f407 [Joseph K. Bradley] added DecisionTreeClassifier example 0bdc486 [Joseph K. Bradley] fixed issues after param PR was merged f9fbb60 [Joseph K. Bradley] Done with DecisionTreeClassifier, but no save/load yet. Need to add example as well 2532c9a [Joseph K. 
Bradley] partial move to spark.ml API, not done yet c72c1a0 [Joseph K. Bradley] Copied changes for common items, plus DecisionTreeClassifier from original PR --- .../examples/ml/DecisionTreeExample.scala | 322 +++++++++++++++ .../spark/ml/attribute/AttributeGroup.scala | 10 +- .../spark/ml/attribute/attributes.scala | 43 +- .../DecisionTreeClassifier.scala | 155 ++++++++ .../spark/ml/feature/StringIndexer.scala | 2 +- .../spark/ml/impl/tree/treeParams.scala | 300 ++++++++++++++ .../scala/org/apache/spark/ml/package.scala | 12 + .../org/apache/spark/ml/param/params.scala | 3 +- .../ml/regression/DecisionTreeRegressor.scala | 145 +++++++ .../scala/org/apache/spark/ml/tree/Node.scala | 205 ++++++++++ .../org/apache/spark/ml/tree/Split.scala | 151 +++++++ .../org/apache/spark/ml/tree/treeModels.scala | 60 +++ .../apache/spark/ml/util/MetadataUtils.scala | 82 ++++ .../spark/mllib/tree/DecisionTree.scala | 5 +- .../mllib/tree/GradientBoostedTrees.scala | 12 +- .../spark/mllib/tree/RandomForest.scala | 2 +- .../tree/configuration/BoostingStrategy.scala | 10 +- .../spark/mllib/tree/loss/AbsoluteError.scala | 5 +- .../spark/mllib/tree/loss/LogLoss.scala | 5 +- .../apache/spark/mllib/tree/loss/Loss.scala | 4 +- .../spark/mllib/tree/loss/SquaredError.scala | 5 +- .../mllib/tree/model/DecisionTreeModel.scala | 4 +- .../apache/spark/mllib/tree/model/Node.scala | 2 +- .../mllib/tree/model/treeEnsembleModels.scala | 32 +- .../JavaDecisionTreeClassifierSuite.java | 98 +++++ .../JavaDecisionTreeRegressorSuite.java | 97 +++++ .../ml/attribute/AttributeGroupSuite.scala | 4 +- .../spark/ml/attribute/AttributeSuite.scala | 42 +- .../DecisionTreeClassifierSuite.scala | 274 +++++++++++++ .../spark/ml/feature/VectorIndexerSuite.scala | 2 +- .../org/apache/spark/ml/impl/TreeTests.scala | 132 +++++++ .../DecisionTreeRegressorSuite.scala | 91 +++++ .../spark/mllib/tree/DecisionTreeSuite.scala | 373 +++++++++--------- 33 files changed, 2426 insertions(+), 263 deletions(-) create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala create mode 100644 mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java create mode 100644 mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java create mode 100644 mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala new file mode 100644 index 0000000000000..d4cc8dede07ef --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala @@ -0,0 +1,322 @@ +/* + * 
Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +import scala.collection.mutable +import scala.language.reflectiveCalls + +import scopt.OptionParser + +import org.apache.spark.ml.tree.DecisionTreeModel +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.examples.mllib.AbstractParams +import org.apache.spark.ml.{Pipeline, PipelineStage} +import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier} +import org.apache.spark.ml.feature.{VectorIndexer, StringIndexer} +import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor} +import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.mllib.evaluation.{RegressionMetrics, MulticlassMetrics} +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.{SQLContext, DataFrame} + + +/** + * An example runner for decision trees. Run with + * {{{ + * ./bin/run-example ml.DecisionTreeExample [options] + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. 
+ */ +object DecisionTreeExample { + + case class Params( + input: String = null, + testInput: String = "", + dataFormat: String = "libsvm", + algo: String = "Classification", + maxDepth: Int = 5, + maxBins: Int = 32, + minInstancesPerNode: Int = 1, + minInfoGain: Double = 0.0, + numTrees: Int = 1, + featureSubsetStrategy: String = "auto", + fracTest: Double = 0.2, + cacheNodeIds: Boolean = false, + checkpointDir: Option[String] = None, + checkpointInterval: Int = 10) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("DecisionTreeExample") { + head("DecisionTreeExample: an example decision tree app.") + opt[String]("algo") + .text(s"algorithm (Classification, Regression), default: ${defaultParams.algo}") + .action((x, c) => c.copy(algo = x)) + opt[Int]("maxDepth") + .text(s"max depth of the tree, default: ${defaultParams.maxDepth}") + .action((x, c) => c.copy(maxDepth = x)) + opt[Int]("maxBins") + .text(s"max number of bins, default: ${defaultParams.maxBins}") + .action((x, c) => c.copy(maxBins = x)) + opt[Int]("minInstancesPerNode") + .text(s"min number of instances required at child nodes to create the parent split," + + s" default: ${defaultParams.minInstancesPerNode}") + .action((x, c) => c.copy(minInstancesPerNode = x)) + opt[Double]("minInfoGain") + .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}") + .action((x, c) => c.copy(minInfoGain = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing. If given option testInput, " + + s"this option is ignored. default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) + opt[Boolean]("cacheNodeIds") + .text(s"whether to use node Id cache during training, " + + s"default: ${defaultParams.cacheNodeIds}") + .action((x, c) => c.copy(cacheNodeIds = x)) + opt[String]("checkpointDir") + .text(s"checkpoint directory where intermediate node Id caches will be stored, " + + s"default: ${defaultParams.checkpointDir match { + case Some(strVal) => strVal + case None => "None" + }}") + .action((x, c) => c.copy(checkpointDir = Some(x))) + opt[Int]("checkpointInterval") + .text(s"how often to checkpoint the node Id cache, " + + s"default: ${defaultParams.checkpointInterval}") + .action((x, c) => c.copy(checkpointInterval = x)) + opt[String]("testInput") + .text(s"input path to test dataset. If given, option fracTest is ignored." 
+ + s" default: ${defaultParams.testInput}") + .action((x, c) => c.copy(testInput = x)) + opt[String]("") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(dataFormat = x)) + arg[String]("") + .text("input path to labeled examples") + .required() + .action((x, c) => c.copy(input = x)) + checkConfig { params => + if (params.fracTest < 0 || params.fracTest > 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].") + } else { + success + } + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + /** Load a dataset from the given path, using the given format */ + private[ml] def loadData( + sc: SparkContext, + path: String, + format: String, + expectedNumFeatures: Option[Int] = None): RDD[LabeledPoint] = { + format match { + case "dense" => MLUtils.loadLabeledPoints(sc, path) + case "libsvm" => expectedNumFeatures match { + case Some(numFeatures) => MLUtils.loadLibSVMFile(sc, path, numFeatures) + case None => MLUtils.loadLibSVMFile(sc, path) + } + case _ => throw new IllegalArgumentException(s"Bad data format: $format") + } + } + + /** + * Load training and test data from files. + * @param input Path to input dataset. + * @param dataFormat "libsvm" or "dense" + * @param testInput Path to test dataset. + * @param algo Classification or Regression + * @param fracTest Fraction of input data to hold out for testing. Ignored if testInput given. + * @return (training dataset, test dataset) + */ + private[ml] def loadDatasets( + sc: SparkContext, + input: String, + dataFormat: String, + testInput: String, + algo: String, + fracTest: Double): (DataFrame, DataFrame) = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Load training data + val origExamples: RDD[LabeledPoint] = loadData(sc, input, dataFormat) + + // Load or create test set + val splits: Array[RDD[LabeledPoint]] = if (testInput != "") { + // Load testInput. + val numFeatures = origExamples.take(1)(0).features.size + val origTestExamples: RDD[LabeledPoint] = loadData(sc, input, dataFormat, Some(numFeatures)) + Array(origExamples, origTestExamples) + } else { + // Split input into training, test. + origExamples.randomSplit(Array(1.0 - fracTest, fracTest), seed = 12345) + } + + // For classification, convert labels to Strings since we will index them later with + // StringIndexer. + def labelsToStrings(data: DataFrame): DataFrame = { + algo.toLowerCase match { + case "classification" => + data.withColumn("labelString", data("label").cast(StringType)) + case "regression" => + data + case _ => + throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } + } + val dataframes = splits.map(_.toDF()).map(labelsToStrings).map(_.cache()) + + (dataframes(0), dataframes(1)) + } + + def run(params: Params) { + val conf = new SparkConf().setAppName(s"DecisionTreeExample with $params") + val sc = new SparkContext(conf) + params.checkpointDir.foreach(sc.setCheckpointDir) + val algo = params.algo.toLowerCase + + println(s"DecisionTreeExample with parameters:\n$params") + + // Load training and test data and cache it. 
+    val (training: DataFrame, test: DataFrame) =
+      loadDatasets(sc, params.input, params.dataFormat, params.testInput, algo, params.fracTest)
+
+    val numTraining = training.count()
+    val numTest = test.count()
+    val numFeatures = training.select("features").first().getAs[Vector](0).size
+    println("Loaded data:")
+    println(s"  numTraining = $numTraining, numTest = $numTest")
+    println(s"  numFeatures = $numFeatures")
+
+    // Set up Pipeline
+    val stages = new mutable.ArrayBuffer[PipelineStage]()
+    // (1) For classification, re-index classes.
+    val labelColName = if (algo == "classification") "indexedLabel" else "label"
+    if (algo == "classification") {
+      val labelIndexer = new StringIndexer().setInputCol("labelString").setOutputCol(labelColName)
+      stages += labelIndexer
+    }
+    // (2) Identify categorical features using VectorIndexer.
+    //     Features with more than maxCategories values will be treated as continuous.
+    val featuresIndexer = new VectorIndexer().setInputCol("features")
+      .setOutputCol("indexedFeatures").setMaxCategories(10)
+    stages += featuresIndexer
+    // (3) Learn DecisionTree
+    val dt = algo match {
+      case "classification" =>
+        new DecisionTreeClassifier().setFeaturesCol("indexedFeatures")
+          .setLabelCol(labelColName)
+          .setMaxDepth(params.maxDepth)
+          .setMaxBins(params.maxBins)
+          .setMinInstancesPerNode(params.minInstancesPerNode)
+          .setMinInfoGain(params.minInfoGain)
+          .setCacheNodeIds(params.cacheNodeIds)
+          .setCheckpointInterval(params.checkpointInterval)
+      case "regression" =>
+        new DecisionTreeRegressor().setFeaturesCol("indexedFeatures")
+          .setLabelCol(labelColName)
+          .setMaxDepth(params.maxDepth)
+          .setMaxBins(params.maxBins)
+          .setMinInstancesPerNode(params.minInstancesPerNode)
+          .setMinInfoGain(params.minInfoGain)
+          .setCacheNodeIds(params.cacheNodeIds)
+          .setCheckpointInterval(params.checkpointInterval)
+      case _ => throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
+    }
+    stages += dt
+    val pipeline = new Pipeline().setStages(stages.toArray)
+
+    // Fit the Pipeline
+    val startTime = System.nanoTime()
+    val pipelineModel = pipeline.fit(training)
+    val elapsedTime = (System.nanoTime() - startTime) / 1e9
+    println(s"Training time: $elapsedTime seconds")
+
+    // Get the trained Decision Tree from the fitted PipelineModel
+    val treeModel: DecisionTreeModel = algo match {
+      case "classification" =>
+        pipelineModel.getModel[DecisionTreeClassificationModel](
+          dt.asInstanceOf[DecisionTreeClassifier])
+      case "regression" =>
+        pipelineModel.getModel[DecisionTreeRegressionModel](dt.asInstanceOf[DecisionTreeRegressor])
+      case _ => throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
+    }
+    if (treeModel.numNodes < 20) {
+      println(treeModel.toDebugString) // Print full model.
+    } else {
+      println(treeModel) // Print model summary.
+    }
+
+    // Predict on training
+    val trainingFullPredictions = pipelineModel.transform(training).cache()
+    val trainingPredictions = trainingFullPredictions.select("prediction")
+      .map(_.getDouble(0))
+    val trainingLabels = trainingFullPredictions.select(labelColName).map(_.getDouble(0))
+    // Predict on test data
+    val testFullPredictions = pipelineModel.transform(test).cache()
+    val testPredictions = testFullPredictions.select("prediction")
+      .map(_.getDouble(0))
+    val testLabels = testFullPredictions.select(labelColName).map(_.getDouble(0))
+
+    // For classification, print number of classes for reference.
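+    // (The class count below is recovered from the ML attribute metadata that StringIndexer
+    //  attached to the indexed label column earlier in this Pipeline.)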
+ if (algo == "classification") { + val numClasses = + MetadataUtils.getNumClasses(trainingFullPredictions.schema(labelColName)) match { + case Some(n) => n + case None => throw new RuntimeException( + "DecisionTreeExample had unknown failure when indexing labels for classification.") + } + println(s"numClasses = $numClasses.") + } + + // Evaluate model on training, test data + algo match { + case "classification" => + val trainingAccuracy = + new MulticlassMetrics(trainingPredictions.zip(trainingLabels)).precision + println(s"Train accuracy = $trainingAccuracy") + val testAccuracy = + new MulticlassMetrics(testPredictions.zip(testLabels)).precision + println(s"Test accuracy = $testAccuracy") + case "regression" => + val trainingRMSE = + new RegressionMetrics(trainingPredictions.zip(trainingLabels)).rootMeanSquaredError + println(s"Training root mean squared error (RMSE) = $trainingRMSE") + val testRMSE = + new RegressionMetrics(testPredictions.zip(testLabels)).rootMeanSquaredError + println(s"Test root mean squared error (RMSE) = $testRMSE") + case _ => + throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } + + sc.stop() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala index aa27a668f1695..d7dee8fed2a55 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala @@ -117,12 +117,12 @@ class AttributeGroup private ( case numeric: NumericAttribute => // Skip default numeric attributes. if (numeric.withoutIndex != NumericAttribute.defaultAttr) { - numericMetadata += numeric.toMetadata(withType = false) + numericMetadata += numeric.toMetadataImpl(withType = false) } case nominal: NominalAttribute => - nominalMetadata += nominal.toMetadata(withType = false) + nominalMetadata += nominal.toMetadataImpl(withType = false) case binary: BinaryAttribute => - binaryMetadata += binary.toMetadata(withType = false) + binaryMetadata += binary.toMetadataImpl(withType = false) } val attrBldr = new MetadataBuilder if (numericMetadata.nonEmpty) { @@ -151,7 +151,7 @@ class AttributeGroup private ( } /** Converts to ML metadata */ - def toMetadata: Metadata = toMetadata(Metadata.empty) + def toMetadata(): Metadata = toMetadata(Metadata.empty) /** Converts to a StructField with some existing metadata. */ def toStructField(existingMetadata: Metadata): StructField = { @@ -159,7 +159,7 @@ class AttributeGroup private ( } /** Converts to a StructField. */ - def toStructField: StructField = toStructField(Metadata.empty) + def toStructField(): StructField = toStructField(Metadata.empty) override def equals(other: Any): Boolean = { other match { diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala index 00b7566aab434..5717d6ec2eaec 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala @@ -68,21 +68,32 @@ sealed abstract class Attribute extends Serializable { * Converts this attribute to [[Metadata]]. * @param withType whether to include the type info */ - private[attribute] def toMetadata(withType: Boolean): Metadata + private[attribute] def toMetadataImpl(withType: Boolean): Metadata /** * Converts this attribute to [[Metadata]]. 
For numeric attributes, the type info is excluded to * save space, because numeric type is the default attribute type. For nominal and binary * attributes, the type info is included. */ - private[attribute] def toMetadata(): Metadata = { + private[attribute] def toMetadataImpl(): Metadata = { if (attrType == AttributeType.Numeric) { - toMetadata(withType = false) + toMetadataImpl(withType = false) } else { - toMetadata(withType = true) + toMetadataImpl(withType = true) } } + /** Converts to ML metadata with some existing metadata. */ + def toMetadata(existingMetadata: Metadata): Metadata = { + new MetadataBuilder() + .withMetadata(existingMetadata) + .putMetadata(AttributeKeys.ML_ATTR, toMetadataImpl()) + .build() + } + + /** Converts to ML metadata */ + def toMetadata(): Metadata = toMetadata(Metadata.empty) + /** * Converts to a [[StructField]] with some existing metadata. * @param existingMetadata existing metadata to carry over @@ -90,7 +101,7 @@ sealed abstract class Attribute extends Serializable { def toStructField(existingMetadata: Metadata): StructField = { val newMetadata = new MetadataBuilder() .withMetadata(existingMetadata) - .putMetadata(AttributeKeys.ML_ATTR, withoutName.withoutIndex.toMetadata()) + .putMetadata(AttributeKeys.ML_ATTR, withoutName.withoutIndex.toMetadataImpl()) .build() StructField(name.get, DoubleType, nullable = false, newMetadata) } @@ -98,7 +109,7 @@ sealed abstract class Attribute extends Serializable { /** Converts to a [[StructField]]. */ def toStructField(): StructField = toStructField(Metadata.empty) - override def toString: String = toMetadata(withType = true).toString + override def toString: String = toMetadataImpl(withType = true).toString } /** Trait for ML attribute factories. */ @@ -210,7 +221,7 @@ class NumericAttribute private[ml] ( override def isNominal: Boolean = false /** Convert this attribute to metadata. */ - private[attribute] override def toMetadata(withType: Boolean): Metadata = { + override private[attribute] def toMetadataImpl(withType: Boolean): Metadata = { import org.apache.spark.ml.attribute.AttributeKeys._ val bldr = new MetadataBuilder() if (withType) bldr.putString(TYPE, attrType.name) @@ -353,6 +364,20 @@ class NominalAttribute private[ml] ( /** Copy without the `numValues`. */ def withoutNumValues: NominalAttribute = copy(numValues = None) + /** + * Get the number of values, either from `numValues` or from `values`. + * Return None if unknown. + */ + def getNumValues: Option[Int] = { + if (numValues.nonEmpty) { + numValues + } else if (values.nonEmpty) { + Some(values.get.length) + } else { + None + } + } + /** Creates a copy of this attribute with optional changes. 
*/ private def copy( name: Option[String] = name, @@ -363,7 +388,7 @@ class NominalAttribute private[ml] ( new NominalAttribute(name, index, isOrdinal, numValues, values) } - private[attribute] override def toMetadata(withType: Boolean): Metadata = { + override private[attribute] def toMetadataImpl(withType: Boolean): Metadata = { import org.apache.spark.ml.attribute.AttributeKeys._ val bldr = new MetadataBuilder() if (withType) bldr.putString(TYPE, attrType.name) @@ -465,7 +490,7 @@ class BinaryAttribute private[ml] ( new BinaryAttribute(name, index, values) } - private[attribute] override def toMetadata(withType: Boolean): Metadata = { + override private[attribute] def toMetadataImpl(withType: Boolean): Metadata = { import org.apache.spark.ml.attribute.AttributeKeys._ val bldr = new MetadataBuilder if (withType) bldr.putString(TYPE, attrType.name) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala new file mode 100644 index 0000000000000..3855e396b5534 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.impl.estimator.{Predictor, PredictionModel} +import org.apache.spark.ml.impl.tree._ +import org.apache.spark.ml.param.{Params, ParamMap} +import org.apache.spark.ml.tree.{DecisionTreeModel, Node} +import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} +import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + + +/** + * :: AlphaComponent :: + * + * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm + * for classification. + * It supports both binary and multiclass labels, as well as both continuous and categorical + * features. + */ +@AlphaComponent +final class DecisionTreeClassifier + extends Predictor[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel] + with DecisionTreeParams + with TreeClassifierParams { + + // Override parameter setters from parent trait for Java API compatibility. 
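+  // Each override re-declares the setter with this concrete class as its return type, so
+  // Java callers can chain calls such as setMaxDepth(...).setMaxBins(...) without casts.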
+ + override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + + override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + + override def setMinInstancesPerNode(value: Int): this.type = + super.setMinInstancesPerNode(value) + + override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + + override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + + override def setCacheNodeIds(value: Boolean): this.type = + super.setCacheNodeIds(value) + + override def setCheckpointInterval(value: Int): this.type = + super.setCheckpointInterval(value) + + override def setImpurity(value: String): this.type = super.setImpurity(value) + + override protected def train( + dataset: DataFrame, + paramMap: ParamMap): DecisionTreeClassificationModel = { + val categoricalFeatures: Map[Int, Int] = + MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol))) + val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema(paramMap(labelCol))) match { + case Some(n: Int) => n + case None => throw new IllegalArgumentException("DecisionTreeClassifier was given input" + + s" with invalid label column, without the number of classes specified.") + // TODO: Automatically index labels. + } + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap) + val strategy = getOldStrategy(categoricalFeatures, numClasses) + val oldModel = OldDecisionTree.train(oldDataset, strategy) + DecisionTreeClassificationModel.fromOld(oldModel, this, paramMap, categoricalFeatures) + } + + /** (private[ml]) Create a Strategy instance to use with the old API. */ + override private[ml] def getOldStrategy( + categoricalFeatures: Map[Int, Int], + numClasses: Int): OldStrategy = { + val strategy = super.getOldStrategy(categoricalFeatures, numClasses) + strategy.algo = OldAlgo.Classification + strategy.setImpurity(getOldImpurity) + strategy + } +} + +object DecisionTreeClassifier { + /** Accessor for supported impurities */ + final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities +} + +/** + * :: AlphaComponent :: + * + * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] model for classification. + * It supports both binary and multiclass labels, as well as both continuous and categorical + * features. 
+ */ +@AlphaComponent +final class DecisionTreeClassificationModel private[ml] ( + override val parent: DecisionTreeClassifier, + override val fittingParamMap: ParamMap, + override val rootNode: Node) + extends PredictionModel[Vector, DecisionTreeClassificationModel] + with DecisionTreeModel with Serializable { + + require(rootNode != null, + "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.") + + override protected def predict(features: Vector): Double = { + rootNode.predict(features) + } + + override protected def copy(): DecisionTreeClassificationModel = { + val m = new DecisionTreeClassificationModel(parent, fittingParamMap, rootNode) + Params.inheritValues(this.extractParamMap(), this, m) + m + } + + override def toString: String = { + s"DecisionTreeClassificationModel of depth $depth with $numNodes nodes" + } + + /** (private[ml]) Convert to a model in the old API */ + private[ml] def toOld: OldDecisionTreeModel = { + new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Classification) + } +} + +private[ml] object DecisionTreeClassificationModel { + + /** (private[ml]) Convert a model from the old API */ + def fromOld( + oldModel: OldDecisionTreeModel, + parent: DecisionTreeClassifier, + fittingParamMap: ParamMap, + categoricalFeatures: Map[Int, Int]): DecisionTreeClassificationModel = { + require(oldModel.algo == OldAlgo.Classification, + s"Cannot convert non-classification DecisionTreeModel (old API) to" + + s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}") + val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures) + new DecisionTreeClassificationModel(parent, fittingParamMap, rootNode) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 4d960df357fe9..23956c512c8a6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -118,7 +118,7 @@ class StringIndexerModel private[ml] ( } val outputColName = map(outputCol) val metadata = NominalAttribute.defaultAttr - .withName(outputColName).withValues(labels).toStructField().metadata + .withName(outputColName).withValues(labels).toMetadata() dataset.select(col("*"), indexer(dataset(map(inputCol))).as(outputColName, metadata)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala new file mode 100644 index 0000000000000..6f4509f03d033 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala @@ -0,0 +1,300 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.impl.tree + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.ml.impl.estimator.PredictorParams +import org.apache.spark.ml.param._ +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} +import org.apache.spark.mllib.tree.impurity.{Gini => OldGini, Entropy => OldEntropy, + Impurity => OldImpurity, Variance => OldVariance} + + +/** + * :: DeveloperApi :: + * Parameters for Decision Tree-based algorithms. + * + * Note: Marked as private and DeveloperApi since this may be made public in the future. + */ +@DeveloperApi +private[ml] trait DecisionTreeParams extends PredictorParams { + + /** + * Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. + * (default = 5) + * @group param + */ + final val maxDepth: IntParam = + new IntParam(this, "maxDepth", "Maximum depth of the tree." + + " E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.") + + /** + * Maximum number of bins used for discretizing continuous features and for choosing how to split + * on features at each node. More bins give higher granularity. + * Must be >= 2 and >= number of categories in any categorical feature. + * (default = 32) + * @group param + */ + final val maxBins: IntParam = new IntParam(this, "maxBins", "Max number of bins for" + + " discretizing continuous features. Must be >=2 and >= number of categories for any" + + " categorical feature.") + + /** + * Minimum number of instances each child must have after split. + * If a split causes the left or right child to have fewer than minInstancesPerNode, + * the split will be discarded as invalid. + * Should be >= 1. + * (default = 1) + * @group param + */ + final val minInstancesPerNode: IntParam = new IntParam(this, "minInstancesPerNode", "Minimum" + + " number of instances each child must have after split. If a split causes the left or right" + + " child to have fewer than minInstancesPerNode, the split will be discarded as invalid." + + " Should be >= 1.") + + /** + * Minimum information gain for a split to be considered at a tree node. + * (default = 0.0) + * @group param + */ + final val minInfoGain: DoubleParam = new DoubleParam(this, "minInfoGain", + "Minimum information gain for a split to be considered at a tree node.") + + /** + * Maximum memory in MB allocated to histogram aggregation. + * (default = 256 MB) + * @group expertParam + */ + final val maxMemoryInMB: IntParam = new IntParam(this, "maxMemoryInMB", + "Maximum memory in MB allocated to histogram aggregation.") + + /** + * If false, the algorithm will pass trees to executors to match instances with nodes. + * If true, the algorithm will cache node IDs for each instance. + * Caching can speed up training of deeper trees. + * (default = false) + * @group expertParam + */ + final val cacheNodeIds: BooleanParam = new BooleanParam(this, "cacheNodeIds", "If false, the" + + " algorithm will pass trees to executors to match instances with nodes. If true, the" + + " algorithm will cache node IDs for each instance. Caching can speed up training of deeper" + + " trees.") + + /** + * Specifies how often to checkpoint the cached node IDs. + * E.g. 10 means that the cache will get checkpointed every 10 iterations. + * This is only used if cacheNodeIds is true and if the checkpoint directory is set in + * [[org.apache.spark.SparkContext]]. + * Must be >= 1. 
+ * (default = 10) + * @group expertParam + */ + final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "Specifies" + + " how often to checkpoint the cached node IDs. E.g. 10 means that the cache will get" + + " checkpointed every 10 iterations. This is only used if cacheNodeIds is true and if the" + + " checkpoint directory is set in the SparkContext. Must be >= 1.") + + setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0, + maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10) + + /** @group setParam */ + def setMaxDepth(value: Int): this.type = { + require(value >= 0, s"maxDepth parameter must be >= 0. Given bad value: $value") + set(maxDepth, value) + this.asInstanceOf[this.type] + } + + /** @group getParam */ + def getMaxDepth: Int = getOrDefault(maxDepth) + + /** @group setParam */ + def setMaxBins(value: Int): this.type = { + require(value >= 2, s"maxBins parameter must be >= 2. Given bad value: $value") + set(maxBins, value) + this + } + + /** @group getParam */ + def getMaxBins: Int = getOrDefault(maxBins) + + /** @group setParam */ + def setMinInstancesPerNode(value: Int): this.type = { + require(value >= 1, s"minInstancesPerNode parameter must be >= 1. Given bad value: $value") + set(minInstancesPerNode, value) + this + } + + /** @group getParam */ + def getMinInstancesPerNode: Int = getOrDefault(minInstancesPerNode) + + /** @group setParam */ + def setMinInfoGain(value: Double): this.type = { + set(minInfoGain, value) + this + } + + /** @group getParam */ + def getMinInfoGain: Double = getOrDefault(minInfoGain) + + /** @group expertSetParam */ + def setMaxMemoryInMB(value: Int): this.type = { + require(value > 0, s"maxMemoryInMB parameter must be > 0. Given bad value: $value") + set(maxMemoryInMB, value) + this + } + + /** @group expertGetParam */ + def getMaxMemoryInMB: Int = getOrDefault(maxMemoryInMB) + + /** @group expertSetParam */ + def setCacheNodeIds(value: Boolean): this.type = { + set(cacheNodeIds, value) + this + } + + /** @group expertGetParam */ + def getCacheNodeIds: Boolean = getOrDefault(cacheNodeIds) + + /** @group expertSetParam */ + def setCheckpointInterval(value: Int): this.type = { + require(value >= 1, s"checkpointInterval parameter must be >= 1. Given bad value: $value") + set(checkpointInterval, value) + this + } + + /** @group expertGetParam */ + def getCheckpointInterval: Int = getOrDefault(checkpointInterval) + + /** + * Create a Strategy instance to use with the old API. + * NOTE: The caller should set impurity and subsamplingRate (which is set to 1.0, + * the default for single trees). + */ + private[ml] def getOldStrategy( + categoricalFeatures: Map[Int, Int], + numClasses: Int): OldStrategy = { + val strategy = OldStrategy.defaultStategy(OldAlgo.Classification) + strategy.checkpointInterval = getCheckpointInterval + strategy.maxBins = getMaxBins + strategy.maxDepth = getMaxDepth + strategy.maxMemoryInMB = getMaxMemoryInMB + strategy.minInfoGain = getMinInfoGain + strategy.minInstancesPerNode = getMinInstancesPerNode + strategy.useNodeIdCache = getCacheNodeIds + strategy.numClasses = numClasses + strategy.categoricalFeaturesInfo = categoricalFeatures + strategy.subsamplingRate = 1.0 // default for individual trees + strategy + } +} + +/** + * (private trait) Parameters for Decision Tree-based classification algorithms. + */ +private[ml] trait TreeClassifierParams extends Params { + + /** + * Criterion used for information gain calculation (case-insensitive). 
+ * Supported: "entropy" and "gini". + * (default = gini) + * @group param + */ + val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + + " information gain calculation (case-insensitive). Supported options:" + + s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}") + + setDefault(impurity -> "gini") + + /** @group setParam */ + def setImpurity(value: String): this.type = { + val impurityStr = value.toLowerCase + require(TreeClassifierParams.supportedImpurities.contains(impurityStr), + s"Tree-based classifier was given unrecognized impurity: $value." + + s" Supported options: ${TreeClassifierParams.supportedImpurities.mkString(", ")}") + set(impurity, impurityStr) + this + } + + /** @group getParam */ + def getImpurity: String = getOrDefault(impurity) + + /** Convert new impurity to old impurity. */ + private[ml] def getOldImpurity: OldImpurity = { + getImpurity match { + case "entropy" => OldEntropy + case "gini" => OldGini + case _ => + // Should never happen because of check in setter method. + throw new RuntimeException( + s"TreeClassifierParams was given unrecognized impurity: $impurity.") + } + } +} + +private[ml] object TreeClassifierParams { + // These options should be lowercase. + val supportedImpurities: Array[String] = Array("entropy", "gini").map(_.toLowerCase) +} + +/** + * (private trait) Parameters for Decision Tree-based regression algorithms. + */ +private[ml] trait TreeRegressorParams extends Params { + + /** + * Criterion used for information gain calculation (case-insensitive). + * Supported: "variance". + * (default = variance) + * @group param + */ + val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + + " information gain calculation (case-insensitive). Supported options:" + + s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}") + + setDefault(impurity -> "variance") + + /** @group setParam */ + def setImpurity(value: String): this.type = { + val impurityStr = value.toLowerCase + require(TreeRegressorParams.supportedImpurities.contains(impurityStr), + s"Tree-based regressor was given unrecognized impurity: $value." + + s" Supported options: ${TreeRegressorParams.supportedImpurities.mkString(", ")}") + set(impurity, impurityStr) + this + } + + /** @group getParam */ + def getImpurity: String = getOrDefault(impurity) + + /** Convert new impurity to old impurity. */ + protected def getOldImpurity: OldImpurity = { + getImpurity match { + case "variance" => OldVariance + case _ => + // Should never happen because of check in setter method. + throw new RuntimeException( + s"TreeRegressorParams was given unrecognized impurity: $impurity") + } + } +} + +private[ml] object TreeRegressorParams { + // These options should be lowercase. + val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/package.scala b/mllib/src/main/scala/org/apache/spark/ml/package.scala index b45bd1499b72e..ac75e9de1a8f2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/package.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/package.scala @@ -32,6 +32,18 @@ package org.apache.spark * @groupname getParam Parameter getters * @groupprio getParam 6 * + * @groupname expertParam (expert-only) Parameters + * @groupdesc expertParam A list of advanced, expert-only (hyper-)parameter keys this algorithm can + * take. Users can set and get the parameter values through setters and getters, + * respectively. 
+ * @groupprio expertParam 7 + * + * @groupname expertSetParam (expert-only) Parameter setters + * @groupprio expertSetParam 8 + * + * @groupname expertGetParam (expert-only) Parameter getters + * @groupprio expertGetParam 9 + * * @groupname Ungrouped Members * @groupprio Ungrouped 0 */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 849c60433c777..ddc5907e7facd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -296,8 +296,9 @@ private[spark] object Params { paramMap: ParamMap, parent: E, child: M): Unit = { + val childParams = child.params.map(_.name).toSet parent.params.foreach { param => - if (paramMap.contains(param)) { + if (paramMap.contains(param) && childParams.contains(param.name)) { child.set(child.getParam(param.name), paramMap(param)) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala new file mode 100644 index 0000000000000..49a8b77acf960 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor} +import org.apache.spark.ml.impl.tree._ +import org.apache.spark.ml.param.{Params, ParamMap} +import org.apache.spark.ml.tree.{DecisionTreeModel, Node} +import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} +import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + + +/** + * :: AlphaComponent :: + * + * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm + * for regression. + * It supports both continuous and categorical features. + */ +@AlphaComponent +final class DecisionTreeRegressor + extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel] + with DecisionTreeParams + with TreeRegressorParams { + + // Override parameter setters from parent trait for Java API compatibility. 
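+  // (Same this.type-narrowing pattern as in DecisionTreeClassifier above, for Java callers.)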
+
+  override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value)
+
+  override def setMaxBins(value: Int): this.type = super.setMaxBins(value)
+
+  override def setMinInstancesPerNode(value: Int): this.type =
+    super.setMinInstancesPerNode(value)
+
+  override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value)
+
+  override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value)
+
+  override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value)
+
+  override def setCheckpointInterval(value: Int): this.type =
+    super.setCheckpointInterval(value)
+
+  override def setImpurity(value: String): this.type = super.setImpurity(value)
+
+  override protected def train(
+      dataset: DataFrame,
+      paramMap: ParamMap): DecisionTreeRegressionModel = {
+    val categoricalFeatures: Map[Int, Int] =
+      MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol)))
+    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap)
+    val strategy = getOldStrategy(categoricalFeatures)
+    val oldModel = OldDecisionTree.train(oldDataset, strategy)
+    DecisionTreeRegressionModel.fromOld(oldModel, this, paramMap, categoricalFeatures)
+  }
+
+  /** (private[ml]) Create a Strategy instance to use with the old API. */
+  private[ml] def getOldStrategy(categoricalFeatures: Map[Int, Int]): OldStrategy = {
+    val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 0)
+    strategy.algo = OldAlgo.Regression
+    strategy.setImpurity(getOldImpurity)
+    strategy
+  }
+}
+
+object DecisionTreeRegressor {
+  /** Accessor for supported impurities */
+  final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities
+}
+
+/**
+ * :: AlphaComponent ::
+ *
+ * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] model for regression.
+ * It supports both continuous and categorical features.
+ * @param rootNode Root of the decision tree
+ */
+@AlphaComponent
+final class DecisionTreeRegressionModel private[ml] (
+    override val parent: DecisionTreeRegressor,
+    override val fittingParamMap: ParamMap,
+    override val rootNode: Node)
+  extends PredictionModel[Vector, DecisionTreeRegressionModel]
+  with DecisionTreeModel with Serializable {
+
+  require(rootNode != null,
+    "DecisionTreeRegressionModel given null rootNode, but it requires a non-null rootNode.")
+
+  override protected def predict(features: Vector): Double = {
+    rootNode.predict(features)
+  }
+
+  override protected def copy(): DecisionTreeRegressionModel = {
+    val m = new DecisionTreeRegressionModel(parent, fittingParamMap, rootNode)
+    Params.inheritValues(this.extractParamMap(), this, m)
+    m
+  }
+
+  override def toString: String = {
+    s"DecisionTreeRegressionModel of depth $depth with $numNodes nodes"
+  }
+
+  /** Convert to a model in the old API */
+  private[ml] def toOld: OldDecisionTreeModel = {
+    new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Regression)
+  }
+}
+
+private[ml] object DecisionTreeRegressionModel {
+
+  /** (private[ml]) Convert a model from the old API */
+  def fromOld(
+      oldModel: OldDecisionTreeModel,
+      parent: DecisionTreeRegressor,
+      fittingParamMap: ParamMap,
+      categoricalFeatures: Map[Int, Int]): DecisionTreeRegressionModel = {
+    require(oldModel.algo == OldAlgo.Regression,
+      s"Cannot convert non-regression DecisionTreeModel (old API) to" +
+      s" DecisionTreeRegressionModel (new API). Algo is: ${oldModel.algo}")
+    val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
+    new DecisionTreeRegressionModel(parent, fittingParamMap, rootNode)
+  }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
new file mode 100644
index 0000000000000..d6e2203d9f937
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
@@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree
+
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.tree.model.{InformationGainStats => OldInformationGainStats,
+  Node => OldNode, Predict => OldPredict}
+
+
+/**
+ * Decision tree node interface.
+ */
+sealed abstract class Node extends Serializable {
+
+  // TODO: Add aggregate stats (once available). This will happen after we move the DecisionTree
+  // code into the new API and deprecate the old API.
+
+  /** Prediction this node makes (or would make, if it is an internal node) */
+  def prediction: Double
+
+  /** Impurity measure at this node (for training data) */
+  def impurity: Double
+
+  /** Recursive prediction helper method */
+  private[ml] def predict(features: Vector): Double = prediction
+
+  /**
+   * Get the number of nodes in tree below this node, including leaf nodes.
+   * E.g., if this is a leaf, returns 0. If both children are leaves, returns 2.
+   */
+  private[tree] def numDescendants: Int
+
+  /**
+   * Recursive print function.
+   * @param indentFactor The number of spaces to add to each level of indentation.
+   */
+  private[tree] def subtreeToString(indentFactor: Int = 0): String
+
+  /**
+   * Get depth of tree from this node.
+   * E.g.: Depth 0 means this is a leaf node. Depth 1 means 1 internal and 2 leaf nodes.
+   */
+  private[tree] def subtreeDepth: Int
+
+  /**
+   * Create a copy of this node in the old Node format, recursively creating child nodes as needed.
+   * @param id Node ID using old format IDs
+   */
+  private[ml] def toOld(id: Int): OldNode
+}
+
+private[ml] object Node {
+
+  /**
+   * Create a new Node from the old Node format, recursively creating child nodes as needed.
+   */
+  def fromOld(oldNode: OldNode, categoricalFeatures: Map[Int, Int]): Node = {
+    if (oldNode.isLeaf) {
+      // TODO: Once the implementation has been moved to this API, then include sufficient
+      //       statistics here.
+ new LeafNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity) + } else { + val gain = if (oldNode.stats.nonEmpty) { + oldNode.stats.get.gain + } else { + 0.0 + } + new InternalNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity, + gain = gain, leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures), + rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures), + split = Split.fromOld(oldNode.split.get, categoricalFeatures)) + } + } +} + +/** + * Decision tree leaf node. + * @param prediction Prediction this node makes + * @param impurity Impurity measure at this node (for training data) + */ +final class LeafNode private[ml] ( + override val prediction: Double, + override val impurity: Double) extends Node { + + override def toString: String = s"LeafNode(prediction = $prediction, impurity = $impurity)" + + override private[ml] def predict(features: Vector): Double = prediction + + override private[tree] def numDescendants: Int = 0 + + override private[tree] def subtreeToString(indentFactor: Int = 0): String = { + val prefix: String = " " * indentFactor + prefix + s"Predict: $prediction\n" + } + + override private[tree] def subtreeDepth: Int = 0 + + override private[ml] def toOld(id: Int): OldNode = { + // NOTE: We do NOT store 'prob' in the new API currently. + new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = true, + None, None, None, None) + } +} + +/** + * Internal Decision Tree node. + * @param prediction Prediction this node would make if it were a leaf node + * @param impurity Impurity measure at this node (for training data) + * @param gain Information gain value. + * Values < 0 indicate missing values; this quirk will be removed with future updates. + * @param leftChild Left-hand child node + * @param rightChild Right-hand child node + * @param split Information about the test used to split to the left or right child. + */ +final class InternalNode private[ml] ( + override val prediction: Double, + override val impurity: Double, + val gain: Double, + val leftChild: Node, + val rightChild: Node, + val split: Split) extends Node { + + override def toString: String = { + s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)" + } + + override private[ml] def predict(features: Vector): Double = { + if (split.shouldGoLeft(features)) { + leftChild.predict(features) + } else { + rightChild.predict(features) + } + } + + override private[tree] def numDescendants: Int = { + 2 + leftChild.numDescendants + rightChild.numDescendants + } + + override private[tree] def subtreeToString(indentFactor: Int = 0): String = { + val prefix: String = " " * indentFactor + prefix + s"If (${InternalNode.splitToString(split, left=true)})\n" + + leftChild.subtreeToString(indentFactor + 1) + + prefix + s"Else (${InternalNode.splitToString(split, left=false)})\n" + + rightChild.subtreeToString(indentFactor + 1) + } + + override private[tree] def subtreeDepth: Int = { + 1 + math.max(leftChild.subtreeDepth, rightChild.subtreeDepth) + } + + override private[ml] def toOld(id: Int): OldNode = { + assert(id.toLong * 2 < Int.MaxValue, "Decision Tree could not be converted from new to old API" + + " since the old API does not support deep trees.") + // NOTE: We do NOT store 'prob' in the new API currently. 
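+    // The old API addresses nodes with 1-based heap indexing (children of node `id` are
+    // 2 * id and 2 * id + 1), which is why the assert above guards against Int overflow.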
+ new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = false, + Some(split.toOld), Some(leftChild.toOld(OldNode.leftChildIndex(id))), + Some(rightChild.toOld(OldNode.rightChildIndex(id))), + Some(new OldInformationGainStats(gain, impurity, leftChild.impurity, rightChild.impurity, + new OldPredict(leftChild.prediction, prob = 0.0), + new OldPredict(rightChild.prediction, prob = 0.0)))) + } +} + +private object InternalNode { + + /** + * Helper method for [[Node.subtreeToString()]]. + * @param split Split to print + * @param left Indicates whether this is the part of the split going to the left, + * or that going to the right. + */ + private def splitToString(split: Split, left: Boolean): String = { + val featureStr = s"feature ${split.featureIndex}" + split match { + case contSplit: ContinuousSplit => + if (left) { + s"$featureStr <= ${contSplit.threshold}" + } else { + s"$featureStr > ${contSplit.threshold}" + } + case catSplit: CategoricalSplit => + val categoriesStr = catSplit.getLeftCategories.mkString("{", ",", "}") + if (left) { + s"$featureStr in $categoriesStr" + } else { + s"$featureStr not in $categoriesStr" + } + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala new file mode 100644 index 0000000000000..cb940f62990ed --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree + +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.tree.configuration.{FeatureType => OldFeatureType} +import org.apache.spark.mllib.tree.model.{Split => OldSplit} + + +/** + * Interface for a "Split," which specifies a test made at a decision tree node + * to choose the left or right path. + */ +sealed trait Split extends Serializable { + + /** Index of feature which this split tests */ + def featureIndex: Int + + /** Return true (split to left) or false (split to right) */ + private[ml] def shouldGoLeft(features: Vector): Boolean + + /** Convert to old Split format */ + private[tree] def toOld: OldSplit +} + +private[ml] object Split { + + def fromOld(oldSplit: OldSplit, categoricalFeatures: Map[Int, Int]): Split = { + oldSplit.featureType match { + case OldFeatureType.Categorical => + new CategoricalSplit(featureIndex = oldSplit.feature, + leftCategories = oldSplit.categories.toArray, categoricalFeatures(oldSplit.feature)) + case OldFeatureType.Continuous => + new ContinuousSplit(featureIndex = oldSplit.feature, threshold = oldSplit.threshold) + } + } +} + +/** + * Split which tests a categorical feature. 
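+ * E.g., with `numCategories = 4` and `leftCategories = Array(0.0, 2.0)`, feature values
+ * 0 and 2 send an example to the left child, while values 1 and 3 send it to the right.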
+ * @param featureIndex Index of the feature to test + * @param leftCategories If the feature value is in this set of categories, then the split goes + * left. Otherwise, it goes right. + * @param numCategories Number of categories for this feature. + */ +final class CategoricalSplit( + override val featureIndex: Int, + leftCategories: Array[Double], + private val numCategories: Int) + extends Split { + + require(leftCategories.forall(cat => 0 <= cat && cat < numCategories), "Invalid leftCategories" + + s" (should be in range [0, $numCategories)): ${leftCategories.mkString(",")}") + + /** + * If true, then "categories" is the set of categories for splitting to the left, and vice versa. + */ + private val isLeft: Boolean = leftCategories.length <= numCategories / 2 + + /** Set of categories determining the splitting rule, along with [[isLeft]]. */ + private val categories: Set[Double] = { + if (isLeft) { + leftCategories.toSet + } else { + setComplement(leftCategories.toSet) + } + } + + override private[ml] def shouldGoLeft(features: Vector): Boolean = { + if (isLeft) { + categories.contains(features(featureIndex)) + } else { + !categories.contains(features(featureIndex)) + } + } + + override def equals(o: Any): Boolean = { + o match { + case other: CategoricalSplit => featureIndex == other.featureIndex && + isLeft == other.isLeft && categories == other.categories + case _ => false + } + } + + override private[tree] def toOld: OldSplit = { + val oldCats = if (isLeft) { + categories + } else { + setComplement(categories) + } + OldSplit(featureIndex, threshold = 0.0, OldFeatureType.Categorical, oldCats.toList) + } + + /** Get sorted categories which split to the left */ + def getLeftCategories: Array[Double] = { + val cats = if (isLeft) categories else setComplement(categories) + cats.toArray.sorted + } + + /** Get sorted categories which split to the right */ + def getRightCategories: Array[Double] = { + val cats = if (isLeft) setComplement(categories) else categories + cats.toArray.sorted + } + + /** [0, numCategories) \ cats */ + private def setComplement(cats: Set[Double]): Set[Double] = { + Range(0, numCategories).map(_.toDouble).filter(cat => !cats.contains(cat)).toSet + } +} + +/** + * Split which tests a continuous feature. + * @param featureIndex Index of the feature to test + * @param threshold If the feature value is <= this threshold, then the split goes left. + * Otherwise, it goes right. + */ +final class ContinuousSplit(override val featureIndex: Int, val threshold: Double) extends Split { + + override private[ml] def shouldGoLeft(features: Vector): Boolean = { + features(featureIndex) <= threshold + } + + override def equals(o: Any): Boolean = { + o match { + case other: ContinuousSplit => + featureIndex == other.featureIndex && threshold == other.threshold + case _ => + false + } + } + + override private[tree] def toOld: OldSplit = { + OldSplit(featureIndex, threshold, OldFeatureType.Continuous, List.empty[Double]) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala new file mode 100644 index 0000000000000..8e3bc3849dcf0 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree + +import org.apache.spark.annotation.AlphaComponent + + +/** + * :: AlphaComponent :: + * + * Abstraction for Decision Tree models. + * + * TODO: Add support for predicting probabilities and raw predictions + */ +@AlphaComponent +trait DecisionTreeModel { + + /** Root of the decision tree */ + def rootNode: Node + + /** Number of nodes in tree, including leaf nodes. */ + def numNodes: Int = { + 1 + rootNode.numDescendants + } + + /** + * Depth of the tree. + * E.g.: Depth 0 means 1 leaf node. Depth 1 means 1 internal node and 2 leaf nodes. + */ + lazy val depth: Int = { + rootNode.subtreeDepth + } + + /** Summary of the model */ + override def toString: String = { + // Implementing classes should generally override this method to be more descriptive. + s"DecisionTreeModel of depth $depth with $numNodes nodes" + } + + /** Full description of model */ + def toDebugString: String = { + val header = toString + "\n" + header + rootNode.subtreeToString(2) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala new file mode 100644 index 0000000000000..c84c8b4eb744f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import scala.collection.immutable.HashMap + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, BinaryAttribute, NominalAttribute, + NumericAttribute} +import org.apache.spark.sql.types.StructField + + +/** + * :: Experimental :: + * + * Helper utilities for tree-based algorithms + */ +@Experimental +object MetadataUtils { + + /** + * Examine a schema to identify the number of classes in a label column. + * Returns None if the number of labels is not specified, or if the label column is continuous. 
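+   *
+   * A sketch of intended usage (hypothetical names; any label column produced by
+   * [[org.apache.spark.ml.feature.StringIndexer]] carries the needed metadata):
+   * {{{
+   *   val numClasses: Option[Int] = MetadataUtils.getNumClasses(df.schema("indexedLabel"))
+   * }}}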
+ */ + def getNumClasses(labelSchema: StructField): Option[Int] = { + Attribute.fromStructField(labelSchema) match { + case numAttr: NumericAttribute => None + case binAttr: BinaryAttribute => Some(2) + case nomAttr: NominalAttribute => nomAttr.getNumValues + } + } + + /** + * Examine a schema to identify categorical (Binary and Nominal) features. + * + * @param featuresSchema Schema of the features column. + * If a feature does not have metadata, it is assumed to be continuous. + * If a feature is Nominal, then it must have the number of values + * specified. + * @return Map: feature index --> number of categories. + * The map's set of keys will be the set of categorical feature indices. + */ + def getCategoricalFeatures(featuresSchema: StructField): Map[Int, Int] = { + val metadata = AttributeGroup.fromStructField(featuresSchema) + if (metadata.attributes.isEmpty) { + HashMap.empty[Int, Int] + } else { + metadata.attributes.get.zipWithIndex.flatMap { case (attr, idx) => + if (attr == null) { + Iterator() + } else { + attr match { + case numAttr: NumericAttribute => Iterator() + case binAttr: BinaryAttribute => Iterator(idx -> 2) + case nomAttr: NominalAttribute => + nomAttr.getNumValues match { + case Some(numValues: Int) => Iterator(idx -> numValues) + case None => throw new IllegalArgumentException(s"Feature $idx is marked as" + + " Nominal (categorical), but it does not have the number of values specified.") + } + } + } + }.toMap + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index b9d0c56dd1ea3..dfe3a0b6913ef 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -1147,7 +1147,10 @@ object DecisionTree extends Serializable with Logging { } } - assert(splits.length > 0) + // TODO: Do not fail; just ignore the useless feature. + assert(splits.length > 0, + s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." + + " Please remove this feature and then try again.") // set number of splits accordingly metadata.setNumSplits(featureIndex, splits.length) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index c02c79f094b66..0e31c7ed58df8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -81,11 +81,11 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) /** * Method to validate a gradient boosting model * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * @param validationInput Validation dataset: - RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - Should be different from and follow the same distribution as input. - e.g., these two datasets could be created from an original dataset - by using [[org.apache.spark.rdd.RDD.randomSplit()]] + * @param validationInput Validation dataset. + * This dataset should be different from the training dataset, + * but it should follow the same distribution. 
+ * E.g., these two datasets could be created from an original dataset + * by using [[org.apache.spark.rdd.RDD.randomSplit()]] * @return a gradient boosted trees model that can be used for prediction */ def runWithValidation( @@ -194,8 +194,6 @@ object GradientBoostedTrees extends Logging { val firstTreeWeight = 1.0 baseLearners(0) = firstTreeModel baseLearnerWeights(0) = firstTreeWeight - val startingModel = new GradientBoostedTreesModel( - Regression, Array(firstTreeModel), baseLearnerWeights.slice(0, 1)) var predError: RDD[(Double, Double)] = GradientBoostedTreesModel. computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index db01f2e229e5a..055e60c7d9c95 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -249,7 +249,7 @@ private class RandomForest ( nodeIdCache.get.deleteAllCheckpoints() } catch { case e:IOException => - logWarning(s"delete all chackpoints failed. Error reason: ${e.getMessage}") + logWarning(s"delete all checkpoints failed. Error reason: ${e.getMessage}") } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index 664c8df019233..2d6b01524ff3d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -89,14 +89,14 @@ object BoostingStrategy { * @return Configuration for boosting algorithm */ def defaultParams(algo: Algo): BoostingStrategy = { - val treeStragtegy = Strategy.defaultStategy(algo) - treeStragtegy.maxDepth = 3 + val treeStrategy = Strategy.defaultStategy(algo) + treeStrategy.maxDepth = 3 algo match { case Algo.Classification => - treeStragtegy.numClasses = 2 - new BoostingStrategy(treeStragtegy, LogLoss) + treeStrategy.numClasses = 2 + new BoostingStrategy(treeStrategy, LogLoss) case Algo.Regression => - new BoostingStrategy(treeStragtegy, SquaredError) + new BoostingStrategy(treeStrategy, SquaredError) case _ => throw new IllegalArgumentException(s"$algo is not supported by boosting.") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala index 6f570b4e09c79..2bdef73c4a8f1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.loss import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.TreeEnsembleModel -import org.apache.spark.rdd.RDD + /** * :: DeveloperApi :: @@ -45,9 +45,8 @@ object AbsoluteError extends Loss { if (label - prediction < 0) 1.0 else -1.0 } - override def computeError(prediction: Double, label: Double): Double = { + override private[mllib] def computeError(prediction: Double, label: Double): Double = { val err = label - prediction math.abs(err) } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala index 24ee9f3d51293..778c24526de70 
100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.TreeEnsembleModel import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.rdd.RDD + /** * :: DeveloperApi :: @@ -47,10 +47,9 @@ object LogLoss extends Loss { - 4.0 * label / (1.0 + math.exp(2.0 * label * prediction)) } - override def computeError(prediction: Double, label: Double): Double = { + override private[mllib] def computeError(prediction: Double, label: Double): Double = { val margin = 2.0 * label * prediction // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable. 2.0 * MLUtils.log1pExp(-margin) } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala index d3b82b752fa0d..64ffccbce073f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala @@ -22,6 +22,7 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.TreeEnsembleModel import org.apache.spark.rdd.RDD + /** * :: DeveloperApi :: * Trait for adding "pluggable" loss functions for the gradient boosting algorithm. @@ -57,6 +58,5 @@ trait Loss extends Serializable { * @param label True label. * @return Measure of model error on datapoint. */ - def computeError(prediction: Double, label: Double): Double - + private[mllib] def computeError(prediction: Double, label: Double): Double } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala index 58857ae15e93e..a5582d3ef3324 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.loss import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.TreeEnsembleModel -import org.apache.spark.rdd.RDD + /** * :: DeveloperApi :: @@ -45,9 +45,8 @@ object SquaredError extends Loss { 2.0 * (prediction - label) } - override def computeError(prediction: Double, label: Double): Double = { + override private[mllib] def computeError(prediction: Double, label: Double): Double = { val err = prediction - label err * err } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index c9bafd60fba4d..331af428533de 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -113,11 +113,13 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable DecisionTreeModel.SaveLoadV1_0.save(sc, path, this) } - override protected def formatVersion: String = "1.0" + override protected def formatVersion: String = DecisionTreeModel.formatVersion } object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { + private[spark] def formatVersion: String = "1.0" + private[tree] object SaveLoadV1_0 { def thisFormatVersion: 
String = "1.0" diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 4f72bb8014cc0..708ba04b567d3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -175,7 +175,7 @@ class Node ( } } -private[tree] object Node { +private[spark] object Node { /** * Return a node with the given node id (but nothing else set). diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index fef3d2acb202a..8341219bfa71c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -38,6 +38,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext import org.apache.spark.util.Utils + /** * :: Experimental :: * Represents a random forest model. @@ -47,7 +48,7 @@ import org.apache.spark.util.Utils */ @Experimental class RandomForestModel(override val algo: Algo, override val trees: Array[DecisionTreeModel]) - extends TreeEnsembleModel(algo, trees, Array.fill(trees.size)(1.0), + extends TreeEnsembleModel(algo, trees, Array.fill(trees.length)(1.0), combiningStrategy = if (algo == Classification) Vote else Average) with Saveable { @@ -58,11 +59,13 @@ class RandomForestModel(override val algo: Algo, override val trees: Array[Decis RandomForestModel.SaveLoadV1_0.thisClassName) } - override protected def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion + override protected def formatVersion: String = RandomForestModel.formatVersion } object RandomForestModel extends Loader[RandomForestModel] { + private[mllib] def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion + override def load(sc: SparkContext, path: String): RandomForestModel = { val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path) val classNameV1_0 = SaveLoadV1_0.thisClassName @@ -102,15 +105,13 @@ class GradientBoostedTreesModel( extends TreeEnsembleModel(algo, trees, treeWeights, combiningStrategy = Sum) with Saveable { - require(trees.size == treeWeights.size) + require(trees.length == treeWeights.length) override def save(sc: SparkContext, path: String): Unit = { TreeEnsembleModel.SaveLoadV1_0.save(sc, path, this, GradientBoostedTreesModel.SaveLoadV1_0.thisClassName) } - override protected def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion - /** * Method to compute error or loss for every iteration of gradient boosting. 
* @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] @@ -138,7 +139,7 @@ class GradientBoostedTreesModel( evaluationArray(0) = predictionAndError.values.mean() val broadcastTrees = sc.broadcast(trees) - (1 until numIterations).map { nTree => + (1 until numIterations).foreach { nTree => predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter => val currentTree = broadcastTrees.value(nTree) val currentTreeWeight = localTreeWeights(nTree) @@ -155,6 +156,7 @@ class GradientBoostedTreesModel( evaluationArray } + override protected def formatVersion: String = GradientBoostedTreesModel.formatVersion } object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { @@ -200,17 +202,17 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { loss: Loss): RDD[(Double, Double)] = { val newPredError = data.zip(predictionAndError).mapPartitions { iter => - iter.map { - case (lp, (pred, error)) => { - val newPred = pred + tree.predict(lp.features) * treeWeight - val newError = loss.computeError(newPred, lp.label) - (newPred, newError) - } + iter.map { case (lp, (pred, error)) => + val newPred = pred + tree.predict(lp.features) * treeWeight + val newError = loss.computeError(newPred, lp.label) + (newPred, newError) } } newPredError } + private[mllib] def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion + override def load(sc: SparkContext, path: String): GradientBoostedTreesModel = { val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path) val classNameV1_0 = SaveLoadV1_0.thisClassName @@ -340,12 +342,12 @@ private[tree] sealed class TreeEnsembleModel( } /** - * Get number of trees in forest. + * Get number of trees in ensemble. */ - def numTrees: Int = trees.size + def numTrees: Int = trees.length /** - * Get total number of nodes, summed over all trees in the forest. + * Get total number of nodes, summed over all trees in the ensemble. */ def totalNumNodes: Int = trees.map(_.numNodes).sum } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java new file mode 100644 index 0000000000000..43b8787f9dd7e --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.classification; + +import java.io.File; +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.impl.TreeTests; +import org.apache.spark.mllib.classification.LogisticRegressionSuite; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.util.Utils; + + +public class JavaDecisionTreeClassifierSuite implements Serializable { + + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaDecisionTreeClassifierSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void runDT() { + int nPoints = 20; + double A = 2.0; + double B = -1.5; + + JavaRDD data = sc.parallelize( + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); + Map categoricalFeatures = new HashMap(); + DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); + + // This tests setters. Training with various options is tested in Scala. + DecisionTreeClassifier dt = new DecisionTreeClassifier() + .setMaxDepth(2) + .setMaxBins(10) + .setMinInstancesPerNode(5) + .setMinInfoGain(0.0) + .setMaxMemoryInMB(256) + .setCacheNodeIds(false) + .setCheckpointInterval(10) + .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern + for (int i = 0; i < DecisionTreeClassifier.supportedImpurities().length; ++i) { + dt.setImpurity(DecisionTreeClassifier.supportedImpurities()[i]); + } + DecisionTreeClassificationModel model = dt.fit(dataFrame); + + model.transform(dataFrame); + model.numNodes(); + model.depth(); + model.toDebugString(); + + /* + // TODO: Add test once save/load are implemented. + File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark"); + String path = tempDir.toURI().toString(); + try { + model3.save(sc.sc(), path); + DecisionTreeClassificationModel sameModel = + DecisionTreeClassificationModel.load(sc.sc(), path); + TreeTests.checkEqual(model3, sameModel); + } finally { + Utils.deleteRecursively(tempDir); + } + */ + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java new file mode 100644 index 0000000000000..a3a339004f31c --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.regression; + +import java.io.File; +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.impl.TreeTests; +import org.apache.spark.mllib.classification.LogisticRegressionSuite; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.util.Utils; + + +public class JavaDecisionTreeRegressorSuite implements Serializable { + + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaDecisionTreeRegressorSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void runDT() { + int nPoints = 20; + double A = 2.0; + double B = -1.5; + + JavaRDD data = sc.parallelize( + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); + Map categoricalFeatures = new HashMap(); + DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); + + // This tests setters. Training with various options is tested in Scala. + DecisionTreeRegressor dt = new DecisionTreeRegressor() + .setMaxDepth(2) + .setMaxBins(10) + .setMinInstancesPerNode(5) + .setMinInfoGain(0.0) + .setMaxMemoryInMB(256) + .setCacheNodeIds(false) + .setCheckpointInterval(10) + .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern + for (int i = 0; i < DecisionTreeRegressor.supportedImpurities().length; ++i) { + dt.setImpurity(DecisionTreeRegressor.supportedImpurities()[i]); + } + DecisionTreeRegressionModel model = dt.fit(dataFrame); + + model.transform(dataFrame); + model.numNodes(); + model.depth(); + model.toDebugString(); + + /* + // TODO: Add test once save/load are implemented. 
+ File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark"); + String path = tempDir.toURI().toString(); + try { + model2.save(sc.sc(), path); + DecisionTreeRegressionModel sameModel = DecisionTreeRegressionModel.load(sc.sc(), path); + TreeTests.checkEqual(model2, sameModel); + } finally { + Utils.deleteRecursively(tempDir); + } + */ + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala index 0dcfe5a2002dc..17ddd335deb6d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala @@ -44,7 +44,7 @@ class AttributeGroupSuite extends FunSuite { group("abc") } assert(group === AttributeGroup.fromMetadata(group.toMetadataImpl, group.name)) - assert(group === AttributeGroup.fromStructField(group.toStructField)) + assert(group === AttributeGroup.fromStructField(group.toStructField())) } test("attribute group without attributes") { @@ -54,7 +54,7 @@ class AttributeGroupSuite extends FunSuite { assert(group0.size === 10) assert(group0.attributes.isEmpty) assert(group0 === AttributeGroup.fromMetadata(group0.toMetadataImpl, group0.name)) - assert(group0 === AttributeGroup.fromStructField(group0.toStructField)) + assert(group0 === AttributeGroup.fromStructField(group0.toStructField())) val group1 = new AttributeGroup("item") assert(group1.name === "item") diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala index 6ec35b03656f9..3e1a7196e37cb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala @@ -36,9 +36,9 @@ class AttributeSuite extends FunSuite { assert(attr.max.isEmpty) assert(attr.std.isEmpty) assert(attr.sparsity.isEmpty) - assert(attr.toMetadata() === metadata) - assert(attr.toMetadata(withType = false) === metadata) - assert(attr.toMetadata(withType = true) === metadataWithType) + assert(attr.toMetadataImpl() === metadata) + assert(attr.toMetadataImpl(withType = false) === metadata) + assert(attr.toMetadataImpl(withType = true) === metadataWithType) assert(attr === Attribute.fromMetadata(metadata)) assert(attr === Attribute.fromMetadata(metadataWithType)) intercept[NoSuchElementException] { @@ -59,9 +59,9 @@ class AttributeSuite extends FunSuite { assert(!attr.isNominal) assert(attr.name === Some(name)) assert(attr.index === Some(index)) - assert(attr.toMetadata() === metadata) - assert(attr.toMetadata(withType = false) === metadata) - assert(attr.toMetadata(withType = true) === metadataWithType) + assert(attr.toMetadataImpl() === metadata) + assert(attr.toMetadataImpl(withType = false) === metadata) + assert(attr.toMetadataImpl(withType = true) === metadataWithType) assert(attr === Attribute.fromMetadata(metadata)) assert(attr === Attribute.fromMetadata(metadataWithType)) val field = attr.toStructField() @@ -81,7 +81,7 @@ class AttributeSuite extends FunSuite { assert(attr2.max === Some(1.0)) assert(attr2.std === Some(0.5)) assert(attr2.sparsity === Some(0.3)) - assert(attr2 === Attribute.fromMetadata(attr2.toMetadata())) + assert(attr2 === Attribute.fromMetadata(attr2.toMetadataImpl())) } test("bad numeric attributes") { @@ -105,9 +105,9 @@ class AttributeSuite extends FunSuite { assert(attr.values.isEmpty) 
assert(attr.numValues.isEmpty) assert(attr.isOrdinal.isEmpty) - assert(attr.toMetadata() === metadata) - assert(attr.toMetadata(withType = true) === metadata) - assert(attr.toMetadata(withType = false) === metadataWithoutType) + assert(attr.toMetadataImpl() === metadata) + assert(attr.toMetadataImpl(withType = true) === metadata) + assert(attr.toMetadataImpl(withType = false) === metadataWithoutType) assert(attr === Attribute.fromMetadata(metadata)) assert(attr === NominalAttribute.fromMetadata(metadataWithoutType)) intercept[NoSuchElementException] { @@ -135,9 +135,9 @@ class AttributeSuite extends FunSuite { assert(attr.values === Some(values)) assert(attr.indexOf("medium") === 1) assert(attr.getValue(1) === "medium") - assert(attr.toMetadata() === metadata) - assert(attr.toMetadata(withType = true) === metadata) - assert(attr.toMetadata(withType = false) === metadataWithoutType) + assert(attr.toMetadataImpl() === metadata) + assert(attr.toMetadataImpl(withType = true) === metadata) + assert(attr.toMetadataImpl(withType = false) === metadataWithoutType) assert(attr === Attribute.fromMetadata(metadata)) assert(attr === NominalAttribute.fromMetadata(metadataWithoutType)) assert(attr.withoutIndex === Attribute.fromStructField(attr.toStructField())) @@ -147,8 +147,8 @@ class AttributeSuite extends FunSuite { assert(attr2.index.isEmpty) assert(attr2.values.get === Array("small", "medium", "large", "x-large")) assert(attr2.indexOf("x-large") === 3) - assert(attr2 === Attribute.fromMetadata(attr2.toMetadata())) - assert(attr2 === NominalAttribute.fromMetadata(attr2.toMetadata(withType = false))) + assert(attr2 === Attribute.fromMetadata(attr2.toMetadataImpl())) + assert(attr2 === NominalAttribute.fromMetadata(attr2.toMetadataImpl(withType = false))) } test("bad nominal attributes") { @@ -168,9 +168,9 @@ class AttributeSuite extends FunSuite { assert(attr.name.isEmpty) assert(attr.index.isEmpty) assert(attr.values.isEmpty) - assert(attr.toMetadata() === metadata) - assert(attr.toMetadata(withType = true) === metadata) - assert(attr.toMetadata(withType = false) === metadataWithoutType) + assert(attr.toMetadataImpl() === metadata) + assert(attr.toMetadataImpl(withType = true) === metadata) + assert(attr.toMetadataImpl(withType = false) === metadataWithoutType) assert(attr === Attribute.fromMetadata(metadata)) assert(attr === BinaryAttribute.fromMetadata(metadataWithoutType)) intercept[NoSuchElementException] { @@ -196,9 +196,9 @@ class AttributeSuite extends FunSuite { assert(attr.name === Some(name)) assert(attr.index === Some(index)) assert(attr.values.get === values) - assert(attr.toMetadata() === metadata) - assert(attr.toMetadata(withType = true) === metadata) - assert(attr.toMetadata(withType = false) === metadataWithoutType) + assert(attr.toMetadataImpl() === metadata) + assert(attr.toMetadataImpl(withType = true) === metadata) + assert(attr.toMetadataImpl(withType = false) === metadataWithoutType) assert(attr === Attribute.fromMetadata(metadata)) assert(attr === BinaryAttribute.fromMetadata(metadataWithoutType)) assert(attr.withoutIndex === Attribute.fromStructField(attr.toStructField())) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala new file mode 100644 index 0000000000000..af88595df5245 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -0,0 +1,274 @@ +/* + * Licensed to the 
Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.scalatest.FunSuite + +import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, + DecisionTreeSuite => OldDecisionTreeSuite} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + + +class DecisionTreeClassifierSuite extends FunSuite with MLlibTestSparkContext { + + import DecisionTreeClassifierSuite.compareAPIs + + private var categoricalDataPointsRDD: RDD[LabeledPoint] = _ + private var orderedLabeledPointsWithLabel0RDD: RDD[LabeledPoint] = _ + private var orderedLabeledPointsWithLabel1RDD: RDD[LabeledPoint] = _ + private var categoricalDataPointsForMulticlassRDD: RDD[LabeledPoint] = _ + private var continuousDataPointsForMulticlassRDD: RDD[LabeledPoint] = _ + private var categoricalDataPointsForMulticlassForOrderedFeaturesRDD: RDD[LabeledPoint] = _ + + override def beforeAll() { + super.beforeAll() + categoricalDataPointsRDD = + sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPoints()) + orderedLabeledPointsWithLabel0RDD = + sc.parallelize(OldDecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()) + orderedLabeledPointsWithLabel1RDD = + sc.parallelize(OldDecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()) + categoricalDataPointsForMulticlassRDD = + sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPointsForMulticlass()) + continuousDataPointsForMulticlassRDD = + sc.parallelize(OldDecisionTreeSuite.generateContinuousDataPointsForMulticlass()) + categoricalDataPointsForMulticlassForOrderedFeaturesRDD = sc.parallelize( + OldDecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()) + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests calling train() + ///////////////////////////////////////////////////////////////////////////// + + test("Binary classification stump with ordered categorical features") { + val dt = new DecisionTreeClassifier() + .setImpurity("gini") + .setMaxDepth(2) + .setMaxBins(100) + val categoricalFeatures = Map(0 -> 3, 1-> 3) + val numClasses = 2 + compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures, numClasses) + } + + test("Binary classification stump with fixed labels 0,1 for Entropy,Gini") { + val dt = new DecisionTreeClassifier() + .setMaxDepth(3) + .setMaxBins(100) + val numClasses = 2 + Array(orderedLabeledPointsWithLabel0RDD, orderedLabeledPointsWithLabel1RDD).foreach { rdd => + DecisionTreeClassifier.supportedImpurities.foreach { impurity => + dt.setImpurity(impurity) + compareAPIs(rdd, dt, 
categoricalFeatures = Map.empty[Int, Int], numClasses) + } + } + } + + test("Multiclass classification stump with 3-ary (unordered) categorical features") { + val rdd = categoricalDataPointsForMulticlassRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + val numClasses = 3 + val categoricalFeatures = Map(0 -> 3, 1 -> 3) + compareAPIs(rdd, dt, categoricalFeatures, numClasses) + } + + test("Binary classification stump with 1 continuous feature, to check off-by-1 error") { + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(1.0, Vectors.dense(2.0)), + LabeledPoint(1.0, Vectors.dense(3.0))) + val rdd = sc.parallelize(arr) + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + val numClasses = 2 + compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses) + } + + test("Binary classification stump with 2 continuous features") { + val arr = Array( + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0))))) + val rdd = sc.parallelize(arr) + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + val numClasses = 2 + compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses) + } + + test("Multiclass classification stump with unordered categorical features," + + " with just enough bins") { + val maxBins = 2 * (math.pow(2, 3 - 1).toInt - 1) // just enough bins to allow unordered features + val rdd = categoricalDataPointsForMulticlassRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(maxBins) + val categoricalFeatures = Map(0 -> 3, 1 -> 3) + val numClasses = 3 + compareAPIs(rdd, dt, categoricalFeatures, numClasses) + } + + test("Multiclass classification stump with continuous features") { + val rdd = continuousDataPointsForMulticlassRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(100) + val numClasses = 3 + compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses) + } + + test("Multiclass classification stump with continuous + unordered categorical features") { + val rdd = continuousDataPointsForMulticlassRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(100) + val categoricalFeatures = Map(0 -> 3) + val numClasses = 3 + compareAPIs(rdd, dt, categoricalFeatures, numClasses) + } + + test("Multiclass classification stump with 10-ary (ordered) categorical features") { + val rdd = categoricalDataPointsForMulticlassForOrderedFeaturesRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(100) + val categoricalFeatures = Map(0 -> 10, 1 -> 10) + val numClasses = 3 + compareAPIs(rdd, dt, categoricalFeatures, numClasses) + } + + test("Multiclass classification tree with 10-ary (ordered) categorical features," + + " with just enough bins") { + val rdd = categoricalDataPointsForMulticlassForOrderedFeaturesRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(10) + val categoricalFeatures = Map(0 -> 10, 1 -> 10) + val numClasses = 3 + compareAPIs(rdd, dt, categoricalFeatures, numClasses) + } + + test("split must satisfy min instances per node requirements") { + val arr = Array( + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 
0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))) + val rdd = sc.parallelize(arr) + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(2) + .setMinInstancesPerNode(2) + val numClasses = 2 + compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses) + } + + test("do not choose split that does not satisfy min instance per node requirements") { + // if a split does not satisfy min instances per node requirements, + // this split is invalid, even though the information gain of split is large. + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(1.0, Vectors.dense(1.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), + LabeledPoint(0.0, Vectors.dense(0.0, 0.0))) + val rdd = sc.parallelize(arr) + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxBins(2) + .setMaxDepth(2) + .setMinInstancesPerNode(2) + val categoricalFeatures = Map(0 -> 2, 1-> 2) + val numClasses = 2 + compareAPIs(rdd, dt, categoricalFeatures, numClasses) + } + + test("split must satisfy min info gain requirements") { + val arr = Array( + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))) + val rdd = sc.parallelize(arr) + + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(2) + .setMinInfoGain(1.0) + val numClasses = 2 + compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses) + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests of model save/load + ///////////////////////////////////////////////////////////////////////////// + + // TODO: Reinstate test once save/load are implemented + /* + test("model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + val oldModel = OldDecisionTreeSuite.createModel(OldAlgo.Classification) + val newModel = DecisionTreeClassificationModel.fromOld(oldModel) + + // Save model, load it back, and compare. + try { + newModel.save(sc, path) + val sameNewModel = DecisionTreeClassificationModel.load(sc, path) + TreeTests.checkEqual(newModel, sameNewModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } + */ +} + +private[ml] object DecisionTreeClassifierSuite extends FunSuite { + + /** + * Train 2 decision trees on the given dataset, one using the old API and one using the new API. + * Convert the old tree to the new format, compare them, and fail if they are not exactly equal. + */ + def compareAPIs( + data: RDD[LabeledPoint], + dt: DecisionTreeClassifier, + categoricalFeatures: Map[Int, Int], + numClasses: Int): Unit = { + val oldStrategy = dt.getOldStrategy(categoricalFeatures, numClasses) + val oldTree = OldDecisionTree.train(data, oldStrategy) + val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses) + val newTree = dt.fit(newData) + // Use parent, fittingParamMap from newTree since these are not checked anyways. 
+ val oldTreeAsNew = DecisionTreeClassificationModel.fromOld(oldTree, newTree.parent, + newTree.fittingParamMap, categoricalFeatures) + TreeTests.checkEqual(oldTreeAsNew, newTree) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index 81ef831c42e55..1b261b2643854 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -228,7 +228,7 @@ class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext { } val attrGroup = new AttributeGroup("features", featureAttributes) val densePoints1WithMeta = - densePoints1.select(densePoints1("features").as("features", attrGroup.toMetadata)) + densePoints1.select(densePoints1("features").as("features", attrGroup.toMetadata())) val vectorIndexer = getIndexer.setMaxCategories(2) val model = vectorIndexer.fit(densePoints1WithMeta) // Check that ML metadata are preserved. diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala new file mode 100644 index 0000000000000..2e57d4ce37f1d --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.impl + +import scala.collection.JavaConverters._ + +import org.scalatest.FunSuite + +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} +import org.apache.spark.ml.impl.tree._ +import org.apache.spark.ml.tree.{DecisionTreeModel, InternalNode, LeafNode, Node} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{SQLContext, DataFrame} + + +private[ml] object TreeTests extends FunSuite { + + /** + * Convert the given data to a DataFrame, and set the features and label metadata. + * @param data Dataset. Categorical features and labels must already have 0-based indices. + * This must be non-empty. + * @param categoricalFeatures Map: categorical feature index -> number of distinct values + * @param numClasses Number of classes label can take. If 0, mark as continuous. 
+ * @return DataFrame with metadata + */ + def setMetadata( + data: RDD[LabeledPoint], + categoricalFeatures: Map[Int, Int], + numClasses: Int): DataFrame = { + val sqlContext = new SQLContext(data.sparkContext) + import sqlContext.implicits._ + val df = data.toDF() + val numFeatures = data.first().features.size + val featuresAttributes = Range(0, numFeatures).map { feature => + if (categoricalFeatures.contains(feature)) { + NominalAttribute.defaultAttr.withIndex(feature).withNumValues(categoricalFeatures(feature)) + } else { + NumericAttribute.defaultAttr.withIndex(feature) + } + }.toArray + val featuresMetadata = new AttributeGroup("features", featuresAttributes).toMetadata() + val labelAttribute = if (numClasses == 0) { + NumericAttribute.defaultAttr.withName("label") + } else { + NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses) + } + val labelMetadata = labelAttribute.toMetadata() + df.select(df("features").as("features", featuresMetadata), + df("label").as("label", labelMetadata)) + } + + /** Java-friendly version of [[setMetadata()]] */ + def setMetadata( + data: JavaRDD[LabeledPoint], + categoricalFeatures: java.util.Map[java.lang.Integer, java.lang.Integer], + numClasses: Int): DataFrame = { + setMetadata(data.rdd, categoricalFeatures.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, + numClasses) + } + + /** + * Check if the two trees are exactly the same. + * Note: I hesitate to override Node.equals since it could cause problems if users + * make mistakes such as creating loops of Nodes. + * If the trees are not equal, this prints the two trees and throws an exception. + */ + def checkEqual(a: DecisionTreeModel, b: DecisionTreeModel): Unit = { + try { + checkEqual(a.rootNode, b.rootNode) + } catch { + case ex: Exception => + throw new AssertionError("checkEqual failed since the two trees were not identical.\n" + + "TREE A:\n" + a.toDebugString + "\n" + + "TREE B:\n" + b.toDebugString + "\n", ex) + } + } + + /** + * Return true iff the two nodes and their descendants are exactly the same. + * Note: I hesitate to override Node.equals since it could cause problems if users + * make mistakes such as creating loops of Nodes. + */ + private def checkEqual(a: Node, b: Node): Unit = { + assert(a.prediction === b.prediction) + assert(a.impurity === b.impurity) + (a, b) match { + case (aye: InternalNode, bee: InternalNode) => + assert(aye.split === bee.split) + checkEqual(aye.leftChild, bee.leftChild) + checkEqual(aye.rightChild, bee.rightChild) + case (aye: LeafNode, bee: LeafNode) => // do nothing + case _ => + throw new AssertionError("Found mismatched nodes") + } + } + + // TODO: Reinstate after adding ensembles + /** + * Check if the two models are exactly the same. + * If the models are not equal, this throws an exception. 
+ */ + /* + def checkEqual(a: TreeEnsembleModel, b: TreeEnsembleModel): Unit = { + try { + a.getTrees.zip(b.getTrees).foreach { case (treeA, treeB) => + TreeTests.checkEqual(treeA, treeB) + } + assert(a.getTreeWeights === b.getTreeWeights) + } catch { + case ex: Exception => throw new AssertionError( + "checkEqual failed since the two tree ensembles were not identical") + } + } + */ +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala new file mode 100644 index 0000000000000..0b40fe33fae9d --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.scalatest.FunSuite + +import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, + DecisionTreeSuite => OldDecisionTreeSuite} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + + +class DecisionTreeRegressorSuite extends FunSuite with MLlibTestSparkContext { + + import DecisionTreeRegressorSuite.compareAPIs + + private var categoricalDataPointsRDD: RDD[LabeledPoint] = _ + + override def beforeAll() { + super.beforeAll() + categoricalDataPointsRDD = + sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPoints()) + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests calling train() + ///////////////////////////////////////////////////////////////////////////// + + test("Regression stump with 3-ary (ordered) categorical features") { + val dt = new DecisionTreeRegressor() + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(100) + val categoricalFeatures = Map(0 -> 3, 1-> 3) + compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures) + } + + test("Regression stump with binary (ordered) categorical features") { + val dt = new DecisionTreeRegressor() + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(100) + val categoricalFeatures = Map(0 -> 2, 1-> 2) + compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures) + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests of model save/load + ///////////////////////////////////////////////////////////////////////////// + + // TODO: test("model save/load") +} + +private[ml] object DecisionTreeRegressorSuite extends FunSuite { + + /** + * Train 2 decision trees on the given dataset, one using the old API and one using the new API. 
+ * Convert the old tree to the new format, compare them, and fail if they are not exactly equal. + */ + def compareAPIs( + data: RDD[LabeledPoint], + dt: DecisionTreeRegressor, + categoricalFeatures: Map[Int, Int]): Unit = { + val oldStrategy = dt.getOldStrategy(categoricalFeatures) + val oldTree = OldDecisionTree.train(data, oldStrategy) + val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) + val newTree = dt.fit(newData) + // Use parent, fittingParamMap from newTree since these are not checked anyways. + val oldTreeAsNew = DecisionTreeRegressionModel.fromOld(oldTree, newTree.parent, + newTree.fittingParamMap, categoricalFeatures) + TreeTests.checkEqual(oldTreeAsNew, newTree) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 4c162df810bb2..249b8eae19b17 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -36,6 +36,10 @@ import org.apache.spark.util.Utils class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { + ///////////////////////////////////////////////////////////////////////////// + // Tests examining individual elements of training + ///////////////////////////////////////////////////////////////////////////// + test("Binary classification with continuous features: split and bin calculation") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) @@ -254,6 +258,165 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { assert(bins(0).length === 0) } + test("Avoid aggregation on the last level") { + val arr = Array( + LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))) + val input = sc.parallelize(arr) + + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1, + numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) + + val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) + val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) + + val topNode = Node.emptyNode(nodeIndex = 1) + assert(topNode.predict.predict === Double.MinValue) + assert(topNode.impurity === -1.0) + assert(topNode.isLeaf === false) + + val nodesForGroup = Map((0, Array(topNode))) + val treeToNodeToIndexInfo = Map((0, Map( + (topNode.id, new RandomForest.NodeIndexInfo(0, None)) + ))) + val nodeQueue = new mutable.Queue[(Int, Node)]() + DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode), + nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) + + // don't enqueue leaf nodes into node queue + assert(nodeQueue.isEmpty) + + // set impurity and predict for topNode + assert(topNode.predict.predict !== Double.MinValue) + assert(topNode.impurity !== -1.0) + + // set impurity and predict for child nodes + assert(topNode.leftNode.get.predict.predict === 0.0) + assert(topNode.rightNode.get.predict.predict === 1.0) + assert(topNode.leftNode.get.impurity === 0.0) + assert(topNode.rightNode.get.impurity === 0.0) + } + + test("Avoid aggregation if impurity is 0.0") { + val arr = Array( + LabeledPoint(0.0, 
Vectors.dense(1.0, 0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))) + val input = sc.parallelize(arr) + + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) + + val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) + val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) + + val topNode = Node.emptyNode(nodeIndex = 1) + assert(topNode.predict.predict === Double.MinValue) + assert(topNode.impurity === -1.0) + assert(topNode.isLeaf === false) + + val nodesForGroup = Map((0, Array(topNode))) + val treeToNodeToIndexInfo = Map((0, Map( + (topNode.id, new RandomForest.NodeIndexInfo(0, None)) + ))) + val nodeQueue = new mutable.Queue[(Int, Node)]() + DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode), + nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) + + // don't enqueue a node into node queue if its impurity is 0.0 + assert(nodeQueue.isEmpty) + + // set impurity and predict for topNode + assert(topNode.predict.predict !== Double.MinValue) + assert(topNode.impurity !== -1.0) + + // set impurity and predict for child nodes + assert(topNode.leftNode.get.predict.predict === 0.0) + assert(topNode.rightNode.get.predict.predict === 1.0) + assert(topNode.leftNode.get.impurity === 0.0) + assert(topNode.rightNode.get.impurity === 0.0) + } + + test("Second level node building with vs. without groups") { + val arr = DecisionTreeSuite.generateOrderedLabeledPoints() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy(Classification, Entropy, 3, 2, 100) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + assert(splits.length === 2) + assert(splits(0).length === 99) + assert(bins.length === 2) + assert(bins(0).length === 100) + + // Train a 1-node model + val strategyOneNode = new Strategy(Classification, Entropy, maxDepth = 1, + numClasses = 2, maxBins = 100) + val modelOneNode = DecisionTree.train(rdd, strategyOneNode) + val rootNode1 = modelOneNode.topNode.deepCopy() + val rootNode2 = modelOneNode.topNode.deepCopy() + assert(rootNode1.leftNode.nonEmpty) + assert(rootNode1.rightNode.nonEmpty) + + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) + + // Single group second level tree construction. + val nodesForGroup = Map((0, Array(rootNode1.leftNode.get, rootNode1.rightNode.get))) + val treeToNodeToIndexInfo = Map((0, Map( + (rootNode1.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)), + (rootNode1.rightNode.get.id, new RandomForest.NodeIndexInfo(1, None))))) + val nodeQueue = new mutable.Queue[(Int, Node)]() + DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode1), + nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) + val children1 = new Array[Node](2) + children1(0) = rootNode1.leftNode.get + children1(1) = rootNode1.rightNode.get + + // Train one second-level node at a time. 
+ val nodesForGroupA = Map((0, Array(rootNode2.leftNode.get))) + val treeToNodeToIndexInfoA = Map((0, Map( + (rootNode2.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None))))) + nodeQueue.clear() + DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2), + nodesForGroupA, treeToNodeToIndexInfoA, splits, bins, nodeQueue) + val nodesForGroupB = Map((0, Array(rootNode2.rightNode.get))) + val treeToNodeToIndexInfoB = Map((0, Map( + (rootNode2.rightNode.get.id, new RandomForest.NodeIndexInfo(0, None))))) + nodeQueue.clear() + DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2), + nodesForGroupB, treeToNodeToIndexInfoB, splits, bins, nodeQueue) + val children2 = new Array[Node](2) + children2(0) = rootNode2.leftNode.get + children2(1) = rootNode2.rightNode.get + + // Verify whether the splits obtained using single group and multiple group level + // construction strategies are the same. + for (i <- 0 until 2) { + assert(children1(i).stats.nonEmpty && children1(i).stats.get.gain > 0) + assert(children2(i).stats.nonEmpty && children2(i).stats.get.gain > 0) + assert(children1(i).split === children2(i).split) + assert(children1(i).stats.nonEmpty && children2(i).stats.nonEmpty) + val stats1 = children1(i).stats.get + val stats2 = children2(i).stats.get + assert(stats1.gain === stats2.gain) + assert(stats1.impurity === stats2.impurity) + assert(stats1.leftImpurity === stats2.leftImpurity) + assert(stats1.rightImpurity === stats2.rightImpurity) + assert(children1(i).predict.predict === children2(i).predict.predict) + } + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests calling train() + ///////////////////////////////////////////////////////////////////////////// test("Binary classification stump with ordered categorical features") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() @@ -438,76 +601,6 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { assert(rootNode.predict.predict === 1) } - test("Second level node building with vs. without groups") { - val arr = DecisionTreeSuite.generateOrderedLabeledPoints() - assert(arr.length === 1000) - val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Entropy, 3, 2, 100) - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - assert(splits.length === 2) - assert(splits(0).length === 99) - assert(bins.length === 2) - assert(bins(0).length === 100) - - // Train a 1-node model - val strategyOneNode = new Strategy(Classification, Entropy, maxDepth = 1, - numClasses = 2, maxBins = 100) - val modelOneNode = DecisionTree.train(rdd, strategyOneNode) - val rootNode1 = modelOneNode.topNode.deepCopy() - val rootNode2 = modelOneNode.topNode.deepCopy() - assert(rootNode1.leftNode.nonEmpty) - assert(rootNode1.rightNode.nonEmpty) - - val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) - - // Single group second level tree construction. 
- val nodesForGroup = Map((0, Array(rootNode1.leftNode.get, rootNode1.rightNode.get))) - val treeToNodeToIndexInfo = Map((0, Map( - (rootNode1.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)), - (rootNode1.rightNode.get.id, new RandomForest.NodeIndexInfo(1, None))))) - val nodeQueue = new mutable.Queue[(Int, Node)]() - DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode1), - nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) - val children1 = new Array[Node](2) - children1(0) = rootNode1.leftNode.get - children1(1) = rootNode1.rightNode.get - - // Train one second-level node at a time. - val nodesForGroupA = Map((0, Array(rootNode2.leftNode.get))) - val treeToNodeToIndexInfoA = Map((0, Map( - (rootNode2.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None))))) - nodeQueue.clear() - DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2), - nodesForGroupA, treeToNodeToIndexInfoA, splits, bins, nodeQueue) - val nodesForGroupB = Map((0, Array(rootNode2.rightNode.get))) - val treeToNodeToIndexInfoB = Map((0, Map( - (rootNode2.rightNode.get.id, new RandomForest.NodeIndexInfo(0, None))))) - nodeQueue.clear() - DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2), - nodesForGroupB, treeToNodeToIndexInfoB, splits, bins, nodeQueue) - val children2 = new Array[Node](2) - children2(0) = rootNode2.leftNode.get - children2(1) = rootNode2.rightNode.get - - // Verify whether the splits obtained using single group and multiple group level - // construction strategies are the same. - for (i <- 0 until 2) { - assert(children1(i).stats.nonEmpty && children1(i).stats.get.gain > 0) - assert(children2(i).stats.nonEmpty && children2(i).stats.get.gain > 0) - assert(children1(i).split === children2(i).split) - assert(children1(i).stats.nonEmpty && children2(i).stats.nonEmpty) - val stats1 = children1(i).stats.get - val stats2 = children2(i).stats.get - assert(stats1.gain === stats2.gain) - assert(stats1.impurity === stats2.impurity) - assert(stats1.leftImpurity === stats2.leftImpurity) - assert(stats1.rightImpurity === stats2.rightImpurity) - assert(children1(i).predict.predict === children2(i).predict.predict) - } - } - test("Multiclass classification stump with 3-ary (unordered) categorical features") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() val rdd = sc.parallelize(arr) @@ -528,11 +621,11 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { } test("Binary classification stump with 1 continuous feature, to check off-by-1 error") { - val arr = new Array[LabeledPoint](4) - arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0)) - arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0)) - arr(2) = new LabeledPoint(1.0, Vectors.dense(2.0)) - arr(3) = new LabeledPoint(1.0, Vectors.dense(3.0)) + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(1.0, Vectors.dense(2.0)), + LabeledPoint(1.0, Vectors.dense(3.0))) val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClasses = 2) @@ -544,11 +637,11 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { } test("Binary classification stump with 2 continuous features") { - val arr = new Array[LabeledPoint](4) - arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) - arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))) - arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) - arr(3) = new 
LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0)))) + val arr = Array( + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0))))) val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, @@ -668,11 +761,10 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { } test("split must satisfy min instances per node requirements") { - val arr = new Array[LabeledPoint](3) - arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) - arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))) - arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))) - + val arr = Array( + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))) val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, numClasses = 2, minInstancesPerNode = 2) @@ -695,11 +787,11 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { test("do not choose split that does not satisfy min instance per node requirements") { // if a split does not satisfy min instances per node requirements, // this split is invalid, even though the information gain of split is large. - val arr = new Array[LabeledPoint](4) - arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0, 1.0)) - arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0, 1.0)) - arr(2) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)) - arr(3) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)) + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(1.0, Vectors.dense(1.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), + LabeledPoint(0.0, Vectors.dense(0.0, 0.0))) val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, @@ -715,10 +807,10 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { } test("split must satisfy min info gain requirements") { - val arr = new Array[LabeledPoint](3) - arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) - arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))) - arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))) + val arr = Array( + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))) val input = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, @@ -739,91 +831,9 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { assert(gain == InformationGainStats.invalidInformationGainStats) } - test("Avoid aggregation on the last level") { - val arr = new Array[LabeledPoint](4) - arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)) - arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)) - arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)) - arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)) - val input = sc.parallelize(arr) - - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1, - numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) - val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) - val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) - - val treeInput 
= TreePoint.convertToTreeRDD(input, bins, metadata) - val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) - - val topNode = Node.emptyNode(nodeIndex = 1) - assert(topNode.predict.predict === Double.MinValue) - assert(topNode.impurity === -1.0) - assert(topNode.isLeaf === false) - - val nodesForGroup = Map((0, Array(topNode))) - val treeToNodeToIndexInfo = Map((0, Map( - (topNode.id, new RandomForest.NodeIndexInfo(0, None)) - ))) - val nodeQueue = new mutable.Queue[(Int, Node)]() - DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode), - nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) - - // don't enqueue leaf nodes into node queue - assert(nodeQueue.isEmpty) - - // set impurity and predict for topNode - assert(topNode.predict.predict !== Double.MinValue) - assert(topNode.impurity !== -1.0) - - // set impurity and predict for child nodes - assert(topNode.leftNode.get.predict.predict === 0.0) - assert(topNode.rightNode.get.predict.predict === 1.0) - assert(topNode.leftNode.get.impurity === 0.0) - assert(topNode.rightNode.get.impurity === 0.0) - } - - test("Avoid aggregation if impurity is 0.0") { - val arr = new Array[LabeledPoint](4) - arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)) - arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)) - arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)) - arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)) - val input = sc.parallelize(arr) - - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, - numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) - val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) - val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) - - val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) - val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) - - val topNode = Node.emptyNode(nodeIndex = 1) - assert(topNode.predict.predict === Double.MinValue) - assert(topNode.impurity === -1.0) - assert(topNode.isLeaf === false) - - val nodesForGroup = Map((0, Array(topNode))) - val treeToNodeToIndexInfo = Map((0, Map( - (topNode.id, new RandomForest.NodeIndexInfo(0, None)) - ))) - val nodeQueue = new mutable.Queue[(Int, Node)]() - DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode), - nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) - - // don't enqueue a node into node queue if its impurity is 0.0 - assert(nodeQueue.isEmpty) - - // set impurity and predict for topNode - assert(topNode.predict.predict !== Double.MinValue) - assert(topNode.impurity !== -1.0) - - // set impurity and predict for child nodes - assert(topNode.leftNode.get.predict.predict === 0.0) - assert(topNode.rightNode.get.predict.predict === 1.0) - assert(topNode.leftNode.get.impurity === 0.0) - assert(topNode.rightNode.get.impurity === 0.0) - } + ///////////////////////////////////////////////////////////////////////////// + // Tests of model save/load + ///////////////////////////////////////////////////////////////////////////// test("Node.subtreeIterator") { val model = DecisionTreeSuite.createModel(Classification) @@ -996,8 +1006,9 @@ object DecisionTreeSuite extends FunSuite { /** * Create a tree model. This is deterministic and contains a variety of node and feature types. + * TODO: Update this to be a correct tree (with matching probabilities, impurities, etc.) 
*/ - private[tree] def createModel(algo: Algo): DecisionTreeModel = { + private[mllib] def createModel(algo: Algo): DecisionTreeModel = { val topNode = createInternalNode(id = 1, Continuous) val (node2, node3) = (createLeafNode(id = 2), createInternalNode(id = 3, Categorical)) val (node6, node7) = (createLeafNode(id = 6), createLeafNode(id = 7)) @@ -1017,7 +1028,7 @@ object DecisionTreeSuite extends FunSuite { * make mistakes such as creating loops of Nodes. * If the trees are not equal, this prints the two trees and throws an exception. */ - private[tree] def checkEqual(a: DecisionTreeModel, b: DecisionTreeModel): Unit = { + private[mllib] def checkEqual(a: DecisionTreeModel, b: DecisionTreeModel): Unit = { try { assert(a.algo === b.algo) checkEqual(a.topNode, b.topNode) From 59e206deb7346148412bbf5ba4ab626718fadf18 Mon Sep 17 00:00:00 2001 From: cafreeman Date: Fri, 17 Apr 2015 13:42:19 -0700 Subject: [PATCH 036/191] [SPARK-6807] [SparkR] Merge recent SparkR-pkg changes This PR pulls in recent changes in SparkR-pkg, including cartesian, intersection, sampleByKey, subtract, subtractByKey, except, and some API for StructType and StructField. Author: cafreeman Author: Davies Liu Author: Zongheng Yang Author: Shivaram Venkataraman Author: Shivaram Venkataraman Author: Sun Rui Closes #5436 from davies/R3 and squashes the following commits: c2b09be [Davies Liu] SQLTypes -> schema a5a02f2 [Davies Liu] Merge branch 'master' of github.com:apache/spark into R3 168b7fe [Davies Liu] sort generics b1fe460 [Davies Liu] fix conflict in README.md e74c04e [Davies Liu] fix schema.R 4f5ac09 [Davies Liu] Merge branch 'master' of github.com:apache/spark into R5 41f8184 [Davies Liu] rm man ae78312 [Davies Liu] Merge pull request #237 from sun-rui/SPARKR-154_3 1bdcb63 [Zongheng Yang] Updates to README.md. 5a553e7 [cafreeman] Use object attribute instead of argument 71372d9 [cafreeman] Update docs and examples 8526d2e71 [cafreeman] Remove `tojson` functions 6ef5f2d [cafreeman] Fix spacing 7741d66 [cafreeman] Rename the SQL DataType function 141efd8 [Shivaram Venkataraman] Merge pull request #245 from hqzizania/upstream 9387402 [Davies Liu] fix style 40199eb [Shivaram Venkataraman] Move except into sorted position 07d0dbc [Sun Rui] [SPARKR-244] Fix test failure after integration of subtract() and subtractByKey() for RDD. 
7e8caa3 [Shivaram Venkataraman] Merge pull request #246 from hlin09/fixCombineByKey ed66c81 [cafreeman] Update `subtract` to work with `generics.R` f3ba785 [cafreeman] Fixed duplicate export 275deb4 [cafreeman] Update `NAMESPACE` and tests 1a3b63d [cafreeman] new version of `CreateDF` 836c4bf [cafreeman] Update `createDataFrame` and `toDF` be5d5c1 [cafreeman] refactor schema functions 40338a4 [Zongheng Yang] Merge pull request #244 from sun-rui/SPARKR-154_5 20b97a6 [Zongheng Yang] Merge pull request #234 from hqzizania/assist ba54e34 [Shivaram Venkataraman] Merge pull request #238 from sun-rui/SPARKR-154_4 c9497a3 [Shivaram Venkataraman] Merge pull request #208 from lythesia/master b317aa7 [Zongheng Yang] Merge pull request #243 from hqzizania/master 136a07e [Zongheng Yang] Merge pull request #242 from hqzizania/stats cd66603 [cafreeman] new line at EOF 8b76e81 [Shivaram Venkataraman] Merge pull request #233 from redbaron/fail-early-on-missing-dep 7dd81b7 [cafreeman] Documentation 0e2a94f [cafreeman] Define functions for schema and fields --- R/pkg/DESCRIPTION | 2 +- R/pkg/NAMESPACE | 20 +- R/pkg/R/DataFrame.R | 18 +- R/pkg/R/RDD.R | 205 ++++++++++++------ R/pkg/R/SQLContext.R | 44 +--- R/pkg/R/SQLTypes.R | 64 ------ R/pkg/R/column.R | 2 +- R/pkg/R/generics.R | 46 +++- R/pkg/R/group.R | 2 +- R/pkg/R/pairRDD.R | 192 +++++++++++++--- R/pkg/R/schema.R | 162 ++++++++++++++ R/pkg/R/serialize.R | 9 +- R/pkg/R/utils.R | 80 +++++++ R/pkg/inst/tests/test_rdd.R | 193 ++++++++++++++--- R/pkg/inst/tests/test_shuffle.R | 12 + R/pkg/inst/tests/test_sparkSQL.R | 35 +-- R/pkg/inst/worker/worker.R | 59 ++++- .../scala/org/apache/spark/api/r/RRDD.scala | 131 +++++------ .../scala/org/apache/spark/api/r/SerDe.scala | 14 +- .../org/apache/spark/sql/api/r/SQLUtils.scala | 32 ++- 20 files changed, 971 insertions(+), 351 deletions(-) delete mode 100644 R/pkg/R/SQLTypes.R create mode 100644 R/pkg/R/schema.R diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 052f68c6c24e2..1c1779a763c7e 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -19,7 +19,7 @@ Collate: 'jobj.R' 'RDD.R' 'pairRDD.R' - 'SQLTypes.R' + 'schema.R' 'column.R' 'group.R' 'DataFrame.R' diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index a354cdce74afa..80283643861ac 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -5,6 +5,7 @@ exportMethods( "aggregateByKey", "aggregateRDD", "cache", + "cartesian", "checkpoint", "coalesce", "cogroup", @@ -28,6 +29,7 @@ exportMethods( "fullOuterJoin", "glom", "groupByKey", + "intersection", "join", "keyBy", "keys", @@ -52,11 +54,14 @@ exportMethods( "reduceByKeyLocally", "repartition", "rightOuterJoin", + "sampleByKey", "sampleRDD", "saveAsTextFile", "saveAsObjectFile", "sortBy", "sortByKey", + "subtract", + "subtractByKey", "sumRDD", "take", "takeOrdered", @@ -95,6 +100,7 @@ exportClasses("DataFrame") exportMethods("columns", "distinct", "dtypes", + "except", "explain", "filter", "groupBy", @@ -118,7 +124,6 @@ exportMethods("columns", "show", "showDF", "sortDF", - "subtract", "toJSON", "toRDD", "unionAll", @@ -178,5 +183,14 @@ export("cacheTable", "toDF", "uncacheTable") -export("print.structType", - "print.structField") +export("sparkRSQL.init", + "sparkRHive.init") + +export("structField", + "structField.jobj", + "structField.character", + "print.structField", + "structType", + "structType.jobj", + "structType.structField", + "print.structType") diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 044fdb4d01223..861fe1c78b0db 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ 
-17,7 +17,7 @@ # DataFrame.R - DataFrame class and methods implemented in S4 OO classes -#' @include generics.R jobj.R SQLTypes.R RDD.R pairRDD.R column.R group.R +#' @include generics.R jobj.R schema.R RDD.R pairRDD.R column.R group.R NULL setOldClass("jobj") @@ -1141,15 +1141,15 @@ setMethod("intersect", dataFrame(intersected) }) -#' Subtract +#' except #' #' Return a new DataFrame containing rows in this DataFrame #' but not in another DataFrame. This is equivalent to `EXCEPT` in SQL. #' #' @param x A Spark DataFrame #' @param y A Spark DataFrame -#' @return A DataFrame containing the result of the subtract operation. -#' @rdname subtract +#' @return A DataFrame containing the result of the except operation. +#' @rdname except #' @export #' @examples #'\dontrun{ @@ -1157,13 +1157,15 @@ setMethod("intersect", #' sqlCtx <- sparkRSQL.init(sc) #' df1 <- jsonFile(sqlCtx, path) #' df2 <- jsonFile(sqlCtx, path2) -#' subtractDF <- subtract(df, df2) +#' exceptDF <- except(df, df2) #' } -setMethod("subtract", +#' @rdname except +#' @export +setMethod("except", signature(x = "DataFrame", y = "DataFrame"), function(x, y) { - subtracted <- callJMethod(x@sdf, "except", y@sdf) - dataFrame(subtracted) + excepted <- callJMethod(x@sdf, "except", y@sdf) + dataFrame(excepted) }) #' Save the contents of the DataFrame to a data source diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index 820027ef67e3b..128431334ca52 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -730,6 +730,7 @@ setMethod("take", index <- -1 jrdd <- getJRDD(x) numPartitions <- numPartitions(x) + serializedModeRDD <- getSerializedMode(x) # TODO(shivaram): Collect more than one partition based on size # estimates similar to the scala version of `take`. @@ -748,13 +749,14 @@ setMethod("take", elems <- convertJListToRList(partition, flatten = TRUE, logicalUpperBound = size, - serializedMode = getSerializedMode(x)) - # TODO: Check if this append is O(n^2)? + serializedMode = serializedModeRDD) + resList <- append(resList, elems) } resList }) + #' First #' #' Return the first element of an RDD @@ -1092,21 +1094,42 @@ takeOrderedElem <- function(x, num, ascending = TRUE) { if (num < length(part)) { # R limitation: order works only on primitive types! ord <- order(unlist(part, recursive = FALSE), decreasing = !ascending) - list(part[ord[1:num]]) + part[ord[1:num]] } else { - list(part) + part } } - reduceFunc <- function(elems, part) { - newElems <- append(elems, part) - # R limitation: order works only on primitive types! - ord <- order(unlist(newElems, recursive = FALSE), decreasing = !ascending) - newElems[ord[1:num]] - } - newRdd <- mapPartitions(x, partitionFunc) - reduce(newRdd, reduceFunc) + + resList <- list() + index <- -1 + jrdd <- getJRDD(newRdd) + numPartitions <- numPartitions(newRdd) + serializedModeRDD <- getSerializedMode(newRdd) + + while (TRUE) { + index <- index + 1 + + if (index >= numPartitions) { + ord <- order(unlist(resList, recursive = FALSE), decreasing = !ascending) + resList <- resList[ord[1:num]] + break + } + + # a JList of byte arrays + partitionArr <- callJMethod(jrdd, "collectPartitions", as.list(as.integer(index))) + partition <- partitionArr[[1]] + + # elems is capped to have at most `num` elements + elems <- convertJListToRList(partition, + flatten = TRUE, + logicalUpperBound = num, + serializedMode = serializedModeRDD) + + resList <- append(resList, elems) + } + resList } #' Returns the first N elements from an RDD in ascending order. 
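For reference, a minimal SparkR sketch of the except() operation introduced in the DataFrame.R hunk above; this is an illustration rather than part of the patch, and the sqlCtx, path, and path2 values are assumed to exist as in the roxygen example:

    # assumes a running SparkR context and two JSON inputs at `path` and `path2`
    sc <- sparkR.init()
    sqlCtx <- sparkRSQL.init(sc)
    df1 <- jsonFile(sqlCtx, path)
    df2 <- jsonFile(sqlCtx, path2)
    exceptDF <- except(df1, df2)   # rows of df1 not present in df2, i.e. SQL EXCEPT
    showDF(exceptDF)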
@@ -1465,67 +1488,105 @@ setMethod("zipRDD", stop("Can only zip RDDs which have the same number of partitions.") } - if (getSerializedMode(x) != getSerializedMode(other) || - getSerializedMode(x) == "byte") { - # Append the number of elements in each partition to that partition so that we can later - # check if corresponding partitions of both RDDs have the same number of elements. - # - # Note that this appending also serves the purpose of reserialization, because even if - # any RDD is serialized, we need to reserialize it to make sure its partitions are encoded - # as a single byte array. For example, partitions of an RDD generated from partitionBy() - # may be encoded as multiple byte arrays. - appendLength <- function(part) { - part[[length(part) + 1]] <- length(part) + 1 - part - } - x <- lapplyPartition(x, appendLength) - other <- lapplyPartition(other, appendLength) - } + rdds <- appendPartitionLengths(x, other) + jrdd <- callJMethod(getJRDD(rdds[[1]]), "zip", getJRDD(rdds[[2]])) + # The jrdd's elements are of scala Tuple2 type. The serialized + # flag here is used for the elements inside the tuples. + rdd <- RDD(jrdd, getSerializedMode(rdds[[1]])) - zippedJRDD <- callJMethod(getJRDD(x), "zip", getJRDD(other)) - # The zippedRDD's elements are of scala Tuple2 type. The serialized - # flag Here is used for the elements inside the tuples. - serializerMode <- getSerializedMode(x) - zippedRDD <- RDD(zippedJRDD, serializerMode) + mergePartitions(rdd, TRUE) + }) + +#' Cartesian product of this RDD and another one. +#' +#' Return the Cartesian product of this RDD and another one, +#' that is, the RDD of all pairs of elements (a, b) where a +#' is in this and b is in other. +#' +#' @param x An RDD. +#' @param other An RDD. +#' @return A new RDD which is the Cartesian product of these two RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:2) +#' sortByKey(cartesian(rdd, rdd)) +#' # list(list(1, 1), list(1, 2), list(2, 1), list(2, 2)) +#'} +#' @rdname cartesian +#' @aliases cartesian,RDD,RDD-method +setMethod("cartesian", + signature(x = "RDD", other = "RDD"), + function(x, other) { + rdds <- appendPartitionLengths(x, other) + jrdd <- callJMethod(getJRDD(rdds[[1]]), "cartesian", getJRDD(rdds[[2]])) + # The jrdd's elements are of scala Tuple2 type. The serialized + # flag here is used for the elements inside the tuples. + rdd <- RDD(jrdd, getSerializedMode(rdds[[1]])) - partitionFunc <- function(split, part) { - len <- length(part) - if (len > 0) { - if (serializerMode == "byte") { - lengthOfValues <- part[[len]] - lengthOfKeys <- part[[len - lengthOfValues]] - stopifnot(len == lengthOfKeys + lengthOfValues) - - # check if corresponding partitions of both RDDs have the same number of elements. - if (lengthOfKeys != lengthOfValues) { - stop("Can only zip RDDs with same number of elements in each pair of corresponding partitions.") - } - - if (lengthOfKeys > 1) { - keys <- part[1 : (lengthOfKeys - 1)] - values <- part[(lengthOfKeys + 1) : (len - 1)] - } else { - keys <- list() - values <- list() - } - } else { - # Keys, values must have same length here, because this has - # been validated inside the JavaRDD.zip() function. - keys <- part[c(TRUE, FALSE)] - values <- part[c(FALSE, TRUE)] - } - mapply( - function(k, v) { - list(k, v) - }, - keys, - values, - SIMPLIFY = FALSE, - USE.NAMES = FALSE) - } else { - part - } + mergePartitions(rdd, FALSE) + }) + +#' Subtract an RDD with another RDD. +#' +#' Return an RDD with the elements from this that are not in other. 
+#' +#' @param x An RDD. +#' @param other An RDD. +#' @param numPartitions Number of the partitions in the result RDD. +#' @return An RDD with the elements from this that are not in other. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(1, 1, 2, 2, 3, 4)) +#' rdd2 <- parallelize(sc, list(2, 4)) +#' collect(subtract(rdd1, rdd2)) +#' # list(1, 1, 3) +#'} +#' @rdname subtract +#' @aliases subtract,RDD +setMethod("subtract", + signature(x = "RDD", other = "RDD"), + function(x, other, numPartitions = SparkR::numPartitions(x)) { + mapFunction <- function(e) { list(e, NA) } + rdd1 <- map(x, mapFunction) + rdd2 <- map(other, mapFunction) + keys(subtractByKey(rdd1, rdd2, numPartitions)) + }) + +#' Intersection of this RDD and another one. +#' +#' Return the intersection of this RDD and another one. +#' The output will not contain any duplicate elements, +#' even if the input RDDs did. Performs a hash partition +#' across the cluster. +#' Note that this method performs a shuffle internally. +#' +#' @param x An RDD. +#' @param other An RDD. +#' @param numPartitions The number of partitions in the result RDD. +#' @return An RDD which is the intersection of these two RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(1, 10, 2, 3, 4, 5)) +#' rdd2 <- parallelize(sc, list(1, 6, 2, 3, 7, 8)) +#' collect(sortBy(intersection(rdd1, rdd2), function(x) { x })) +#' # list(1, 2, 3) +#'} +#' @rdname intersection +#' @aliases intersection,RDD +setMethod("intersection", + signature(x = "RDD", other = "RDD"), + function(x, other, numPartitions = SparkR::numPartitions(x)) { + rdd1 <- map(x, function(v) { list(v, NA) }) + rdd2 <- map(other, function(v) { list(v, NA) }) + + filterFunction <- function(elem) { + iters <- elem[[2]] + all(as.vector( + lapply(iters, function(iter) { length(iter) > 0 }), mode = "logical")) } - - PipelinedRDD(zippedRDD, partitionFunc) + + keys(filterRDD(cogroup(rdd1, rdd2, numPartitions = numPartitions), filterFunction)) }) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 930ada22f4c38..4f05ba524a01a 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -54,9 +54,9 @@ infer_type <- function(x) { # StructType types <- lapply(x, infer_type) fields <- lapply(1:length(x), function(i) { - list(name = names[[i]], type = types[[i]], nullable = TRUE) + structField(names[[i]], types[[i]], TRUE) }) - list(type = "struct", fields = fields) + do.call(structType, fields) } } else if (length(x) > 1) { list(type = "array", elementType = type, containsNull = TRUE) @@ -65,30 +65,6 @@ infer_type <- function(x) { } } -#' dump the schema into JSON string -tojson <- function(x) { - if (is.list(x)) { - names <- names(x) - if (!is.null(names)) { - items <- lapply(names, function(n) { - safe_n <- gsub('"', '\\"', n) - paste(tojson(safe_n), ':', tojson(x[[n]]), sep = '') - }) - d <- paste(items, collapse = ', ') - paste('{', d, '}', sep = '') - } else { - l <- paste(lapply(x, tojson), collapse = ', ') - paste('[', l, ']', sep = '') - } - } else if (is.character(x)) { - paste('"', x, '"', sep = '') - } else if (is.logical(x)) { - if (x) "true" else "false" - } else { - stop(paste("unexpected type:", class(x))) - } -} - #' Create a DataFrame from an RDD #' #' Converts an RDD to a DataFrame by infer the types. 
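Before the createDataFrame() hunk below, a minimal sketch of how the new structType()/structField() API replaces the old nested-list schemas; it mirrors the roxygen examples added in schema.R and the test_sparkSQL.R tests, and is not part of the patch:

    sc <- sparkR.init()
    sqlCtx <- sparkRSQL.init(sc)
    rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) })
    # a schema is now built from structField objects instead of named lists
    schema <- structType(structField("a", "integer"), structField("b", "string"))
    df <- createDataFrame(sqlCtx, rdd, schema)
    columns(df)   # "a" "b"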
@@ -134,7 +110,7 @@ createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) { stop(paste("unexpected type:", class(data))) } - if (is.null(schema) || is.null(names(schema))) { + if (is.null(schema) || (!inherits(schema, "structType") && is.null(names(schema)))) { row <- first(rdd) names <- if (is.null(schema)) { names(row) @@ -143,7 +119,7 @@ createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) { } if (is.null(names)) { names <- lapply(1:length(row), function(x) { - paste("_", as.character(x), sep = "") + paste("_", as.character(x), sep = "") }) } @@ -159,20 +135,18 @@ createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) { types <- lapply(row, infer_type) fields <- lapply(1:length(row), function(i) { - list(name = names[[i]], type = types[[i]], nullable = TRUE) + structField(names[[i]], types[[i]], TRUE) }) - schema <- list(type = "struct", fields = fields) + schema <- do.call(structType, fields) } - stopifnot(class(schema) == "list") - stopifnot(schema$type == "struct") - stopifnot(class(schema$fields) == "list") - schemaString <- tojson(schema) + stopifnot(class(schema) == "structType") + # schemaString <- tojson(schema) jrdd <- getJRDD(lapply(rdd, function(x) x), "row") srdd <- callJMethod(jrdd, "rdd") sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "createDF", - srdd, schemaString, sqlCtx) + srdd, schema$jobj, sqlCtx) dataFrame(sdf) } diff --git a/R/pkg/R/SQLTypes.R b/R/pkg/R/SQLTypes.R deleted file mode 100644 index 962fba5b3cf03..0000000000000 --- a/R/pkg/R/SQLTypes.R +++ /dev/null @@ -1,64 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Utility functions for handling SparkSQL DataTypes. - -# Handler for StructType -structType <- function(st) { - obj <- structure(new.env(parent = emptyenv()), class = "structType") - obj$jobj <- st - obj$fields <- function() { lapply(callJMethod(st, "fields"), structField) } - obj -} - -#' Print a Spark StructType. -#' -#' This function prints the contents of a StructType returned from the -#' SparkR JVM backend. -#' -#' @param x A StructType object -#' @param ... further arguments passed to or from other methods -print.structType <- function(x, ...) 
{ - fieldsList <- lapply(x$fields(), function(i) { i$print() }) - print(fieldsList) -} - -# Handler for StructField -structField <- function(sf) { - obj <- structure(new.env(parent = emptyenv()), class = "structField") - obj$jobj <- sf - obj$name <- function() { callJMethod(sf, "name") } - obj$dataType <- function() { callJMethod(sf, "dataType") } - obj$dataType.toString <- function() { callJMethod(obj$dataType(), "toString") } - obj$dataType.simpleString <- function() { callJMethod(obj$dataType(), "simpleString") } - obj$nullable <- function() { callJMethod(sf, "nullable") } - obj$print <- function() { paste("StructField(", - paste(obj$name(), obj$dataType.toString(), obj$nullable(), sep = ", "), - ")", sep = "") } - obj -} - -#' Print a Spark StructField. -#' -#' This function prints the contents of a StructField returned from the -#' SparkR JVM backend. -#' -#' @param x A StructField object -#' @param ... further arguments passed to or from other methods -print.structField <- function(x, ...) { - cat(x$print()) -} diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index b282001d8b6b5..95fb9ff0887b6 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -17,7 +17,7 @@ # Column Class -#' @include generics.R jobj.R SQLTypes.R +#' @include generics.R jobj.R schema.R NULL setOldClass("jobj") diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 5fb1ccaa84ee2..6c6233390134c 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -230,6 +230,10 @@ setGeneric("zipWithUniqueId", function(x) { standardGeneric("zipWithUniqueId") } ############ Binary Functions ############# +#' @rdname cartesian +#' @export +setGeneric("cartesian", function(x, other) { standardGeneric("cartesian") }) + #' @rdname countByKey #' @export setGeneric("countByKey", function(x) { standardGeneric("countByKey") }) @@ -238,6 +242,11 @@ setGeneric("countByKey", function(x) { standardGeneric("countByKey") }) #' @export setGeneric("flatMapValues", function(X, FUN) { standardGeneric("flatMapValues") }) +#' @rdname intersection +#' @export +setGeneric("intersection", function(x, other, numPartitions = 1L) { + standardGeneric("intersection") }) + #' @rdname keys #' @export setGeneric("keys", function(x) { standardGeneric("keys") }) @@ -250,12 +259,18 @@ setGeneric("lookup", function(x, key) { standardGeneric("lookup") }) #' @export setGeneric("mapValues", function(X, FUN) { standardGeneric("mapValues") }) +#' @rdname sampleByKey +#' @export +setGeneric("sampleByKey", + function(x, withReplacement, fractions, seed) { + standardGeneric("sampleByKey") + }) + #' @rdname values #' @export setGeneric("values", function(x) { standardGeneric("values") }) - ############ Shuffle Functions ############ #' @rdname aggregateByKey @@ -330,9 +345,24 @@ setGeneric("rightOuterJoin", function(x, y, numPartitions) { standardGeneric("ri #' @rdname sortByKey #' @export -setGeneric("sortByKey", function(x, ascending = TRUE, numPartitions = 1L) { - standardGeneric("sortByKey") -}) +setGeneric("sortByKey", + function(x, ascending = TRUE, numPartitions = 1L) { + standardGeneric("sortByKey") + }) + +#' @rdname subtract +#' @export +setGeneric("subtract", + function(x, other, numPartitions = 1L) { + standardGeneric("subtract") + }) + +#' @rdname subtractByKey +#' @export +setGeneric("subtractByKey", + function(x, other, numPartitions = 1L) { + standardGeneric("subtractByKey") + }) ################### Broadcast Variable Methods ################# @@ -357,6 +387,10 @@ setGeneric("dtypes", function(x) { standardGeneric("dtypes") }) #' @export 
setGeneric("explain", function(x, ...) { standardGeneric("explain") }) +#' @rdname except +#' @export +setGeneric("except", function(x, y) { standardGeneric("except") }) + #' @rdname filter #' @export setGeneric("filter", function(x, condition) { standardGeneric("filter") }) @@ -434,10 +468,6 @@ setGeneric("showDF", function(x,...) { standardGeneric("showDF") }) #' @export setGeneric("sortDF", function(x, col, ...) { standardGeneric("sortDF") }) -#' @rdname subtract -#' @export -setGeneric("subtract", function(x, y) { standardGeneric("subtract") }) - #' @rdname tojson #' @export setGeneric("toJSON", function(x) { standardGeneric("toJSON") }) diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index 855fbdfc7c4ca..02237b3672d6b 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -17,7 +17,7 @@ # group.R - GroupedData class and methods implemented in S4 OO classes -#' @include generics.R jobj.R SQLTypes.R column.R +#' @include generics.R jobj.R schema.R column.R NULL setOldClass("jobj") diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index 5d64822859d1f..13efebc11c46e 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -430,7 +430,7 @@ setMethod("combineByKey", pred <- function(item) exists(item$hash, keys) lapply(part, function(item) { - item$hash <- as.character(item[[1]]) + item$hash <- as.character(hashCode(item[[1]])) updateOrCreatePair(item, keys, combiners, pred, mergeValue, createCombiner) }) convertEnvsToList(keys, combiners) @@ -443,7 +443,7 @@ setMethod("combineByKey", pred <- function(item) exists(item$hash, keys) lapply(part, function(item) { - item$hash <- as.character(item[[1]]) + item$hash <- as.character(hashCode(item[[1]])) updateOrCreatePair(item, keys, combiners, pred, mergeCombiners, identity) }) convertEnvsToList(keys, combiners) @@ -452,19 +452,19 @@ setMethod("combineByKey", }) #' Aggregate a pair RDD by each key. -#' +#' #' Aggregate the values of each key in an RDD, using given combine functions #' and a neutral "zero value". This function can return a different result type, #' U, than the type of the values in this RDD, V. Thus, we need one operation -#' for merging a V into a U and one operation for merging two U's, The former -#' operation is used for merging values within a partition, and the latter is -#' used for merging values between partitions. To avoid memory allocation, both -#' of these functions are allowed to modify and return their first argument +#' for merging a V into a U and one operation for merging two U's, The former +#' operation is used for merging values within a partition, and the latter is +#' used for merging values between partitions. To avoid memory allocation, both +#' of these functions are allowed to modify and return their first argument #' instead of creating a new U. -#' +#' #' @param x An RDD. #' @param zeroValue A neutral "zero value". -#' @param seqOp A function to aggregate the values of each key. It may return +#' @param seqOp A function to aggregate the values of each key. It may return #' a different result type from the type of the values. #' @param combOp A function to aggregate results of seqOp. #' @return An RDD containing the aggregation result. 
@@ -476,7 +476,7 @@ setMethod("combineByKey", #' zeroValue <- list(0, 0) #' seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } #' combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } -#' aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) +#' aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) #' # list(list(1, list(3, 2)), list(2, list(7, 2))) #'} #' @rdname aggregateByKey @@ -493,12 +493,12 @@ setMethod("aggregateByKey", }) #' Fold a pair RDD by each key. -#' +#' #' Aggregate the values of each key in an RDD, using an associative function "func" -#' and a neutral "zero value" which may be added to the result an arbitrary -#' number of times, and must not change the result (e.g., 0 for addition, or +#' and a neutral "zero value" which may be added to the result an arbitrary +#' number of times, and must not change the result (e.g., 0 for addition, or #' 1 for multiplication.). -#' +#' #' @param x An RDD. #' @param zeroValue A neutral "zero value". #' @param func An associative function for folding values of each key. @@ -548,11 +548,11 @@ setMethod("join", function(x, y, numPartitions) { xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) - + doJoin <- function(v) { joinTaggedList(v, list(FALSE, FALSE)) } - + joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numToInt(numPartitions)), doJoin) }) @@ -568,8 +568,8 @@ setMethod("join", #' @param y An RDD to be joined. Should be an RDD where each element is #' list(K, V). #' @param numPartitions Number of partitions to create. -#' @return For each element (k, v) in x, the resulting RDD will either contain -#' all pairs (k, (v, w)) for (k, w) in rdd2, or the pair (k, (v, NULL)) +#' @return For each element (k, v) in x, the resulting RDD will either contain +#' all pairs (k, (v, w)) for (k, w) in rdd2, or the pair (k, (v, NULL)) #' if no elements in rdd2 have key k. #' @examples #'\dontrun{ @@ -586,11 +586,11 @@ setMethod("leftOuterJoin", function(x, y, numPartitions) { xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) - + doJoin <- function(v) { joinTaggedList(v, list(FALSE, TRUE)) } - + joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin) }) @@ -623,18 +623,18 @@ setMethod("rightOuterJoin", function(x, y, numPartitions) { xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) - + doJoin <- function(v) { joinTaggedList(v, list(TRUE, FALSE)) } - + joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin) }) #' Full outer join two RDDs #' #' @description -#' \code{fullouterjoin} This function full-outer-joins two RDDs where every element is of the form list(K, V). +#' \code{fullouterjoin} This function full-outer-joins two RDDs where every element is of the form list(K, V). #' The key types of the two RDDs should be the same. #' #' @param x An RDD to be joined. Should be an RDD where each element is @@ -644,7 +644,7 @@ setMethod("rightOuterJoin", #' @param numPartitions Number of partitions to create. #' @return For each element (k, v) in x and (k, w) in y, the resulting RDD #' will contain all pairs (k, (v, w)) for both (k, v) in x and -#' (k, w) in y, or the pair (k, (NULL, w))/(k, (v, NULL)) if no elements +#' (k, w) in y, or the pair (k, (NULL, w))/(k, (v, NULL)) if no elements #' in x/y have key k. 
#' @examples #'\dontrun{ @@ -683,7 +683,7 @@ setMethod("fullOuterJoin", #' sc <- sparkR.init() #' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) #' rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) -#' cogroup(rdd1, rdd2, numPartitions = 2L) +#' cogroup(rdd1, rdd2, numPartitions = 2L) #' # list(list(1, list(1, list(2, 3))), list(2, list(list(4), list())) #'} #' @rdname cogroup @@ -694,7 +694,7 @@ setMethod("cogroup", rdds <- list(...) rddsLen <- length(rdds) for (i in 1:rddsLen) { - rdds[[i]] <- lapply(rdds[[i]], + rdds[[i]] <- lapply(rdds[[i]], function(x) { list(x[[1]], list(i, x[[2]])) }) } union.rdd <- Reduce(unionRDD, rdds) @@ -719,7 +719,7 @@ setMethod("cogroup", } }) } - cogroup.rdd <- mapValues(groupByKey(union.rdd, numPartitions), + cogroup.rdd <- mapValues(groupByKey(union.rdd, numPartitions), group.func) }) @@ -741,18 +741,18 @@ setMethod("sortByKey", signature(x = "RDD"), function(x, ascending = TRUE, numPartitions = SparkR::numPartitions(x)) { rangeBounds <- list() - + if (numPartitions > 1) { rddSize <- count(x) # constant from Spark's RangePartitioner maxSampleSize <- numPartitions * 20 fraction <- min(maxSampleSize / max(rddSize, 1), 1.0) - + samples <- collect(keys(sampleRDD(x, FALSE, fraction, 1L))) - + # Note: the built-in R sort() function only works on atomic vectors samples <- sort(unlist(samples, recursive = FALSE), decreasing = !ascending) - + if (length(samples) > 0) { rangeBounds <- lapply(seq_len(numPartitions - 1), function(i) { @@ -764,24 +764,146 @@ setMethod("sortByKey", rangePartitionFunc <- function(key) { partition <- 0 - + # TODO: Use binary search instead of linear search, similar with Spark while (partition < length(rangeBounds) && key > rangeBounds[[partition + 1]]) { partition <- partition + 1 } - + if (ascending) { partition } else { numPartitions - partition - 1 } } - + partitionFunc <- function(part) { sortKeyValueList(part, decreasing = !ascending) } - + newRDD <- partitionBy(x, numPartitions, rangePartitionFunc) lapplyPartition(newRDD, partitionFunc) }) +#' Subtract a pair RDD with another pair RDD. +#' +#' Return an RDD with the pairs from x whose keys are not in other. +#' +#' @param x An RDD. +#' @param other An RDD. +#' @param numPartitions Number of the partitions in the result RDD. +#' @return An RDD with the pairs from x whose keys are not in other. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list("a", 1), list("b", 4), +#' list("b", 5), list("a", 2))) +#' rdd2 <- parallelize(sc, list(list("a", 3), list("c", 1))) +#' collect(subtractByKey(rdd1, rdd2)) +#' # list(list("b", 4), list("b", 5)) +#'} +#' @rdname subtractByKey +#' @aliases subtractByKey,RDD +setMethod("subtractByKey", + signature(x = "RDD", other = "RDD"), + function(x, other, numPartitions = SparkR::numPartitions(x)) { + filterFunction <- function(elem) { + iters <- elem[[2]] + (length(iters[[1]]) > 0) && (length(iters[[2]]) == 0) + } + + flatMapValues(filterRDD(cogroup(x, + other, + numPartitions = numPartitions), + filterFunction), + function (v) { v[[1]] }) + }) + +#' Return a subset of this RDD sampled by key. +#' +#' @description +#' \code{sampleByKey} Create a sample of this RDD using variable sampling rates +#' for different keys as specified by fractions, a key to sampling rate map. +#' +#' @param x The RDD to sample elements by key, where each element is +#' list(K, V) or c(K, V). 
+#' @param withReplacement Sampling with replacement or not +#' @param fraction The (rough) sample target fraction +#' @param seed Randomness seed value +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:3000) +#' pairs <- lapply(rdd, function(x) { if (x %% 3 == 0) list("a", x) +#' else { if (x %% 3 == 1) list("b", x) else list("c", x) }}) +#' fractions <- list(a = 0.2, b = 0.1, c = 0.3) +#' sample <- sampleByKey(pairs, FALSE, fractions, 1618L) +#' 100 < length(lookup(sample, "a")) && 300 > length(lookup(sample, "a")) # TRUE +#' 50 < length(lookup(sample, "b")) && 150 > length(lookup(sample, "b")) # TRUE +#' 200 < length(lookup(sample, "c")) && 400 > length(lookup(sample, "c")) # TRUE +#' lookup(sample, "a")[which.min(lookup(sample, "a"))] >= 0 # TRUE +#' lookup(sample, "a")[which.max(lookup(sample, "a"))] <= 2000 # TRUE +#' lookup(sample, "b")[which.min(lookup(sample, "b"))] >= 0 # TRUE +#' lookup(sample, "b")[which.max(lookup(sample, "b"))] <= 2000 # TRUE +#' lookup(sample, "c")[which.min(lookup(sample, "c"))] >= 0 # TRUE +#' lookup(sample, "c")[which.max(lookup(sample, "c"))] <= 2000 # TRUE +#' fractions <- list(a = 0.2, b = 0.1, c = 0.3, d = 0.4) +#' sample <- sampleByKey(pairs, FALSE, fractions, 1618L) # Key "d" will be ignored +#' fractions <- list(a = 0.2, b = 0.1) +#' sample <- sampleByKey(pairs, FALSE, fractions, 1618L) # KeyError: "c" +#'} +#' @rdname sampleByKey +#' @aliases sampleByKey,RDD-method +setMethod("sampleByKey", + signature(x = "RDD", withReplacement = "logical", + fractions = "vector", seed = "integer"), + function(x, withReplacement, fractions, seed) { + + for (elem in fractions) { + if (elem < 0.0) { + stop(paste("Negative fraction value ", fractions[which(fractions == elem)])) + } + } + + # The sampler: takes a partition and returns its sampled version. + samplingFunc <- function(split, part) { + set.seed(bitwXor(seed, split)) + res <- vector("list", length(part)) + len <- 0 + + # mixing because the initial seeds are close to each other + runif(10) + + for (elem in part) { + if (elem[[1]] %in% names(fractions)) { + frac <- as.numeric(fractions[which(elem[[1]] == names(fractions))]) + if (withReplacement) { + count <- rpois(1, frac) + if (count > 0) { + res[(len + 1):(len + count)] <- rep(list(elem), count) + len <- len + count + } + } else { + if (runif(1) < frac) { + len <- len + 1 + res[[len]] <- elem + } + } + } else { + stop("KeyError: \"", elem[[1]], "\"") + } + } + + # TODO(zongheng): look into the performance of the current + # implementation. Look into some iterator package? Note that + # Scala avoids many calls to creating an empty list and PySpark + # similarly achieves this using `yield'. (duplicated from sampleRDD) + if (len > 0) { + res[1:len] + } else { + list() + } + } + + lapplyPartitionsWithIndex(x, samplingFunc) + }) diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R new file mode 100644 index 0000000000000..e442119086b17 --- /dev/null +++ b/R/pkg/R/schema.R @@ -0,0 +1,162 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# A set of S3 classes and methods that support the SparkSQL `StructType` and `StructField +# datatypes. These are used to create and interact with DataFrame schemas. + +#' structType +#' +#' Create a structType object that contains the metadata for a DataFrame. Intended for +#' use with createDataFrame and toDF. +#' +#' @param x a structField object (created with the field() function) +#' @param ... additional structField objects +#' @return a structType object +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) +#' schema <- structType(structField("a", "integer"), structField("b", "string")) +#' df <- createDataFrame(sqlCtx, rdd, schema) +#' } +structType <- function(x, ...) { + UseMethod("structType", x) +} + +structType.jobj <- function(x) { + obj <- structure(list(), class = "structType") + obj$jobj <- x + obj$fields <- function() { lapply(callJMethod(obj$jobj, "fields"), structField) } + obj +} + +structType.structField <- function(x, ...) { + fields <- list(x, ...) + if (!all(sapply(fields, inherits, "structField"))) { + stop("All arguments must be structField objects.") + } + sfObjList <- lapply(fields, function(field) { + field$jobj + }) + stObj <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", + "createStructType", + listToSeq(sfObjList)) + structType(stObj) +} + +#' Print a Spark StructType. +#' +#' This function prints the contents of a StructType returned from the +#' SparkR JVM backend. +#' +#' @param x A StructType object +#' @param ... further arguments passed to or from other methods +print.structType <- function(x, ...) { + cat("StructType\n", + sapply(x$fields(), function(field) { paste("|-", "name = \"", field$name(), + "\", type = \"", field$dataType.toString(), + "\", nullable = ", field$nullable(), "\n", + sep = "") }) + , sep = "") +} + +#' structField +#' +#' Create a structField object that contains the metadata for a single field in a schema. +#' +#' @param x The name of the field +#' @param type The data type of the field +#' @param nullable A logical vector indicating whether or not the field is nullable +#' @return a structField object +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) +#' field1 <- structField("a", "integer", TRUE) +#' field2 <- structField("b", "string", TRUE) +#' schema <- structType(field1, field2) +#' df <- createDataFrame(sqlCtx, rdd, schema) +#' } + +structField <- function(x, ...) 
{ + UseMethod("structField", x) +} + +structField.jobj <- function(x) { + obj <- structure(list(), class = "structField") + obj$jobj <- x + obj$name <- function() { callJMethod(x, "name") } + obj$dataType <- function() { callJMethod(x, "dataType") } + obj$dataType.toString <- function() { callJMethod(obj$dataType(), "toString") } + obj$dataType.simpleString <- function() { callJMethod(obj$dataType(), "simpleString") } + obj$nullable <- function() { callJMethod(x, "nullable") } + obj +} + +structField.character <- function(x, type, nullable = TRUE) { + if (class(x) != "character") { + stop("Field name must be a string.") + } + if (class(type) != "character") { + stop("Field type must be a string.") + } + if (class(nullable) != "logical") { + stop("nullable must be either TRUE or FALSE") + } + options <- c("byte", + "integer", + "double", + "numeric", + "character", + "string", + "binary", + "raw", + "logical", + "boolean", + "timestamp", + "date") + dataType <- if (type %in% options) { + type + } else { + stop(paste("Unsupported type for Dataframe:", type)) + } + sfObj <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", + "createStructField", + x, + dataType, + nullable) + structField(sfObj) +} + +#' Print a Spark StructField. +#' +#' This function prints the contents of a StructField returned from the +#' SparkR JVM backend. +#' +#' @param x A StructField object +#' @param ... further arguments passed to or from other methods +print.structField <- function(x, ...) { + cat("StructField(name = \"", x$name(), + "\", type = \"", x$dataType.toString(), + "\", nullable = ", x$nullable(), + ")", + sep = "") +} diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R index 8a9c0c652ce24..c53d0a961016f 100644 --- a/R/pkg/R/serialize.R +++ b/R/pkg/R/serialize.R @@ -69,8 +69,9 @@ writeJobj <- function(con, value) { } writeString <- function(con, value) { - writeInt(con, as.integer(nchar(value) + 1)) - writeBin(value, con, endian = "big") + utfVal <- enc2utf8(value) + writeInt(con, as.integer(nchar(utfVal, type = "bytes") + 1)) + writeBin(utfVal, con, endian = "big") } writeInt <- function(con, value) { @@ -189,7 +190,3 @@ writeArgs <- function(con, args) { } } } - -writeStrings <- function(con, stringList) { - writeLines(unlist(stringList), con) -} diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index c337fb0751e72..23305d3c67074 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -465,3 +465,83 @@ cleanClosure <- function(func, checkedFuncs = new.env()) { } func } + +# Append partition lengths to each partition in two input RDDs if needed. +# param +# x An RDD. +# Other An RDD. +# return value +# A list of two result RDDs. +appendPartitionLengths <- function(x, other) { + if (getSerializedMode(x) != getSerializedMode(other) || + getSerializedMode(x) == "byte") { + # Append the number of elements in each partition to that partition so that we can later + # know the boundary of elements from x and other. + # + # Note that this appending also serves the purpose of reserialization, because even if + # any RDD is serialized, we need to reserialize it to make sure its partitions are encoded + # as a single byte array. For example, partitions of an RDD generated from partitionBy() + # may be encoded as multiple byte arrays. 
+ appendLength <- function(part) { + len <- length(part) + part[[len + 1]] <- len + 1 + part + } + x <- lapplyPartition(x, appendLength) + other <- lapplyPartition(other, appendLength) + } + list (x, other) +} + +# Perform zip or cartesian between elements from two RDDs in each partition +# param +# rdd An RDD. +# zip A boolean flag indicating this call is for zip operation or not. +# return value +# A result RDD. +mergePartitions <- function(rdd, zip) { + serializerMode <- getSerializedMode(rdd) + partitionFunc <- function(split, part) { + len <- length(part) + if (len > 0) { + if (serializerMode == "byte") { + lengthOfValues <- part[[len]] + lengthOfKeys <- part[[len - lengthOfValues]] + stopifnot(len == lengthOfKeys + lengthOfValues) + + # For zip operation, check if corresponding partitions of both RDDs have the same number of elements. + if (zip && lengthOfKeys != lengthOfValues) { + stop("Can only zip RDDs with same number of elements in each pair of corresponding partitions.") + } + + if (lengthOfKeys > 1) { + keys <- part[1 : (lengthOfKeys - 1)] + } else { + keys <- list() + } + if (lengthOfValues > 1) { + values <- part[(lengthOfKeys + 1) : (len - 1)] + } else { + values <- list() + } + + if (!zip) { + return(mergeCompactLists(keys, values)) + } + } else { + keys <- part[c(TRUE, FALSE)] + values <- part[c(FALSE, TRUE)] + } + mapply( + function(k, v) { list(k, v) }, + keys, + values, + SIMPLIFY = FALSE, + USE.NAMES = FALSE) + } else { + part + } + } + + PipelinedRDD(rdd, partitionFunc) +} diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R index b76e4db03e715..3ba7d1716302a 100644 --- a/R/pkg/inst/tests/test_rdd.R +++ b/R/pkg/inst/tests/test_rdd.R @@ -35,7 +35,7 @@ test_that("get number of partitions in RDD", { test_that("first on RDD", { expect_true(first(rdd) == 1) newrdd <- lapply(rdd, function(x) x + 1) - expect_true(first(newrdd) == 2) + expect_true(first(newrdd) == 2) }) test_that("count and length on RDD", { @@ -48,7 +48,7 @@ test_that("count by values and keys", { actual <- countByValue(mods) expected <- list(list(0, 3L), list(1, 4L), list(2, 3L)) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) - + actual <- countByKey(intRdd) expected <- list(list(2L, 2L), list(1L, 2L)) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) @@ -82,11 +82,11 @@ test_that("filterRDD on RDD", { filtered.rdd <- filterRDD(rdd, function(x) { x %% 2 == 0 }) actual <- collect(filtered.rdd) expect_equal(actual, list(2, 4, 6, 8, 10)) - + filtered.rdd <- Filter(function(x) { x[[2]] < 0 }, intRdd) actual <- collect(filtered.rdd) expect_equal(actual, list(list(1L, -1))) - + # Filter out all elements. 
filtered.rdd <- filterRDD(rdd, function(x) { x > 10 }) actual <- collect(filtered.rdd) @@ -96,7 +96,7 @@ test_that("filterRDD on RDD", { test_that("lookup on RDD", { vals <- lookup(intRdd, 1L) expect_equal(vals, list(-1, 200)) - + vals <- lookup(intRdd, 3L) expect_equal(vals, list()) }) @@ -110,7 +110,7 @@ test_that("several transformations on RDD (a benchmark on PipelinedRDD)", { }) rdd2 <- lapply(rdd2, function(x) x + x) actual <- collect(rdd2) - expected <- list(24, 24, 24, 24, 24, + expected <- list(24, 24, 24, 24, 24, 168, 170, 172, 174, 176) expect_equal(actual, expected) }) @@ -248,10 +248,10 @@ test_that("flatMapValues() on pairwise RDDs", { l <- parallelize(sc, list(list(1, c(1,2)), list(2, c(3,4)))) actual <- collect(flatMapValues(l, function(x) { x })) expect_equal(actual, list(list(1,1), list(1,2), list(2,3), list(2,4))) - + # Generate x to x+1 for every value actual <- collect(flatMapValues(intRdd, function(x) { x:(x + 1) })) - expect_equal(actual, + expect_equal(actual, list(list(1L, -1), list(1L, 0), list(2L, 100), list(2L, 101), list(2L, 1), list(2L, 2), list(1L, 200), list(1L, 201))) }) @@ -348,7 +348,7 @@ test_that("top() on RDDs", { rdd <- parallelize(sc, l) actual <- top(rdd, 6L) expect_equal(actual, as.list(sort(unlist(l), decreasing = TRUE))[1:6]) - + l <- list("e", "d", "c", "d", "a") rdd <- parallelize(sc, l) actual <- top(rdd, 3L) @@ -358,7 +358,7 @@ test_that("top() on RDDs", { test_that("fold() on RDDs", { actual <- fold(rdd, 0, "+") expect_equal(actual, Reduce("+", nums, 0)) - + rdd <- parallelize(sc, list()) actual <- fold(rdd, 0, "+") expect_equal(actual, 0) @@ -371,7 +371,7 @@ test_that("aggregateRDD() on RDDs", { combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } actual <- aggregateRDD(rdd, zeroValue, seqOp, combOp) expect_equal(actual, list(10, 4)) - + rdd <- parallelize(sc, list()) actual <- aggregateRDD(rdd, zeroValue, seqOp, combOp) expect_equal(actual, list(0, 0)) @@ -380,13 +380,13 @@ test_that("aggregateRDD() on RDDs", { test_that("zipWithUniqueId() on RDDs", { rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) actual <- collect(zipWithUniqueId(rdd)) - expected <- list(list("a", 0), list("b", 3), list("c", 1), + expected <- list(list("a", 0), list("b", 3), list("c", 1), list("d", 4), list("e", 2)) expect_equal(actual, expected) - + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 1L) actual <- collect(zipWithUniqueId(rdd)) - expected <- list(list("a", 0), list("b", 1), list("c", 2), + expected <- list(list("a", 0), list("b", 1), list("c", 2), list("d", 3), list("e", 4)) expect_equal(actual, expected) }) @@ -394,13 +394,13 @@ test_that("zipWithUniqueId() on RDDs", { test_that("zipWithIndex() on RDDs", { rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) actual <- collect(zipWithIndex(rdd)) - expected <- list(list("a", 0), list("b", 1), list("c", 2), + expected <- list(list("a", 0), list("b", 1), list("c", 2), list("d", 3), list("e", 4)) expect_equal(actual, expected) - + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 1L) actual <- collect(zipWithIndex(rdd)) - expected <- list(list("a", 0), list("b", 1), list("c", 2), + expected <- list(list("a", 0), list("b", 1), list("c", 2), list("d", 3), list("e", 4)) expect_equal(actual, expected) }) @@ -427,12 +427,12 @@ test_that("pipeRDD() on RDDs", { actual <- collect(pipeRDD(rdd, "more")) expected <- as.list(as.character(1:10)) expect_equal(actual, expected) - + trailed.rdd <- parallelize(sc, c("1", "", "2\n", "3\n\r\n")) actual <- collect(pipeRDD(trailed.rdd, 
"sort")) expected <- list("", "1", "2", "3") expect_equal(actual, expected) - + rev.nums <- 9:0 rev.rdd <- parallelize(sc, rev.nums, 2L) actual <- collect(pipeRDD(rev.rdd, "sort")) @@ -446,11 +446,11 @@ test_that("zipRDD() on RDDs", { actual <- collect(zipRDD(rdd1, rdd2)) expect_equal(actual, list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004))) - + mockFile = c("Spark is pretty.", "Spark is awesome.") fileName <- tempfile(pattern="spark-test", fileext=".tmp") writeLines(mockFile, fileName) - + rdd <- textFile(sc, fileName, 1) actual <- collect(zipRDD(rdd, rdd)) expected <- lapply(mockFile, function(x) { list(x ,x) }) @@ -465,10 +465,125 @@ test_that("zipRDD() on RDDs", { actual <- collect(zipRDD(rdd, rdd1)) expected <- lapply(mockFile, function(x) { list(x, x) }) expect_equal(actual, expected) - + + unlink(fileName) +}) + +test_that("cartesian() on RDDs", { + rdd <- parallelize(sc, 1:3) + actual <- collect(cartesian(rdd, rdd)) + expect_equal(sortKeyValueList(actual), + list( + list(1, 1), list(1, 2), list(1, 3), + list(2, 1), list(2, 2), list(2, 3), + list(3, 1), list(3, 2), list(3, 3))) + + # test case where one RDD is empty + emptyRdd <- parallelize(sc, list()) + actual <- collect(cartesian(rdd, emptyRdd)) + expect_equal(actual, list()) + + mockFile = c("Spark is pretty.", "Spark is awesome.") + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName) + actual <- collect(cartesian(rdd, rdd)) + expected <- list( + list("Spark is awesome.", "Spark is pretty."), + list("Spark is awesome.", "Spark is awesome."), + list("Spark is pretty.", "Spark is pretty."), + list("Spark is pretty.", "Spark is awesome.")) + expect_equal(sortKeyValueList(actual), expected) + + rdd1 <- parallelize(sc, 0:1) + actual <- collect(cartesian(rdd1, rdd)) + expect_equal(sortKeyValueList(actual), + list( + list(0, "Spark is pretty."), + list(0, "Spark is awesome."), + list(1, "Spark is pretty."), + list(1, "Spark is awesome."))) + + rdd1 <- map(rdd, function(x) { x }) + actual <- collect(cartesian(rdd, rdd1)) + expect_equal(sortKeyValueList(actual), expected) + unlink(fileName) }) +test_that("subtract() on RDDs", { + l <- list(1, 1, 2, 2, 3, 4) + rdd1 <- parallelize(sc, l) + + # subtract by itself + actual <- collect(subtract(rdd1, rdd1)) + expect_equal(actual, list()) + + # subtract by an empty RDD + rdd2 <- parallelize(sc, list()) + actual <- collect(subtract(rdd1, rdd2)) + expect_equal(as.list(sort(as.vector(actual, mode="integer"))), + l) + + rdd2 <- parallelize(sc, list(2, 4)) + actual <- collect(subtract(rdd1, rdd2)) + expect_equal(as.list(sort(as.vector(actual, mode="integer"))), + list(1, 1, 3)) + + l <- list("a", "a", "b", "b", "c", "d") + rdd1 <- parallelize(sc, l) + rdd2 <- parallelize(sc, list("b", "d")) + actual <- collect(subtract(rdd1, rdd2)) + expect_equal(as.list(sort(as.vector(actual, mode="character"))), + list("a", "a", "c")) +}) + +test_that("subtractByKey() on pairwise RDDs", { + l <- list(list("a", 1), list("b", 4), + list("b", 5), list("a", 2)) + rdd1 <- parallelize(sc, l) + + # subtractByKey by itself + actual <- collect(subtractByKey(rdd1, rdd1)) + expect_equal(actual, list()) + + # subtractByKey by an empty RDD + rdd2 <- parallelize(sc, list()) + actual <- collect(subtractByKey(rdd1, rdd2)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(l)) + + rdd2 <- parallelize(sc, list(list("a", 3), list("c", 1))) + actual <- collect(subtractByKey(rdd1, rdd2)) + expect_equal(actual, + 
list(list("b", 4), list("b", 5))) + + l <- list(list(1, 1), list(2, 4), + list(2, 5), list(1, 2)) + rdd1 <- parallelize(sc, l) + rdd2 <- parallelize(sc, list(list(1, 3), list(3, 1))) + actual <- collect(subtractByKey(rdd1, rdd2)) + expect_equal(actual, + list(list(2, 4), list(2, 5))) +}) + +test_that("intersection() on RDDs", { + # intersection with self + actual <- collect(intersection(rdd, rdd)) + expect_equal(sort(as.integer(actual)), nums) + + # intersection with an empty RDD + emptyRdd <- parallelize(sc, list()) + actual <- collect(intersection(rdd, emptyRdd)) + expect_equal(actual, list()) + + rdd1 <- parallelize(sc, list(1, 10, 2, 3, 4, 5)) + rdd2 <- parallelize(sc, list(1, 6, 2, 3, 7, 8)) + actual <- collect(intersection(rdd1, rdd2)) + expect_equal(sort(as.integer(actual)), 1:3) +}) + test_that("join() on pairwise RDDs", { rdd1 <- parallelize(sc, list(list(1,1), list(2,4))) rdd2 <- parallelize(sc, list(list(1,2), list(1,3))) @@ -596,9 +711,9 @@ test_that("sortByKey() on pairwise RDDs", { sortedRdd3 <- sortByKey(rdd3) actual <- collect(sortedRdd3) expect_equal(actual, list(list("1", 3), list("2", 5), list("a", 1), list("b", 2), list("d", 4))) - + # test on the boundary cases - + # boundary case 1: the RDD to be sorted has only 1 partition rdd4 <- parallelize(sc, l, 1L) sortedRdd4 <- sortByKey(rdd4) @@ -623,7 +738,7 @@ test_that("sortByKey() on pairwise RDDs", { rdd7 <- parallelize(sc, l3, 2L) sortedRdd7 <- sortByKey(rdd7) actual <- collect(sortedRdd7) - expect_equal(actual, l3) + expect_equal(actual, l3) }) test_that("collectAsMap() on a pairwise RDD", { @@ -634,12 +749,36 @@ test_that("collectAsMap() on a pairwise RDD", { rdd <- parallelize(sc, list(list("a", 1), list("b", 2))) vals <- collectAsMap(rdd) expect_equal(vals, list(a = 1, b = 2)) - + rdd <- parallelize(sc, list(list(1.1, 2.2), list(1.2, 2.4))) vals <- collectAsMap(rdd) expect_equal(vals, list(`1.1` = 2.2, `1.2` = 2.4)) - + rdd <- parallelize(sc, list(list(1, "a"), list(2, "b"))) vals <- collectAsMap(rdd) expect_equal(vals, list(`1` = "a", `2` = "b")) }) + +test_that("sampleByKey() on pairwise RDDs", { + rdd <- parallelize(sc, 1:2000) + pairsRDD <- lapply(rdd, function(x) { if (x %% 2 == 0) list("a", x) else list("b", x) }) + fractions <- list(a = 0.2, b = 0.1) + sample <- sampleByKey(pairsRDD, FALSE, fractions, 1618L) + expect_equal(100 < length(lookup(sample, "a")) && 300 > length(lookup(sample, "a")), TRUE) + expect_equal(50 < length(lookup(sample, "b")) && 150 > length(lookup(sample, "b")), TRUE) + expect_equal(lookup(sample, "a")[which.min(lookup(sample, "a"))] >= 0, TRUE) + expect_equal(lookup(sample, "a")[which.max(lookup(sample, "a"))] <= 2000, TRUE) + expect_equal(lookup(sample, "b")[which.min(lookup(sample, "b"))] >= 0, TRUE) + expect_equal(lookup(sample, "b")[which.max(lookup(sample, "b"))] <= 2000, TRUE) + + rdd <- parallelize(sc, 1:2000) + pairsRDD <- lapply(rdd, function(x) { if (x %% 2 == 0) list(2, x) else list(3, x) }) + fractions <- list(`2` = 0.2, `3` = 0.1) + sample <- sampleByKey(pairsRDD, TRUE, fractions, 1618L) + expect_equal(100 < length(lookup(sample, 2)) && 300 > length(lookup(sample, 2)), TRUE) + expect_equal(50 < length(lookup(sample, 3)) && 150 > length(lookup(sample, 3)), TRUE) + expect_equal(lookup(sample, 2)[which.min(lookup(sample, 2))] >= 0, TRUE) + expect_equal(lookup(sample, 2)[which.max(lookup(sample, 2))] <= 2000, TRUE) + expect_equal(lookup(sample, 3)[which.min(lookup(sample, 3))] >= 0, TRUE) + expect_equal(lookup(sample, 3)[which.max(lookup(sample, 3))] <= 2000, TRUE) +}) diff 
--git a/R/pkg/inst/tests/test_shuffle.R b/R/pkg/inst/tests/test_shuffle.R index d1da8232aea81..d7dedda553c56 100644 --- a/R/pkg/inst/tests/test_shuffle.R +++ b/R/pkg/inst/tests/test_shuffle.R @@ -87,6 +87,18 @@ test_that("combineByKey for doubles", { expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) }) +test_that("combineByKey for characters", { + stringKeyRDD <- parallelize(sc, + list(list("max", 1L), list("min", 2L), + list("other", 3L), list("max", 4L)), 2L) + reduced <- combineByKey(stringKeyRDD, + function(x) { x }, "+", "+", 2L) + actual <- collect(reduced) + + expected <- list(list("max", 5L), list("min", 2L), list("other", 3L)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + test_that("aggregateByKey", { # test aggregateByKey for int keys rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index cf5cf6d1692af..25831ae2d9e18 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -44,9 +44,8 @@ test_that("infer types", { expect_equal(infer_type(list(1L, 2L)), list(type = 'array', elementType = "integer", containsNull = TRUE)) expect_equal(infer_type(list(a = 1L, b = "2")), - list(type = "struct", - fields = list(list(name = "a", type = "integer", nullable = TRUE), - list(name = "b", type = "string", nullable = TRUE)))) + structType(structField(x = "a", type = "integer", nullable = TRUE), + structField(x = "b", type = "string", nullable = TRUE))) e <- new.env() assign("a", 1L, envir = e) expect_equal(infer_type(e), @@ -54,6 +53,18 @@ test_that("infer types", { valueContainsNull = TRUE)) }) +test_that("structType and structField", { + testField <- structField("a", "string") + expect_true(inherits(testField, "structField")) + expect_true(testField$name() == "a") + expect_true(testField$nullable()) + + testSchema <- structType(testField, structField("b", "integer")) + expect_true(inherits(testSchema, "structType")) + expect_true(inherits(testSchema$fields()[[2]], "structField")) + expect_true(testSchema$fields()[[1]]$dataType.toString() == "StringType") +}) + test_that("create DataFrame from RDD", { rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) df <- createDataFrame(sqlCtx, rdd, list("a", "b")) @@ -66,9 +77,8 @@ test_that("create DataFrame from RDD", { expect_true(inherits(df, "DataFrame")) expect_equal(columns(df), c("_1", "_2")) - fields <- list(list(name = "a", type = "integer", nullable = TRUE), - list(name = "b", type = "string", nullable = TRUE)) - schema <- list(type = "struct", fields = fields) + schema <- structType(structField(x = "a", type = "integer", nullable = TRUE), + structField(x = "b", type = "string", nullable = TRUE)) df <- createDataFrame(sqlCtx, rdd, schema) expect_true(inherits(df, "DataFrame")) expect_equal(columns(df), c("a", "b")) @@ -94,9 +104,8 @@ test_that("toDF", { expect_true(inherits(df, "DataFrame")) expect_equal(columns(df), c("_1", "_2")) - fields <- list(list(name = "a", type = "integer", nullable = TRUE), - list(name = "b", type = "string", nullable = TRUE)) - schema <- list(type = "struct", fields = fields) + schema <- structType(structField(x = "a", type = "integer", nullable = TRUE), + structField(x = "b", type = "string", nullable = TRUE)) df <- toDF(rdd, schema) expect_true(inherits(df, "DataFrame")) expect_equal(columns(df), c("a", "b")) @@ -635,7 +644,7 @@ test_that("isLocal()", { expect_false(isLocal(df)) }) -test_that("unionAll(), 
subtract(), and intersect() on a DataFrame", { +test_that("unionAll(), except(), and intersect() on a DataFrame", { df <- jsonFile(sqlCtx, jsonPath) lines <- c("{\"name\":\"Bob\", \"age\":24}", @@ -650,10 +659,10 @@ test_that("unionAll(), subtract(), and intersect() on a DataFrame", { expect_true(count(unioned) == 6) expect_true(first(unioned)$name == "Michael") - subtracted <- sortDF(subtract(df, df2), desc(df$age)) + excepted <- sortDF(except(df, df2), desc(df$age)) expect_true(inherits(unioned, "DataFrame")) - expect_true(count(subtracted) == 2) - expect_true(first(subtracted)$name == "Justin") + expect_true(count(excepted) == 2) + expect_true(first(excepted)$name == "Justin") intersected <- sortDF(intersect(df, df2), df$age) expect_true(inherits(unioned, "DataFrame")) diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R index c6542928e8ddd..014bf7bd7b3fe 100644 --- a/R/pkg/inst/worker/worker.R +++ b/R/pkg/inst/worker/worker.R @@ -17,6 +17,23 @@ # Worker class +# Get current system time +currentTimeSecs <- function() { + as.numeric(Sys.time()) +} + +# Get elapsed time +elapsedSecs <- function() { + proc.time()[3] +} + +# Constants +specialLengths <- list(END_OF_STREAM = 0L, TIMING_DATA = -1L) + +# Timing R process boot +bootTime <- currentTimeSecs() +bootElap <- elapsedSecs() + rLibDir <- Sys.getenv("SPARKR_RLIBDIR") # Set libPaths to include SparkR package as loadNamespace needs this # TODO: Figure out if we can avoid this by not loading any objects that require @@ -37,7 +54,7 @@ serializer <- SparkR:::readString(inputCon) # Include packages as required packageNames <- unserialize(SparkR:::readRaw(inputCon)) for (pkg in packageNames) { - suppressPackageStartupMessages(require(as.character(pkg), character.only=TRUE)) + suppressPackageStartupMessages(library(as.character(pkg), character.only=TRUE)) } # read function dependencies @@ -46,6 +63,9 @@ computeFunc <- unserialize(SparkR:::readRawLen(inputCon, funcLen)) env <- environment(computeFunc) parent.env(env) <- .GlobalEnv # Attach under global environment. +# Timing init envs for computing +initElap <- elapsedSecs() + # Read and set broadcast variables numBroadcastVars <- SparkR:::readInt(inputCon) if (numBroadcastVars > 0) { @@ -56,6 +76,9 @@ if (numBroadcastVars > 0) { } } +# Timing broadcast +broadcastElap <- elapsedSecs() + # If -1: read as normal RDD; if >= 0, treat as pairwise RDD and treat the int # as number of partitions to create.
numPartitions <- SparkR:::readInt(inputCon) @@ -73,14 +96,23 @@ if (isEmpty != 0) { } else if (deserializer == "row") { data <- SparkR:::readDeserializeRows(inputCon) } + # Timing reading input data for execution + inputElap <- elapsedSecs() + output <- computeFunc(partition, data) + # Timing computing + computeElap <- elapsedSecs() + if (serializer == "byte") { SparkR:::writeRawSerialize(outputCon, output) } else if (serializer == "row") { SparkR:::writeRowSerialize(outputCon, output) } else { - SparkR:::writeStrings(outputCon, output) + # write lines one-by-one with flag + lapply(output, function(line) SparkR:::writeString(outputCon, line)) } + # Timing output + outputElap <- elapsedSecs() } else { if (deserializer == "byte") { # Now read as many characters as described in funcLen @@ -90,6 +122,8 @@ if (isEmpty != 0) { } else if (deserializer == "row") { data <- SparkR:::readDeserializeRows(inputCon) } + # Timing reading input data for execution + inputElap <- elapsedSecs() res <- new.env() @@ -107,6 +141,8 @@ if (isEmpty != 0) { res[[bucket]] <- acc } invisible(lapply(data, hashTupleToEnvir)) + # Timing computing + computeElap <- elapsedSecs() # Step 2: write out all of the environment as key-value pairs. for (name in ls(res)) { @@ -116,13 +152,26 @@ if (isEmpty != 0) { length(res[[name]]$data) <- res[[name]]$counter SparkR:::writeRawSerialize(outputCon, res[[name]]$data) } + # Timing output + outputElap <- elapsedSecs() } +} else { + inputElap <- broadcastElap + computeElap <- broadcastElap + outputElap <- broadcastElap } +# Report timing +SparkR:::writeInt(outputCon, specialLengths$TIMING_DATA) +SparkR:::writeDouble(outputCon, bootTime) +SparkR:::writeDouble(outputCon, initElap - bootElap) # init +SparkR:::writeDouble(outputCon, broadcastElap - initElap) # broadcast +SparkR:::writeDouble(outputCon, inputElap - broadcastElap) # input +SparkR:::writeDouble(outputCon, computeElap - inputElap) # compute +SparkR:::writeDouble(outputCon, outputElap - computeElap) # output + # End of output -if (serializer %in% c("byte", "row")) { - SparkR:::writeInt(outputCon, 0L) -} +SparkR:::writeInt(outputCon, specialLengths$END_OF_STREAM) close(outputCon) close(inputCon) diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index 5fa4d483b8342..6fea5e1144f2f 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -42,10 +42,15 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( rLibDir: String, broadcastVars: Array[Broadcast[Object]]) extends RDD[U](parent) with Logging { + protected var dataStream: DataInputStream = _ + private var bootTime: Double = _ override def getPartitions: Array[Partition] = parent.partitions override def compute(partition: Partition, context: TaskContext): Iterator[U] = { + // Timing start + bootTime = System.currentTimeMillis / 1000.0 + // The parent may also be an RRDD, so we should launch it first.
val parentIterator = firstParent[T].iterator(partition, context) @@ -69,7 +74,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( // the socket used to receive the output of task val outSocket = serverSocket.accept() val inputStream = new BufferedInputStream(outSocket.getInputStream) - val dataStream = openDataStream(inputStream) + dataStream = new DataInputStream(inputStream) serverSocket.close() try { @@ -155,6 +160,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( } else if (deserializer == SerializationFormats.ROW) { dataOut.write(elem.asInstanceOf[Array[Byte]]) } else if (deserializer == SerializationFormats.STRING) { + // write string(for StringRRDD) printOut.println(elem) } } @@ -180,9 +186,41 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( }.start() } - protected def openDataStream(input: InputStream): Closeable + protected def readData(length: Int): U - protected def read(): U + protected def read(): U = { + try { + val length = dataStream.readInt() + + length match { + case SpecialLengths.TIMING_DATA => + // Timing data from R worker + val boot = dataStream.readDouble - bootTime + val init = dataStream.readDouble + val broadcast = dataStream.readDouble + val input = dataStream.readDouble + val compute = dataStream.readDouble + val output = dataStream.readDouble + logInfo( + ("Times: boot = %.3f s, init = %.3f s, broadcast = %.3f s, " + + "read-input = %.3f s, compute = %.3f s, write-output = %.3f s, " + + "total = %.3f s").format( + boot, + init, + broadcast, + input, + compute, + output, + boot + init + broadcast + input + compute + output)) + read() + case length if length >= 0 => + readData(length) + } + } catch { + case eof: EOFException => + throw new SparkException("R worker exited unexpectedly (crashed)", eof) + } + } } /** @@ -202,31 +240,16 @@ private class PairwiseRRDD[T: ClassTag]( SerializationFormats.BYTE, packageNames, rLibDir, broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { - private var dataStream: DataInputStream = _ - - override protected def openDataStream(input: InputStream): Closeable = { - dataStream = new DataInputStream(input) - dataStream - } - - override protected def read(): (Int, Array[Byte]) = { - try { - val length = dataStream.readInt() - - length match { - case length if length == 2 => - val hashedKey = dataStream.readInt() - val contentPairsLength = dataStream.readInt() - val contentPairs = new Array[Byte](contentPairsLength) - dataStream.readFully(contentPairs) - (hashedKey, contentPairs) - case _ => null // End of input - } - } catch { - case eof: EOFException => { - throw new SparkException("R worker exited unexpectedly (crashed)", eof) - } - } + override protected def readData(length: Int): (Int, Array[Byte]) = { + length match { + case length if length == 2 => + val hashedKey = dataStream.readInt() + val contentPairsLength = dataStream.readInt() + val contentPairs = new Array[Byte](contentPairsLength) + dataStream.readFully(contentPairs) + (hashedKey, contentPairs) + case _ => null + } } lazy val asJavaPairRDD : JavaPairRDD[Int, Array[Byte]] = JavaPairRDD.fromRDD(this) @@ -247,28 +270,13 @@ private class RRDD[T: ClassTag]( parent, -1, func, deserializer, serializer, packageNames, rLibDir, broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { - private var dataStream: DataInputStream = _ - - override protected def openDataStream(input: InputStream): Closeable = { - dataStream = new DataInputStream(input) - dataStream - } - - override protected def read(): Array[Byte] = { - try { -
val length = dataStream.readInt() - - length match { - case length if length > 0 => - val obj = new Array[Byte](length) - dataStream.readFully(obj, 0, length) - obj - case _ => null - } - } catch { - case eof: EOFException => { - throw new SparkException("R worker exited unexpectedly (crashed)", eof) - } + override protected def readData(length: Int): Array[Byte] = { + length match { + case length if length > 0 => + val obj = new Array[Byte](length) + dataStream.readFully(obj) + obj + case _ => null } } @@ -289,26 +297,21 @@ private class StringRRDD[T: ClassTag]( parent, -1, func, deserializer, SerializationFormats.STRING, packageNames, rLibDir, broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { - private var dataStream: BufferedReader = _ - - override protected def openDataStream(input: InputStream): Closeable = { - dataStream = new BufferedReader(new InputStreamReader(input)) - dataStream - } - - override protected def read(): String = { - try { - dataStream.readLine() - } catch { - case e: IOException => { - throw new SparkException("R worker exited unexpectedly (crashed)", e) - } + override protected def readData(length: Int): String = { + length match { + case length if length > 0 => + SerDe.readStringBytes(dataStream, length) + case _ => null } } lazy val asJavaRDD : JavaRDD[String] = JavaRDD.fromRDD(this) } +private object SpecialLengths { + val TIMING_DATA = -1 +} + private[r] class BufferedStreamThread( in: InputStream, name: String, diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index ccb2a371f4e48..371dfe454d1a2 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -85,13 +85,17 @@ private[spark] object SerDe { in.readDouble() } + def readStringBytes(in: DataInputStream, len: Int): String = { + val bytes = new Array[Byte](len) + in.readFully(bytes) + assert(bytes(len - 1) == 0) + val str = new String(bytes.dropRight(1), "UTF-8") + str + } + def readString(in: DataInputStream): String = { val len = in.readInt() - val asciiBytes = new Array[Byte](len) - in.readFully(asciiBytes) - assert(asciiBytes(len - 1) == 0) - val str = new String(asciiBytes.dropRight(1).map(_.toChar)) - str + readStringBytes(in, len) } def readBoolean(in: DataInputStream): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index d1ea7cc3e9162..ae77f72998a22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -23,7 +23,7 @@ import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.api.r.SerDe import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression} -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.sql.{Column, DataFrame, GroupedData, Row, SQLContext, SaveMode} private[r] object SQLUtils { @@ -39,8 +39,34 @@ private[r] object SQLUtils { arr.toSeq } - def createDF(rdd: RDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { - val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] + def createStructType(fields : Seq[StructField]): StructType = { + StructType(fields) + } + + def getSQLDataType(dataType: String): DataType = { + dataType match { + case
"byte" => org.apache.spark.sql.types.ByteType + case "integer" => org.apache.spark.sql.types.IntegerType + case "double" => org.apache.spark.sql.types.DoubleType + case "numeric" => org.apache.spark.sql.types.DoubleType + case "character" => org.apache.spark.sql.types.StringType + case "string" => org.apache.spark.sql.types.StringType + case "binary" => org.apache.spark.sql.types.BinaryType + case "raw" => org.apache.spark.sql.types.BinaryType + case "logical" => org.apache.spark.sql.types.BooleanType + case "boolean" => org.apache.spark.sql.types.BooleanType + case "timestamp" => org.apache.spark.sql.types.TimestampType + case "date" => org.apache.spark.sql.types.DateType + case _ => throw new IllegalArgumentException(s"Invalid type $dataType") + } + } + + def createStructField(name: String, dataType: String, nullable: Boolean): StructField = { + val dtObj = getSQLDataType(dataType) + StructField(name, dtObj, nullable) + } + + def createDF(rdd: RDD[Array[Byte]], schema: StructType, sqlContext: SQLContext): DataFrame = { val num = schema.fields.size val rowRDD = rdd.map(bytesToRow) sqlContext.createDataFrame(rowRDD, schema) From d305e686b3d73213784bd75cdad7d168b22a1dc4 Mon Sep 17 00:00:00 2001 From: Olivier Girardot Date: Fri, 17 Apr 2015 16:23:10 -0500 Subject: [PATCH 037/191] SPARK-6988 : Fix documentation regarding DataFrames using the Java API This patch includes: * adding how to use map after a SQL query using javaRDD * fixing the first few Java examples that were written in Scala Thank you for your time, Olivier. Author: Olivier Girardot Closes #5564 from ogirardot/branch-1.3 and squashes the following commits: 9f8d60e [Olivier Girardot] SPARK-6988 : Fix documentation regarding DataFrames using the Java API (cherry picked from commit 6b528dc139da594ef2e651d84bd91fe0f738a39d) Signed-off-by: Reynold Xin --- docs/sql-programming-guide.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 03500867df70f..d49233714a0bb 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -193,8 +193,8 @@ df.groupBy("age").count().show()
{% highlight java %} -val sc: JavaSparkContext // An existing SparkContext. -val sqlContext = new org.apache.spark.sql.SQLContext(sc) +JavaSparkContext sc // An existing SparkContext. +SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc) // Create the DataFrame DataFrame df = sqlContext.jsonFile("examples/src/main/resources/people.json"); @@ -308,8 +308,8 @@ val df = sqlContext.sql("SELECT * FROM table")
{% highlight java %} -val sqlContext = ... // An existing SQLContext -val df = sqlContext.sql("SELECT * FROM table") +SQLContext sqlContext = ... // An existing SQLContext +DataFrame df = sqlContext.sql("SELECT * FROM table") {% endhighlight %}
@@ -435,7 +435,7 @@ DataFrame teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AN // The results of SQL queries are DataFrames and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. -List teenagerNames = teenagers.map(new Function() { +List teenagerNames = teenagers.javaRDD().map(new Function() { public String call(Row row) { return "Name: " + row.getString(0); } @@ -599,7 +599,7 @@ DataFrame results = sqlContext.sql("SELECT name FROM people"); // The results of SQL queries are DataFrames and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. -List names = results.map(new Function() { +List names = results.javaRDD().map(new Function() { public String call(Row row) { return "Name: " + row.getString(0); } @@ -860,7 +860,7 @@ DataFrame parquetFile = sqlContext.parquetFile("people.parquet"); //Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile"); DataFrame teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); -List teenagerNames = teenagers.map(new Function() { +List teenagerNames = teenagers.javaRDD().map(new Function() { public String call(Row row) { return "Name: " + row.getString(0); } From a452c59210cf2c8ff8601cdb11401eea6dc14973 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 17 Apr 2015 16:30:13 -0500 Subject: [PATCH 038/191] Minor fix to SPARK-6958: Improve Python docstring for DataFrame.sort. As a follow up PR to #5544. cc davies Author: Reynold Xin Closes #5558 from rxin/sort-doc-improvement and squashes the following commits: f4c276f [Reynold Xin] Review feedback. d2dcf24 [Reynold Xin] Minor fix to SPARK-6958: Improve Python docstring for DataFrame.sort. --- python/pyspark/sql/dataframe.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 326d22e72f104..d70c5b0a6930c 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -489,8 +489,9 @@ def sort(self, *cols, **kwargs): """Returns a new :class:`DataFrame` sorted by the specified column(s). :param cols: list of :class:`Column` or column names to sort by. - :param ascending: sort by ascending order or not, could be bool, int - or list of bool, int (default: True). + :param ascending: boolean or list of boolean (default True). + Sort ascending vs. descending. Specify list for multiple sort orders. + If a list is specified, length of the list must equal length of the `cols`. >>> df.sort(df.age.desc()).collect() [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] @@ -519,7 +520,7 @@ def sort(self, *cols, **kwargs): jcols = [jc if asc else jc.desc() for asc, jc in zip(ascending, jcols)] else: - raise TypeError("ascending can only be bool or list, but got %s" % type(ascending)) + raise TypeError("ascending can only be boolean or list, but got %s" % type(ascending)) jdf = self._jdf.sort(self._jseq(jcols)) return DataFrame(jdf, self.sql_ctx) From c5ed510135aee3a1a0402057b3b5229892aa6f3a Mon Sep 17 00:00:00 2001 From: Ilya Ganelin Date: Fri, 17 Apr 2015 18:28:42 -0700 Subject: [PATCH 039/191] [SPARK-6703][Core] Provide a way to discover existing SparkContext's I've added a static getOrCreate method to the static SparkContext object that allows one to either retrieve a previously created SparkContext or to instantiate a new one with the provided config. 
The method accepts an optional SparkConf to make usage intuitive. Still working on a test for this, basically want to create a new context from scratch, then ensure that subsequent calls don't overwrite that. Author: Ilya Ganelin Closes #5501 from ilganeli/SPARK-6703 and squashes the following commits: db9a963 [Ilya Ganelin] Closing second spark context 1dc0444 [Ilya Ganelin] Added ref equality check 8c884fa [Ilya Ganelin] Made getOrCreate synchronized cb0c6b7 [Ilya Ganelin] Doc updates and code cleanup 270cfe3 [Ilya Ganelin] [SPARK-6703] Documentation fixes 15e8dea [Ilya Ganelin] Updated comments and added MiMa Exclude 0e1567c [Ilya Ganelin] Got rid of unecessary option for AtomicReference dfec4da [Ilya Ganelin] Changed activeContext to AtomicReference 733ec9f [Ilya Ganelin] Fixed some bugs in test code 8be2f83 [Ilya Ganelin] Replaced match with if e92caf7 [Ilya Ganelin] [SPARK-6703] Added test to ensure that getOrCreate both allows creation, retrieval, and a second context if desired a99032f [Ilya Ganelin] Spacing fix d7a06b8 [Ilya Ganelin] Updated SparkConf class to add getOrCreate method. Started test suite implementation --- .../scala/org/apache/spark/SparkContext.scala | 49 ++++++++++++++++--- .../org/apache/spark/SparkContextSuite.scala | 20 ++++++++ project/MimaExcludes.scala | 4 ++ 3 files changed, 66 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index e106c5c4bef60..86269eac52db0 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -23,7 +23,7 @@ import java.io._ import java.lang.reflect.Constructor import java.net.URI import java.util.{Arrays, Properties, UUID} -import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger} +import java.util.concurrent.atomic.{AtomicReference, AtomicBoolean, AtomicInteger} import java.util.UUID.randomUUID import scala.collection.{Map, Set} @@ -1887,11 +1887,12 @@ object SparkContext extends Logging { private val SPARK_CONTEXT_CONSTRUCTOR_LOCK = new Object() /** - * The active, fully-constructed SparkContext. If no SparkContext is active, then this is `None`. + * The active, fully-constructed SparkContext. If no SparkContext is active, then this is `null`. * - * Access to this field is guarded by SPARK_CONTEXT_CONSTRUCTOR_LOCK + * Access to this field is guarded by SPARK_CONTEXT_CONSTRUCTOR_LOCK. */ - private var activeContext: Option[SparkContext] = None + private val activeContext: AtomicReference[SparkContext] = + new AtomicReference[SparkContext](null) /** * Points to a partially-constructed SparkContext if some thread is in the SparkContext @@ -1926,7 +1927,8 @@ object SparkContext extends Logging { logWarning(warnMsg) } - activeContext.foreach { ctx => + if (activeContext.get() != null) { + val ctx = activeContext.get() val errMsg = "Only one SparkContext may be running in this JVM (see SPARK-2243)." + " To ignore this error, set spark.driver.allowMultipleContexts = true. " + s"The currently running SparkContext was created at:\n${ctx.creationSite.longForm}" @@ -1941,6 +1943,39 @@ object SparkContext extends Logging { } } + /** + * This function may be used to get or instantiate a SparkContext and register it as a + * singleton object. Because we can only have one active SparkContext per JVM, + * this is useful when applications may wish to share a SparkContext. 
+ * + * Note: This function cannot be used to create multiple SparkContext instances + * even if multiple contexts are allowed. + */ + def getOrCreate(config: SparkConf): SparkContext = { + // Synchronize to ensure that multiple create requests don't trigger an exception + // from assertNoOtherContextIsRunning within setActiveContext + SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { + if (activeContext.get() == null) { + setActiveContext(new SparkContext(config), allowMultipleContexts = false) + } + activeContext.get() + } + } + + /** + * This function may be used to get or instantiate a SparkContext and register it as a + * singleton object. Because we can only have one active SparkContext per JVM, + * this is useful when applications may wish to share a SparkContext. + * + * This method allows not passing a SparkConf (useful if just retrieving). + * + * Note: This function cannot be used to create multiple SparkContext instances + * even if multiple contexts are allowed. + */ + def getOrCreate(): SparkContext = { + getOrCreate(new SparkConf()) + } + /** * Called at the beginning of the SparkContext constructor to ensure that no SparkContext is * running. Throws an exception if a running context is detected and logs a warning if another @@ -1967,7 +2002,7 @@ object SparkContext extends Logging { SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { assertNoOtherContextIsRunning(sc, allowMultipleContexts) contextBeingConstructed = None - activeContext = Some(sc) + activeContext.set(sc) } } @@ -1978,7 +2013,7 @@ object SparkContext extends Logging { */ private[spark] def clearActiveContext(): Unit = { SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { - activeContext = None + activeContext.set(null) } } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 94be1c6d6397c..728558a424780 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -67,6 +67,26 @@ class SparkContextSuite extends FunSuite with LocalSparkContext { } } + test("Test getOrCreate") { + var sc2: SparkContext = null + SparkContext.clearActiveContext() + val conf = new SparkConf().setAppName("test").setMaster("local") + + sc = SparkContext.getOrCreate(conf) + + assert(sc.getConf.get("spark.app.name").equals("test")) + sc2 = SparkContext.getOrCreate(new SparkConf().setAppName("test2").setMaster("local")) + assert(sc2.getConf.get("spark.app.name").equals("test")) + assert(sc === sc2) + assert(sc eq sc2) + + // Try creating second context to confirm that it's still possible, if desired + sc2 = new SparkContext(new SparkConf().setAppName("test3").setMaster("local") + .set("spark.driver.allowMultipleContexts", "true")) + + sc2.stop() + } + test("BytesWritable implicit conversion is correct") { // Regression test for SPARK-3121 val bytesWritable = new BytesWritable() diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 1564babefa62f..7ef363a2f07ad 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -68,6 +68,10 @@ object MimaExcludes { // SPARK-6693 add tostring with max lines and width for matrix ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.mllib.linalg.Matrix.toString") + )++ Seq( + // SPARK-6703 Add getOrCreate method to SparkContext + ProblemFilters.exclude[IncompatibleResultTypeProblem] + ("org.apache.spark.SparkContext.org$apache$spark$SparkContext$$activeContext") ) case v if v.startsWith("1.3") => 
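A hypothetical usage sketch of the getOrCreate API added in the patch above (this snippet is not part of the patch; it only assumes the two overloads shown in the diff, and mirrors the semantics asserted in SparkContextSuite: the first call creates the JVM-wide singleton from the given conf, and any later call returns that same instance while ignoring the new conf):

    import org.apache.spark.{SparkConf, SparkContext}

    object GetOrCreateDemo {
      def main(args: Array[String]): Unit = {
        // First call: no context is active yet, so one is built from this conf.
        val sc1 = SparkContext.getOrCreate(
          new SparkConf().setAppName("demo").setMaster("local[2]"))
        // Second call: the active context is returned; the new conf is ignored.
        val sc2 = SparkContext.getOrCreate(new SparkConf().setAppName("other"))
        assert(sc1 eq sc2)
        assert(sc2.getConf.get("spark.app.name") == "demo")
        sc1.stop()
      }
    }

Because getOrCreate never reconfigures a live context, any conf passed after the first call is silently dropped.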
From 6fbeb82e13db7117d8f216e6148632490a4bc5be Mon Sep 17 00:00:00 2001 From: Jongyoul Lee Date: Fri, 17 Apr 2015 18:30:55 -0700 Subject: [PATCH 040/191] [SPARK-6350][Mesos] Make mesosExecutorCores configurable in mesos "fine-grained" mode - Defined executorCores from "spark.mesos.executor.cores" - Changed the amount of mesosExecutor's cores to executorCores. - Added new configuration option on running-on-mesos.md Author: Jongyoul Lee Closes #5063 from jongyoul/SPARK-6350 and squashes the following commits: 9238d6e [Jongyoul Lee] [SPARK-6350][Mesos] Make mesosExecutorCores configurable in mesos "fine-grained" mode - Fixed docs - Changed configuration name - Made mesosExecutorCores private 2d41241 [Jongyoul Lee] [SPARK-6350][Mesos] Make mesosExecutorCores configurable in mesos "fine-grained" mode - Fixed docs 89edb4f [Jongyoul Lee] [SPARK-6350][Mesos] Make mesosExecutorCores configurable in mesos "fine-grained" mode - Fixed docs 8ba7694 [Jongyoul Lee] [SPARK-6350][Mesos] Make mesosExecutorCores configurable in mesos "fine-grained" mode - Fixed docs 7549314 [Jongyoul Lee] [SPARK-6453][Mesos] Some Mesos*Suite have a different package with their classes - Fixed docs 4ae7b0c [Jongyoul Lee] [SPARK-6453][Mesos] Some Mesos*Suite have a different package with their classes - Removed TODO c27efce [Jongyoul Lee] [SPARK-6453][Mesos] Some Mesos*Suite have a different package with their classes - Fixed Mesos*Suite for supporting integer WorkerOffers - Fixed Documentation 1fe4c03 [Jongyoul Lee] [SPARK-6453][Mesos] Some Mesos*Suite have a different package with their classes - Change available resources of cpus to integer value beacuse WorkerOffer support the amount cpus as integer value 5f3767e [Jongyoul Lee] Revert "[SPARK-6350][Mesos] Make mesosExecutorCores configurable in mesos "fine-grained" mode" 4b7c69e [Jongyoul Lee] [SPARK-6350][Mesos] Make mesosExecutorCores configurable in mesos "fine-grained" mode - Changed configruation name and description from "spark.mesos.executor.cores" to "spark.executor.frameworkCores" 0556792 [Jongyoul Lee] [SPARK-6350][Mesos] Make mesosExecutorCores configurable in mesos "fine-grained" mode - Defined executorCores from "spark.mesos.executor.cores" - Changed the amount of mesosExecutor's cores to executorCores. - Added new configuration option on running-on-mesos.md --- .../cluster/mesos/MesosSchedulerBackend.scala | 14 +++++++------- .../cluster/mesos/MesosSchedulerBackendSuite.scala | 4 ++-- docs/running-on-mesos.md | 10 ++++++++++ 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index b381436839227..d9d62b0e287ed 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -67,6 +67,8 @@ private[spark] class MesosSchedulerBackend( // The listener bus to publish executor added/removed events. 
val listenerBus = sc.listenerBus + + private[mesos] val mesosExecutorCores = sc.conf.getDouble("spark.mesos.mesosExecutor.cores", 1) @volatile var appId: String = _ @@ -139,7 +141,7 @@ private[spark] class MesosSchedulerBackend( .setName("cpus") .setType(Value.Type.SCALAR) .setScalar(Value.Scalar.newBuilder() - .setValue(scheduler.CPUS_PER_TASK).build()) + .setValue(mesosExecutorCores).build()) .build() val memory = Resource.newBuilder() .setName("mem") @@ -220,10 +222,9 @@ private[spark] class MesosSchedulerBackend( val mem = getResource(o.getResourcesList, "mem") val cpus = getResource(o.getResourcesList, "cpus") val slaveId = o.getSlaveId.getValue - // TODO(pwendell): Should below be 1 + scheduler.CPUS_PER_TASK? (mem >= MemoryUtils.calculateTotalMemory(sc) && // need at least 1 for executor, 1 for task - cpus >= 2 * scheduler.CPUS_PER_TASK) || + cpus >= (mesosExecutorCores + scheduler.CPUS_PER_TASK)) || (slaveIdsWithExecutors.contains(slaveId) && cpus >= scheduler.CPUS_PER_TASK) } @@ -232,10 +233,9 @@ private[spark] class MesosSchedulerBackend( val cpus = if (slaveIdsWithExecutors.contains(o.getSlaveId.getValue)) { getResource(o.getResourcesList, "cpus").toInt } else { - // If the executor doesn't exist yet, subtract CPU for executor - // TODO(pwendell): Should below just subtract "1"? - getResource(o.getResourcesList, "cpus").toInt - - scheduler.CPUS_PER_TASK + // If the Mesos executor has not been started on this slave yet, set aside a few + // cores for the Mesos executor by offering fewer cores to the Spark executor + (getResource(o.getResourcesList, "cpus") - mesosExecutorCores).toInt } new WorkerOffer( o.getSlaveId.getValue, diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala index a311512e82c5e..cdd7be0fbe5dd 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala @@ -118,12 +118,12 @@ class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with Mo expectedWorkerOffers.append(new WorkerOffer( mesosOffers.get(0).getSlaveId.getValue, mesosOffers.get(0).getHostname, - 2 + (minCpu - backend.mesosExecutorCores).toInt )) expectedWorkerOffers.append(new WorkerOffer( mesosOffers.get(2).getSlaveId.getValue, mesosOffers.get(2).getHostname, - 2 + (minCpu - backend.mesosExecutorCores).toInt )) val taskDesc = new TaskDescription(1L, 0, "s1", "n1", 0, ByteBuffer.wrap(new Array[Byte](0))) when(taskScheduler.resourceOffers(expectedWorkerOffers)).thenReturn(Seq(Seq(taskDesc))) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index c984639bd34cf..594bf78b67713 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -210,6 +210,16 @@ See the [configuration page](configuration.html) for information on Spark config Note that total amount of cores the executor will request in total will not exceed the spark.cores.max setting. +
+<tr>
+  <td>spark.mesos.mesosExecutor.cores</td>
+  <td>1.0</td>
+  <td>
+    (Fine-grained mode only) Number of cores to give each Mesos executor. This does not
+    include the cores used to run the Spark tasks. In other words, even if no Spark task
+    is being run, each Mesos executor will occupy the number of cores configured here.
+    The value can be a floating point number.
+  </td>
+</tr>
 <tr>
   <td>spark.mesos.executor.home</td>
   <td>driver side SPARK_HOME</td>

From 1991337336596f94698e79c2366f065c374128ab Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 17 Apr 2015 19:02:07 -0700 Subject: [PATCH 041/191] [SPARK-5933] [core] Move config deprecation warnings to SparkConf. I didn't find many deprecated configs after a grep-based search, but the ones I could find were moved to the centralized location in SparkConf. While there, I deprecated a couple more HS configs that mentioned time units. Author: Marcelo Vanzin Closes #5562 from vanzin/SPARK-5933 and squashes the following commits: dcb617e7 [Marcelo Vanzin] [SPARK-5933] [core] Move config deprecation warnings to SparkConf. --- .../main/scala/org/apache/spark/SparkConf.scala | 17 ++++++++++++++--- .../main/scala/org/apache/spark/SparkEnv.scala | 10 ++-------- .../deploy/history/FsHistoryProvider.scala | 15 +++------------ .../scala/org/apache/spark/SparkConfSuite.scala | 3 +++ docs/monitoring.md | 15 +++++++-------- .../spark/deploy/yarn/ApplicationMaster.scala | 9 +-------- 6 files changed, 30 insertions(+), 39 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index b0186e9a007b8..e3a649d755450 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -403,6 +403,9 @@ private[spark] object SparkConf extends Logging { */ private val deprecatedConfigs: Map[String, DeprecatedConfig] = { val configs = Seq( + DeprecatedConfig("spark.cache.class", "0.8", + "The spark.cache.class property is no longer being used! Specify storage levels using " + + "the RDD.persist() method instead."), DeprecatedConfig("spark.yarn.user.classpath.first", "1.3", "Please use spark.{driver,executor}.userClassPathFirst instead.")) Map(configs.map { cfg => (cfg.key -> cfg) }:_*) @@ -420,7 +423,15 @@ private[spark] object SparkConf extends Logging { "spark.history.fs.update.interval" -> Seq( AlternateConfig("spark.history.fs.update.interval.seconds", "1.4"), AlternateConfig("spark.history.fs.updateInterval", "1.3"), - AlternateConfig("spark.history.updateInterval", "1.3")) + AlternateConfig("spark.history.updateInterval", "1.3")), + "spark.history.fs.cleaner.interval" -> Seq( + AlternateConfig("spark.history.fs.cleaner.interval.seconds", "1.4")), + "spark.history.fs.cleaner.maxAge" -> Seq( + AlternateConfig("spark.history.fs.cleaner.maxAge.seconds", "1.4")), + "spark.yarn.am.waitTime" -> Seq( + AlternateConfig("spark.yarn.applicationMaster.waitTries", "1.3", + // Translate old value to a duration, with 10s wait time per try.
+ translation = s => s"${s.toLong * 10}s")) ) /** @@ -470,7 +481,7 @@ private[spark] object SparkConf extends Logging { configsWithAlternatives.get(key).flatMap { alts => alts.collectFirst { case alt if conf.contains(alt.key) => val value = conf.get(alt.key) - alt.translation.map(_(value)).getOrElse(value) + if (alt.translation != null) alt.translation(value) else value } } } @@ -514,6 +525,6 @@ private[spark] object SparkConf extends Logging { private case class AlternateConfig( key: String, version: String, - translation: Option[String => String] = None) + translation: String => String = null) } diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 0171488e09562..959aefabd8de4 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -103,7 +103,7 @@ class SparkEnv ( // actorSystem.awaitTermination() // Note that blockTransferService is stopped by BlockManager since it is started by it. - + // If we only stop sc, but the driver process still run as a services then we need to delete // the tmp dir, if not, it will create too many tmp dirs. // We only need to delete the tmp dir create by driver, because sparkFilesDir is point to the @@ -375,12 +375,6 @@ object SparkEnv extends Logging { "." } - // Warn about deprecated spark.cache.class property - if (conf.contains("spark.cache.class")) { - logWarning("The spark.cache.class property is no longer being used! Specify storage " + - "levels using the RDD.persist() method instead.") - } - val outputCommitCoordinator = mockOutputCommitCoordinator.getOrElse { new OutputCommitCoordinator(conf) } @@ -406,7 +400,7 @@ object SparkEnv extends Logging { shuffleMemoryManager, outputCommitCoordinator, conf) - + // Add a reference to tmp dir created by driver, we will delete this tmp dir when stop() is // called, and we only need to do it for driver. Because driver may run as a service, and if we // don't delete this tmp dir when sc is stopped, then will create too many tmp dirs. diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 985545742df67..47bdd7749ec3d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -52,8 +52,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis private val UPDATE_INTERVAL_S = conf.getTimeAsSeconds("spark.history.fs.update.interval", "10s") // Interval between each cleaner checks for event logs to delete - private val CLEAN_INTERVAL_MS = conf.getLong("spark.history.fs.cleaner.interval.seconds", - DEFAULT_SPARK_HISTORY_FS_CLEANER_INTERVAL_S) * 1000 + private val CLEAN_INTERVAL_S = conf.getTimeAsSeconds("spark.history.fs.cleaner.interval", "1d") private val logDir = conf.getOption("spark.history.fs.logDirectory") .map { d => Utils.resolveURI(d).toString } @@ -130,8 +129,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis if (conf.getBoolean("spark.history.fs.cleaner.enabled", false)) { // A task that periodically cleans event logs on disk. 
- pool.scheduleAtFixedRate(getRunner(cleanLogs), 0, CLEAN_INTERVAL_MS, - TimeUnit.MILLISECONDS) + pool.scheduleAtFixedRate(getRunner(cleanLogs), 0, CLEAN_INTERVAL_S, TimeUnit.SECONDS) } } } @@ -270,8 +268,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis try { val statusList = Option(fs.listStatus(new Path(logDir))).map(_.toSeq) .getOrElse(Seq[FileStatus]()) - val maxAge = conf.getLong("spark.history.fs.cleaner.maxAge.seconds", - DEFAULT_SPARK_HISTORY_FS_MAXAGE_S) * 1000 + val maxAge = conf.getTimeAsSeconds("spark.history.fs.cleaner.maxAge", "7d") * 1000 val now = System.currentTimeMillis() val appsToRetain = new mutable.LinkedHashMap[String, FsApplicationHistoryInfo]() @@ -417,12 +414,6 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis private object FsHistoryProvider { val DEFAULT_LOG_DIR = "file:/tmp/spark-events" - - // One day - val DEFAULT_SPARK_HISTORY_FS_CLEANER_INTERVAL_S = Duration(1, TimeUnit.DAYS).toSeconds - - // One week - val DEFAULT_SPARK_HISTORY_FS_MAXAGE_S = Duration(7, TimeUnit.DAYS).toSeconds } private class FsApplicationHistoryInfo( diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index 7d87ba5fd2610..8e6c200c4ba00 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -217,6 +217,9 @@ class SparkConfSuite extends FunSuite with LocalSparkContext with ResetSystemPro val count = conf.getAll.filter { case (k, v) => k.startsWith("spark.history.") }.size assert(count === 4) + + conf.set("spark.yarn.applicationMaster.waitTries", "42") + assert(conf.getTimeAsSeconds("spark.yarn.am.waitTime") === 420) } } diff --git a/docs/monitoring.md b/docs/monitoring.md index 2a130224591ca..8a85928d6d44d 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -153,19 +153,18 @@ follows:
 <tr>
-  <td>spark.history.fs.update.interval.seconds</td>
-  <td>10</td>
+  <td>spark.history.fs.update.interval</td>
+  <td>10s</td>
   <td>
-    The period, in seconds, at which information displayed by this history server is updated.
+    The period at which information displayed by this history server is updated.
     Each update checks for any changes made to the event logs in persisted storage.
   </td>
 </tr>
 <tr>
-  <td>spark.history.fs.cleaner.interval.seconds</td>
-  <td>86400</td>
+  <td>spark.history.fs.cleaner.interval</td>
+  <td>1d</td>
   <td>
-    How often the job history cleaner checks for files to delete, in seconds. Defaults to 86400 (one day).
-    Files are only deleted if they are older than spark.history.fs.cleaner.maxAge.seconds.
+    How often the job history cleaner checks for files to delete.
+    Files are only deleted if they are older than spark.history.fs.cleaner.maxAge.
   </td>
 </tr>
 <tr>
-  <td>spark.history.fs.cleaner.maxAge.seconds</td>
-  <td>3600 * 24 * 7</td>
+  <td>spark.history.fs.cleaner.maxAge</td>
+  <td>7d</td>
   <td>
-    Job history files older than this many seconds will be deleted when the history cleaner runs.
-    Defaults to 3600 * 24 * 7 (1 week).
+    Job history files older than this will be deleted when the history cleaner runs.
   </td>
 </tr>
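To make the deprecation translation above concrete, here is a small illustrative sketch (standalone code, not part of the patch) of the rule registered for the YARN wait-time setting, where each legacy "wait try" counts as ten seconds:

    // Mirrors `translation = s => s"${s.toLong * 10}s"` from the diff above.
    object WaitTriesTranslation {
      val translate: String => String = s => s"${s.toLong * 10}s"

      def main(args: Array[String]): Unit = {
        // An old value of 42 tries becomes the duration string "420s", which
        // getTimeAsSeconds then reads back as 420 seconds, the same round trip
        // that SparkConfSuite asserts.
        assert(translate("42") == "420s")
      }
    }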
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index c357b7ae9d4da..f7a84207e9da6 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -373,14 +373,7 @@ private[spark] class ApplicationMaster( private def waitForSparkContextInitialized(): SparkContext = { logInfo("Waiting for spark context initialization") sparkContextRef.synchronized { - val waitTries = sparkConf.getOption("spark.yarn.applicationMaster.waitTries") - .map(_.toLong * 10000L) - if (waitTries.isDefined) { - logWarning( - "spark.yarn.applicationMaster.waitTries is deprecated, use spark.yarn.am.waitTime") - } - val totalWaitTime = sparkConf.getTimeAsMs("spark.yarn.am.waitTime", - s"${waitTries.getOrElse(100000L)}ms") + val totalWaitTime = sparkConf.getTimeAsMs("spark.yarn.am.waitTime", "100s") val deadline = System.currentTimeMillis() + totalWaitTime while (sparkContextRef.get() == null && System.currentTimeMillis < deadline && !finished) { From d850b4bd3a294dd245881e03f7f94bf970a7ee79 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Fri, 17 Apr 2015 19:17:06 -0700 Subject: [PATCH 042/191] [SPARK-6975][Yarn] Fix argument validation error The `numExecutors` validation fails when dynamic allocation is enabled with the default configuration. Details can be seen in [SPARK-6975](https://issues.apache.org/jira/browse/SPARK-6975). sryza, please help me review this; I'm not sure if this is the correct way. I think you changed this part previously :) Author: jerryshao Closes #5551 from jerryshao/SPARK-6975 and squashes the following commits: 4335da1 [jerryshao] Change according to the comments 77bdcbd [jerryshao] Fix argument validation error --- .../org/apache/spark/deploy/yarn/ClientArguments.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index da6798cb1b279..1423533470fc0 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -103,9 +103,13 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) * This is intended to be called only after the provided arguments have been parsed. */ private def validateArgs(): Unit = { - if (numExecutors <= 0) { + if (numExecutors < 0 || (!isDynamicAllocationEnabled && numExecutors == 0)) { throw new IllegalArgumentException( - "You must specify at least 1 executor!\n" + getUsageMessage()) + s""" + |Number of executors was $numExecutors, but must be at least 1 + |(or 0 if dynamic executor allocation is enabled). + |${getUsageMessage()} + """.stripMargin) } if (executorCores < sparkConf.getInt("spark.task.cpus", 1)) { throw new SparkException("Executor cores must not be less than " + From 5f095d56054d57c54d81db1d36cd46312810fb6a Mon Sep 17 00:00:00 2001 From: Olivier Girardot Date: Sat, 18 Apr 2015 00:31:01 -0700 Subject: [PATCH 043/191] SPARK-6992 : Fix documentation example for Spark SQL on StructType This patch fixes the Java examples for Spark SQL when programmatically defining a schema and mapping Rows.
Author: Olivier Girardot Closes #5569 from ogirardot/branch-1.3 and squashes the following commits: c29e58d [Olivier Girardot] SPARK-6992 : Fix documentation example for Spark SQL on StructType (cherry picked from commit c9b1ba4b16a7afe93d45bf75b128cc0dd287ded0) Signed-off-by: Reynold Xin --- docs/sql-programming-guide.md | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index d49233714a0bb..b2022546268a7 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -555,13 +555,16 @@ by `SQLContext`. For example: {% highlight java %} -// Import factory methods provided by DataType. -import org.apache.spark.sql.types.DataType; +import org.apache.spark.api.java.function.Function; +// Import factory methods provided by DataTypes. +import org.apache.spark.sql.types.DataTypes; // Import StructType and StructField import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.types.StructField; // Import Row. import org.apache.spark.sql.Row; +// Import RowFactory. +import org.apache.spark.sql.RowFactory; // sc is an existing JavaSparkContext. SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); @@ -575,16 +578,16 @@ String schemaString = "name age"; // Generate the schema based on the string of schema List fields = new ArrayList(); for (String fieldName: schemaString.split(" ")) { - fields.add(DataType.createStructField(fieldName, DataType.StringType, true)); + fields.add(DataTypes.createStructField(fieldName, DataTypes.StringType, true)); } -StructType schema = DataType.createStructType(fields); +StructType schema = DataTypes.createStructType(fields); // Convert records of the RDD (people) to Rows. JavaRDD rowRDD = people.map( new Function() { public Row call(String record) throws Exception { String[] fields = record.split(","); - return Row.create(fields[0], fields[1].trim()); + return RowFactory.create(fields[0], fields[1].trim()); } }); From 327ebf0cb5e236579bece057eda27b21aed0e2dc Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Sat, 18 Apr 2015 10:14:56 +0100 Subject: [PATCH 044/191] [core] [minor] Make sure ConnectionManager stops. My previous fix (force a selector wakeup) didn't seem to work since I ran into the hang again. So change the code a bit to be more explicit about the condition when the selector thread should exit. Author: Marcelo Vanzin Closes #5566 from vanzin/conn-mgr-hang and squashes the following commits: ddb2c03 [Marcelo Vanzin] [core] [minor] Make sure ConnectionManager stops. 
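Before the diff itself, a simplified sketch (illustrative only; the real class carries far more state) of the shutdown handshake this change adopts: a volatile flag that the selector loop checks, plus closing the selector so that a select() blocked inside the loop is forced out via ClosedSelectorException:

    import java.nio.channels.{ClosedSelectorException, Selector}

    object SelectorShutdownSketch {
      @volatile private var isActive = true
      private val selector = Selector.open()

      private val selectorThread = new Thread("connection-manager-thread") {
        override def run(): Unit = {
          while (isActive) {
            try {
              selector.select() // blocks until traffic, wakeup, or close
            } catch {
              case _: ClosedSelectorException => return // stop() closed us
            }
          }
        }
      }

      def stop(): Unit = {
        isActive = false           // loop exits on its next pass...
        selector.close()           // ...or immediately, by breaking select()
        selectorThread.interrupt()
        selectorThread.join()
      }

      def main(args: Array[String]): Unit = {
        selectorThread.start()
        Thread.sleep(100)          // let the loop block inside select()
        stop()                     // returns promptly instead of hanging
        println("selector thread stopped cleanly")
      }
    }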
--- .../spark/network/nio/ConnectionManager.scala | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 5a74c13b38bf7..1a68e621eaee7 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -188,6 +188,7 @@ private[nio] class ConnectionManager( private val writeRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() + @volatile private var isActive = true private val selectorThread = new Thread("connection-manager-thread") { override def run(): Unit = ConnectionManager.this.run() } @@ -342,7 +343,7 @@ private[nio] class ConnectionManager( def run() { try { - while(!selectorThread.isInterrupted) { + while (isActive) { while (!registerRequests.isEmpty) { val conn: SendingConnection = registerRequests.dequeue() addListeners(conn) @@ -398,7 +399,7 @@ private[nio] class ConnectionManager( } catch { // Explicitly only dealing with CancelledKeyException here since other exceptions // should be dealt with differently. - case e: CancelledKeyException => { + case e: CancelledKeyException => // Some keys within the selectors list are invalid/closed. clear them. val allKeys = selector.keys().iterator() @@ -420,8 +421,11 @@ private[nio] class ConnectionManager( } } } - } - 0 + 0 + + case e: ClosedSelectorException => + logDebug("Failed select() as selector is closed.", e) + return } if (selectedKeysCount == 0) { @@ -988,11 +992,11 @@ private[nio] class ConnectionManager( } def stop() { + isActive = false ackTimeoutMonitor.stop() - selector.wakeup() + selector.close() selectorThread.interrupt() selectorThread.join() - selector.close() val connections = connectionsByKey.values connections.foreach(_.close()) if (connectionsByKey.size != 0) { From 28683b4df5de06373b867068b9b8adfbcaf93176 Mon Sep 17 00:00:00 2001 From: Nicholas Chammas Date: Sat, 18 Apr 2015 16:46:28 -0700 Subject: [PATCH 045/191] [SPARK-6219] Reuse pep8.py Per the discussion in the comments on [this commit](https://github.com/apache/spark/commit/f17d43b033d928dbc46aef8e367aa08902e698ad#commitcomment-10780649), this PR allows the Python lint script to reuse `pep8.py` when possible. Author: Nicholas Chammas Closes #5561 from nchammas/save-dem-pep8-bytes and squashes the following commits: b7c91e6 [Nicholas Chammas] reuse pep8.py --- dev/.gitignore | 1 + dev/lint-python | 21 +++++++++++---------- 2 files changed, 12 insertions(+), 10 deletions(-) create mode 100644 dev/.gitignore diff --git a/dev/.gitignore b/dev/.gitignore new file mode 100644 index 0000000000000..4a6027429e0d3 --- /dev/null +++ b/dev/.gitignore @@ -0,0 +1 @@ +pep8*.py diff --git a/dev/lint-python b/dev/lint-python index fded654893a7c..f50d149dc4d44 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -32,18 +32,19 @@ compile_status="${PIPESTATUS[0]}" #+ See: https://github.com/apache/spark/pull/1744#issuecomment-50982162 #+ TODOs: #+ - Download pep8 from PyPI. It's more "official". 
-PEP8_SCRIPT_PATH="$SPARK_ROOT_DIR/dev/pep8.py" -PEP8_SCRIPT_REMOTE_PATH="https://raw.githubusercontent.com/jcrocholl/pep8/1.6.2/pep8.py" +PEP8_VERSION="1.6.2" +PEP8_SCRIPT_PATH="$SPARK_ROOT_DIR/dev/pep8-$PEP8_VERSION.py" +PEP8_SCRIPT_REMOTE_PATH="https://raw.githubusercontent.com/jcrocholl/pep8/$PEP8_VERSION/pep8.py" -# if [ ! -e "$PEP8_SCRIPT_PATH" ]; then -curl --silent -o "$PEP8_SCRIPT_PATH" "$PEP8_SCRIPT_REMOTE_PATH" -curl_status="$?" +if [ ! -e "$PEP8_SCRIPT_PATH" ]; then + curl --silent -o "$PEP8_SCRIPT_PATH" "$PEP8_SCRIPT_REMOTE_PATH" + curl_status="$?" -if [ "$curl_status" -ne 0 ]; then - echo "Failed to download pep8.py from \"$PEP8_SCRIPT_REMOTE_PATH\"." - exit "$curl_status" + if [ "$curl_status" -ne 0 ]; then + echo "Failed to download pep8.py from \"$PEP8_SCRIPT_REMOTE_PATH\"." + exit "$curl_status" + fi fi -# fi # There is no need to write this output to a file #+ first, but we do so so that the check status can @@ -65,7 +66,7 @@ else echo "Python lint checks passed." fi -rm "$PEP8_SCRIPT_PATH" +# rm "$PEP8_SCRIPT_PATH" rm "$PYTHON_LINT_REPORT_PATH" exit "$lint_status" From 729885ec6b4be61144d04821f1a6e8d2134eea00 Mon Sep 17 00:00:00 2001 From: Gaurav Nanda Date: Sat, 18 Apr 2015 17:20:46 -0700 Subject: [PATCH 046/191] Fixed doc Just fixed a doc. Author: Gaurav Nanda Closes #5576 from gaurav324/master and squashes the following commits: 8a7323f [Gaurav Nanda] Fixed doc --- docs/mllib-linear-methods.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 9270741d439d9..2b2be4d9d0273 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -377,7 +377,7 @@ references. Here is an [detailed mathematical derivation](http://www.slideshare.net/dbtsai/2014-0620-mlor-36132297). -For multiclass classification problems, the algorithm will outputs a multinomial logistic regression +For multiclass classification problems, the algorithm will output a multinomial logistic regression model, which contains $K - 1$ binary logistic regression models regressed against the first class. Given a new data points, $K - 1$ models will be run, and the class with largest probability will be chosen as the predicted class. From 8fbd45c74e762dd6b071ea58a60f5bb649f74042 Mon Sep 17 00:00:00 2001 From: Olivier Girardot Date: Sat, 18 Apr 2015 18:21:44 -0700 Subject: [PATCH 047/191] SPARK-6993 : Add default min, max methods for JavaDoubleRDD The default method will use Guava's Ordering instead of java.util.Comparator.naturalOrder() because it's not available in Java 7, only in Java 8. Author: Olivier Girardot Closes #5571 from ogirardot/master and squashes the following commits: 7fe2e9e [Olivier Girardot] SPARK-6993 : Add default min, max methods for JavaDoubleRDD --- .../org/apache/spark/api/java/JavaDoubleRDD.scala | 14 ++++++++++++++ .../test/java/org/apache/spark/JavaAPISuite.java | 14 ++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index 79e4ebf2db578..61af867b11b9c 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -163,6 +163,20 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) /** Add up the elements in this RDD. */ def sum(): JDouble = srdd.sum() + /** + * Returns the minimum element from this RDD as defined by + * the default comparator natural order. 
+ * @return the minimum of the RDD + */ + def min(): JDouble = min(com.google.common.collect.Ordering.natural()) + + /** + * Returns the maximum element from this RDD as defined by + * the default comparator natural order. + * @return the maximum of the RDD + */ + def max(): JDouble = max(com.google.common.collect.Ordering.natural()) + /** * Return a [[org.apache.spark.util.StatCounter]] object that captures the mean, variance and * count of the RDD's elements in one operation. diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index d4b5bb519157c..8a4f2a08fe701 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -761,6 +761,20 @@ public void min() { Assert.assertEquals(1.0, max, 0.001); } + @Test + public void naturalMax() { + JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); + double max = rdd.max(); + Assert.assertTrue(4.0 == max); + } + + @Test + public void naturalMin() { + JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); + double max = rdd.min(); + Assert.assertTrue(1.0 == max); + } + @Test public void takeOrdered() { JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); From 0424da68d4c81dc3a9944d8485feb1233c6633c4 Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Sun, 19 Apr 2015 09:37:09 +0100 Subject: [PATCH 048/191] [SPARK-6963][CORE]Flaky test: o.a.s.ContextCleanerSuite automatically cleanup checkpoint cc andrewor14 Author: GuoQiang Li Closes #5548 from witgo/SPARK-6963 and squashes the following commits: 964aea7 [GuoQiang Li] review commits b08b3c9 [GuoQiang Li] Flaky test: o.a.s.ContextCleanerSuite automatically cleanup checkpoint --- .../org/apache/spark/ContextCleaner.scala | 2 ++ .../apache/spark/ContextCleanerSuite.scala | 21 +++++++++++++------ 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index 715b259057569..37198d887b07b 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -236,6 +236,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { try { logDebug("Cleaning rdd checkpoint data " + rddId) RDDCheckpointData.clearRDDCheckpointData(sc, rddId) + listeners.foreach(_.checkpointCleaned(rddId)) logInfo("Cleaned rdd checkpoint data " + rddId) } catch { @@ -260,4 +261,5 @@ private[spark] trait CleanerListener { def shuffleCleaned(shuffleId: Int) def broadcastCleaned(broadcastId: Long) def accumCleaned(accId: Long) + def checkpointCleaned(rddId: Long) } diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 097e7076e5391..c7868ddcf770f 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -224,7 +224,7 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase { assert(fs.exists(path)) // the checkpoint is not cleaned by default (without the configuration set) - var postGCTester = new CleanerTester(sc, Seq(rddId), Nil, Nil) + var postGCTester = new CleanerTester(sc, Seq(rddId), Nil, Nil, Nil) rdd = null // Make RDD out of scope runGC() postGCTester.assertCleanup() @@ -245,7 +245,7 @@ class ContextCleanerSuite extends 
ContextCleanerSuiteBase { assert(fs.exists(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get)) // Test that GC causes checkpoint data cleanup after dereferencing the RDD - postGCTester = new CleanerTester(sc, Seq(rddId), Nil, Nil) + postGCTester = new CleanerTester(sc, Seq(rddId), Nil, Nil, Seq(rddId)) rdd = null // Make RDD out of scope runGC() postGCTester.assertCleanup() @@ -406,12 +406,14 @@ class CleanerTester( sc: SparkContext, rddIds: Seq[Int] = Seq.empty, shuffleIds: Seq[Int] = Seq.empty, - broadcastIds: Seq[Long] = Seq.empty) + broadcastIds: Seq[Long] = Seq.empty, + checkpointIds: Seq[Long] = Seq.empty) extends Logging { val toBeCleanedRDDIds = new HashSet[Int] with SynchronizedSet[Int] ++= rddIds val toBeCleanedShuffleIds = new HashSet[Int] with SynchronizedSet[Int] ++= shuffleIds val toBeCleanedBroadcstIds = new HashSet[Long] with SynchronizedSet[Long] ++= broadcastIds + val toBeCheckpointIds = new HashSet[Long] with SynchronizedSet[Long] ++= checkpointIds val isDistributed = !sc.isLocal val cleanerListener = new CleanerListener { @@ -427,12 +429,17 @@ class CleanerTester( def broadcastCleaned(broadcastId: Long): Unit = { toBeCleanedBroadcstIds -= broadcastId - logInfo("Broadcast" + broadcastId + " cleaned") + logInfo("Broadcast " + broadcastId + " cleaned") } def accumCleaned(accId: Long): Unit = { logInfo("Cleaned accId " + accId + " cleaned") } + + def checkpointCleaned(rddId: Long): Unit = { + toBeCheckpointIds -= rddId + logInfo("checkpoint " + rddId + " cleaned") + } } val MAX_VALIDATION_ATTEMPTS = 10 @@ -456,7 +463,8 @@ class CleanerTester( /** Verify that RDDs, shuffles, etc. occupy resources */ private def preCleanupValidate() { - assert(rddIds.nonEmpty || shuffleIds.nonEmpty || broadcastIds.nonEmpty, "Nothing to cleanup") + assert(rddIds.nonEmpty || shuffleIds.nonEmpty || broadcastIds.nonEmpty || + checkpointIds.nonEmpty, "Nothing to cleanup") // Verify the RDDs have been persisted and blocks are present rddIds.foreach { rddId => @@ -547,7 +555,8 @@ class CleanerTester( private def isAllCleanedUp = toBeCleanedRDDIds.isEmpty && toBeCleanedShuffleIds.isEmpty && - toBeCleanedBroadcstIds.isEmpty + toBeCleanedBroadcstIds.isEmpty && + toBeCheckpointIds.isEmpty private def getRDDBlocks(rddId: Int): Seq[BlockId] = { blockManager.master.getMatchingBlockIds( _ match { From fa73da024000386eecef79573e8ac96d6f05b2c7 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sun, 19 Apr 2015 20:33:51 -0700 Subject: [PATCH 049/191] [SPARK-6998][MLlib] Make StreamingKMeans 'Serializable' If `StreamingKMeans` is not `Serializable`, we cannot checkpoint applications that use `StreamingKMeans`. So we should make it `Serializable`.
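For context only (this example is not part of the patch; `sc` and the directories below are assumed, e.g. from a spark-shell session), a minimal sketch of the kind of application affected: checkpointing serializes the DStream graph, and `trainOn` registers a closure that captures the `StreamingKMeans` instance, so the estimator itself has to serialize cleanly.

```scala
// Minimal sketch of a streaming job that checkpoints while training a
// StreamingKMeans model; `sc` and the paths are assumed/hypothetical.
import org.apache.spark.mllib.clustering.StreamingKMeans
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.streaming.{Seconds, StreamingContext}

val ssc = new StreamingContext(sc, Seconds(1))
ssc.checkpoint("/tmp/checkpoints") // enables DStream graph checkpointing

val model = new StreamingKMeans()
  .setK(2)
  .setDecayFactor(1.0)
  .setRandomCenters(dim = 2, weight = 0.0)

// trainOn registers a foreachRDD closure that captures `model`, so the model
// becomes part of the checkpointed graph.
val points = ssc.textFileStream("/tmp/points").map(Vectors.parse)
model.trainOn(points)

ssc.start()
ssc.awaitTermination()
```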
Author: zsxwing Closes #5582 from zsxwing/SPARK-6998 and squashes the following commits: 67c2a14 [zsxwing] Make StreamingKMeans 'Serializable' --- .../org/apache/spark/mllib/clustering/StreamingKMeans.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index f483fd1c7d2cf..d4606fda37b0d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -165,7 +165,7 @@ class StreamingKMeansModel( class StreamingKMeans( var k: Int, var decayFactor: Double, - var timeUnit: String) extends Logging { + var timeUnit: String) extends Logging with Serializable { def this() = this(2, 1.0, StreamingKMeans.BATCHES) From d8e1b7b06c499289ff3ce5ec91ff354493a17c48 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sun, 19 Apr 2015 20:35:43 -0700 Subject: [PATCH 050/191] [SPARK-6983][Streaming] Update ReceiverTrackerActor to use the new Rpc interface A subtask of [SPARK-5293](https://issues.apache.org/jira/browse/SPARK-5293) Author: zsxwing Closes #5557 from zsxwing/SPARK-6983 and squashes the following commits: e777e9f [zsxwing] Update ReceiverTrackerActor to use the new Rpc interface --- .../scala/org/apache/spark/rpc/RpcEnv.scala | 2 +- .../receiver/ReceiverSupervisorImpl.scala | 52 +++++---------- .../streaming/scheduler/ReceiverInfo.scala | 4 +- .../streaming/scheduler/ReceiverTracker.scala | 64 ++++++++++--------- 4 files changed, 52 insertions(+), 70 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index f2c1c86af767e..cba038ca355d7 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -258,7 +258,7 @@ private[spark] trait RpcEndpoint { final def stop(): Unit = { val _self = self if (_self != null) { - rpcEnv.stop(self) + rpcEnv.stop(_self) } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 8f2f1fef76874..89af40330b9d9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -21,18 +21,16 @@ import java.nio.ByteBuffer import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable.ArrayBuffer -import scala.concurrent.Await -import akka.actor.{ActorRef, Actor, Props} -import akka.pattern.ask import com.google.common.base.Throwables import org.apache.hadoop.conf.Configuration import org.apache.spark.{Logging, SparkEnv, SparkException} +import org.apache.spark.rpc.{RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.Time import org.apache.spark.streaming.scheduler._ -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.{RpcUtils, Utils} /** * Concrete implementation of [[org.apache.spark.streaming.receiver.ReceiverSupervisor]] @@ -63,37 +61,23 @@ private[streaming] class ReceiverSupervisorImpl( } - /** Remote Akka actor for the ReceiverTracker */ - private val trackerActor = { - val ip = env.conf.get("spark.driver.host", "localhost") - val port = 
env.conf.getInt("spark.driver.port", 7077) - val url = AkkaUtils.address( - AkkaUtils.protocol(env.actorSystem), - SparkEnv.driverActorSystemName, - ip, - port, - "ReceiverTracker") - env.actorSystem.actorSelection(url) - } - - /** Timeout for Akka actor messages */ - private val askTimeout = AkkaUtils.askTimeout(env.conf) + /** Remote RpcEndpointRef for the ReceiverTracker */ + private val trackerEndpoint = RpcUtils.makeDriverRef("ReceiverTracker", env.conf, env.rpcEnv) - /** Akka actor for receiving messages from the ReceiverTracker in the driver */ - private val actor = env.actorSystem.actorOf( - Props(new Actor { + /** RpcEndpointRef for receiving messages from the ReceiverTracker in the driver */ + private val endpoint = env.rpcEnv.setupEndpoint( + "Receiver-" + streamId + "-" + System.currentTimeMillis(), new ThreadSafeRpcEndpoint { + override val rpcEnv: RpcEnv = env.rpcEnv override def receive: PartialFunction[Any, Unit] = { case StopReceiver => logInfo("Received stop signal") - stop("Stopped by driver", None) + ReceiverSupervisorImpl.this.stop("Stopped by driver", None) case CleanupOldBlocks(threshTime) => logDebug("Received delete old batch signal") cleanupOldBlocks(threshTime) } - - def ref: ActorRef = self - }), "Receiver-" + streamId + "-" + System.currentTimeMillis()) + }) /** Unique block ids if one wants to add blocks directly */ private val newBlockId = new AtomicLong(System.currentTimeMillis()) @@ -162,15 +146,14 @@ private[streaming] class ReceiverSupervisorImpl( logDebug(s"Pushed block $blockId in ${(System.currentTimeMillis - time)} ms") val blockInfo = ReceivedBlockInfo(streamId, numRecords, blockStoreResult) - val future = trackerActor.ask(AddBlock(blockInfo))(askTimeout) - Await.result(future, askTimeout) + trackerEndpoint.askWithReply[Boolean](AddBlock(blockInfo)) logDebug(s"Reported block $blockId") } /** Report error to the receiver tracker */ def reportError(message: String, error: Throwable) { val errorString = Option(error).map(Throwables.getStackTraceAsString).getOrElse("") - trackerActor ! 
ReportError(streamId, message, errorString) + trackerEndpoint.send(ReportError(streamId, message, errorString)) logWarning("Reported error " + message + " - " + error) } @@ -180,22 +163,19 @@ private[streaming] class ReceiverSupervisorImpl( override protected def onStop(message: String, error: Option[Throwable]) { blockGenerator.stop() - env.actorSystem.stop(actor) + env.rpcEnv.stop(endpoint) } override protected def onReceiverStart() { val msg = RegisterReceiver( - streamId, receiver.getClass.getSimpleName, Utils.localHostName(), actor) - val future = trackerActor.ask(msg)(askTimeout) - Await.result(future, askTimeout) + streamId, receiver.getClass.getSimpleName, Utils.localHostName(), endpoint) + trackerEndpoint.askWithReply[Boolean](msg) } override protected def onReceiverStop(message: String, error: Option[Throwable]) { logInfo("Deregistering receiver " + streamId) val errorString = error.map(Throwables.getStackTraceAsString).getOrElse("") - val future = trackerActor.ask( - DeregisterReceiver(streamId, message, errorString))(askTimeout) - Await.result(future, askTimeout) + trackerEndpoint.askWithReply[Boolean](DeregisterReceiver(streamId, message, errorString)) logInfo("Stopped receiver " + streamId) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala index d7e39c528c519..52f08b9c9de68 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala @@ -17,8 +17,8 @@ package org.apache.spark.streaming.scheduler -import akka.actor.ActorRef import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rpc.RpcEndpointRef /** * :: DeveloperApi :: @@ -28,7 +28,7 @@ import org.apache.spark.annotation.DeveloperApi case class ReceiverInfo( streamId: Int, name: String, - private[streaming] val actor: ActorRef, + private[streaming] val endpoint: RpcEndpointRef, active: Boolean, location: String, lastErrorMessage: String = "", diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 98900473138fe..c4ead6f30a63d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -17,13 +17,11 @@ package org.apache.spark.streaming.scheduler - import scala.collection.mutable.{HashMap, SynchronizedMap} import scala.language.existentials -import akka.actor._ - import org.apache.spark.{Logging, SerializableWritable, SparkEnv, SparkException} +import org.apache.spark.rpc._ import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.receiver.{CleanupOldBlocks, Receiver, ReceiverSupervisorImpl, StopReceiver} @@ -36,7 +34,7 @@ private[streaming] case class RegisterReceiver( streamId: Int, typ: String, host: String, - receiverActor: ActorRef + receiverEndpoint: RpcEndpointRef ) extends ReceiverTrackerMessage private[streaming] case class AddBlock(receivedBlockInfo: ReceivedBlockInfo) extends ReceiverTrackerMessage @@ -67,19 +65,19 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false ) private val listenerBus = ssc.scheduler.listenerBus - // actor is created when generator starts. + // endpoint is created when generator starts. 
// This not being null means the tracker has been started and not stopped - private var actor: ActorRef = null + private var endpoint: RpcEndpointRef = null - /** Start the actor and receiver execution thread. */ + /** Start the endpoint and receiver execution thread. */ def start(): Unit = synchronized { - if (actor != null) { + if (endpoint != null) { throw new SparkException("ReceiverTracker already started") } if (!receiverInputStreams.isEmpty) { - actor = ssc.env.actorSystem.actorOf(Props(new ReceiverTrackerActor), - "ReceiverTracker") + endpoint = ssc.env.rpcEnv.setupEndpoint( + "ReceiverTracker", new ReceiverTrackerEndpoint(ssc.env.rpcEnv)) if (!skipReceiverLaunch) receiverExecutor.start() logInfo("ReceiverTracker started") } @@ -87,13 +85,13 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false /** Stop the receiver execution thread. */ def stop(graceful: Boolean): Unit = synchronized { - if (!receiverInputStreams.isEmpty && actor != null) { + if (!receiverInputStreams.isEmpty && endpoint != null) { // First, stop the receivers if (!skipReceiverLaunch) receiverExecutor.stop(graceful) - // Finally, stop the actor - ssc.env.actorSystem.stop(actor) - actor = null + // Finally, stop the endpoint + ssc.env.rpcEnv.stop(endpoint) + endpoint = null receivedBlockTracker.stop() logInfo("ReceiverTracker stopped") } @@ -129,8 +127,8 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false // Signal the receivers to delete old block data if (ssc.conf.getBoolean("spark.streaming.receiver.writeAheadLog.enable", false)) { logInfo(s"Cleanup old received batch data: $cleanupThreshTime") - receiverInfo.values.flatMap { info => Option(info.actor) } - .foreach { _ ! CleanupOldBlocks(cleanupThreshTime) } + receiverInfo.values.flatMap { info => Option(info.endpoint) } + .foreach { _.send(CleanupOldBlocks(cleanupThreshTime)) } } } @@ -139,23 +137,23 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false streamId: Int, typ: String, host: String, - receiverActor: ActorRef, - sender: ActorRef + receiverEndpoint: RpcEndpointRef, + senderAddress: RpcAddress ) { if (!receiverInputStreamIds.contains(streamId)) { throw new SparkException("Register received for unexpected id " + streamId) } receiverInfo(streamId) = ReceiverInfo( - streamId, s"${typ}-${streamId}", receiverActor, true, host) + streamId, s"${typ}-${streamId}", receiverEndpoint, true, host) listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId))) - logInfo("Registered receiver for stream " + streamId + " from " + sender.path.address) + logInfo("Registered receiver for stream " + streamId + " from " + senderAddress) } /** Deregister a receiver */ private def deregisterReceiver(streamId: Int, message: String, error: String) { val newReceiverInfo = receiverInfo.get(streamId) match { case Some(oldInfo) => - oldInfo.copy(actor = null, active = false, lastErrorMessage = message, lastError = error) + oldInfo.copy(endpoint = null, active = false, lastErrorMessage = message, lastError = error) case None => logWarning("No prior receiver info") ReceiverInfo(streamId, "", null, false, "", lastErrorMessage = message, lastError = error) @@ -199,19 +197,23 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false receivedBlockTracker.hasUnallocatedReceivedBlocks } - /** Actor to receive messages from the receivers. */ - private class ReceiverTrackerActor extends Actor { + /** RpcEndpoint to receive messages from the receivers. 
*/ + private class ReceiverTrackerEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint { + override def receive: PartialFunction[Any, Unit] = { - case RegisterReceiver(streamId, typ, host, receiverActor) => - registerReceiver(streamId, typ, host, receiverActor, sender) - sender ! true - case AddBlock(receivedBlockInfo) => - sender ! addBlock(receivedBlockInfo) case ReportError(streamId, message, error) => reportError(streamId, message, error) + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RegisterReceiver(streamId, typ, host, receiverEndpoint) => + registerReceiver(streamId, typ, host, receiverEndpoint, context.sender.address) + context.reply(true) + case AddBlock(receivedBlockInfo) => + context.reply(addBlock(receivedBlockInfo)) case DeregisterReceiver(streamId, message, error) => deregisterReceiver(streamId, message, error) - sender ! true + context.reply(true) } } @@ -314,8 +316,8 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false /** Stops the receivers. */ private def stopReceivers() { // Signal the receivers to stop - receiverInfo.values.flatMap { info => Option(info.actor)} - .foreach { _ ! StopReceiver } + receiverInfo.values.flatMap { info => Option(info.endpoint)} + .foreach { _.send(StopReceiver) } logInfo("Sent stop signal to all " + receiverInfo.size + " receivers") } } From c776ee8a6fdcdc463746a815b7686e4e33a874a9 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sun, 19 Apr 2015 20:48:36 -0700 Subject: [PATCH 051/191] [SPARK-6979][Streaming] Replace JobScheduler.eventActor and JobGenerator.eventActor with EventLoop Title says it all. cc rxin tdas Author: zsxwing Closes #5554 from zsxwing/SPARK-6979 and squashes the following commits: 5304350 [zsxwing] Fix NotSerializableException e9d3479 [zsxwing] Add blank lines 633e279 [zsxwing] Fix NotSerializableException e496ace [zsxwing] Replace JobGenerator.eventActor with EventLoop ec6ec58 [zsxwing] Fix the import order ce0fa73 [zsxwing] Replace JobScheduler.eventActor with EventLoop --- .../mllib/clustering/StreamingKMeans.scala | 3 +- .../streaming/scheduler/JobGenerator.scala | 38 +++++++++--------- .../streaming/scheduler/JobScheduler.scala | 40 ++++++++++--------- 3 files changed, 42 insertions(+), 39 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index d4606fda37b0d..812014a041719 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -20,8 +20,7 @@ package org.apache.spark.mllib.clustering import scala.reflect.ClassTag import org.apache.spark.Logging -import org.apache.spark.SparkContext._ -import org.apache.spark.annotation.{Experimental, DeveloperApi} +import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} import org.apache.spark.rdd.RDD import org.apache.spark.streaming.dstream.DStream diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 58e56638a2dca..2467d50839add 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -19,12 +19,10 @@ package 
org.apache.spark.streaming.scheduler import scala.util.{Failure, Success, Try} -import akka.actor.{ActorRef, Props, Actor} - import org.apache.spark.{SparkEnv, Logging} import org.apache.spark.streaming.{Checkpoint, CheckpointWriter, Time} import org.apache.spark.streaming.util.RecurringTimer -import org.apache.spark.util.{Clock, ManualClock, Utils} +import org.apache.spark.util.{Clock, EventLoop, ManualClock} /** Event classes for JobGenerator */ private[scheduler] sealed trait JobGeneratorEvent @@ -58,7 +56,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { } private val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds, - longTime => eventActor ! GenerateJobs(new Time(longTime)), "JobGenerator") + longTime => eventLoop.post(GenerateJobs(new Time(longTime))), "JobGenerator") // This is marked lazy so that this is initialized after checkpoint duration has been set // in the context and the generator has been started. @@ -70,22 +68,26 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { null } - // eventActor is created when generator starts. + // eventLoop is created when generator starts. // This not being null means the scheduler has been started and not stopped - private var eventActor: ActorRef = null + private var eventLoop: EventLoop[JobGeneratorEvent] = null // last batch whose completion,checkpointing and metadata cleanup has been completed private var lastProcessedBatch: Time = null /** Start generation of jobs */ def start(): Unit = synchronized { - if (eventActor != null) return // generator has already been started + if (eventLoop != null) return // generator has already been started + + eventLoop = new EventLoop[JobGeneratorEvent]("JobGenerator") { + override protected def onReceive(event: JobGeneratorEvent): Unit = processEvent(event) - eventActor = ssc.env.actorSystem.actorOf(Props(new Actor { - override def receive: PartialFunction[Any, Unit] = { - case event: JobGeneratorEvent => processEvent(event) + override protected def onError(e: Throwable): Unit = { + jobScheduler.reportError("Error in job generator", e) } - }), "JobGenerator") + } + eventLoop.start() + if (ssc.isCheckpointPresent) { restart() } else { @@ -99,7 +101,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { * checkpoints written. */ def stop(processReceivedData: Boolean): Unit = synchronized { - if (eventActor == null) return // generator has already been stopped + if (eventLoop == null) return // generator has already been stopped if (processReceivedData) { logInfo("Stopping JobGenerator gracefully") @@ -146,9 +148,9 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { graph.stop() } - // Stop the actor and checkpoint writer + // Stop the event loop and checkpoint writer if (shouldCheckpoint) checkpointWriter.stop() - ssc.env.actorSystem.stop(eventActor) + eventLoop.stop() logInfo("Stopped JobGenerator") } @@ -156,7 +158,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { * Callback called when a batch has been completely processed. */ def onBatchCompletion(time: Time) { - eventActor ! ClearMetadata(time) + eventLoop.post(ClearMetadata(time)) } /** @@ -164,7 +166,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { */ def onCheckpointCompletion(time: Time, clearCheckpointDataLater: Boolean) { if (clearCheckpointDataLater) { - eventActor ! 
ClearCheckpointData(time) + eventLoop.post(ClearCheckpointData(time)) } } @@ -247,7 +249,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { case Failure(e) => jobScheduler.reportError("Error generating jobs for time " + time, e) } - eventActor ! DoCheckpoint(time, clearCheckpointDataLater = false) + eventLoop.post(DoCheckpoint(time, clearCheckpointDataLater = false)) } /** Clear DStream metadata for the given `time`. */ @@ -257,7 +259,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { // If checkpointing is enabled, then checkpoint, // else mark batch to be fully processed if (shouldCheckpoint) { - eventActor ! DoCheckpoint(time, clearCheckpointDataLater = true) + eventLoop.post(DoCheckpoint(time, clearCheckpointDataLater = true)) } else { // If checkpointing is not enabled, then delete metadata information about // received blocks (block data not saved in any case). Otherwise, wait for diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 95f1857b4c377..508b89278dcba 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -17,13 +17,15 @@ package org.apache.spark.streaming.scheduler -import scala.util.{Failure, Success, Try} -import scala.collection.JavaConversions._ import java.util.concurrent.{TimeUnit, ConcurrentHashMap, Executors} -import akka.actor.{ActorRef, Actor, Props} -import org.apache.spark.{SparkException, Logging, SparkEnv} + +import scala.collection.JavaConversions._ +import scala.util.{Failure, Success} + +import org.apache.spark.Logging import org.apache.spark.rdd.PairRDDFunctions import org.apache.spark.streaming._ +import org.apache.spark.util.EventLoop private[scheduler] sealed trait JobSchedulerEvent @@ -46,20 +48,20 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { val listenerBus = new StreamingListenerBus() // These two are created only when scheduler starts. 
- // eventActor not being null means the scheduler has been started and not stopped + // eventLoop not being null means the scheduler has been started and not stopped var receiverTracker: ReceiverTracker = null - private var eventActor: ActorRef = null - + private var eventLoop: EventLoop[JobSchedulerEvent] = null def start(): Unit = synchronized { - if (eventActor != null) return // scheduler has already been started + if (eventLoop != null) return // scheduler has already been started logDebug("Starting JobScheduler") - eventActor = ssc.env.actorSystem.actorOf(Props(new Actor { - override def receive: PartialFunction[Any, Unit] = { - case event: JobSchedulerEvent => processEvent(event) - } - }), "JobScheduler") + eventLoop = new EventLoop[JobSchedulerEvent]("JobScheduler") { + override protected def onReceive(event: JobSchedulerEvent): Unit = processEvent(event) + + override protected def onError(e: Throwable): Unit = reportError("Error in job scheduler", e) + } + eventLoop.start() listenerBus.start(ssc.sparkContext) receiverTracker = new ReceiverTracker(ssc) @@ -69,7 +71,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { } def stop(processAllReceivedData: Boolean): Unit = synchronized { - if (eventActor == null) return // scheduler has already been stopped + if (eventLoop == null) return // scheduler has already been stopped logDebug("Stopping JobScheduler") // First, stop receiving @@ -96,8 +98,8 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { // Stop everything else listenerBus.stop() - ssc.env.actorSystem.stop(eventActor) - eventActor = null + eventLoop.stop() + eventLoop = null logInfo("Stopped JobScheduler") } @@ -117,7 +119,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { } def reportError(msg: String, e: Throwable) { - eventActor ! ErrorReported(msg, e) + eventLoop.post(ErrorReported(msg, e)) } private def processEvent(event: JobSchedulerEvent) { @@ -172,14 +174,14 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { private class JobHandler(job: Job) extends Runnable { def run() { - eventActor ! JobStarted(job) + eventLoop.post(JobStarted(job)) // Disable checks for existing output directories in jobs launched by the streaming scheduler, // since we may need to write output to an existing directory during checkpoint recovery; // see SPARK-4835 for more details. PairRDDFunctions.disableOutputSpecValidation.withValue(true) { job.run() } - eventActor ! 
JobCompleted(job) + eventLoop.post(JobCompleted(job)) } } } From 6fe690d5a8216ba7efde4b52e7a19fb00814341c Mon Sep 17 00:00:00 2001 From: dobashim Date: Mon, 20 Apr 2015 00:03:23 -0400 Subject: [PATCH 052/191] [doc][mllib] Fix typo of the page title in Isotonic regression documents * Fix the page title in Isotonic regression documents (Naive Bayes -> Isotonic regression) * Add a newline character at the end of the file Author: dobashim Closes #5581 from dobashim/master and squashes the following commits: d54a041 [dobashim] Fix typo of the page title in Isotonic regression documents --- docs/mllib-isotonic-regression.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/mllib-isotonic-regression.md b/docs/mllib-isotonic-regression.md index 12fb29d426741..b521c2f27cd6e 100644 --- a/docs/mllib-isotonic-regression.md +++ b/docs/mllib-isotonic-regression.md @@ -1,6 +1,6 @@ --- layout: global -title: Naive Bayes - MLlib +title: Isotonic regression - MLlib displayTitle: MLlib - Regression --- @@ -152,4 +152,4 @@ Double meanSquaredError = new JavaDoubleRDD(predictionAndLabel.map( System.out.println("Mean Squared Error = " + meanSquaredError); {% endhighlight %} - \ No newline at end of file + From 1be207078cef48c5935595969bf9f6b1ec1334ca Mon Sep 17 00:00:00 2001 From: jrabary Date: Mon, 20 Apr 2015 09:47:56 -0700 Subject: [PATCH 053/191] [SPARK-5924] Add the ability to specify withMean or withStd parameters with StandardScaler The current implementation calls the default constructor of mllib.feature.StandardScaler without the possibility to specify withMean or withStd options. Author: jrabary Closes #4704 from jrabary/master and squashes the following commits: fae8568 [jrabary] style fix 8896b0e [jrabary] Comments fix ef96d73 [jrabary] style fix 8e52607 [jrabary] style fix edd9d48 [jrabary] Fix default param initialization 17e1a76 [jrabary] Fix default param initialization 298f405 [jrabary] Typo fix 45ed914 [jrabary] Add withMean and withStd params to StandardScaler --- .../spark/ml/feature/StandardScaler.scala | 32 ++++++++++++++++--- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 1b102619b3524..447851ec034d6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -30,7 +30,22 @@ import org.apache.spark.sql.types.{StructField, StructType} /** * Params for [[StandardScaler]] and [[StandardScalerModel]]. */ -private[feature] trait StandardScalerParams extends Params with HasInputCol with HasOutputCol +private[feature] trait StandardScalerParams extends Params with HasInputCol with HasOutputCol { + + /** + * False by default. Centers the data with mean before scaling. + * It will build a dense output, so this does not work on sparse input + * and will raise an exception. + * @group param + */ + val withMean: BooleanParam = new BooleanParam(this, "withMean", "Center data with mean") + + /** + * True by default. Scales the data to unit standard deviation. 
+ * @group param + */ + val withStd: BooleanParam = new BooleanParam(this, "withStd", "Scale to unit standard deviation") +} /** * :: AlphaComponent :: @@ -40,18 +55,27 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with @AlphaComponent class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerParams { + setDefault(withMean -> false, withStd -> true) + /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - + + /** @group setParam */ + def setWithMean(value: Boolean): this.type = set(withMean, value) + + /** @group setParam */ + def setWithStd(value: Boolean): this.type = set(withStd, value) + override def fit(dataset: DataFrame, paramMap: ParamMap): StandardScalerModel = { transformSchema(dataset.schema, paramMap, logging = true) val map = extractParamMap(paramMap) val input = dataset.select(map(inputCol)).map { case Row(v: Vector) => v } - val scaler = new feature.StandardScaler().fit(input) - val model = new StandardScalerModel(this, map, scaler) + val scaler = new feature.StandardScaler(withMean = map(withMean), withStd = map(withStd)) + val scalerModel = scaler.fit(input) + val model = new StandardScalerModel(this, map, scalerModel) Params.inheritValues(map, this, model) model } From 968ad972175390bb0a96918fd3c7b318d70fa466 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Mon, 20 Apr 2015 09:54:21 -0700 Subject: [PATCH 054/191] [SPARK-7003] Improve reliability of connection failure detection between Netty block transfer service endpoints Currently we rely on the assumption that an exception will be raised and the channel closed if two endpoints cannot communicate over a Netty TCP channel. However, this guarantee does not hold in all network environments, and [SPARK-6962](https://issues.apache.org/jira/browse/SPARK-6962) seems to point to a case where only the server side of the connection detected a fault. This patch improves robustness of fetch/rpc requests by having an explicit timeout in the transport layer which closes the connection if there is a period of inactivity while there are outstanding requests. NB: This patch is actually only around 50 lines added if you exclude the testing-related code. 
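As an illustrative usage note (not part of this patch): the inactivity window is derived from the transport connection-timeout configuration, so a deployment that legitimately goes quiet for long stretches (large GC pauses, slow disks) can raise the timeout in its SparkConf instead of having connections closed. The keys below are the ones this change already reads or mentions in its error message; the example values and accepted time formats are assumptions to be checked against TransportConf for the running version.

```scala
// Illustrative tuning sketch; values are examples, not recommendations.
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .setAppName("timeout-tuning-example") // hypothetical application
  // Umbrella network timeout referenced by the new "connection is dead" log message.
  .set("spark.network.timeout", "300s")
  // Transport-level connection timeout for the shuffle module (same key the new test suite sets).
  .set("spark.shuffle.io.connectionTimeout", "240s")
```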
Author: Aaron Davidson Closes #5584 from aarondav/timeout and squashes the following commits: 8699680 [Aaron Davidson] Address Reynold's comments 37ce656 [Aaron Davidson] [SPARK-7003] Improve reliability of connection failure detection between Netty block transfer service endpoints --- .../spark/network/TransportContext.java | 5 +- .../client/TransportResponseHandler.java | 14 +- .../server/TransportChannelHandler.java | 33 ++- .../spark/network/util/MapConfigProvider.java | 41 +++ .../apache/spark/network/util/NettyUtils.java | 2 +- .../RequestTimeoutIntegrationSuite.java | 277 ++++++++++++++++++ .../network/TransportClientFactorySuite.java | 21 +- 7 files changed, 375 insertions(+), 18 deletions(-) create mode 100644 network/common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java create mode 100644 network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java index f0a89c9d9116c..3fe69b1bd8851 100644 --- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java +++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java @@ -22,6 +22,7 @@ import com.google.common.collect.Lists; import io.netty.channel.Channel; import io.netty.channel.socket.SocketChannel; +import io.netty.handler.timeout.IdleStateHandler; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -106,6 +107,7 @@ public TransportChannelHandler initializePipeline(SocketChannel channel) { .addLast("encoder", encoder) .addLast("frameDecoder", NettyUtils.createFrameDecoder()) .addLast("decoder", decoder) + .addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000)) // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this // would require more logic to guarantee if this were not part of the same event loop. .addLast("handler", channelHandler); @@ -126,7 +128,8 @@ private TransportChannelHandler createChannelHandler(Channel channel) { TransportClient client = new TransportClient(channel, responseHandler); TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client, rpcHandler); - return new TransportChannelHandler(client, responseHandler, requestHandler); + return new TransportChannelHandler(client, responseHandler, requestHandler, + conf.connectionTimeoutMs()); } public TransportConf getConf() { return conf; } diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index 2044afb0d85db..94fc21af5e606 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -20,8 +20,8 @@ import java.io.IOException; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; -import com.google.common.annotations.VisibleForTesting; import io.netty.channel.Channel; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -50,13 +50,18 @@ public class TransportResponseHandler extends MessageHandler { private final Map outstandingRpcs; + /** Records the time (in system nanoseconds) that the last fetch or RPC request was sent. 
*/ + private final AtomicLong timeOfLastRequestNs; + public TransportResponseHandler(Channel channel) { this.channel = channel; this.outstandingFetches = new ConcurrentHashMap(); this.outstandingRpcs = new ConcurrentHashMap(); + this.timeOfLastRequestNs = new AtomicLong(0); } public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback) { + timeOfLastRequestNs.set(System.nanoTime()); outstandingFetches.put(streamChunkId, callback); } @@ -65,6 +70,7 @@ public void removeFetchRequest(StreamChunkId streamChunkId) { } public void addRpcRequest(long requestId, RpcResponseCallback callback) { + timeOfLastRequestNs.set(System.nanoTime()); outstandingRpcs.put(requestId, callback); } @@ -161,8 +167,12 @@ public void handle(ResponseMessage message) { } /** Returns total number of outstanding requests (fetch requests + rpcs) */ - @VisibleForTesting public int numOutstandingRequests() { return outstandingFetches.size() + outstandingRpcs.size(); } + + /** Returns the time in nanoseconds of when the last request was sent out. */ + public long getTimeOfLastRequestNs() { + return timeOfLastRequestNs.get(); + } } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java index e491367fa4528..8e0ee709e38e3 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -19,6 +19,8 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.handler.timeout.IdleState; +import io.netty.handler.timeout.IdleStateEvent; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -40,6 +42,11 @@ * Client. * This means that the Client also needs a RequestHandler and the Server needs a ResponseHandler, * for the Client's responses to the Server's requests. + * + * This class also handles timeouts from a {@link io.netty.handler.timeout.IdleStateHandler}. + * We consider a connection timed out if there are outstanding fetch or RPC requests but no traffic + * on the channel for at least `requestTimeoutMs`. Note that this is duplex traffic; we will not + * timeout if the client is continuously sending but getting no responses, for simplicity. */ public class TransportChannelHandler extends SimpleChannelInboundHandler { private final Logger logger = LoggerFactory.getLogger(TransportChannelHandler.class); @@ -47,14 +54,17 @@ public class TransportChannelHandler extends SimpleChannelInboundHandler 0; + boolean isActuallyOverdue = + System.nanoTime() - responseHandler.getTimeOfLastRequestNs() > requestTimeoutNs; + if (e.state() == IdleState.ALL_IDLE && hasInFlightRequests && isActuallyOverdue) { + String address = NettyUtils.getRemoteAddress(ctx.channel()); + logger.error("Connection to {} has been quiet for {} ms while there are outstanding " + + "requests. 
Assuming connection is dead; please adjust spark.network.timeout if this " + + "is wrong.", address, requestTimeoutNs / 1000 / 1000); + ctx.close(); + } + } + } } diff --git a/network/common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java b/network/common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java new file mode 100644 index 0000000000000..668d2356b955d --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import com.google.common.collect.Maps; + +import java.util.Map; +import java.util.NoSuchElementException; + +/** ConfigProvider based on a Map (copied in the constructor). */ +public class MapConfigProvider extends ConfigProvider { + private final Map config; + + public MapConfigProvider(Map config) { + this.config = Maps.newHashMap(config); + } + + @Override + public String get(String name) { + String value = config.get(name); + if (value == null) { + throw new NoSuchElementException(name); + } + return value; + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java index dabd6261d2aa0..26c6399ce7dbc 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java @@ -98,7 +98,7 @@ public static ByteToMessageDecoder createFrameDecoder() { return new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 8, -8, 8); } - /** Returns the remote address on the channel or "<remote address>" if none exists. */ + /** Returns the remote address on the channel or "<unknown remote>" if none exists. */ public static String getRemoteAddress(Channel channel) { if (channel != null && channel.remoteAddress() != null) { return channel.remoteAddress().toString(); diff --git a/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java new file mode 100644 index 0000000000000..84ebb337e6d54 --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java @@ -0,0 +1,277 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import com.google.common.collect.Maps; +import com.google.common.util.concurrent.Uninterruptibles; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.client.ChunkReceivedCallback; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.util.MapConfigProvider; +import org.apache.spark.network.util.TransportConf; +import org.junit.*; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.*; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; + +/** + * Suite which ensures that requests that go without a response for the network timeout period are + * failed, and the connection closed. + * + * In this suite, we use 2 seconds as the connection timeout, with some slack given in the tests, + * to ensure stability in different test environments. + */ +public class RequestTimeoutIntegrationSuite { + + private TransportServer server; + private TransportClientFactory clientFactory; + + private StreamManager defaultManager; + private TransportConf conf; + + // A large timeout that "shouldn't happen", for the sake of faulty tests not hanging forever. + private final int FOREVER = 60 * 1000; + + @Before + public void setUp() throws Exception { + Map configMap = Maps.newHashMap(); + configMap.put("spark.shuffle.io.connectionTimeout", "2s"); + conf = new TransportConf(new MapConfigProvider(configMap)); + + defaultManager = new StreamManager() { + @Override + public ManagedBuffer getChunk(long streamId, int chunkIndex) { + throw new UnsupportedOperationException(); + } + }; + } + + @After + public void tearDown() { + if (server != null) { + server.close(); + } + if (clientFactory != null) { + clientFactory.close(); + } + } + + // Basic suite: First request completes quickly, and second waits for longer than network timeout. + @Test + public void timeoutInactiveRequests() throws Exception { + final Semaphore semaphore = new Semaphore(1); + final byte[] response = new byte[16]; + RpcHandler handler = new RpcHandler() { + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + try { + semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS); + callback.onSuccess(response); + } catch (InterruptedException e) { + // do nothing + } + } + + @Override + public StreamManager getStreamManager() { + return defaultManager; + } + }; + + TransportContext context = new TransportContext(conf, handler); + server = context.createServer(); + clientFactory = context.createClientFactory(); + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + + // First completes quickly (semaphore starts at 1). 
+ TestCallback callback0 = new TestCallback(); + synchronized (callback0) { + client.sendRpc(new byte[0], callback0); + callback0.wait(FOREVER); + assert (callback0.success.length == response.length); + } + + // Second times out after 2 seconds, with slack. Must be IOException. + TestCallback callback1 = new TestCallback(); + synchronized (callback1) { + client.sendRpc(new byte[0], callback1); + callback1.wait(4 * 1000); + assert (callback1.failure != null); + assert (callback1.failure instanceof IOException); + } + semaphore.release(); + } + + // A timeout will cause the connection to be closed, invalidating the current TransportClient. + // It should be the case that requesting a client from the factory produces a new, valid one. + @Test + public void timeoutCleanlyClosesClient() throws Exception { + final Semaphore semaphore = new Semaphore(0); + final byte[] response = new byte[16]; + RpcHandler handler = new RpcHandler() { + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + try { + semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS); + callback.onSuccess(response); + } catch (InterruptedException e) { + // do nothing + } + } + + @Override + public StreamManager getStreamManager() { + return defaultManager; + } + }; + + TransportContext context = new TransportContext(conf, handler); + server = context.createServer(); + clientFactory = context.createClientFactory(); + + // First request should eventually fail. + TransportClient client0 = + clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + TestCallback callback0 = new TestCallback(); + synchronized (callback0) { + client0.sendRpc(new byte[0], callback0); + callback0.wait(FOREVER); + assert (callback0.failure instanceof IOException); + assert (!client0.isActive()); + } + + // Increment the semaphore and the second request should succeed quickly. + semaphore.release(2); + TransportClient client1 = + clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + TestCallback callback1 = new TestCallback(); + synchronized (callback1) { + client1.sendRpc(new byte[0], callback1); + callback1.wait(FOREVER); + assert (callback1.success.length == response.length); + assert (callback1.failure == null); + } + } + + // The timeout is relative to the LAST request sent, which is kinda weird, but still. + // This test also makes sure the timeout works for Fetch requests as well as RPCs. + @Test + public void furtherRequestsDelay() throws Exception { + final byte[] response = new byte[16]; + final StreamManager manager = new StreamManager() { + @Override + public ManagedBuffer getChunk(long streamId, int chunkIndex) { + Uninterruptibles.sleepUninterruptibly(FOREVER, TimeUnit.MILLISECONDS); + return new NioManagedBuffer(ByteBuffer.wrap(response)); + } + }; + RpcHandler handler = new RpcHandler() { + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + throw new UnsupportedOperationException(); + } + + @Override + public StreamManager getStreamManager() { + return manager; + } + }; + + TransportContext context = new TransportContext(conf, handler); + server = context.createServer(); + clientFactory = context.createClientFactory(); + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + + // Send one request, which will eventually fail. 
+ TestCallback callback0 = new TestCallback(); + client.fetchChunk(0, 0, callback0); + Uninterruptibles.sleepUninterruptibly(1200, TimeUnit.MILLISECONDS); + + // Send a second request before the first has failed. + TestCallback callback1 = new TestCallback(); + client.fetchChunk(0, 1, callback1); + Uninterruptibles.sleepUninterruptibly(1200, TimeUnit.MILLISECONDS); + + synchronized (callback0) { + // not complete yet, but should complete soon + assert (callback0.success == null && callback0.failure == null); + callback0.wait(2 * 1000); + assert (callback0.failure instanceof IOException); + } + + synchronized (callback1) { + // failed at same time as previous + assert (callback0.failure instanceof IOException); + } + } + + /** + * Callback which sets 'success' or 'failure' on completion. + * Additionally notifies all waiters on this callback when invoked. + */ + class TestCallback implements RpcResponseCallback, ChunkReceivedCallback { + + byte[] success; + Throwable failure; + + @Override + public void onSuccess(byte[] response) { + synchronized(this) { + success = response; + this.notifyAll(); + } + } + + @Override + public void onFailure(Throwable e) { + synchronized(this) { + failure = e; + this.notifyAll(); + } + } + + @Override + public void onSuccess(int chunkIndex, ManagedBuffer buffer) { + synchronized(this) { + try { + success = buffer.nioByteBuffer().array(); + this.notifyAll(); + } catch (IOException e) { + // weird + } + } + } + + @Override + public void onFailure(int chunkIndex, Throwable e) { + synchronized(this) { + failure = e; + this.notifyAll(); + } + } + } +} diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java index 416dc1b969fa4..35de5e57ccb98 100644 --- a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java +++ b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java @@ -20,10 +20,11 @@ import java.io.IOException; import java.util.Collections; import java.util.HashSet; -import java.util.NoSuchElementException; +import java.util.Map; import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; +import com.google.common.collect.Maps; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -36,9 +37,9 @@ import org.apache.spark.network.server.NoOpRpcHandler; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.TransportServer; -import org.apache.spark.network.util.ConfigProvider; -import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.util.MapConfigProvider; import org.apache.spark.network.util.TransportConf; public class TransportClientFactorySuite { @@ -70,16 +71,10 @@ public void tearDown() { */ private void testClientReuse(final int maxConnections, boolean concurrent) throws IOException, InterruptedException { - TransportConf conf = new TransportConf(new ConfigProvider() { - @Override - public String get(String name) { - if (name.equals("spark.shuffle.io.numConnectionsPerPeer")) { - return Integer.toString(maxConnections); - } else { - throw new NoSuchElementException(); - } - } - }); + + Map configMap = Maps.newHashMap(); + configMap.put("spark.shuffle.io.numConnectionsPerPeer", Integer.toString(maxConnections)); + TransportConf conf = new TransportConf(new 
MapConfigProvider(configMap)); RpcHandler rpcHandler = new NoOpRpcHandler(); TransportContext context = new TransportContext(conf, rpcHandler); From 77176619a97d07811ab20e1dde4677359d85eb33 Mon Sep 17 00:00:00 2001 From: Elisey Zanko Date: Mon, 20 Apr 2015 10:44:09 -0700 Subject: [PATCH 055/191] [SPARK-6661] Python type errors should print type, not object Author: Elisey Zanko Closes #5361 from 31z4/spark-6661 and squashes the following commits: 73c5d79 [Elisey Zanko] Python type errors should print type, not object --- python/pyspark/accumulators.py | 2 +- python/pyspark/context.py | 2 +- python/pyspark/ml/param/__init__.py | 2 +- python/pyspark/ml/pipeline.py | 4 ++-- python/pyspark/mllib/linalg.py | 4 ++-- python/pyspark/mllib/regression.py | 2 +- python/pyspark/mllib/tests.py | 6 ++++-- python/pyspark/sql/_types.py | 12 ++++++------ python/pyspark/sql/context.py | 8 ++++---- python/pyspark/sql/dataframe.py | 2 +- 10 files changed, 23 insertions(+), 21 deletions(-) diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index 7271809e43880..0d21a132048a5 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -83,7 +83,7 @@ >>> sc.accumulator([1.0, 2.0, 3.0]) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... -Exception:... +TypeError:... """ import sys diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 1dc2fec0ae5c8..6a743ac8bd600 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -671,7 +671,7 @@ def accumulator(self, value, accum_param=None): elif isinstance(value, complex): accum_param = accumulators.COMPLEX_ACCUMULATOR_PARAM else: - raise Exception("No default accumulator param for type %s" % type(value)) + raise TypeError("No default accumulator param for type %s" % type(value)) SparkContext._next_accum_id += 1 return Accumulator(SparkContext._next_accum_id - 1, value, accum_param) diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index 9fccb65675185..49c20b4cf70cf 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -30,7 +30,7 @@ class Param(object): def __init__(self, parent, name, doc): if not isinstance(parent, Params): - raise ValueError("Parent must be a Params but got type %s." % type(parent).__name__) + raise TypeError("Parent must be a Params but got type %s." % type(parent)) self.parent = parent self.name = str(name) self.doc = str(doc) diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index d94ecfff09f66..7c1ec3026da6f 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -131,8 +131,8 @@ def fit(self, dataset, params={}): stages = paramMap[self.stages] for stage in stages: if not (isinstance(stage, Estimator) or isinstance(stage, Transformer)): - raise ValueError( - "Cannot recognize a pipeline stage of type %s." % type(stage).__name__) + raise TypeError( + "Cannot recognize a pipeline stage of type %s." 
% type(stage)) indexOfLastEstimator = -1 for i, stage in enumerate(stages): if isinstance(stage, Estimator): diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 38b3aa3ad460e..ec8c879ea9389 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -145,7 +145,7 @@ def serialize(self, obj): values = [float(v) for v in obj] return (1, None, None, values) else: - raise ValueError("cannot serialize %r of type %r" % (obj, type(obj))) + raise TypeError("cannot serialize %r of type %r" % (obj, type(obj))) def deserialize(self, datum): assert len(datum) == 4, \ @@ -561,7 +561,7 @@ def __getitem__(self, index): inds = self.indices vals = self.values if not isinstance(index, int): - raise ValueError( + raise TypeError( "Indices must be of type integer, got type %s" % type(index)) if index < 0: index += self.size diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index cd7310a64f4ae..a0117c57133ae 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -170,7 +170,7 @@ def _regression_train_wrapper(train_func, modelClass, data, initial_weights): from pyspark.mllib.classification import LogisticRegressionModel first = data.first() if not isinstance(first, LabeledPoint): - raise ValueError("data should be an RDD of LabeledPoint, but got %s" % first) + raise TypeError("data should be an RDD of LabeledPoint, but got %s" % type(first)) if initial_weights is None: initial_weights = [0.0] * len(data.first().features) if (modelClass == LogisticRegressionModel): diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index c6ed5acd1770e..849c88341a967 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -135,8 +135,10 @@ def test_sparse_vector_indexing(self): self.assertEquals(sv[-1], 2) self.assertEquals(sv[-2], 0) self.assertEquals(sv[-4], 0) - for ind in [4, -5, 7.8]: + for ind in [4, -5]: self.assertRaises(ValueError, sv.__getitem__, ind) + for ind in [7.8, '1']: + self.assertRaises(TypeError, sv.__getitem__, ind) def test_matrix_indexing(self): mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) @@ -450,7 +452,7 @@ def test_infer_schema(self): elif isinstance(v, DenseVector): self.assertEqual(v, self.dv1) else: - raise ValueError("expecting a vector but got %r of type %r" % (v, type(v))) + raise TypeError("expecting a vector but got %r of type %r" % (v, type(v))) @unittest.skipIf(not _have_scipy, "SciPy not installed") diff --git a/python/pyspark/sql/_types.py b/python/pyspark/sql/_types.py index 492c0cbdcf693..110d1152fbdb6 100644 --- a/python/pyspark/sql/_types.py +++ b/python/pyspark/sql/_types.py @@ -562,8 +562,8 @@ def _infer_type(obj): else: try: return _infer_schema(obj) - except ValueError: - raise ValueError("not supported type: %s" % type(obj)) + except TypeError: + raise TypeError("not supported type: %s" % type(obj)) def _infer_schema(row): @@ -584,7 +584,7 @@ def _infer_schema(row): items = sorted(row.__dict__.items()) else: - raise ValueError("Can not infer schema for type: %s" % type(row)) + raise TypeError("Can not infer schema for type: %s" % type(row)) fields = [StructField(k, _infer_type(v), True) for k, v in items] return StructType(fields) @@ -696,7 +696,7 @@ def _merge_type(a, b): return a elif type(a) is not type(b): # TODO: type cast (such as int -> long) - raise TypeError("Can not merge type %s and %s" % (a, b)) + raise TypeError("Can not merge type %s and %s" % (type(a), type(b))) # same type if isinstance(a, 
StructType): @@ -773,7 +773,7 @@ def convert_struct(obj): elif hasattr(obj, "__dict__"): # object d = obj.__dict__ else: - raise ValueError("Unexpected obj: %s" % obj) + raise TypeError("Unexpected obj type: %s" % type(obj)) if convert_fields: return tuple([conv(d.get(name)) for name, conv in zip(names, converters)]) @@ -912,7 +912,7 @@ def _infer_schema_type(obj, dataType): return StructType(fields) else: - raise ValueError("Unexpected dataType: %s" % dataType) + raise TypeError("Unexpected dataType: %s" % type(dataType)) _acceptable_types = { diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index c90afc326ca0e..acf3c114548c0 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -208,7 +208,7 @@ def applySchema(self, rdd, schema): raise TypeError("Cannot apply schema to DataFrame") if not isinstance(schema, StructType): - raise TypeError("schema should be StructType, but got %s" % schema) + raise TypeError("schema should be StructType, but got %s" % type(schema)) return self.createDataFrame(rdd, schema) @@ -281,7 +281,7 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): # data could be list, tuple, generator ... rdd = self._sc.parallelize(data) except Exception: - raise ValueError("cannot create an RDD from type: %s" % type(data)) + raise TypeError("cannot create an RDD from type: %s" % type(data)) else: rdd = data @@ -293,8 +293,8 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): if isinstance(schema, (list, tuple)): first = rdd.first() if not isinstance(first, (list, tuple)): - raise ValueError("each row in `rdd` should be list or tuple, " - "but got %r" % type(first)) + raise TypeError("each row in `rdd` should be list or tuple, " + "but got %r" % type(first)) row_cls = Row(*schema) schema = self._inferSchema(rdd.map(lambda r: row_cls(*r)), samplingRatio) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index d70c5b0a6930c..75c181c0c7f5e 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -608,7 +608,7 @@ def __getitem__(self, item): jc = self._jdf.apply(self.columns[item]) return Column(jc) else: - raise TypeError("unexpected type: %s" % type(item)) + raise TypeError("unexpected item type: %s" % type(item)) def __getattr__(self, name): """Returns the :class:`Column` denoted by ``name``. From 1ebceaa55bec28850a48fb28b4cf7b44c8448a78 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 20 Apr 2015 10:47:37 -0700 Subject: [PATCH 056/191] [Minor][MLlib] Incorrect path to test data is used in DecisionTreeExample It should load from `testInput` instead of `input` for test data. Author: Liang-Chi Hsieh Closes #5594 from viirya/use_testinput and squashes the following commits: 5e8b174 [Liang-Chi Hsieh] Fix style. b60b475 [Liang-Chi Hsieh] Use testInput. --- .../org/apache/spark/examples/ml/DecisionTreeExample.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala index d4cc8dede07ef..921b396e799e7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala @@ -173,7 +173,8 @@ object DecisionTreeExample { val splits: Array[RDD[LabeledPoint]] = if (testInput != "") { // Load testInput. 
val numFeatures = origExamples.take(1)(0).features.size - val origTestExamples: RDD[LabeledPoint] = loadData(sc, input, dataFormat, Some(numFeatures)) + val origTestExamples: RDD[LabeledPoint] = + loadData(sc, testInput, dataFormat, Some(numFeatures)) Array(origExamples, origTestExamples) } else { // Split input into training, test. From 97fda73db4efda2ba5b12937954de428258a5b56 Mon Sep 17 00:00:00 2001 From: Eric Chiang Date: Mon, 20 Apr 2015 13:11:21 -0700 Subject: [PATCH 057/191] fixed doc The contribution is my original work. I license the work to the project under the project's open source license. Small typo in the programming guide. Author: Eric Chiang Closes #5599 from ericchiang/docs-typo and squashes the following commits: 1177942 [Eric Chiang] fixed doc --- docs/programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/programming-guide.md b/docs/programming-guide.md index f4fabb0927b66..27816515c5de2 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -1093,7 +1093,7 @@ for details. ### Shuffle operations Certain operations within Spark trigger an event known as the shuffle. The shuffle is Spark's -mechanism for re-distributing data so that is grouped differently across partitions. This typically +mechanism for re-distributing data so that it's grouped differently across partitions. This typically involves copying data across executors and machines, making the shuffle a complex and costly operation. From 517bdf36aecdc94ef569b68f0a96892e707b5c7b Mon Sep 17 00:00:00 2001 From: BenFradet Date: Mon, 20 Apr 2015 13:46:55 -0700 Subject: [PATCH 058/191] [doc][streaming] Fixed broken link in mllib section The commit message is pretty self-explanatory. Author: BenFradet Closes #5600 from BenFradet/master and squashes the following commits: 108492d [BenFradet] [doc][streaming] Fixed broken link in mllib section --- docs/streaming-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 262512a639046..2f2fea53168a3 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -1588,7 +1588,7 @@ See the [DataFrames and SQL](sql-programming-guide.html) guide to learn more abo *** ## MLlib Operations -You can also easily use machine learning algorithms provided by [MLlib](mllib-guide.html). First of all, there are streaming machine learning algorithms (e.g. (Streaming Linear Regression](mllib-linear-methods.html#streaming-linear-regression), [Streaming KMeans](mllib-clustering.html#streaming-k-means), etc.) which can simultaneously learn from the streaming data as well as apply the model on the streaming data. Beyond these, for a much larger class of machine learning algorithms, you can learn a learning model offline (i.e. using historical data) and then apply the model online on streaming data. See the [MLlib](mllib-guide.html) guide for more details. +You can also easily use machine learning algorithms provided by [MLlib](mllib-guide.html). First of all, there are streaming machine learning algorithms (e.g. [Streaming Linear Regression](mllib-linear-methods.html#streaming-linear-regression), [Streaming KMeans](mllib-clustering.html#streaming-k-means), etc.) which can simultaneously learn from the streaming data as well as apply the model on the streaming data. Beyond these, for a much larger class of machine learning algorithms, you can learn a learning model offline (i.e. 
using historical data) and then apply the model online on streaming data. See the [MLlib](mllib-guide.html) guide for more details. *** From ce7ddabbcd330b19f6d0c17082304dfa6e1621b2 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 20 Apr 2015 18:42:50 -0700 Subject: [PATCH 059/191] [SPARK-6368][SQL] Build a specialized serializer for Exchange operator. JIRA: https://issues.apache.org/jira/browse/SPARK-6368 Author: Yin Huai Closes #5497 from yhuai/serializer2 and squashes the following commits: da562c5 [Yin Huai] Merge remote-tracking branch 'upstream/master' into serializer2 50e0c3d [Yin Huai] When no filed is emitted to shuffle, use SparkSqlSerializer for now. 9f1ed92 [Yin Huai] Merge remote-tracking branch 'upstream/master' into serializer2 6d07678 [Yin Huai] Address comments. 4273b8c [Yin Huai] Enabled SparkSqlSerializer2. 09e587a [Yin Huai] Remove TODO. 791b96a [Yin Huai] Use UTF8String. 60a1487 [Yin Huai] Merge remote-tracking branch 'upstream/master' into serializer2 3e09655 [Yin Huai] Use getAs for Date column. 43b9fb4 [Yin Huai] Test. 8297732 [Yin Huai] Fix test. c9373c8 [Yin Huai] Support DecimalType. 2379eeb [Yin Huai] ASF header. 39704ab [Yin Huai] Specialized serializer for Exchange. --- .../scala/org/apache/spark/sql/SQLConf.scala | 4 + .../apache/spark/sql/execution/Exchange.scala | 59 ++- .../sql/execution/SparkSqlSerializer2.scala | 421 ++++++++++++++++++ .../execution/SparkSqlSerializer2Suite.scala | 195 ++++++++ 4 files changed, 673 insertions(+), 6 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 5c65f04ee8497..4fc5de7e824fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -64,6 +64,8 @@ private[spark] object SQLConf { // Set to false when debugging requires the ability to look at invalid query plans. val DATAFRAME_EAGER_ANALYSIS = "spark.sql.eagerAnalysis" + val USE_SQL_SERIALIZER2 = "spark.sql.useSerializer2" + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -147,6 +149,8 @@ private[sql] class SQLConf extends Serializable { */ private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, "false").toBoolean + private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2, "true").toBoolean + /** * Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to * a broadcast value during the physical executions of join operations. 
Setting this to -1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 69a620e1ec929..5b2e46962cd3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -19,13 +19,15 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.shuffle.sort.SortShuffleManager -import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner, SparkConf} +import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner} import org.apache.spark.rdd.{RDD, ShuffledRDD} +import org.apache.spark.serializer.Serializer import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.types.DataType import org.apache.spark.util.MutablePair object Exchange { @@ -77,9 +79,48 @@ case class Exchange( } } - override def execute(): RDD[Row] = attachTree(this , "execute") { - lazy val sparkConf = child.sqlContext.sparkContext.getConf + @transient private lazy val sparkConf = child.sqlContext.sparkContext.getConf + + def serializer( + keySchema: Array[DataType], + valueSchema: Array[DataType], + numPartitions: Int): Serializer = { + // In ExternalSorter's spillToMergeableFile function, key-value pairs are written out + // through write(key) and then write(value) instead of write((key, value)). Because + // SparkSqlSerializer2 assumes that objects passed in are Product2, we cannot safely use + // it when spillToMergeableFile in ExternalSorter will be used. + // So, we will not use SparkSqlSerializer2 when + // - Sort-based shuffle is enabled and the number of reducers (numPartitions) is greater + // then the bypassMergeThreshold; or + // - newOrdering is defined. + val cannotUseSqlSerializer2 = + (sortBasedShuffleOn && numPartitions > bypassMergeThreshold) || newOrdering.nonEmpty + + // It is true when there is no field that needs to be write out. + // For now, we will not use SparkSqlSerializer2 when noField is true. + val noField = + (keySchema == null || keySchema.length == 0) && + (valueSchema == null || valueSchema.length == 0) + + val useSqlSerializer2 = + child.sqlContext.conf.useSqlSerializer2 && // SparkSqlSerializer2 is enabled. + !cannotUseSqlSerializer2 && // Safe to use Serializer2. + SparkSqlSerializer2.support(keySchema) && // The schema of key is supported. + SparkSqlSerializer2.support(valueSchema) && // The schema of value is supported. + !noField + + val serializer = if (useSqlSerializer2) { + logInfo("Using SparkSqlSerializer2.") + new SparkSqlSerializer2(keySchema, valueSchema) + } else { + logInfo("Using SparkSqlSerializer.") + new SparkSqlSerializer(sparkConf) + } + + serializer + } + override def execute(): RDD[Row] = attachTree(this , "execute") { newPartitioning match { case HashPartitioning(expressions, numPartitions) => // TODO: Eliminate redundant expressions in grouping key and value. 
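The hunks below wire the serializer selection above into Exchange.execute(). Whether SparkSqlSerializer2 is actually used is governed by the new spark.sql.useSerializer2 flag (default true in the SQLConf change earlier in this patch) together with the schema and shuffle-mode checks. A minimal usage sketch, not part of the patch itself; the SparkContext `sc` and the SQLContext name are illustrative:

    import org.apache.spark.sql.SQLContext
    // assumes an existing SparkContext `sc` (illustrative)
    val sqlContext = new SQLContext(sc)
    sqlContext.sql("SET spark.sql.useSerializer2=false")  // fall back to the general-purpose SparkSqlSerializer
    sqlContext.sql("SET spark.sql.useSerializer2=true")   // restore the specialized serializer, the default in this patch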
@@ -111,7 +152,10 @@ case class Exchange( } else { new ShuffledRDD[Row, Row, Row](rdd, part) } - shuffled.setSerializer(new SparkSqlSerializer(sparkConf)) + val keySchema = expressions.map(_.dataType).toArray + val valueSchema = child.output.map(_.dataType).toArray + shuffled.setSerializer(serializer(keySchema, valueSchema, numPartitions)) + shuffled.map(_._2) case RangePartitioning(sortingExpressions, numPartitions) => @@ -134,7 +178,9 @@ case class Exchange( } else { new ShuffledRDD[Row, Null, Null](rdd, part) } - shuffled.setSerializer(new SparkSqlSerializer(sparkConf)) + val keySchema = child.output.map(_.dataType).toArray + shuffled.setSerializer(serializer(keySchema, null, numPartitions)) + shuffled.map(_._1) case SinglePartition => @@ -152,7 +198,8 @@ case class Exchange( } val partitioner = new HashPartitioner(1) val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner) - shuffled.setSerializer(new SparkSqlSerializer(sparkConf)) + val valueSchema = child.output.map(_.dataType).toArray + shuffled.setSerializer(serializer(null, valueSchema, 1)) shuffled.map(_._2) case _ => sys.error(s"Exchange not implemented for $newPartitioning") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala new file mode 100644 index 0000000000000..cec97de2cd8e4 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala @@ -0,0 +1,421 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.io._ +import java.math.{BigDecimal, BigInteger} +import java.nio.ByteBuffer +import java.sql.Timestamp + +import scala.reflect.ClassTag + +import org.apache.spark.serializer._ +import org.apache.spark.Logging +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow +import org.apache.spark.sql.types._ + +/** + * The serialization stream for [[SparkSqlSerializer2]]. It assumes that the object passed in + * its `writeObject` are [[Product2]]. The serialization functions for the key and value of the + * [[Product2]] are constructed based on their schemata. + * The benefit of this serialization stream is that compared with general-purpose serializers like + * Kryo and Java serializer, it can significantly reduce the size of serialized and has a lower + * allocation cost, which can benefit the shuffle operation. Right now, its main limitations are: + * 1. It does not support complex types, i.e. Map, Array, and Struct. + * 2. It assumes that the objects passed in are [[Product2]]. 
So, it cannot be used when + * [[org.apache.spark.util.collection.ExternalSorter]]'s merge sort operation is used because + * the objects passed in the serializer are not in the type of [[Product2]]. Also also see + * the comment of the `serializer` method in [[Exchange]] for more information on it. + */ +private[sql] class Serializer2SerializationStream( + keySchema: Array[DataType], + valueSchema: Array[DataType], + out: OutputStream) + extends SerializationStream with Logging { + + val rowOut = new DataOutputStream(out) + val writeKey = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut) + val writeValue = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut) + + def writeObject[T: ClassTag](t: T): SerializationStream = { + val kv = t.asInstanceOf[Product2[Row, Row]] + writeKey(kv._1) + writeValue(kv._2) + + this + } + + def flush(): Unit = { + rowOut.flush() + } + + def close(): Unit = { + rowOut.close() + } +} + +/** + * The corresponding deserialization stream for [[Serializer2SerializationStream]]. + */ +private[sql] class Serializer2DeserializationStream( + keySchema: Array[DataType], + valueSchema: Array[DataType], + in: InputStream) + extends DeserializationStream with Logging { + + val rowIn = new DataInputStream(new BufferedInputStream(in)) + + val key = if (keySchema != null) new SpecificMutableRow(keySchema) else null + val value = if (valueSchema != null) new SpecificMutableRow(valueSchema) else null + val readKey = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn, key) + val readValue = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn, value) + + def readObject[T: ClassTag](): T = { + readKey() + readValue() + + (key, value).asInstanceOf[T] + } + + def close(): Unit = { + rowIn.close() + } +} + +private[sql] class ShuffleSerializerInstance( + keySchema: Array[DataType], + valueSchema: Array[DataType]) + extends SerializerInstance { + + def serialize[T: ClassTag](t: T): ByteBuffer = + throw new UnsupportedOperationException("Not supported.") + + def deserialize[T: ClassTag](bytes: ByteBuffer): T = + throw new UnsupportedOperationException("Not supported.") + + def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = + throw new UnsupportedOperationException("Not supported.") + + def serializeStream(s: OutputStream): SerializationStream = { + new Serializer2SerializationStream(keySchema, valueSchema, s) + } + + def deserializeStream(s: InputStream): DeserializationStream = { + new Serializer2DeserializationStream(keySchema, valueSchema, s) + } +} + +/** + * SparkSqlSerializer2 is a special serializer that creates serialization function and + * deserialization function based on the schema of data. It assumes that values passed in + * are key/value pairs and values returned from it are also key/value pairs. + * The schema of keys is represented by `keySchema` and that of values is represented by + * `valueSchema`. + */ +private[sql] class SparkSqlSerializer2(keySchema: Array[DataType], valueSchema: Array[DataType]) + extends Serializer + with Logging + with Serializable{ + + def newInstance(): SerializerInstance = new ShuffleSerializerInstance(keySchema, valueSchema) +} + +private[sql] object SparkSqlSerializer2 { + + final val NULL = 0 + final val NOT_NULL = 1 + + /** + * Check if rows with the given schema can be serialized with ShuffleSerializer. 
+ */ + def support(schema: Array[DataType]): Boolean = { + if (schema == null) return true + + var i = 0 + while (i < schema.length) { + schema(i) match { + case udt: UserDefinedType[_] => return false + case array: ArrayType => return false + case map: MapType => return false + case struct: StructType => return false + case _ => + } + i += 1 + } + + return true + } + + /** + * The util function to create the serialization function based on the given schema. + */ + def createSerializationFunction(schema: Array[DataType], out: DataOutputStream): Row => Unit = { + (row: Row) => + // If the schema is null, the returned function does nothing when it get called. + if (schema != null) { + var i = 0 + while (i < schema.length) { + schema(i) match { + // When we write values to the underlying stream, we also first write the null byte + // first. Then, if the value is not null, we write the contents out. + + case NullType => // Write nothing. + + case BooleanType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeBoolean(row.getBoolean(i)) + } + + case ByteType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeByte(row.getByte(i)) + } + + case ShortType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeShort(row.getShort(i)) + } + + case IntegerType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeInt(row.getInt(i)) + } + + case LongType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeLong(row.getLong(i)) + } + + case FloatType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeFloat(row.getFloat(i)) + } + + case DoubleType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeDouble(row.getDouble(i)) + } + + case decimal: DecimalType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + val value = row.apply(i).asInstanceOf[Decimal] + val javaBigDecimal = value.toJavaBigDecimal + // First, write out the unscaled value. + val bytes: Array[Byte] = javaBigDecimal.unscaledValue().toByteArray + out.writeInt(bytes.length) + out.write(bytes) + // Then, write out the scale. + out.writeInt(javaBigDecimal.scale()) + } + + case DateType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeInt(row.getAs[Int](i)) + } + + case TimestampType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + val timestamp = row.getAs[java.sql.Timestamp](i) + val time = timestamp.getTime + val nanos = timestamp.getNanos + out.writeLong(time - (nanos / 1000000)) // Write the milliseconds value. + out.writeInt(nanos) // Write the nanoseconds part. + } + + case StringType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + val bytes = row.getAs[UTF8String](i).getBytes + out.writeInt(bytes.length) + out.write(bytes) + } + + case BinaryType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + val bytes = row.getAs[Array[Byte]](i) + out.writeInt(bytes.length) + out.write(bytes) + } + } + i += 1 + } + } + } + + /** + * The util function to create the deserialization function based on the given schema. 
+ */ + def createDeserializationFunction( + schema: Array[DataType], + in: DataInputStream, + mutableRow: SpecificMutableRow): () => Unit = { + () => { + // If the schema is null, the returned function does nothing when it get called. + if (schema != null) { + var i = 0 + while (i < schema.length) { + schema(i) match { + // When we read values from the underlying stream, we also first read the null byte + // first. Then, if the value is not null, we update the field of the mutable row. + + case NullType => mutableRow.setNullAt(i) // Read nothing. + + case BooleanType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.setBoolean(i, in.readBoolean()) + } + + case ByteType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.setByte(i, in.readByte()) + } + + case ShortType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.setShort(i, in.readShort()) + } + + case IntegerType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.setInt(i, in.readInt()) + } + + case LongType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.setLong(i, in.readLong()) + } + + case FloatType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.setFloat(i, in.readFloat()) + } + + case DoubleType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.setDouble(i, in.readDouble()) + } + + case decimal: DecimalType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + // First, read in the unscaled value. + val length = in.readInt() + val bytes = new Array[Byte](length) + in.readFully(bytes) + val unscaledVal = new BigInteger(bytes) + // Then, read the scale. + val scale = in.readInt() + // Finally, create the Decimal object and set it in the row. + mutableRow.update(i, Decimal(new BigDecimal(unscaledVal, scale))) + } + + case DateType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.update(i, in.readInt()) + } + + case TimestampType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + val time = in.readLong() // Read the milliseconds value. + val nanos = in.readInt() // Read the nanoseconds part. + val timestamp = new Timestamp(time) + timestamp.setNanos(nanos) + mutableRow.update(i, timestamp) + } + + case StringType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + val length = in.readInt() + val bytes = new Array[Byte](length) + in.readFully(bytes) + mutableRow.update(i, UTF8String(bytes)) + } + + case BinaryType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + val length = in.readInt() + val bytes = new Array[Byte](length) + in.readFully(bytes) + mutableRow.update(i, bytes) + } + } + i += 1 + } + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala new file mode 100644 index 0000000000000..27f063d73a9a9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.sql.{Timestamp, Date} + +import org.scalatest.{FunSuite, BeforeAndAfterAll} + +import org.apache.spark.rdd.ShuffledRDD +import org.apache.spark.serializer.Serializer +import org.apache.spark.ShuffleDependency +import org.apache.spark.sql.types._ +import org.apache.spark.sql.Row +import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.{MyDenseVectorUDT, QueryTest} + +class SparkSqlSerializer2DataTypeSuite extends FunSuite { + // Make sure that we will not use serializer2 for unsupported data types. + def checkSupported(dataType: DataType, isSupported: Boolean): Unit = { + val testName = + s"${if (dataType == null) null else dataType.toString} is " + + s"${if (isSupported) "supported" else "unsupported"}" + + test(testName) { + assert(SparkSqlSerializer2.support(Array(dataType)) === isSupported) + } + } + + checkSupported(null, isSupported = true) + checkSupported(NullType, isSupported = true) + checkSupported(BooleanType, isSupported = true) + checkSupported(ByteType, isSupported = true) + checkSupported(ShortType, isSupported = true) + checkSupported(IntegerType, isSupported = true) + checkSupported(LongType, isSupported = true) + checkSupported(FloatType, isSupported = true) + checkSupported(DoubleType, isSupported = true) + checkSupported(DateType, isSupported = true) + checkSupported(TimestampType, isSupported = true) + checkSupported(StringType, isSupported = true) + checkSupported(BinaryType, isSupported = true) + checkSupported(DecimalType(10, 5), isSupported = true) + checkSupported(DecimalType.Unlimited, isSupported = true) + + // For now, ArrayType, MapType, and StructType are not supported. + checkSupported(ArrayType(DoubleType, true), isSupported = false) + checkSupported(ArrayType(StringType, false), isSupported = false) + checkSupported(MapType(IntegerType, StringType, true), isSupported = false) + checkSupported(MapType(IntegerType, ArrayType(DoubleType), false), isSupported = false) + checkSupported(StructType(StructField("a", IntegerType, true) :: Nil), isSupported = false) + // UDTs are not supported right now. 
+ checkSupported(new MyDenseVectorUDT, isSupported = false) +} + +abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll { + var allColumns: String = _ + val serializerClass: Class[Serializer] = + classOf[SparkSqlSerializer2].asInstanceOf[Class[Serializer]] + var numShufflePartitions: Int = _ + var useSerializer2: Boolean = _ + + override def beforeAll(): Unit = { + numShufflePartitions = conf.numShufflePartitions + useSerializer2 = conf.useSqlSerializer2 + + sql("set spark.sql.useSerializer2=true") + + val supportedTypes = + Seq(StringType, BinaryType, NullType, BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType.Unlimited, DecimalType(6, 5), + DateType, TimestampType) + + val fields = supportedTypes.zipWithIndex.map { case (dataType, index) => + StructField(s"col$index", dataType, true) + } + allColumns = fields.map(_.name).mkString(",") + val schema = StructType(fields) + + // Create a RDD with all data types supported by SparkSqlSerializer2. + val rdd = + sparkContext.parallelize((1 to 1000), 10).map { i => + Row( + s"str${i}: test serializer2.", + s"binary${i}: test serializer2.".getBytes("UTF-8"), + null, + i % 2 == 0, + i.toByte, + i.toShort, + i, + Long.MaxValue - i.toLong, + (i + 0.25).toFloat, + (i + 0.75), + BigDecimal(Long.MaxValue.toString + ".12345"), + new java.math.BigDecimal(s"${i % 9 + 1}" + ".23456"), + new Date(i), + new Timestamp(i)) + } + + createDataFrame(rdd, schema).registerTempTable("shuffle") + + super.beforeAll() + } + + override def afterAll(): Unit = { + dropTempTable("shuffle") + sql(s"set spark.sql.shuffle.partitions=$numShufflePartitions") + sql(s"set spark.sql.useSerializer2=$useSerializer2") + super.afterAll() + } + + def checkSerializer[T <: Serializer]( + executedPlan: SparkPlan, + expectedSerializerClass: Class[T]): Unit = { + executedPlan.foreach { + case exchange: Exchange => + val shuffledRDD = exchange.execute().firstParent.asInstanceOf[ShuffledRDD[_, _, _]] + val dependency = shuffledRDD.getDependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] + val serializerNotSetMessage = + s"Expected $expectedSerializerClass as the serializer of Exchange. " + + s"However, the serializer was not set." + val serializer = dependency.serializer.getOrElse(fail(serializerNotSetMessage)) + assert(serializer.getClass === expectedSerializerClass) + case _ => // Ignore other nodes. + } + } + + test("key schema and value schema are not nulls") { + val df = sql(s"SELECT DISTINCT ${allColumns} FROM shuffle") + checkSerializer(df.queryExecution.executedPlan, serializerClass) + checkAnswer( + df, + table("shuffle").collect()) + } + + test("value schema is null") { + val df = sql(s"SELECT col0 FROM shuffle ORDER BY col0") + checkSerializer(df.queryExecution.executedPlan, serializerClass) + assert( + df.map(r => r.getString(0)).collect().toSeq === + table("shuffle").select("col0").map(r => r.getString(0)).collect().sorted.toSeq) + } + + test("no map output field") { + val df = sql(s"SELECT 1 + 1 FROM shuffle") + checkSerializer(df.queryExecution.executedPlan, classOf[SparkSqlSerializer]) + } +} + +/** Tests SparkSqlSerializer2 with sort based shuffle without sort merge. */ +class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite { + override def beforeAll(): Unit = { + super.beforeAll() + // Sort merge will not be triggered. 
+ sql("set spark.sql.shuffle.partitions = 200") + } + + test("key schema is null") { + val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",") + val df = sql(s"SELECT $aggregations FROM shuffle") + checkSerializer(df.queryExecution.executedPlan, serializerClass) + checkAnswer( + df, + Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000)) + } +} + +/** For now, we will use SparkSqlSerializer for sort based shuffle with sort merge. */ +class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite { + + // We are expecting SparkSqlSerializer. + override val serializerClass: Class[Serializer] = + classOf[SparkSqlSerializer].asInstanceOf[Class[Serializer]] + + override def beforeAll(): Unit = { + super.beforeAll() + // To trigger the sort merge. + sql("set spark.sql.shuffle.partitions = 201") + } +} From c736220dac51cf73181fdd7f621c960c4e7bf0c2 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 20 Apr 2015 18:54:01 -0700 Subject: [PATCH 060/191] [SPARK-6635][SQL] DataFrame.withColumn should replace columns with identical column names JIRA https://issues.apache.org/jira/browse/SPARK-6635 Author: Liang-Chi Hsieh Closes #5541 from viirya/replace_with_column and squashes the following commits: b539c7b [Liang-Chi Hsieh] For comment. 72f35b1 [Liang-Chi Hsieh] DataFrame.withColumn can replace original column with identical column name. --- .../scala/org/apache/spark/sql/DataFrame.scala | 14 +++++++++++++- .../org/apache/spark/sql/DataFrameSuite.scala | 8 ++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 17c21f6e3a0e9..45f5da387692e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -747,7 +747,19 @@ class DataFrame private[sql]( * Returns a new [[DataFrame]] by adding a column. * @group dfops */ - def withColumn(colName: String, col: Column): DataFrame = select(Column("*"), col.as(colName)) + def withColumn(colName: String, col: Column): DataFrame = { + val resolver = sqlContext.analyzer.resolver + val replaced = schema.exists(f => resolver(f.name, colName)) + if (replaced) { + val colNames = schema.map { field => + val name = field.name + if (resolver(name, colName)) col.as(colName) else Column(name) + } + select(colNames :_*) + } else { + select(Column("*"), col.as(colName)) + } + } /** * Returns a new [[DataFrame]] with a column renamed. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 3250ab476aeb4..b9b6a400ae195 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -473,6 +473,14 @@ class DataFrameSuite extends QueryTest { assert(df.schema.map(_.name).toSeq === Seq("key", "value", "newCol")) } + test("replace column using withColumn") { + val df2 = TestSQLContext.sparkContext.parallelize(Array(1, 2, 3)).toDF("x") + val df3 = df2.withColumn("x", df2("x") + 1) + checkAnswer( + df3.select("x"), + Row(2) :: Row(3) :: Row(4) :: Nil) + } + test("withColumnRenamed") { val df = testData.toDF().withColumn("newCol", col("key") + 1) .withColumnRenamed("value", "valueRenamed") From 8136810dfad12008ac300116df7bc8448740f1ae Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 20 Apr 2015 23:18:42 -0700 Subject: [PATCH 061/191] [SPARK-6490][Core] Add spark.rpc.* and deprecate spark.akka.* Deprecated `spark.akka.num.retries`, `spark.akka.retry.wait`, `spark.akka.askTimeout`, `spark.akka.lookupTimeout`, and added `spark.rpc.num.retries`, `spark.rpc.retry.wait`, `spark.rpc.askTimeout`, `spark.rpc.lookupTimeout`. Author: zsxwing Closes #5595 from zsxwing/SPARK-6490 and squashes the following commits: e0d80a9 [zsxwing] Use getTimeAsMs and getTimeAsSeconds and other minor fixes 31dbe69 [zsxwing] Add spark.rpc.* and deprecate spark.akka.* --- .../scala/org/apache/spark/SparkConf.scala | 10 ++++++- .../org/apache/spark/deploy/Client.scala | 6 ++--- .../spark/deploy/client/AppClient.scala | 4 +-- .../apache/spark/deploy/master/Master.scala | 4 +-- .../spark/deploy/master/ui/MasterWebUI.scala | 4 +-- .../deploy/rest/StandaloneRestServer.scala | 8 +++--- .../spark/deploy/worker/ui/WorkerWebUI.scala | 4 +-- .../scala/org/apache/spark/rpc/RpcEnv.scala | 10 +++---- .../cluster/YarnSchedulerBackend.scala | 4 +-- .../spark/storage/BlockManagerMaster.scala | 4 +-- .../org/apache/spark/util/AkkaUtils.scala | 26 +++---------------- .../org/apache/spark/util/RpcUtils.scala | 23 ++++++++++++++++ .../apache/spark/MapOutputTrackerSuite.scala | 4 +-- .../org/apache/spark/SparkConfSuite.scala | 24 ++++++++++++++++- .../org/apache/spark/rpc/RpcEnvSuite.scala | 4 +-- 15 files changed, 86 insertions(+), 53 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index e3a649d755450..c1996e08756a6 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -431,7 +431,15 @@ private[spark] object SparkConf extends Logging { "spark.yarn.am.waitTime" -> Seq( AlternateConfig("spark.yarn.applicationMaster.waitTries", "1.3", // Translate old value to a duration, with 10s wait time per try. 
- translation = s => s"${s.toLong * 10}s")) + translation = s => s"${s.toLong * 10}s")), + "spark.rpc.numRetries" -> Seq( + AlternateConfig("spark.akka.num.retries", "1.4")), + "spark.rpc.retry.wait" -> Seq( + AlternateConfig("spark.akka.retry.wait", "1.4")), + "spark.rpc.askTimeout" -> Seq( + AlternateConfig("spark.akka.askTimeout", "1.4")), + "spark.rpc.lookupTimeout" -> Seq( + AlternateConfig("spark.akka.lookupTimeout", "1.4")) ) /** diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 8d13b2a2cd4f3..c2c3e9a9e4827 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -27,7 +27,7 @@ import org.apache.log4j.{Level, Logger} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.{DriverState, Master} -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, Utils} +import org.apache.spark.util.{ActorLogReceive, AkkaUtils, RpcUtils, Utils} /** * Proxy that relays messages to the driver. @@ -36,7 +36,7 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) extends Actor with ActorLogReceive with Logging { var masterActor: ActorSelection = _ - val timeout = AkkaUtils.askTimeout(conf) + val timeout = RpcUtils.askTimeout(conf) override def preStart(): Unit = { masterActor = context.actorSelection( @@ -155,7 +155,7 @@ object Client { if (!driverArgs.logLevel.isGreaterOrEqual(Level.WARN)) { conf.set("spark.akka.logLifecycleEvents", "true") } - conf.set("spark.akka.askTimeout", "10") + conf.set("spark.rpc.askTimeout", "10") conf.set("akka.loglevel", driverArgs.logLevel.toString.replace("WARN", "WARNING")) Logger.getRootLogger.setLevel(driverArgs.logLevel) diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 4f06d7f96c46e..43c8a934c311a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -30,7 +30,7 @@ import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.Master -import org.apache.spark.util.{ActorLogReceive, Utils, AkkaUtils} +import org.apache.spark.util.{ActorLogReceive, RpcUtils, Utils, AkkaUtils} /** * Interface allowing applications to speak with a Spark deploy cluster. 
Takes a master URL, @@ -193,7 +193,7 @@ private[spark] class AppClient( def stop() { if (actor != null) { try { - val timeout = AkkaUtils.askTimeout(conf) + val timeout = RpcUtils.askTimeout(conf) val future = actor.ask(StopAppClient)(timeout) Await.result(future, timeout) } catch { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index c5a6b1beac9be..ff2eed6dee70a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -47,7 +47,7 @@ import org.apache.spark.deploy.rest.StandaloneRestServer import org.apache.spark.metrics.MetricsSystem import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus} import org.apache.spark.ui.SparkUI -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils} +import org.apache.spark.util.{ActorLogReceive, AkkaUtils, RpcUtils, SignalLogger, Utils} private[master] class Master( host: String, @@ -931,7 +931,7 @@ private[deploy] object Master extends Logging { securityManager = securityMgr) val actor = actorSystem.actorOf( Props(classOf[Master], host, boundPort, webUiPort, securityMgr, conf), actorName) - val timeout = AkkaUtils.askTimeout(conf) + val timeout = RpcUtils.askTimeout(conf) val portsRequest = actor.ask(BoundPortsRequest)(timeout) val portsResponse = Await.result(portsRequest, timeout).asInstanceOf[BoundPortsResponse] (actorSystem, boundPort, portsResponse.webUIPort, portsResponse.restPort) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index bb11e0642ddc6..aad9c87bdb987 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -21,7 +21,7 @@ import org.apache.spark.Logging import org.apache.spark.deploy.master.Master import org.apache.spark.ui.{SparkUI, WebUI} import org.apache.spark.ui.JettyUtils._ -import org.apache.spark.util.AkkaUtils +import org.apache.spark.util.RpcUtils /** * Web UI server for the standalone master. 
@@ -31,7 +31,7 @@ class MasterWebUI(val master: Master, requestedPort: Int) extends WebUI(master.securityMgr, requestedPort, master.conf, name = "MasterUI") with Logging { val masterActorRef = master.self - val timeout = AkkaUtils.askTimeout(master.conf) + val timeout = RpcUtils.askTimeout(master.conf) val killEnabled = master.conf.getBoolean("spark.ui.killEnabled", true) initialize() diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index 4f19af59f409f..2d6b8d4204795 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -32,7 +32,7 @@ import org.json4s._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion} -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils} import org.apache.spark.deploy.{Command, DeployMessages, DriverDescription} import org.apache.spark.deploy.ClientArguments._ @@ -223,7 +223,7 @@ private[rest] class KillRequestServlet(masterActor: ActorRef, conf: SparkConf) } protected def handleKill(submissionId: String): KillSubmissionResponse = { - val askTimeout = AkkaUtils.askTimeout(conf) + val askTimeout = RpcUtils.askTimeout(conf) val response = AkkaUtils.askWithReply[DeployMessages.KillDriverResponse]( DeployMessages.RequestKillDriver(submissionId), masterActor, askTimeout) val k = new KillSubmissionResponse @@ -257,7 +257,7 @@ private[rest] class StatusRequestServlet(masterActor: ActorRef, conf: SparkConf) } protected def handleStatus(submissionId: String): SubmissionStatusResponse = { - val askTimeout = AkkaUtils.askTimeout(conf) + val askTimeout = RpcUtils.askTimeout(conf) val response = AkkaUtils.askWithReply[DeployMessages.DriverStatusResponse]( DeployMessages.RequestDriverStatus(submissionId), masterActor, askTimeout) val message = response.exception.map { s"Exception from the cluster:\n" + formatException(_) } @@ -321,7 +321,7 @@ private[rest] class SubmitRequestServlet( responseServlet: HttpServletResponse): SubmitRestProtocolResponse = { requestMessage match { case submitRequest: CreateSubmissionRequest => - val askTimeout = AkkaUtils.askTimeout(conf) + val askTimeout = RpcUtils.askTimeout(conf) val driverDescription = buildDriverDescription(submitRequest) val response = AkkaUtils.askWithReply[DeployMessages.SubmitDriverResponse]( DeployMessages.RequestSubmitDriver(driverDescription), masterActor, askTimeout) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index de6423beb543e..b3bb5f911dbd7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -25,7 +25,7 @@ import org.apache.spark.deploy.worker.Worker import org.apache.spark.deploy.worker.ui.WorkerWebUI._ import org.apache.spark.ui.{SparkUI, WebUI} import org.apache.spark.ui.JettyUtils._ -import org.apache.spark.util.AkkaUtils +import org.apache.spark.util.RpcUtils /** * Web UI server for the standalone worker. 
@@ -38,7 +38,7 @@ class WorkerWebUI( extends WebUI(worker.securityMgr, requestedPort, worker.conf, name = "WorkerUI") with Logging { - private[ui] val timeout = AkkaUtils.askTimeout(worker.conf) + private[ui] val timeout = RpcUtils.askTimeout(worker.conf) initialize() diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index cba038ca355d7..a5336b7563802 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -25,7 +25,7 @@ import scala.language.postfixOps import scala.reflect.ClassTag import org.apache.spark.{Logging, SparkException, SecurityManager, SparkConf} -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.{RpcUtils, Utils} /** * An RPC environment. [[RpcEndpoint]]s need to register itself with a name to [[RpcEnv]] to @@ -38,7 +38,7 @@ import org.apache.spark.util.{AkkaUtils, Utils} */ private[spark] abstract class RpcEnv(conf: SparkConf) { - private[spark] val defaultLookupTimeout = AkkaUtils.lookupTimeout(conf) + private[spark] val defaultLookupTimeout = RpcUtils.lookupTimeout(conf) /** * Return RpcEndpointRef of the registered [[RpcEndpoint]]. Will be used to implement @@ -282,9 +282,9 @@ trait ThreadSafeRpcEndpoint extends RpcEndpoint private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) extends Serializable with Logging { - private[this] val maxRetries = conf.getInt("spark.akka.num.retries", 3) - private[this] val retryWaitMs = conf.getLong("spark.akka.retry.wait", 3000) - private[this] val defaultAskTimeout = conf.getLong("spark.akka.askTimeout", 30) seconds + private[this] val maxRetries = RpcUtils.numRetries(conf) + private[this] val retryWaitMs = RpcUtils.retryWaitMs(conf) + private[this] val defaultAskTimeout = RpcUtils.askTimeout(conf) /** * return the address for the [[RpcEndpointRef]] diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index f72566c370a6f..1406a36a669c5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -24,7 +24,7 @@ import org.apache.spark.rpc._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.ui.JettyUtils -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.{RpcUtils, Utils} import scala.util.control.NonFatal @@ -46,7 +46,7 @@ private[spark] abstract class YarnSchedulerBackend( private val yarnSchedulerEndpoint = rpcEnv.setupEndpoint( YarnSchedulerBackend.ENDPOINT_NAME, new YarnSchedulerEndpoint(rpcEnv)) - private implicit val askTimeout = AkkaUtils.askTimeout(sc.conf) + private implicit val askTimeout = RpcUtils.askTimeout(sc.conf) /** * Request executors from the ApplicationMaster by specifying the total number desired. 
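The diffs that follow repeat the same AkkaUtils-to-RpcUtils substitution. For reference, a minimal sketch of how Spark-internal code exercises the new helpers (the helpers themselves are added to RpcUtils later in this patch; the configuration values are illustrative):

    import org.apache.spark.SparkConf
    import org.apache.spark.util.RpcUtils

    val conf = new SparkConf()
      .set("spark.rpc.askTimeout", "30s")
      .set("spark.rpc.numRetries", "3")

    val askTimeout = RpcUtils.askTimeout(conf)   // FiniteDuration of 30 seconds
    val retries    = RpcUtils.numRetries(conf)   // 3
    val retryWait  = RpcUtils.retryWaitMs(conf)  // 3000 ms, the default for spark.rpc.retry.wait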
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index ceacf043029f3..c798843bd5d8a 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -23,7 +23,7 @@ import scala.concurrent.ExecutionContext.Implicits.global import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.storage.BlockManagerMessages._ -import org.apache.spark.util.AkkaUtils +import org.apache.spark.util.RpcUtils private[spark] class BlockManagerMaster( @@ -32,7 +32,7 @@ class BlockManagerMaster( isDriver: Boolean) extends Logging { - val timeout = AkkaUtils.askTimeout(conf) + val timeout = RpcUtils.askTimeout(conf) /** Remove a dead executor from the driver endpoint. This is only called on the driver side. */ def removeExecutor(execId: String) { diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index 8e8cc7cc6389e..b725df3b44596 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -19,7 +19,7 @@ package org.apache.spark.util import scala.collection.JavaConversions.mapAsJavaMap import scala.concurrent.Await -import scala.concurrent.duration.{Duration, FiniteDuration} +import scala.concurrent.duration.FiniteDuration import akka.actor.{ActorRef, ActorSystem, ExtendedActorSystem} import akka.pattern.ask @@ -125,16 +125,6 @@ private[spark] object AkkaUtils extends Logging { (actorSystem, boundPort) } - /** Returns the default Spark timeout to use for Akka ask operations. */ - def askTimeout(conf: SparkConf): FiniteDuration = { - Duration.create(conf.getLong("spark.akka.askTimeout", 30), "seconds") - } - - /** Returns the default Spark timeout to use for Akka remote actor lookup. */ - def lookupTimeout(conf: SparkConf): FiniteDuration = { - Duration.create(conf.getLong("spark.akka.lookupTimeout", 30), "seconds") - } - private val AKKA_MAX_FRAME_SIZE_IN_MB = Int.MaxValue / 1024 / 1024 /** Returns the configured max frame size for Akka messages in bytes. */ @@ -150,16 +140,6 @@ private[spark] object AkkaUtils extends Logging { /** Space reserved for extra data in an Akka message besides serialized task or task result. */ val reservedSizeBytes = 200 * 1024 - /** Returns the configured number of times to retry connecting */ - def numRetries(conf: SparkConf): Int = { - conf.getInt("spark.akka.num.retries", 3) - } - - /** Returns the configured number of milliseconds to wait on each retry */ - def retryWaitMs(conf: SparkConf): Int = { - conf.getInt("spark.akka.retry.wait", 3000) - } - /** * Send a message to the given actor and get its result within a default timeout, or * throw a SparkException if this fails. 
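The remaining AkkaUtils hunks below only switch the helper lookups over to RpcUtils; backwards compatibility is preserved through the deprecation table added to SparkConf at the top of this patch. A sketch mirroring the new "akka deprecated configs" test in SparkConfSuite (assertion shown for illustration only):

    import scala.concurrent.duration._
    import org.apache.spark.SparkConf
    import org.apache.spark.util.RpcUtils

    // The legacy key is still honoured: it is translated to spark.rpc.askTimeout
    // with a deprecation warning, so existing deployments keep their timeout.
    val conf = new SparkConf().set("spark.akka.askTimeout", "3")
    assert(RpcUtils.askTimeout(conf) == 3.seconds)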
@@ -216,7 +196,7 @@ private[spark] object AkkaUtils extends Logging { val driverPort: Int = conf.getInt("spark.driver.port", 7077) Utils.checkHost(driverHost, "Expected hostname") val url = address(protocol(actorSystem), driverActorSystemName, driverHost, driverPort, name) - val timeout = AkkaUtils.lookupTimeout(conf) + val timeout = RpcUtils.lookupTimeout(conf) logInfo(s"Connecting to $name: $url") Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout) } @@ -230,7 +210,7 @@ private[spark] object AkkaUtils extends Logging { val executorActorSystemName = SparkEnv.executorActorSystemName Utils.checkHost(host, "Expected hostname") val url = address(protocol(actorSystem), executorActorSystemName, host, port, name) - val timeout = AkkaUtils.lookupTimeout(conf) + val timeout = RpcUtils.lookupTimeout(conf) logInfo(s"Connecting to $name: $url") Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout) } diff --git a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala index 6665b17c3d5df..5ae793e0e87a3 100644 --- a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala @@ -17,6 +17,9 @@ package org.apache.spark.util +import scala.concurrent.duration._ +import scala.language.postfixOps + import org.apache.spark.{SparkEnv, SparkConf} import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv} @@ -32,4 +35,24 @@ object RpcUtils { Utils.checkHost(driverHost, "Expected hostname") rpcEnv.setupEndpointRef(driverActorSystemName, RpcAddress(driverHost, driverPort), name) } + + /** Returns the configured number of times to retry connecting */ + def numRetries(conf: SparkConf): Int = { + conf.getInt("spark.rpc.numRetries", 3) + } + + /** Returns the configured number of milliseconds to wait on each retry */ + def retryWaitMs(conf: SparkConf): Long = { + conf.getTimeAsMs("spark.rpc.retry.wait", "3s") + } + + /** Returns the default Spark timeout to use for RPC ask operations. */ + def askTimeout(conf: SparkConf): FiniteDuration = { + conf.getTimeAsSeconds("spark.rpc.askTimeout", "30s") seconds + } + + /** Returns the default Spark timeout to use for RPC remote endpoint lookup. 
*/ + def lookupTimeout(conf: SparkConf): FiniteDuration = { + conf.getTimeAsSeconds("spark.rpc.lookupTimeout", "30s") seconds + } } diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 6295d34be5ca9..6ed057a7cab97 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -154,7 +154,7 @@ class MapOutputTrackerSuite extends FunSuite { test("remote fetch below akka frame size") { val newConf = new SparkConf newConf.set("spark.akka.frameSize", "1") - newConf.set("spark.akka.askTimeout", "1") // Fail fast + newConf.set("spark.rpc.askTimeout", "1") // Fail fast val masterTracker = new MapOutputTrackerMaster(conf) val rpcEnv = createRpcEnv("spark") @@ -180,7 +180,7 @@ class MapOutputTrackerSuite extends FunSuite { test("remote fetch exceeds akka frame size") { val newConf = new SparkConf newConf.set("spark.akka.frameSize", "1") - newConf.set("spark.akka.askTimeout", "1") // Fail fast + newConf.set("spark.rpc.askTimeout", "1") // Fail fast val masterTracker = new MapOutputTrackerMaster(conf) val rpcEnv = createRpcEnv("test") diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index 8e6c200c4ba00..d7d8014a20498 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -19,11 +19,13 @@ package org.apache.spark import java.util.concurrent.{TimeUnit, Executors} +import scala.concurrent.duration._ +import scala.language.postfixOps import scala.util.{Try, Random} import org.scalatest.FunSuite import org.apache.spark.serializer.{KryoRegistrator, KryoSerializer} -import org.apache.spark.util.ResetSystemProperties +import org.apache.spark.util.{RpcUtils, ResetSystemProperties} import com.esotericsoftware.kryo.Kryo class SparkConfSuite extends FunSuite with LocalSparkContext with ResetSystemProperties { @@ -222,6 +224,26 @@ class SparkConfSuite extends FunSuite with LocalSparkContext with ResetSystemPro assert(conf.getTimeAsSeconds("spark.yarn.am.waitTime") === 420) } + test("akka deprecated configs") { + val conf = new SparkConf() + + assert(!conf.contains("spark.rpc.num.retries")) + assert(!conf.contains("spark.rpc.retry.wait")) + assert(!conf.contains("spark.rpc.askTimeout")) + assert(!conf.contains("spark.rpc.lookupTimeout")) + + conf.set("spark.akka.num.retries", "1") + assert(RpcUtils.numRetries(conf) === 1) + + conf.set("spark.akka.retry.wait", "2") + assert(RpcUtils.retryWaitMs(conf) === 2L) + + conf.set("spark.akka.askTimeout", "3") + assert(RpcUtils.askTimeout(conf) === (3 seconds)) + + conf.set("spark.akka.lookupTimeout", "4") + assert(RpcUtils.lookupTimeout(conf) === (4 seconds)) + } } class Class1 {} diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index ada07ef11cd7a..5fbda37c7cb88 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -155,8 +155,8 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { }) val conf = new SparkConf() - conf.set("spark.akka.retry.wait", "0") - conf.set("spark.akka.num.retries", "1") + conf.set("spark.rpc.retry.wait", "0") + conf.set("spark.rpc.num.retries", "1") val anotherEnv = createRpcEnv(conf, "remote", 13345) // Use anotherEnv to 
find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-timeout") From ab9128fb7ec7ca479dc91e7cc7c775e8d071eafa Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 21 Apr 2015 00:08:18 -0700 Subject: [PATCH 062/191] [SPARK-6949] [SQL] [PySpark] Support Date/Timestamp in Column expression This PR enable auto_convert in JavaGateway, then we could register a converter for a given types, for example, date and datetime. There are two bugs related to auto_convert, see [1] and [2], we workaround it in this PR. [1] https://github.com/bartdag/py4j/issues/160 [2] https://github.com/bartdag/py4j/issues/161 cc rxin JoshRosen Author: Davies Liu Closes #5570 from davies/py4j_date and squashes the following commits: eb4fa53 [Davies Liu] fix tests in python 3 d17d634 [Davies Liu] rollback changes in mllib 2e7566d [Davies Liu] convert tuple into ArrayList ceb3779 [Davies Liu] Update rdd.py 3c373f3 [Davies Liu] support date and datetime by auto_convert cb094ff [Davies Liu] enable auto convert --- python/pyspark/context.py | 6 +----- python/pyspark/java_gateway.py | 15 ++++++++++++++- python/pyspark/rdd.py | 3 +++ python/pyspark/sql/_types.py | 27 +++++++++++++++++++++++++++ python/pyspark/sql/context.py | 13 ++++--------- python/pyspark/sql/dataframe.py | 18 ++++-------------- python/pyspark/sql/tests.py | 11 +++++++++++ python/pyspark/streaming/context.py | 11 +++-------- python/pyspark/streaming/kafka.py | 7 ++----- python/pyspark/streaming/tests.py | 6 +----- 10 files changed, 70 insertions(+), 47 deletions(-) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 6a743ac8bd600..b006120eb266d 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -23,8 +23,6 @@ from threading import Lock from tempfile import NamedTemporaryFile -from py4j.java_collections import ListConverter - from pyspark import accumulators from pyspark.accumulators import Accumulator from pyspark.broadcast import Broadcast @@ -643,7 +641,6 @@ def union(self, rdds): rdds = [x._reserialize() for x in rdds] first = rdds[0]._jrdd rest = [x._jrdd for x in rdds[1:]] - rest = ListConverter().convert(rest, self._gateway._gateway_client) return RDD(self._jsc.union(first, rest), self, rdds[0]._jrdd_deserializer) def broadcast(self, value): @@ -846,13 +843,12 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False): """ if partitions is None: partitions = range(rdd._jrdd.partitions().size()) - javaPartitions = ListConverter().convert(partitions, self._gateway._gateway_client) # Implementation note: This is implemented as a mapPartitions followed # by runJob() in order to avoid having to pass a Python lambda into # SparkContext#runJob. 
mappedRDD = rdd.mapPartitions(partitionFunc) - port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, + port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, partitions, allowLocal) return list(_load_from_socket(port, mappedRDD._jrdd_deserializer)) diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 45bc38f7e61f8..3cee4ea6e3a35 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -17,17 +17,30 @@ import atexit import os +import sys import select import signal import shlex import socket import platform from subprocess import Popen, PIPE + +if sys.version >= '3': + xrange = range + from py4j.java_gateway import java_import, JavaGateway, GatewayClient +from py4j.java_collections import ListConverter from pyspark.serializers import read_int +# patching ListConverter, or it will convert bytearray into Java ArrayList +def can_convert_list(self, obj): + return isinstance(obj, (list, tuple, xrange)) + +ListConverter.can_convert = can_convert_list + + def launch_gateway(): if "PYSPARK_GATEWAY_PORT" in os.environ: gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"]) @@ -92,7 +105,7 @@ def killChild(): atexit.register(killChild) # Connect to the gateway - gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=False) + gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=True) # Import the classes used by PySpark java_import(gateway.jvm, "org.apache.spark.SparkConf") diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index d9cdbb666f92a..d254deb527d10 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2267,6 +2267,9 @@ def _prepare_for_python_RDD(sc, command, obj=None): # The broadcast will have same life cycle as created PythonRDD broadcast = sc.broadcast(pickled_command) pickled_command = ser.dumps(broadcast) + # There is a bug in py4j.java_gateway.JavaClass with auto_convert + # https://github.com/bartdag/py4j/issues/161 + # TODO: use auto_convert once py4j fix the bug broadcast_vars = ListConverter().convert( [x._jbroadcast for x in sc._pickled_broadcast_vars], sc._gateway._gateway_client) diff --git a/python/pyspark/sql/_types.py b/python/pyspark/sql/_types.py index 110d1152fbdb6..95fb91ad43457 100644 --- a/python/pyspark/sql/_types.py +++ b/python/pyspark/sql/_types.py @@ -17,6 +17,7 @@ import sys import decimal +import time import datetime import keyword import warnings @@ -30,6 +31,9 @@ long = int unicode = str +from py4j.protocol import register_input_converter +from py4j.java_gateway import JavaClass + __all__ = [ "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType", @@ -1237,6 +1241,29 @@ def __repr__(self): return "" % ", ".join(self) +class DateConverter(object): + def can_convert(self, obj): + return isinstance(obj, datetime.date) + + def convert(self, obj, gateway_client): + Date = JavaClass("java.sql.Date", gateway_client) + return Date.valueOf(obj.strftime("%Y-%m-%d")) + + +class DatetimeConverter(object): + def can_convert(self, obj): + return isinstance(obj, datetime.datetime) + + def convert(self, obj, gateway_client): + Timestamp = JavaClass("java.sql.Timestamp", gateway_client) + return Timestamp(int(time.mktime(obj.timetuple())) * 1000 + obj.microsecond // 1000) + + +# datetime is a subclass of date, we should register DatetimeConverter first +register_input_converter(DatetimeConverter()) 
+register_input_converter(DateConverter()) + + def _test(): import doctest from pyspark.context import SparkContext diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index acf3c114548c0..f6f107ca32d2f 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -25,7 +25,6 @@ from itertools import imap as map from py4j.protocol import Py4JError -from py4j.java_collections import MapConverter from pyspark.rdd import RDD, _prepare_for_python_RDD, ignore_unicode_prefix from pyspark.serializers import AutoBatchedSerializer, PickleSerializer @@ -442,15 +441,13 @@ def load(self, path=None, source=None, schema=None, **options): if source is None: source = self.getConf("spark.sql.sources.default", "org.apache.spark.sql.parquet") - joptions = MapConverter().convert(options, - self._sc._gateway._gateway_client) if schema is None: - df = self._ssql_ctx.load(source, joptions) + df = self._ssql_ctx.load(source, options) else: if not isinstance(schema, StructType): raise TypeError("schema should be StructType") scala_datatype = self._ssql_ctx.parseDataType(schema.json()) - df = self._ssql_ctx.load(source, scala_datatype, joptions) + df = self._ssql_ctx.load(source, scala_datatype, options) return DataFrame(df, self) def createExternalTable(self, tableName, path=None, source=None, @@ -471,16 +468,14 @@ def createExternalTable(self, tableName, path=None, source=None, if source is None: source = self.getConf("spark.sql.sources.default", "org.apache.spark.sql.parquet") - joptions = MapConverter().convert(options, - self._sc._gateway._gateway_client) if schema is None: - df = self._ssql_ctx.createExternalTable(tableName, source, joptions) + df = self._ssql_ctx.createExternalTable(tableName, source, options) else: if not isinstance(schema, StructType): raise TypeError("schema should be StructType") scala_datatype = self._ssql_ctx.parseDataType(schema.json()) df = self._ssql_ctx.createExternalTable(tableName, source, scala_datatype, - joptions) + options) return DataFrame(df, self) @ignore_unicode_prefix diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 75c181c0c7f5e..ca9bf8efb945c 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -25,8 +25,6 @@ else: from itertools import imap as map -from py4j.java_collections import ListConverter, MapConverter - from pyspark.context import SparkContext from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer @@ -186,9 +184,7 @@ def saveAsTable(self, tableName, source=None, mode="error", **options): source = self.sql_ctx.getConf("spark.sql.sources.default", "org.apache.spark.sql.parquet") jmode = self._java_save_mode(mode) - joptions = MapConverter().convert(options, - self.sql_ctx._sc._gateway._gateway_client) - self._jdf.saveAsTable(tableName, source, jmode, joptions) + self._jdf.saveAsTable(tableName, source, jmode, options) def save(self, path=None, source=None, mode="error", **options): """Saves the contents of the :class:`DataFrame` to a data source. 
@@ -211,9 +207,7 @@ def save(self, path=None, source=None, mode="error", **options): source = self.sql_ctx.getConf("spark.sql.sources.default", "org.apache.spark.sql.parquet") jmode = self._java_save_mode(mode) - joptions = MapConverter().convert(options, - self._sc._gateway._gateway_client) - self._jdf.save(source, jmode, joptions) + self._jdf.save(source, jmode, options) @property def schema(self): @@ -819,7 +813,6 @@ def fillna(self, value, subset=None): value = float(value) if isinstance(value, dict): - value = MapConverter().convert(value, self.sql_ctx._sc._gateway._gateway_client) return DataFrame(self._jdf.na().fill(value), self.sql_ctx) elif subset is None: return DataFrame(self._jdf.na().fill(value), self.sql_ctx) @@ -932,9 +925,7 @@ def agg(self, *exprs): """ assert exprs, "exprs should not be empty" if len(exprs) == 1 and isinstance(exprs[0], dict): - jmap = MapConverter().convert(exprs[0], - self.sql_ctx._sc._gateway._gateway_client) - jdf = self._jdf.agg(jmap) + jdf = self._jdf.agg(exprs[0]) else: # Columns assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column" @@ -1040,8 +1031,7 @@ def _to_seq(sc, cols, converter=None): """ if converter: cols = [converter(c) for c in cols] - jcols = ListConverter().convert(cols, sc._gateway._gateway_client) - return sc._jvm.PythonUtils.toSeq(jcols) + return sc._jvm.PythonUtils.toSeq(cols) def _unary_op(name, doc="unary operator"): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index aa3aa1d164d9f..23e84283679e1 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -26,6 +26,7 @@ import tempfile import pickle import functools +import datetime import py4j @@ -464,6 +465,16 @@ def test_infer_long_type(self): self.assertEqual(_infer_type(2**61), LongType()) self.assertEqual(_infer_type(2**71), LongType()) + def test_filter_with_datetime(self): + time = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000) + date = time.date() + row = Row(date=date, time=time) + df = self.sqlCtx.createDataFrame([row]) + self.assertEqual(1, df.filter(df.date == date).count()) + self.assertEqual(1, df.filter(df.time == time).count()) + self.assertEqual(0, df.filter(df.date > date).count()) + self.assertEqual(0, df.filter(df.time > time).count()) + def test_dropna(self): schema = StructType([ StructField("name", StringType(), True), diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 4590c58839266..ac5ba69e8dbbb 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -20,7 +20,6 @@ import os import sys -from py4j.java_collections import ListConverter from py4j.java_gateway import java_import, JavaObject from pyspark import RDD, SparkConf @@ -305,9 +304,7 @@ def queueStream(self, rdds, oneAtATime=True, default=None): rdds = [self._sc.parallelize(input) for input in rdds] self._check_serializers(rdds) - jrdds = ListConverter().convert([r._jrdd for r in rdds], - SparkContext._gateway._gateway_client) - queue = self._jvm.PythonDStream.toRDDQueue(jrdds) + queue = self._jvm.PythonDStream.toRDDQueue([r._jrdd for r in rdds]) if default: default = default._reserialize(rdds[0]._jrdd_deserializer) jdstream = self._jssc.queueStream(queue, oneAtATime, default._jrdd) @@ -322,8 +319,7 @@ def transform(self, dstreams, transformFunc): the transform function parameter will be the same as the order of corresponding DStreams in the list. 
""" - jdstreams = ListConverter().convert([d._jdstream for d in dstreams], - SparkContext._gateway._gateway_client) + jdstreams = [d._jdstream for d in dstreams] # change the final serializer to sc.serializer func = TransformFunction(self._sc, lambda t, *rdds: transformFunc(rdds).map(lambda x: x), @@ -346,6 +342,5 @@ def union(self, *dstreams): if len(set(s._slideDuration for s in dstreams)) > 1: raise ValueError("All DStreams should have same slide duration") first = dstreams[0] - jrest = ListConverter().convert([d._jdstream for d in dstreams[1:]], - SparkContext._gateway._gateway_client) + jrest = [d._jdstream for d in dstreams[1:]] return DStream(self._jssc.union(first._jdstream, jrest), self, first._jrdd_deserializer) diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index 7a7b6e1d9a527..8d610d6569b4a 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -15,8 +15,7 @@ # limitations under the License. # -from py4j.java_collections import MapConverter -from py4j.java_gateway import java_import, Py4JError, Py4JJavaError +from py4j.java_gateway import Py4JJavaError from pyspark.storagelevel import StorageLevel from pyspark.serializers import PairDeserializer, NoOpSerializer @@ -57,8 +56,6 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={}, }) if not isinstance(topics, dict): raise TypeError("topics should be dict") - jtopics = MapConverter().convert(topics, ssc.sparkContext._gateway._gateway_client) - jparam = MapConverter().convert(kafkaParams, ssc.sparkContext._gateway._gateway_client) jlevel = ssc._sc._getJavaStorageLevel(storageLevel) try: @@ -66,7 +63,7 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={}, helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\ .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper") helper = helperClass.newInstance() - jstream = helper.createStream(ssc._jssc, jparam, jtopics, jlevel) + jstream = helper.createStream(ssc._jssc, kafkaParams, topics, jlevel) except Py4JJavaError as e: # TODO: use --jar once it also work on driver if 'ClassNotFoundException' in str(e.java_exception): diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 06d22154373bc..33f958a601f3a 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -24,8 +24,6 @@ import struct from functools import reduce -from py4j.java_collections import MapConverter - from pyspark.context import SparkConf, SparkContext, RDD from pyspark.streaming.context import StreamingContext from pyspark.streaming.kafka import KafkaUtils @@ -581,11 +579,9 @@ def test_kafka_stream(self): """Test the Python Kafka stream API.""" topic = "topic1" sendData = {"a": 3, "b": 5, "c": 10} - jSendData = MapConverter().convert(sendData, - self.ssc.sparkContext._gateway._gateway_client) self._kafkaTestUtils.createTopic(topic) - self._kafkaTestUtils.sendMessages(topic, jSendData) + self._kafkaTestUtils.sendMessages(topic, sendData) stream = KafkaUtils.createStream(self.ssc, self._kafkaTestUtils.zkAddress(), "test-streaming-consumer", {topic: 1}, From 1f2f723b0daacbb9e70ec42c19a84470af1d7765 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 21 Apr 2015 00:14:16 -0700 Subject: [PATCH 063/191] [SPARK-5990] [MLLIB] Model import/export for IsotonicRegression Model import/export for IsotonicRegression Author: Yanbo Liang Closes #5270 from yanboliang/spark-5990 and squashes the following commits: 872028d [Yanbo Liang] fix 
code style f80ec1b [Yanbo Liang] address comments 49600cc [Yanbo Liang] address comments 429ff7d [Yanbo Liang] store each interval as a record 2b2f5a1 [Yanbo Liang] Model import/export for IsotonicRegression --- .../mllib/regression/IsotonicRegression.scala | 78 ++++++++++++++++++- .../regression/IsotonicRegressionSuite.scala | 21 +++++ 2 files changed, 98 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala index cb70852e3cc8d..1d7617046b6c7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala @@ -23,9 +23,16 @@ import java.util.Arrays.binarySearch import scala.collection.mutable.ArrayBuffer +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD} +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD +import org.apache.spark.SparkContext +import org.apache.spark.sql.{DataFrame, SQLContext} /** * :: Experimental :: @@ -42,7 +49,7 @@ import org.apache.spark.rdd.RDD class IsotonicRegressionModel ( val boundaries: Array[Double], val predictions: Array[Double], - val isotonic: Boolean) extends Serializable { + val isotonic: Boolean) extends Serializable with Saveable { private val predictionOrd = if (isotonic) Ordering[Double] else Ordering[Double].reverse @@ -124,6 +131,75 @@ class IsotonicRegressionModel ( predictions(foundIndex) } } + + override def save(sc: SparkContext, path: String): Unit = { + IsotonicRegressionModel.SaveLoadV1_0.save(sc, path, boundaries, predictions, isotonic) + } + + override protected def formatVersion: String = "1.0" +} + +object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { + + import org.apache.spark.mllib.util.Loader._ + + private object SaveLoadV1_0 { + + def thisFormatVersion: String = "1.0" + + /** Hard-code class name string in case it changes in the future */ + def thisClassName: String = "org.apache.spark.mllib.regression.IsotonicRegressionModel" + + /** Model data for model import/export */ + case class Data(boundary: Double, prediction: Double) + + def save( + sc: SparkContext, + path: String, + boundaries: Array[Double], + predictions: Array[Double], + isotonic: Boolean): Unit = { + val sqlContext = new SQLContext(sc) + + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ + ("isotonic" -> isotonic))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path)) + + sqlContext.createDataFrame( + boundaries.toSeq.zip(predictions).map { case (b, p) => Data(b, p) } + ).saveAsParquetFile(dataPath(path)) + } + + def load(sc: SparkContext, path: String): (Array[Double], Array[Double]) = { + val sqlContext = new SQLContext(sc) + val dataRDD = sqlContext.parquetFile(dataPath(path)) + + checkSchema[Data](dataRDD.schema) + val dataArray = dataRDD.select("boundary", "prediction").collect() + val (boundaries, predictions) = dataArray.map { x => + (x.getDouble(0), x.getDouble(1)) + }.toList.sortBy(_._1).unzip + (boundaries.toArray, predictions.toArray) + } + } + + override def load(sc: SparkContext, path: String): IsotonicRegressionModel = { + implicit val formats = DefaultFormats + val (loadedClassName, version, metadata) = loadMetadata(sc, path) + 
val isotonic = (metadata \ "isotonic").extract[Boolean] + val classNameV1_0 = SaveLoadV1_0.thisClassName + (loadedClassName, version) match { + case (className, "1.0") if className == classNameV1_0 => + val (boundaries, predictions) = SaveLoadV1_0.load(sc, path) + new IsotonicRegressionModel(boundaries, predictions, isotonic) + case _ => throw new Exception( + s"IsotonicRegressionModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $version). Supported:\n" + + s" ($classNameV1_0, 1.0)" + ) + } + } } /** diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala index 7ef45248281e9..8e12340bbd9d6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala @@ -21,6 +21,7 @@ import org.scalatest.{Matchers, FunSuite} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.Utils class IsotonicRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers { @@ -73,6 +74,26 @@ class IsotonicRegressionSuite extends FunSuite with MLlibTestSparkContext with M assert(model.isotonic) } + test("model save/load") { + val boundaries = Array(0.0, 1.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0) + val predictions = Array(1, 2, 2, 6, 16.5, 16.5, 17.0, 18.0) + val model = new IsotonicRegressionModel(boundaries, predictions, true) + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = IsotonicRegressionModel.load(sc, path) + assert(model.boundaries === sameModel.boundaries) + assert(model.predictions === sameModel.predictions) + assert(model.isotonic === model.isotonic) + } finally { + Utils.deleteRecursively(tempDir) + } + } + test("isotonic regression with size 0") { val model = runIsotonicRegression(Seq(), true) From 5fea3e5c36450658d8b767dd3c06dac2251a0e0c Mon Sep 17 00:00:00 2001 From: David McGuire Date: Tue, 21 Apr 2015 07:21:10 -0400 Subject: [PATCH 064/191] [SPARK-6985][streaming] Receiver maxRate over 1000 causes a StackOverflowError A simple truncation in integer division (on rates over 1000 messages / second) causes the existing implementation to sleep for 0 milliseconds, then call itself recursively; this causes what is essentially an infinite recursion, since the base case of the calculated amount of time having elapsed can't be reached before available stack space is exhausted. A fix to this truncation error is included in this patch. However, even with the defect patched, the accuracy of the existing implementation is abysmal (the error bounds of the original test were effectively [-30%, +10%], although this fact was obscured by hard-coded error margins); as such, when the error bounds were tightened down to [-5%, +5%], the existing implementation failed to meet the new, tightened, requirements. Therefore, an industry-vetted solution (from Guava) was used to get the adapted tests to pass. 
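To make the failure mode concrete, here is a small illustrative sketch (not the patched code itself; it just reuses the field names from the old implementation) of why integer division collapses once the configured rate passes 1000 messages per second, and of the Guava call the replacement leans on instead:

// With a rate above 1000 and a single message written, integer division
// truncates the sleep target straight to zero, so the old code never slept
// and recursed until the stack was exhausted.
val desiredRate = 1001
val messagesWrittenSinceSync = 1L
val targetTimeInMillis = messagesWrittenSinceSync * 1000 / desiredRate  // == 0

// Guava's limiter instead blocks for the sub-millisecond slice itself.
import com.google.common.util.concurrent.{RateLimiter => GuavaRateLimiter}
val rateLimiter = GuavaRateLimiter.create(desiredRate.toDouble)
rateLimiter.acquire()  // returns only once a permit is available at the configured rate
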
Author: David McGuire Closes #5559 from dmcguire81/master and squashes the following commits: d29d2e0 [David McGuire] Back out to +/-5% error margins, for flexibility in timing 8be6934 [David McGuire] Fix spacing per code review 90e98b9 [David McGuire] Address scalastyle errors 29011bd [David McGuire] Further ratchet down the error margins b33b796 [David McGuire] Eliminate dependency on even distribution by BlockGenerator 8f2934b [David McGuire] Remove arbitrary thread timing / cooperation code 70ee310 [David McGuire] Use Thread.yield(), since Thread.sleep(0) is system-dependent 82ee46d [David McGuire] Replace guard clause with nested conditional 2794717 [David McGuire] Replace the RateLimiter with the Guava implementation 38f3ca8 [David McGuire] Ratchet down the error rate to +/- 5%; tests fail 24b1bc0 [David McGuire] Fix truncation in integer division causing infinite recursion d6e1079 [David McGuire] Stack overflow error in RateLimiter on rates over 1000/s --- .../streaming/receiver/RateLimiter.scala | 33 +++---------------- .../spark/streaming/ReceiverSuite.scala | 29 +++++++++------- 2 files changed, 21 insertions(+), 41 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala index e4f6ba626ebbf..97db9ded83367 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala @@ -18,7 +18,7 @@ package org.apache.spark.streaming.receiver import org.apache.spark.{Logging, SparkConf} -import java.util.concurrent.TimeUnit._ +import com.google.common.util.concurrent.{RateLimiter=>GuavaRateLimiter} /** Provides waitToPush() method to limit the rate at which receivers consume data. * @@ -33,37 +33,12 @@ import java.util.concurrent.TimeUnit._ */ private[receiver] abstract class RateLimiter(conf: SparkConf) extends Logging { - private var lastSyncTime = System.nanoTime - private var messagesWrittenSinceSync = 0L private val desiredRate = conf.getInt("spark.streaming.receiver.maxRate", 0) - private val SYNC_INTERVAL = NANOSECONDS.convert(10, SECONDS) + private lazy val rateLimiter = GuavaRateLimiter.create(desiredRate) def waitToPush() { - if( desiredRate <= 0 ) { - return - } - val now = System.nanoTime - val elapsedNanosecs = math.max(now - lastSyncTime, 1) - val rate = messagesWrittenSinceSync.toDouble * 1000000000 / elapsedNanosecs - if (rate < desiredRate) { - // It's okay to write; just update some variables and return - messagesWrittenSinceSync += 1 - if (now > lastSyncTime + SYNC_INTERVAL) { - // Sync interval has passed; let's resync - lastSyncTime = now - messagesWrittenSinceSync = 1 - } - } else { - // Calculate how much time we should sleep to bring ourselves to the desired rate. 
- val targetTimeInMillis = messagesWrittenSinceSync * 1000 / desiredRate - val elapsedTimeInMillis = elapsedNanosecs / 1000000 - val sleepTimeInMillis = targetTimeInMillis - elapsedTimeInMillis - if (sleepTimeInMillis > 0) { - logTrace("Natural rate is " + rate + " per second but desired rate is " + - desiredRate + ", sleeping for " + sleepTimeInMillis + " ms to compensate.") - Thread.sleep(sleepTimeInMillis) - } - waitToPush() + if (desiredRate > 0) { + rateLimiter.acquire() } } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala index 91261a9db7360..e7aee6eadbfc7 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala @@ -158,7 +158,7 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { test("block generator throttling") { val blockGeneratorListener = new FakeBlockGeneratorListener val blockIntervalMs = 100 - val maxRate = 100 + val maxRate = 1001 val conf = new SparkConf().set("spark.streaming.blockInterval", s"${blockIntervalMs}ms"). set("spark.streaming.receiver.maxRate", maxRate.toString) val blockGenerator = new BlockGenerator(blockGeneratorListener, 1, conf) @@ -176,7 +176,6 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { blockGenerator.addData(count) generatedData += count count += 1 - Thread.sleep(1) } blockGenerator.stop() @@ -185,25 +184,31 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { assert(blockGeneratorListener.arrayBuffers.size > 0, "No blocks received") assert(recordedData.toSet === generatedData.toSet, "Received data not same") - // recordedData size should be close to the expected rate - val minExpectedMessages = expectedMessages - 3 - val maxExpectedMessages = expectedMessages + 1 + // recordedData size should be close to the expected rate; use an error margin proportional to + // the value, so that rate changes don't cause a brittle test + val minExpectedMessages = expectedMessages - 0.05 * expectedMessages + val maxExpectedMessages = expectedMessages + 0.05 * expectedMessages val numMessages = recordedData.size assert( numMessages >= minExpectedMessages && numMessages <= maxExpectedMessages, s"#records received = $numMessages, not between $minExpectedMessages and $maxExpectedMessages" ) - val minExpectedMessagesPerBlock = expectedMessagesPerBlock - 3 - val maxExpectedMessagesPerBlock = expectedMessagesPerBlock + 1 + // XXX Checking every block would require an even distribution of messages across blocks, + // which throttling code does not control. Therefore, test against the average. 
+ val minExpectedMessagesPerBlock = expectedMessagesPerBlock - 0.05 * expectedMessagesPerBlock + val maxExpectedMessagesPerBlock = expectedMessagesPerBlock + 0.05 * expectedMessagesPerBlock val receivedBlockSizes = recordedBlocks.map { _.size }.mkString(",") + + // the first and last block may be incomplete, so we slice them out + val validBlocks = recordedBlocks.drop(1).dropRight(1) + val averageBlockSize = validBlocks.map(block => block.size).sum / validBlocks.size + assert( - // the first and last block may be incomplete, so we slice them out - recordedBlocks.drop(1).dropRight(1).forall { block => - block.size >= minExpectedMessagesPerBlock && block.size <= maxExpectedMessagesPerBlock - }, + averageBlockSize >= minExpectedMessagesPerBlock && + averageBlockSize <= maxExpectedMessagesPerBlock, s"# records in received blocks = [$receivedBlockSizes], not between " + - s"$minExpectedMessagesPerBlock and $maxExpectedMessagesPerBlock" + s"$minExpectedMessagesPerBlock and $maxExpectedMessagesPerBlock, on average" ) } From c035c0f2d72f2a303b86fe0037ec43d756fff060 Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Tue, 21 Apr 2015 11:01:18 -0700 Subject: [PATCH 065/191] [SPARK-5360] [SPARK-6606] Eliminate duplicate objects in serialized CoGroupedRDD CoGroupPartition, part of CoGroupedRDD, includes references to each RDD that the CoGroupedRDD narrowly depends on, and a reference to the ShuffleHandle. The partition is serialized separately from the RDD, so when the RDD and partition arrive on the worker, the references in the partition and in the RDD no longer point to the same object. This is a relatively minor performance issue (the closure can be 2x larger than it needs to be because the rdds and partitions are serialized twice; see numbers below) but is more annoying as a developer issue (this is where I ran into): if any state is stored in the RDD or ShuffleHandle on the worker side, subtle bugs can appear due to the fact that the references to the RDD / ShuffleHandle in the RDD and in the partition point to separate objects. I'm not sure if this is enough of a potential future problem to fix this old and central part of the code, so hoping to get input from others here. I did some simple experiments to see how much this effects closure size. For this example: $ val a = sc.parallelize(1 to 10).map((_, 1)) $ val b = sc.parallelize(1 to 2).map(x => (x, 2*x)) $ a.cogroup(b).collect() the closure was 1902 bytes with current Spark, and 1129 bytes after my change. The difference comes from eliminating duplicate serialization of the shuffle handle. For this example: $ val sortedA = a.sortByKey() $ val sortedB = b.sortByKey() $ sortedA.cogroup(sortedB).collect() the closure was 3491 bytes with current Spark, and 1333 bytes after my change. Here, the difference comes from eliminating duplicate serialization of the two RDDs for the narrow dependencies. The ShuffleHandle includes the ShuffleDependency, so this difference will get larger if a ShuffleDependency includes a serializer, a key ordering, or an aggregator (all set to None by default). It would also get bigger for a big RDD -- although I can't think of any examples where the RDD object gets large. The difference is not affected by the size of the function the user specifies, which (based on my understanding) is typically the source of large task closures. 
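For anyone who wants to repeat the size comparison, a rough sketch of one way to measure it from the shell follows. This is not necessarily the method used for the numbers above; serializedSize is a throwaway helper introduced here, and a and b are the RDDs from the first example.

import java.io.{ByteArrayOutputStream, ObjectOutputStream}

// Serialize an object with plain Java serialization and report the byte count.
def serializedSize(obj: AnyRef): Int = {
  val buffer = new ByteArrayOutputStream()
  val out = new ObjectOutputStream(buffer)
  out.writeObject(obj)
  out.close()
  buffer.size()
}

val cogrouped = a.cogroup(b)
// Before this change the partition dragged along its own copy of the shuffle
// handles (and, for narrow dependencies, the parent RDDs); with the fields no
// longer duplicated it should come out noticeably smaller.
println(serializedSize(cogrouped.partitions(0)))
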
Author: Kay Ousterhout Closes #4145 from kayousterhout/SPARK-5360 and squashes the following commits: 85156c3 [Kay Ousterhout] Better comment the narrowDeps parameter cff0209 [Kay Ousterhout] Fixed spelling issue 658e1af [Kay Ousterhout] [SPARK-5360] Eliminate duplicate objects in serialized CoGroupedRDD --- .../org/apache/spark/rdd/CoGroupedRDD.scala | 43 +++++++++++-------- .../org/apache/spark/rdd/SubtractedRDD.scala | 30 +++++++------ 2 files changed, 44 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 7021a339e879b..658e8c8b89318 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -29,15 +29,16 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap, CompactBuffer} import org.apache.spark.util.Utils import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.ShuffleHandle - -private[spark] sealed trait CoGroupSplitDep extends Serializable +/** The references to rdd and splitIndex are transient because redundant information is stored + * in the CoGroupedRDD object. Because CoGroupedRDD is serialized separately from + * CoGroupPartition, if rdd and splitIndex aren't transient, they'll be included twice in the + * task closure. */ private[spark] case class NarrowCoGroupSplitDep( - rdd: RDD[_], - splitIndex: Int, + @transient rdd: RDD[_], + @transient splitIndex: Int, var split: Partition - ) extends CoGroupSplitDep { + ) extends Serializable { @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException { @@ -47,9 +48,16 @@ private[spark] case class NarrowCoGroupSplitDep( } } -private[spark] case class ShuffleCoGroupSplitDep(handle: ShuffleHandle) extends CoGroupSplitDep - -private[spark] class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep]) +/** + * Stores information about the narrow dependencies used by a CoGroupedRdd. + * + * @param narrowDeps maps to the dependencies variable in the parent RDD: for each one to one + * dependency in dependencies, narrowDeps has a NarrowCoGroupSplitDep (describing + * the partition for that dependency) at the corresponding index. The size of + * narrowDeps should always be equal to the number of parents. 
+ */ +private[spark] class CoGroupPartition( + idx: Int, val narrowDeps: Array[Option[NarrowCoGroupSplitDep]]) extends Partition with Serializable { override val index: Int = idx override def hashCode(): Int = idx @@ -105,9 +113,9 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: // Assume each RDD contributed a single dependency, and get it dependencies(j) match { case s: ShuffleDependency[_, _, _] => - new ShuffleCoGroupSplitDep(s.shuffleHandle) + None case _ => - new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i)) + Some(new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i))) } }.toArray) } @@ -120,20 +128,21 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: val sparkConf = SparkEnv.get.conf val externalSorting = sparkConf.getBoolean("spark.shuffle.spill", true) val split = s.asInstanceOf[CoGroupPartition] - val numRdds = split.deps.length + val numRdds = dependencies.length // A list of (rdd iterator, dependency number) pairs val rddIterators = new ArrayBuffer[(Iterator[Product2[K, Any]], Int)] - for ((dep, depNum) <- split.deps.zipWithIndex) dep match { - case NarrowCoGroupSplitDep(rdd, _, itsSplit) => + for ((dep, depNum) <- dependencies.zipWithIndex) dep match { + case oneToOneDependency: OneToOneDependency[Product2[K, Any]] => + val dependencyPartition = split.narrowDeps(depNum).get.split // Read them from the parent - val it = rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]] + val it = oneToOneDependency.rdd.iterator(dependencyPartition, context) rddIterators += ((it, depNum)) - case ShuffleCoGroupSplitDep(handle) => + case shuffleDependency: ShuffleDependency[_, _, _] => // Read map outputs of shuffle val it = SparkEnv.get.shuffleManager - .getReader(handle, split.index, split.index + 1, context) + .getReader(shuffleDependency.shuffleHandle, split.index, split.index + 1, context) .read() rddIterators += ((it, depNum)) } diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index e9d745588ee9a..633aeba3bbae6 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -81,9 +81,9 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( array(i) = new CoGroupPartition(i, Seq(rdd1, rdd2).zipWithIndex.map { case (rdd, j) => dependencies(j) match { case s: ShuffleDependency[_, _, _] => - new ShuffleCoGroupSplitDep(s.shuffleHandle) + None case _ => - new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i)) + Some(new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i))) } }.toArray) } @@ -105,20 +105,26 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( seq } } - def integrate(dep: CoGroupSplitDep, op: Product2[K, V] => Unit): Unit = dep match { - case NarrowCoGroupSplitDep(rdd, _, itsSplit) => - rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, V]]].foreach(op) + def integrate(depNum: Int, op: Product2[K, V] => Unit) = { + dependencies(depNum) match { + case oneToOneDependency: OneToOneDependency[_] => + val dependencyPartition = partition.narrowDeps(depNum).get.split + oneToOneDependency.rdd.iterator(dependencyPartition, context) + .asInstanceOf[Iterator[Product2[K, V]]].foreach(op) - case ShuffleCoGroupSplitDep(handle) => - val iter = SparkEnv.get.shuffleManager - .getReader(handle, partition.index, partition.index + 1, context) - .read() - iter.foreach(op) + case 
shuffleDependency: ShuffleDependency[_, _, _] => + val iter = SparkEnv.get.shuffleManager + .getReader( + shuffleDependency.shuffleHandle, partition.index, partition.index + 1, context) + .read() + iter.foreach(op) + } } + // the first dep is rdd1; add all values to the map - integrate(partition.deps(0), t => getSeq(t._1) += t._2) + integrate(0, t => getSeq(t._1) += t._2) // the second dep is rdd2; remove all of its keys - integrate(partition.deps(1), t => map.remove(t._1)) + integrate(1, t => map.remove(t._1)) map.iterator.map { t => t._2.iterator.map { (t._1, _) } }.flatten } From c25ca7c5a1f2a4f88f40b0c5cdbfa927c186cfa8 Mon Sep 17 00:00:00 2001 From: emres Date: Tue, 21 Apr 2015 16:39:56 -0400 Subject: [PATCH 066/191] SPARK-3276 Added a new configuration spark.streaming.minRememberDuration SPARK-3276 Added a new configuration parameter ``spark.streaming.minRememberDuration``, with a default value of 1 minute. So that when a Spark Streaming application is started, an arbitrary number of minutes can be taken as threshold for remembering. Author: emres Closes #5438 from emres/SPARK-3276 and squashes the following commits: 766f938 [emres] SPARK-3276 Switched to using newly added getTimeAsSeconds method. affee1d [emres] SPARK-3276 Changed the property name and variable name for minRememberDuration c9d58ca [emres] SPARK-3276 Minor code re-formatting. 1c53ba9 [emres] SPARK-3276 Started to use ssc.conf rather than ssc.sparkContext.getConf, and also getLong method directly. bfe0acb [emres] SPARK-3276 Moved the minRememberDurationMin to the class daccc82 [emres] SPARK-3276 Changed the property name to reflect the unit of value and reduced number of fields. 43cc1ce [emres] SPARK-3276 Added a new configuration parameter spark.streaming.minRemember duration, with a default value of 1 minute. --- .../streaming/dstream/FileInputDStream.scala | 30 +++++++++++-------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index 66d519171fd76..eca69f00188e4 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path, PathFilter} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} -import org.apache.spark.SerializableWritable +import org.apache.spark.{SparkConf, SerializableWritable} import org.apache.spark.rdd.{RDD, UnionRDD} import org.apache.spark.streaming._ import org.apache.spark.util.{TimeStampedHashMap, Utils} @@ -63,7 +63,7 @@ import org.apache.spark.util.{TimeStampedHashMap, Utils} * the streaming app. * - If a file is to be visible in the directory listings, it must be visible within a certain * duration of the mod time of the file. This duration is the "remember window", which is set to - * 1 minute (see `FileInputDStream.MIN_REMEMBER_DURATION`). Otherwise, the file will never be + * 1 minute (see `FileInputDStream.minRememberDuration`). Otherwise, the file will never be * selected as the mod time will be less than the ignore threshold when it becomes visible. * - Once a file is visible, the mod time cannot change. If it does due to appends, then the * processing semantics are undefined. 
@@ -80,6 +80,15 @@ class FileInputDStream[K, V, F <: NewInputFormat[K,V]]( private val serializableConfOpt = conf.map(new SerializableWritable(_)) + /** + * Minimum duration of remembering the information of selected files. Defaults to 60 seconds. + * + * Files with mod times older than this "window" of remembering will be ignored. So if new + * files are visible within this window, then the file will get selected in the next batch. + */ + private val minRememberDurationS = + Seconds(ssc.conf.getTimeAsSeconds("spark.streaming.minRememberDuration", "60s")) + // This is a def so that it works during checkpoint recovery: private def clock = ssc.scheduler.clock @@ -95,7 +104,8 @@ class FileInputDStream[K, V, F <: NewInputFormat[K,V]]( * This would allow us to filter away not-too-old files which have already been recently * selected and processed. */ - private val numBatchesToRemember = FileInputDStream.calculateNumBatchesToRemember(slideDuration) + private val numBatchesToRemember = FileInputDStream + .calculateNumBatchesToRemember(slideDuration, minRememberDurationS) private val durationToRemember = slideDuration * numBatchesToRemember remember(durationToRemember) @@ -330,20 +340,14 @@ class FileInputDStream[K, V, F <: NewInputFormat[K,V]]( private[streaming] object FileInputDStream { - /** - * Minimum duration of remembering the information of selected files. Files with mod times - * older than this "window" of remembering will be ignored. So if new files are visible - * within this window, then the file will get selected in the next batch. - */ - private val MIN_REMEMBER_DURATION = Minutes(1) - def defaultFilter(path: Path): Boolean = !path.getName().startsWith(".") /** * Calculate the number of last batches to remember, such that all the files selected in - * at least last MIN_REMEMBER_DURATION duration can be remembered. + * at least last minRememberDurationS duration can be remembered. */ - def calculateNumBatchesToRemember(batchDuration: Duration): Int = { - math.ceil(MIN_REMEMBER_DURATION.milliseconds.toDouble / batchDuration.milliseconds).toInt + def calculateNumBatchesToRemember(batchDuration: Duration, + minRememberDurationS: Duration): Int = { + math.ceil(minRememberDurationS.milliseconds.toDouble / batchDuration.milliseconds).toInt } } From 45c47fa4176ea75886a58f5d73c44afcb29aa629 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Tue, 21 Apr 2015 14:36:50 -0700 Subject: [PATCH 067/191] [SPARK-6845] [MLlib] [PySpark] Add isTranposed flag to DenseMatrix Since sparse matrices now support a isTransposed flag for row major data, DenseMatrices should do the same. 
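Since the flag mirrors what the Scala-side matrices already do, a short Scala sketch (hypothetical values, using the existing org.apache.spark.mllib.linalg classes) may help show what the Python DenseMatrix gains: the same logical matrix can be handed over either column-major or row-major, and indexing and toArray stay consistent.

import org.apache.spark.mllib.linalg.DenseMatrix

// Column-major storage: first column is (0, 1, 3), second is (4, 6, 9).
val colMajor = new DenseMatrix(3, 2, Array(0.0, 1.0, 3.0, 4.0, 6.0, 9.0))
// Row-major storage of the same matrix, flagged as transposed.
val rowMajor = new DenseMatrix(3, 2, Array(0.0, 4.0, 1.0, 6.0, 3.0, 9.0), isTransposed = true)

// Element (i, j) lives at values(i + j * numRows) in the first layout and at
// values(i * numCols + j) in the second; toArray normalizes both to column-major.
assert(colMajor.toArray.sameElements(rowMajor.toArray))
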
Author: MechCoder Closes #5455 from MechCoder/spark-6845 and squashes the following commits: 525c370 [MechCoder] minor 004a37f [MechCoder] Cast boolean to int 151f3b6 [MechCoder] [WIP] Add isTransposed to pickle DenseMatrix cc0b90a [MechCoder] [SPARK-6845] Add isTranposed flag to DenseMatrix --- .../mllib/api/python/PythonMLLibAPI.scala | 13 +++-- python/pyspark/mllib/linalg.py | 49 +++++++++++++------ python/pyspark/mllib/tests.py | 16 ++++++ 3 files changed, 58 insertions(+), 20 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index f976d2f97b043..6237b64c8f984 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -985,8 +985,10 @@ private[spark] object SerDe extends Serializable { val m: DenseMatrix = obj.asInstanceOf[DenseMatrix] val bytes = new Array[Byte](8 * m.values.size) val order = ByteOrder.nativeOrder() + val isTransposed = if (m.isTransposed) 1 else 0 ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().put(m.values) + out.write(Opcodes.MARK) out.write(Opcodes.BININT) out.write(PickleUtils.integer_to_bytes(m.numRows)) out.write(Opcodes.BININT) @@ -994,19 +996,22 @@ private[spark] object SerDe extends Serializable { out.write(Opcodes.BINSTRING) out.write(PickleUtils.integer_to_bytes(bytes.length)) out.write(bytes) - out.write(Opcodes.TUPLE3) + out.write(Opcodes.BININT) + out.write(PickleUtils.integer_to_bytes(isTransposed)) + out.write(Opcodes.TUPLE) } def construct(args: Array[Object]): Object = { - if (args.length != 3) { - throw new PickleException("should be 3") + if (args.length != 4) { + throw new PickleException("should be 4") } val bytes = getBytes(args(2)) val n = bytes.length / 8 val values = new Array[Double](n) val order = ByteOrder.nativeOrder() ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().get(values) - new DenseMatrix(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], values) + val isTransposed = args(3).asInstanceOf[Int] == 1 + new DenseMatrix(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], values, isTransposed) } } diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index ec8c879ea9389..cc9a4cf8ba170 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -638,9 +638,10 @@ class Matrix(object): Represents a local matrix. """ - def __init__(self, numRows, numCols): + def __init__(self, numRows, numCols, isTransposed=False): self.numRows = numRows self.numCols = numCols + self.isTransposed = isTransposed def toArray(self): """ @@ -662,14 +663,16 @@ class DenseMatrix(Matrix): """ Column-major dense matrix. 
""" - def __init__(self, numRows, numCols, values): - Matrix.__init__(self, numRows, numCols) + def __init__(self, numRows, numCols, values, isTransposed=False): + Matrix.__init__(self, numRows, numCols, isTransposed) values = self._convert_to_array(values, np.float64) assert len(values) == numRows * numCols self.values = values def __reduce__(self): - return DenseMatrix, (self.numRows, self.numCols, self.values.tostring()) + return DenseMatrix, ( + self.numRows, self.numCols, self.values.tostring(), + int(self.isTransposed)) def toArray(self): """ @@ -680,15 +683,23 @@ def toArray(self): array([[ 0., 2.], [ 1., 3.]]) """ - return self.values.reshape((self.numRows, self.numCols), order='F') + if self.isTransposed: + return np.asfortranarray( + self.values.reshape((self.numRows, self.numCols))) + else: + return self.values.reshape((self.numRows, self.numCols), order='F') def toSparse(self): """Convert to SparseMatrix""" - indices = np.nonzero(self.values)[0] + if self.isTransposed: + values = np.ravel(self.toArray(), order='F') + else: + values = self.values + indices = np.nonzero(values)[0] colCounts = np.bincount(indices // self.numRows) colPtrs = np.cumsum(np.hstack( (0, colCounts, np.zeros(self.numCols - colCounts.size)))) - values = self.values[indices] + values = values[indices] rowIndices = indices % self.numRows return SparseMatrix(self.numRows, self.numCols, colPtrs, rowIndices, values) @@ -701,21 +712,28 @@ def __getitem__(self, indices): if j >= self.numCols or j < 0: raise ValueError("Column index %d is out of range [0, %d)" % (j, self.numCols)) - return self.values[i + j * self.numRows] + + if self.isTransposed: + return self.values[i * self.numCols + j] + else: + return self.values[i + j * self.numRows] def __eq__(self, other): - return (isinstance(other, DenseMatrix) and - self.numRows == other.numRows and - self.numCols == other.numCols and - all(self.values == other.values)) + if (not isinstance(other, DenseMatrix) or + self.numRows != other.numRows or + self.numCols != other.numCols): + return False + + self_values = np.ravel(self.toArray(), order='F') + other_values = np.ravel(other.toArray(), order='F') + return all(self_values == other_values) class SparseMatrix(Matrix): """Sparse Matrix stored in CSC format.""" def __init__(self, numRows, numCols, colPtrs, rowIndices, values, isTransposed=False): - Matrix.__init__(self, numRows, numCols) - self.isTransposed = isTransposed + Matrix.__init__(self, numRows, numCols, isTransposed) self.colPtrs = self._convert_to_array(colPtrs, np.int32) self.rowIndices = self._convert_to_array(rowIndices, np.int32) self.values = self._convert_to_array(values, np.float64) @@ -777,8 +795,7 @@ def toArray(self): return A def toDense(self): - densevals = np.reshape( - self.toArray(), (self.numRows * self.numCols), order='F') + densevals = np.ravel(self.toArray(), order='F') return DenseMatrix(self.numRows, self.numCols, densevals) # TODO: More efficient implementation: diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 849c88341a967..8f89e2cee0592 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -195,6 +195,22 @@ def test_sparse_matrix(self): self.assertEquals(expected[i][j], sm1t[i, j]) self.assertTrue(array_equal(sm1t.toArray(), expected)) + def test_dense_matrix_is_transposed(self): + mat1 = DenseMatrix(3, 2, [0, 4, 1, 6, 3, 9], isTransposed=True) + mat = DenseMatrix(3, 2, [0, 1, 3, 4, 6, 9]) + self.assertEquals(mat1, mat) + + expected = [[0, 4], [1, 6], [3, 9]] + for i in 
range(3): + for j in range(2): + self.assertEquals(mat1[i, j], expected[i][j]) + self.assertTrue(array_equal(mat1.toArray(), expected)) + + sm = mat1.toSparse() + self.assertTrue(array_equal(sm.rowIndices, [1, 2, 0, 1, 2])) + self.assertTrue(array_equal(sm.colPtrs, [0, 2, 5])) + self.assertTrue(array_equal(sm.values, [1, 3, 4, 6, 9])) + class ListTests(PySparkTestCase): From 04bf34e34f22e3d7e972fe755251774fc6a6d52e Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Tue, 21 Apr 2015 14:43:46 -0700 Subject: [PATCH 068/191] [SPARK-7011] Build(compilation) fails with scala 2.11 option, because a protected[sql] type is accessed in ml package. [This](https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala#L58) is where it is used and fails compilations at. Author: Prashant Sharma Closes #5593 from ScrapCodes/SPARK-7011/build-fix and squashes the following commits: e6d57a3 [Prashant Sharma] [SPARK-7011] Build fails with scala 2.11 option, because a protected[sql] type is accessed in ml package. --- .../src/main/scala/org/apache/spark/sql/types/dataTypes.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala index c6fb22c26bd3c..a108413497829 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala @@ -299,7 +299,7 @@ class NullType private() extends DataType { case object NullType extends NullType -protected[sql] object NativeType { +protected[spark] object NativeType { val all = Seq( IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) @@ -327,7 +327,7 @@ protected[sql] object PrimitiveType { } } -protected[sql] abstract class NativeType extends DataType { +protected[spark] abstract class NativeType extends DataType { private[sql] type JvmType @transient private[sql] val tag: TypeTag[JvmType] private[sql] val ordering: Ordering[JvmType] From 2e8c6ca47df14681c1110f0736234ce76a3eca9b Mon Sep 17 00:00:00 2001 From: vidmantas zemleris Date: Tue, 21 Apr 2015 14:47:09 -0700 Subject: [PATCH 069/191] [SPARK-6994] Allow to fetch field values by name in sql.Row It looked weird that up to now there was no way in Spark's Scala API to access fields of `DataFrame/sql.Row` by name, only by their index. This tries to solve this issue. 
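A short usage sketch of the new accessors, assuming a spark-shell session with the SQLContext implicits imported; the column names simply mirror the test added below:

$ val row = Seq((1, Seq(1))).toDF("a", "b").first()
$ row.fieldIndex("b")                  // 1
$ row.getAs[Int]("a")                  // 1, instead of having to write row.getInt(0)
$ row.getValuesMap[Any](Seq("a", "b")) // both values keyed by field name
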
Author: vidmantas zemleris Closes #5573 from vidma/features/row-with-named-fields and squashes the following commits: 6145ae3 [vidmantas zemleris] [SPARK-6994][SQL] Allow to fetch field values by name on Row 9564ebb [vidmantas zemleris] [SPARK-6994][SQL] Add fieldIndex to schema (StructType) --- .../main/scala/org/apache/spark/sql/Row.scala | 32 +++++++++ .../spark/sql/catalyst/expressions/rows.scala | 2 + .../apache/spark/sql/types/dataTypes.scala | 9 +++ .../scala/org/apache/spark/sql/RowTest.scala | 71 +++++++++++++++++++ .../spark/sql/types/DataTypeSuite.scala | 13 ++++ .../scala/org/apache/spark/sql/RowSuite.scala | 10 +++ 6 files changed, 137 insertions(+) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index ac8a782976465..4190b7ffe1c8f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -306,6 +306,38 @@ trait Row extends Serializable { */ def getAs[T](i: Int): T = apply(i).asInstanceOf[T] + /** + * Returns the value of a given fieldName. + * + * @throws UnsupportedOperationException when schema is not defined. + * @throws IllegalArgumentException when fieldName do not exist. + * @throws ClassCastException when data type does not match. + */ + def getAs[T](fieldName: String): T = getAs[T](fieldIndex(fieldName)) + + /** + * Returns the index of a given field name. + * + * @throws UnsupportedOperationException when schema is not defined. + * @throws IllegalArgumentException when fieldName do not exist. + */ + def fieldIndex(name: String): Int = { + throw new UnsupportedOperationException("fieldIndex on a Row without schema is undefined.") + } + + /** + * Returns a Map(name -> value) for the requested fieldNames + * + * @throws UnsupportedOperationException when schema is not defined. + * @throws IllegalArgumentException when fieldName do not exist. + * @throws ClassCastException when data type does not match. + */ + def getValuesMap[T](fieldNames: Seq[String]): Map[String, T] = { + fieldNames.map { name => + name -> getAs[T](name) + }.toMap + } + override def toString(): String = s"[${this.mkString(",")}]" /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index b6ec7d3417ef8..981373477a4bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -181,6 +181,8 @@ class GenericRowWithSchema(values: Array[Any], override val schema: StructType) /** No-arg constructor for serialization. 
*/ protected def this() = this(null, null) + + override def fieldIndex(name: String): Int = schema.fieldIndex(name) } class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala index a108413497829..7cd7bd1914c95 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala @@ -1025,6 +1025,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru private lazy val fieldNamesSet: Set[String] = fieldNames.toSet private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap + private lazy val nameToIndex: Map[String, Int] = fieldNames.zipWithIndex.toMap /** * Extracts a [[StructField]] of the given name. If the [[StructType]] object does not @@ -1049,6 +1050,14 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru StructType(fields.filter(f => names.contains(f.name))) } + /** + * Returns index of a given field + */ + def fieldIndex(name: String): Int = { + nameToIndex.getOrElse(name, + throw new IllegalArgumentException(s"""Field "$name" does not exist.""")) + } + protected[sql] def toAttributes: Seq[AttributeReference] = map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala new file mode 100644 index 0000000000000..bbb9739e9cc76 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.expressions.{GenericRow, GenericRowWithSchema} +import org.apache.spark.sql.types._ +import org.scalatest.{Matchers, FunSpec} + +class RowTest extends FunSpec with Matchers { + + val schema = StructType( + StructField("col1", StringType) :: + StructField("col2", StringType) :: + StructField("col3", IntegerType) :: Nil) + val values = Array("value1", "value2", 1) + + val sampleRow: Row = new GenericRowWithSchema(values, schema) + val noSchemaRow: Row = new GenericRow(values) + + describe("Row (without schema)") { + it("throws an exception when accessing by fieldName") { + intercept[UnsupportedOperationException] { + noSchemaRow.fieldIndex("col1") + } + intercept[UnsupportedOperationException] { + noSchemaRow.getAs("col1") + } + } + } + + describe("Row (with schema)") { + it("fieldIndex(name) returns field index") { + sampleRow.fieldIndex("col1") shouldBe 0 + sampleRow.fieldIndex("col3") shouldBe 2 + } + + it("getAs[T] retrieves a value by fieldname") { + sampleRow.getAs[String]("col1") shouldBe "value1" + sampleRow.getAs[Int]("col3") shouldBe 1 + } + + it("Accessing non existent field throws an exception") { + intercept[IllegalArgumentException] { + sampleRow.getAs[String]("non_existent") + } + } + + it("getValuesMap() retrieves values of multiple fields as a Map(field -> value)") { + val expected = Map( + "col1" -> "value1", + "col2" -> "value2" + ) + sampleRow.getValuesMap(List("col1", "col2")) shouldBe expected + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index a1341ea13d810..d797510f36685 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -56,6 +56,19 @@ class DataTypeSuite extends FunSuite { } } + test("extract field index from a StructType") { + val struct = StructType( + StructField("a", LongType) :: + StructField("b", FloatType) :: Nil) + + assert(struct.fieldIndex("a") === 0) + assert(struct.fieldIndex("b") === 1) + + intercept[IllegalArgumentException] { + struct.fieldIndex("non_existent") + } + } + def checkDataTypeJsonRepr(dataType: DataType): Unit = { test(s"JSON - $dataType") { assert(DataType.fromJson(dataType.json) === dataType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index bf6cf1321a056..fb3ba4bc1b908 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -62,4 +62,14 @@ class RowSuite extends FunSuite { val de = instance.deserialize(ser).asInstanceOf[Row] assert(de === row) } + + test("get values by field name on Row created via .toDF") { + val row = Seq((1, Seq(1))).toDF("a", "b").first() + assert(row.getAs[Int]("a") === 1) + assert(row.getAs[Seq[Int]]("b") === Seq(1)) + + intercept[IllegalArgumentException]{ + row.getAs[Int]("c") + } + } } From 03fd92167107f1d061c1a7ef216468b508546ac7 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 21 Apr 2015 14:48:02 -0700 Subject: [PATCH 070/191] [SQL][minor] make it more clear that we only need to re-throw GetField exception for UnresolvedAttribute For `GetField` outside `UnresolvedAttribute`, we will throw exception in `Analyzer`. 
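As a hedged illustration of the error path this refactor covers (not part of the patch): referencing a struct field that does not exist fails analysis with a field-specific AnalysisException, while fields the Analyzer can resolve behave as before. The case classes and table name below are hypothetical.

```scala
// Hypothetical data, sketching the user-visible behavior only.
case class Inner(b: Int)
case class Outer(a: Inner)

sqlContext.createDataFrame(Seq(Outer(Inner(1)))).registerTempTable("t")

sqlContext.sql("SELECT a.b FROM t").collect()   // Array([1]) -- resolves fine
sqlContext.sql("SELECT a.c FROM t")             // AnalysisException: struct `a` has no field `c`
```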
Author: Wenchen Fan Closes #5588 from cloud-fan/tmp and squashes the following commits: 7ac74d2 [Wenchen Fan] small refactor --- .../spark/sql/catalyst/analysis/CheckAnalysis.scala | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 1155dac28fc78..a986dd5387c38 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -46,12 +46,11 @@ trait CheckAnalysis { operator transformExpressionsUp { case a: Attribute if !a.resolved => if (operator.childrenResolved) { - val nameParts = a match { - case UnresolvedAttribute(nameParts) => nameParts - case _ => Seq(a.name) + a match { + case UnresolvedAttribute(nameParts) => + // Throw errors for specific problems with get field. + operator.resolveChildren(nameParts, resolver, throwErrors = true) } - // Throw errors for specific problems with get field. - operator.resolveChildren(nameParts, resolver, throwErrors = true) } val from = operator.inputSet.map(_.name).mkString(", ") From 6265cba00f6141575b4be825735d77d4cea500ab Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 21 Apr 2015 14:48:42 -0700 Subject: [PATCH 071/191] [SPARK-6969][SQL] Refresh the cached table when REFRESH TABLE is used https://issues.apache.org/jira/browse/SPARK-6969 Author: Yin Huai Closes #5583 from yhuai/refreshTableRefreshDataCache and squashes the following commits: 1e5142b [Yin Huai] Add todo. 92b2498 [Yin Huai] Minor updates. 367df92 [Yin Huai] Recache data in the command of REFRESH TABLE. --- .../org/apache/spark/sql/sources/ddl.scala | 17 +++++++ .../spark/sql/hive/CachedTableSuite.scala | 50 ++++++++++++++++++- 2 files changed, 66 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index 2e861b84b7133..78d494184e759 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -347,7 +347,24 @@ private[sql] case class RefreshTable(databaseName: String, tableName: String) extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { + // Refresh the given table's metadata first. sqlContext.catalog.refreshTable(databaseName, tableName) + + // If this table is cached as a InMemoryColumnarRelation, drop the original + // cached version and make the new version cached lazily. + val logicalPlan = sqlContext.catalog.lookupRelation(Seq(databaseName, tableName)) + // Use lookupCachedData directly since RefreshTable also takes databaseName. + val isCached = sqlContext.cacheManager.lookupCachedData(logicalPlan).nonEmpty + if (isCached) { + // Create a data frame to represent the table. + // TODO: Use uncacheTable once it supports database name. + val df = DataFrame(sqlContext, logicalPlan) + // Uncache the logicalPlan. + sqlContext.cacheManager.tryUncacheQuery(df, blocking = true) + // Cache it again. 
+ sqlContext.cacheManager.cacheQuery(df, Some(tableName)) + } + Seq.empty[Row] } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index c188264072a84..fc6c3c35037b0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -17,11 +17,14 @@ package org.apache.spark.sql.hive +import java.io.File + import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest} +import org.apache.spark.sql.{SaveMode, AnalysisException, DataFrame, QueryTest} import org.apache.spark.storage.RDDBlockId +import org.apache.spark.util.Utils class CachedTableSuite extends QueryTest { @@ -155,4 +158,49 @@ class CachedTableSuite extends QueryTest { assertCached(table("udfTest")) uncacheTable("udfTest") } + + test("REFRESH TABLE also needs to recache the data (data source tables)") { + val tempPath: File = Utils.createTempDir() + tempPath.delete() + table("src").save(tempPath.toString, "parquet", SaveMode.Overwrite) + sql("DROP TABLE IF EXISTS refreshTable") + createExternalTable("refreshTable", tempPath.toString, "parquet") + checkAnswer( + table("refreshTable"), + table("src").collect()) + // Cache the table. + sql("CACHE TABLE refreshTable") + assertCached(table("refreshTable")) + // Append new data. + table("src").save(tempPath.toString, "parquet", SaveMode.Append) + // We are still using the old data. + assertCached(table("refreshTable")) + checkAnswer( + table("refreshTable"), + table("src").collect()) + // Refresh the table. + sql("REFRESH TABLE refreshTable") + // We are using the new data. + assertCached(table("refreshTable")) + checkAnswer( + table("refreshTable"), + table("src").unionAll(table("src")).collect()) + + // Drop the table and create it again. + sql("DROP TABLE refreshTable") + createExternalTable("refreshTable", tempPath.toString, "parquet") + // It is not cached. + assert(!isCached("refreshTable"), "refreshTable should not be cached.") + // Refresh the table. REFRESH TABLE command should not make a uncached + // table cached. + sql("REFRESH TABLE refreshTable") + checkAnswer( + table("refreshTable"), + table("src").unionAll(table("src")).collect()) + // It is not cached. + assert(!isCached("refreshTable"), "refreshTable should not be cached.") + + sql("DROP TABLE refreshTable") + Utils.deleteRecursively(tempPath) + } } From 2a24bf92e6d36e876bad6a8b4e0ff12c407ebb8a Mon Sep 17 00:00:00 2001 From: Punya Biswal Date: Tue, 21 Apr 2015 14:50:02 -0700 Subject: [PATCH 072/191] [SPARK-6996][SQL] Support map types in java beans liancheng mengxr this is similar to #5146. 
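A sketch of what this enables, assuming a hypothetical bean-style class (written in Scala here, but the same applies to a plain Java bean): properties typed as `java.util.Map` or `java.lang.Iterable` are now inferred as `MapType` / `ArrayType` when a DataFrame is built from a bean class.

```scala
import java.util.{Arrays => JArrays, List => JList, Map => JMap}
import com.google.common.collect.ImmutableMap

// Hypothetical bean: no-arg getters are picked up by java.beans.Introspector.
class ScoreBean extends Serializable {
  def getName: String = "alice"
  def getScores: JMap[String, java.lang.Integer] = ImmutableMap.of("math", Int.box(90))
  def getTags: JList[String] = JArrays.asList("a", "b")
}

val df = sqlContext.createDataFrame(sc.parallelize(Seq(new ScoreBean)), classOf[ScoreBean])
df.printSchema()
// roughly: name: string, scores: map<string,int>, tags: array<string>
```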
Author: Punya Biswal Closes #5578 from punya/feature/SPARK-6996 and squashes the following commits: d56c3e0 [Punya Biswal] Fix imports c7e308b [Punya Biswal] Support java iterable types in POJOs 5e00685 [Punya Biswal] Support map types in java beans --- .../sql/catalyst/CatalystTypeConverters.scala | 20 ++++ .../apache/spark/sql/JavaTypeInference.scala | 110 ++++++++++++++++++ .../org/apache/spark/sql/SQLContext.scala | 52 +-------- .../apache/spark/sql/JavaDataFrameSuite.java | 57 +++++++-- 4 files changed, 180 insertions(+), 59 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/JavaTypeInference.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index d4f9fdacda4fb..a13e2f36a1a1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst +import java.lang.{Iterable => JavaIterable} import java.util.{Map => JavaMap} import scala.collection.mutable.HashMap @@ -49,6 +50,16 @@ object CatalystTypeConverters { case (s: Seq[_], arrayType: ArrayType) => s.map(convertToCatalyst(_, arrayType.elementType)) + case (jit: JavaIterable[_], arrayType: ArrayType) => { + val iter = jit.iterator + var listOfItems: List[Any] = List() + while (iter.hasNext) { + val item = iter.next() + listOfItems :+= convertToCatalyst(item, arrayType.elementType) + } + listOfItems + } + case (s: Array[_], arrayType: ArrayType) => s.toSeq.map(convertToCatalyst(_, arrayType.elementType)) @@ -124,6 +135,15 @@ object CatalystTypeConverters { extractOption(item) match { case a: Array[_] => a.toSeq.map(elementConverter) case s: Seq[_] => s.map(elementConverter) + case i: JavaIterable[_] => { + val iter = i.iterator + var convertedIterable: List[Any] = List() + while (iter.hasNext) { + val item = iter.next() + convertedIterable :+= elementConverter(item) + } + convertedIterable + } case null => null } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/JavaTypeInference.scala b/sql/core/src/main/scala/org/apache/spark/sql/JavaTypeInference.scala new file mode 100644 index 0000000000000..db484c5f50074 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/JavaTypeInference.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import java.beans.Introspector +import java.lang.{Iterable => JIterable} +import java.util.{Iterator => JIterator, Map => JMap} + +import com.google.common.reflect.TypeToken + +import org.apache.spark.sql.types._ + +import scala.language.existentials + +/** + * Type-inference utilities for POJOs and Java collections. + */ +private [sql] object JavaTypeInference { + + private val iterableType = TypeToken.of(classOf[JIterable[_]]) + private val mapType = TypeToken.of(classOf[JMap[_, _]]) + private val iteratorReturnType = classOf[JIterable[_]].getMethod("iterator").getGenericReturnType + private val nextReturnType = classOf[JIterator[_]].getMethod("next").getGenericReturnType + private val keySetReturnType = classOf[JMap[_, _]].getMethod("keySet").getGenericReturnType + private val valuesReturnType = classOf[JMap[_, _]].getMethod("values").getGenericReturnType + + /** + * Infers the corresponding SQL data type of a Java type. + * @param typeToken Java type + * @return (SQL data type, nullable) + */ + private [sql] def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = { + // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific. + typeToken.getRawType match { + case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => + (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true) + + case c: Class[_] if c == classOf[java.lang.String] => (StringType, true) + case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false) + case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false) + case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false) + case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false) + case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false) + case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false) + case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false) + + case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true) + case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true) + case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true) + case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true) + case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true) + case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true) + case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true) + + case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true) + case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true) + case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true) + + case _ if typeToken.isArray => + val (dataType, nullable) = inferDataType(typeToken.getComponentType) + (ArrayType(dataType, nullable), true) + + case _ if iterableType.isAssignableFrom(typeToken) => + val (dataType, nullable) = inferDataType(elementType(typeToken)) + (ArrayType(dataType, nullable), true) + + case _ if mapType.isAssignableFrom(typeToken) => + val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]] + val mapSupertype = typeToken2.getSupertype(classOf[JMap[_, _]]) + val keyType = elementType(mapSupertype.resolveType(keySetReturnType)) + val valueType = elementType(mapSupertype.resolveType(valuesReturnType)) + val (keyDataType, _) = inferDataType(keyType) + val (valueDataType, nullable) = inferDataType(valueType) + (MapType(keyDataType, valueDataType, nullable), true) + + 
case _ => + val beanInfo = Introspector.getBeanInfo(typeToken.getRawType) + val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") + val fields = properties.map { property => + val returnType = typeToken.method(property.getReadMethod).getReturnType + val (dataType, nullable) = inferDataType(returnType) + new StructField(property.getName, dataType, nullable) + } + (new StructType(fields), true) + } + } + + private def elementType(typeToken: TypeToken[_]): TypeToken[_] = { + val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JIterable[_]]] + val iterableSupertype = typeToken2.getSupertype(classOf[JIterable[_]]) + val iteratorType = iterableSupertype.resolveType(iteratorReturnType) + val itemType = iteratorType.resolveType(nextReturnType) + itemType + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index f9f3eb2e03817..bcd20c06c6dca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -25,6 +25,8 @@ import scala.collection.immutable import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag +import com.google.common.reflect.TypeToken + import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDD @@ -1222,56 +1224,12 @@ class SQLContext(@transient val sparkContext: SparkContext) * Returns a Catalyst Schema for the given java bean class. */ protected def getSchema(beanClass: Class[_]): Seq[AttributeReference] = { - val (dataType, _) = inferDataType(beanClass) + val (dataType, _) = JavaTypeInference.inferDataType(TypeToken.of(beanClass)) dataType.asInstanceOf[StructType].fields.map { f => AttributeReference(f.name, f.dataType, f.nullable)() } } - /** - * Infers the corresponding SQL data type of a Java class. - * @param clazz Java class - * @return (SQL data type, nullable) - */ - private def inferDataType(clazz: Class[_]): (DataType, Boolean) = { - // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific. 
- clazz match { - case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => - (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true) - - case c: Class[_] if c == classOf[java.lang.String] => (StringType, true) - case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false) - case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false) - case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false) - case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false) - case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false) - case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false) - case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false) - - case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true) - case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true) - case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true) - case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true) - case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true) - case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true) - case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true) - - case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true) - case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true) - case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true) - - case c: Class[_] if c.isArray => - val (dataType, nullable) = inferDataType(c.getComponentType) - (ArrayType(dataType, nullable), true) - - case _ => - val beanInfo = Introspector.getBeanInfo(clazz) - val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") - val fields = properties.map { property => - val (dataType, nullable) = inferDataType(property.getPropertyType) - new StructField(property.getName, dataType, nullable) - } - (new StructType(fields), true) - } - } } + + diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 6d0fbe83c2f36..fc3ed4a708d46 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -17,23 +17,28 @@ package test.org.apache.spark.sql; -import java.io.Serializable; -import java.util.Arrays; - -import scala.collection.Seq; - -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Test; +import com.google.common.collect.ImmutableMap; +import com.google.common.primitives.Ints; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.*; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.TestData$; import org.apache.spark.sql.test.TestSQLContext; import org.apache.spark.sql.test.TestSQLContext$; import org.apache.spark.sql.types.*; +import org.junit.*; + +import scala.collection.JavaConversions; +import scala.collection.Seq; +import scala.collection.mutable.Buffer; + +import java.io.Serializable; +import java.util.Arrays; +import java.util.List; +import java.util.Map; import static org.apache.spark.sql.functions.*; @@ -106,6 +111,8 @@ public void testShow() { public static class Bean implements Serializable { private double a = 0.0; 
private Integer[] b = new Integer[]{0, 1}; + private Map c = ImmutableMap.of("hello", new int[] { 1, 2 }); + private List d = Arrays.asList("floppy", "disk"); public double getA() { return a; @@ -114,6 +121,14 @@ public double getA() { public Integer[] getB() { return b; } + + public Map getC() { + return c; + } + + public List getD() { + return d; + } } @Test @@ -127,7 +142,15 @@ public void testCreateDataFrameFromJavaBeans() { Assert.assertEquals( new StructField("b", new ArrayType(IntegerType$.MODULE$, true), true, Metadata.empty()), schema.apply("b")); - Row first = df.select("a", "b").first(); + ArrayType valueType = new ArrayType(DataTypes.IntegerType, false); + MapType mapType = new MapType(DataTypes.StringType, valueType, true); + Assert.assertEquals( + new StructField("c", mapType, true, Metadata.empty()), + schema.apply("c")); + Assert.assertEquals( + new StructField("d", new ArrayType(DataTypes.StringType, true), true, Metadata.empty()), + schema.apply("d")); + Row first = df.select("a", "b", "c", "d").first(); Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0); // Now Java lists and maps are converetd to Scala Seq's and Map's. Once we get a Seq below, // verify that it has the expected length, and contains expected elements. @@ -136,5 +159,15 @@ public void testCreateDataFrameFromJavaBeans() { for (int i = 0; i < result.length(); i++) { Assert.assertEquals(bean.getB()[i], result.apply(i)); } + Buffer outputBuffer = (Buffer) first.getJavaMap(2).get("hello"); + Assert.assertArrayEquals( + bean.getC().get("hello"), + Ints.toArray(JavaConversions.asJavaList(outputBuffer))); + Seq d = first.getAs(3); + Assert.assertEquals(bean.getD().size(), d.length()); + for (int i = 0; i < d.length(); i++) { + Assert.assertEquals(bean.getD().get(i), d.apply(i)); + } } + } From 7662ec23bb6c4d4fe4c857b6928eaed0a97d3c04 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Tue, 21 Apr 2015 15:11:15 -0700 Subject: [PATCH 073/191] [SPARK-5817] [SQL] Fix bug of udtf with column names It's a bug while do query like: ```sql select d from (select explode(array(1,1)) d from src limit 1) t ``` And it will throws exception like: ``` org.apache.spark.sql.AnalysisException: cannot resolve 'd' given input columns _c0; line 1 pos 7 at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42) at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$apply$3$$anonfun$apply$1.applyOrElse(CheckAnalysis.scala:48) at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$apply$3$$anonfun$apply$1.applyOrElse(CheckAnalysis.scala:45) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:250) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:250) at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:50) at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:249) at org.apache.spark.sql.catalyst.plans.QueryPlan.org$apache$spark$sql$catalyst$plans$QueryPlan$$transformExpressionUp$1(QueryPlan.scala:103) at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$2$$anonfun$apply$2.apply(QueryPlan.scala:117) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:244) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:244) at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59) at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:47) at 
scala.collection.TraversableLike$class.map(TraversableLike.scala:244) at scala.collection.AbstractTraversable.map(Traversable.scala:105) at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$2.apply(QueryPlan.scala:116) at scala.collection.Iterator$$anon$11.next(Iterator.scala:328) ``` To solve the bug, it requires code refactoring for UDTF The major changes are about: * Simplifying the UDTF development, UDTF will manage the output attribute names any more, instead, the `logical.Generate` will handle that properly. * UDTF will be asked for the output schema (data types) during the logical plan analyzing. Author: Cheng Hao Closes #4602 from chenghao-intel/explode_bug and squashes the following commits: c2a5132 [Cheng Hao] add back resolved for Alias 556e982 [Cheng Hao] revert the unncessary change 002c361 [Cheng Hao] change the rule of resolved for Generate 04ae500 [Cheng Hao] add qualifier only for generator output 5ee5d2c [Cheng Hao] prepend the new qualifier d2e8b43 [Cheng Hao] Update the code as feedback ca5e7f4 [Cheng Hao] shrink the commits --- .../sql/catalyst/analysis/Analyzer.scala | 57 ++++++++++++++++++- .../sql/catalyst/analysis/CheckAnalysis.scala | 12 ++++ .../spark/sql/catalyst/dsl/package.scala | 3 +- .../sql/catalyst/expressions/generators.scala | 49 ++++------------ .../expressions/namedExpressions.scala | 2 + .../sql/catalyst/optimizer/Optimizer.scala | 8 +-- .../plans/logical/basicOperators.scala | 37 +++++++----- .../sql/catalyst/analysis/AnalysisSuite.scala | 2 +- .../optimizer/FilterPushdownSuite.scala | 8 +-- .../org/apache/spark/sql/DataFrame.scala | 21 +++++-- .../apache/spark/sql/execution/Generate.scala | 22 ++----- .../spark/sql/execution/SparkStrategies.scala | 5 +- .../apache/spark/sql/hive/HiveContext.scala | 1 - .../org/apache/spark/sql/hive/HiveQl.scala | 37 ++++++------ .../org/apache/spark/sql/hive/hiveUdfs.scala | 38 +------------ ... 
output-0-d1f244bce64f22b34ad5bf9fd360b632 | 1 + ...mn name-0-7ac701cf43e73e9e416888e4df694348 | 0 ...mn name-1-5cdf9d51fc0e105e365d82e7611e37f3 | 0 ...mn name-2-f963396461294e06cb7cafe22a1419e4 | 3 + ...n names-0-46bdb27b3359dc81d8c246b9f69d4b82 | 0 ...n names-1-cdf6989f3b055257f1692c3bbd80dc73 | 0 ...n names-2-ab3954b69d7a991bc801a509c3166cc5 | 3 + ...mn name-0-7ac701cf43e73e9e416888e4df694348 | 0 ...mn name-1-26599718c322ff4f9740040c066d8292 | 0 ...mn name-2-f963396461294e06cb7cafe22a1419e4 | 3 + .../sql/hive/execution/HiveQuerySuite.scala | 40 ++++++++++++- 26 files changed, 207 insertions(+), 145 deletions(-) create mode 100644 sql/hive/src/test/resources/golden/Specify the udtf output-0-d1f244bce64f22b34ad5bf9fd360b632 create mode 100644 sql/hive/src/test/resources/golden/insert table with generator with column name-0-7ac701cf43e73e9e416888e4df694348 create mode 100644 sql/hive/src/test/resources/golden/insert table with generator with column name-1-5cdf9d51fc0e105e365d82e7611e37f3 create mode 100644 sql/hive/src/test/resources/golden/insert table with generator with column name-2-f963396461294e06cb7cafe22a1419e4 create mode 100644 sql/hive/src/test/resources/golden/insert table with generator with multiple column names-0-46bdb27b3359dc81d8c246b9f69d4b82 create mode 100644 sql/hive/src/test/resources/golden/insert table with generator with multiple column names-1-cdf6989f3b055257f1692c3bbd80dc73 create mode 100644 sql/hive/src/test/resources/golden/insert table with generator with multiple column names-2-ab3954b69d7a991bc801a509c3166cc5 create mode 100644 sql/hive/src/test/resources/golden/insert table with generator without column name-0-7ac701cf43e73e9e416888e4df694348 create mode 100644 sql/hive/src/test/resources/golden/insert table with generator without column name-1-26599718c322ff4f9740040c066d8292 create mode 100644 sql/hive/src/test/resources/golden/insert table with generator without column name-2-f963396461294e06cb7cafe22a1419e4 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index cb49e5ad5586f..5e42b409dcc59 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.util.collection.OpenHashSet import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -59,6 +58,7 @@ class Analyzer( ResolveReferences :: ResolveGroupingAnalytics :: ResolveSortReferences :: + ResolveGenerate :: ImplicitGenerate :: ResolveFunctions :: GlobalAggregates :: @@ -474,8 +474,59 @@ class Analyzer( */ object ImplicitGenerate extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case Project(Seq(Alias(g: Generator, _)), child) => - Generate(g, join = false, outer = false, None, child) + case Project(Seq(Alias(g: Generator, name)), child) => + Generate(g, join = false, outer = false, + qualifier = None, UnresolvedAttribute(name) :: Nil, child) + case Project(Seq(MultiAlias(g: Generator, names)), child) => + Generate(g, join = false, outer = false, + qualifier = None, names.map(UnresolvedAttribute(_)), child) + } + } + + /** + * 
Resolve the Generate, if the output names specified, we will take them, otherwise + * we will try to provide the default names, which follow the same rule with Hive. + */ + object ResolveGenerate extends Rule[LogicalPlan] { + // Construct the output attributes for the generator, + // The output attribute names can be either specified or + // auto generated. + private def makeGeneratorOutput( + generator: Generator, + generatorOutput: Seq[Attribute]): Seq[Attribute] = { + val elementTypes = generator.elementTypes + + if (generatorOutput.length == elementTypes.length) { + generatorOutput.zip(elementTypes).map { + case (a, (t, nullable)) if !a.resolved => + AttributeReference(a.name, t, nullable)() + case (a, _) => a + } + } else if (generatorOutput.length == 0) { + elementTypes.zipWithIndex.map { + // keep the default column names as Hive does _c0, _c1, _cN + case ((t, nullable), i) => AttributeReference(s"_c$i", t, nullable)() + } + } else { + throw new AnalysisException( + s""" + |The number of aliases supplied in the AS clause does not match + |the number of columns output by the UDTF expected + |${elementTypes.size} aliases but got ${generatorOutput.size} + """.stripMargin) + } + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case p: Generate if !p.child.resolved || !p.generator.resolved => p + case p: Generate if p.resolved == false => + // if the generator output names are not specified, we will use the default ones. + Generate( + p.generator, + join = p.join, + outer = p.outer, + p.qualifier, + makeGeneratorOutput(p.generator, p.generatorOutput), p.child) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index a986dd5387c38..2381689e17525 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -38,6 +38,12 @@ trait CheckAnalysis { throw new AnalysisException(msg) } + def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = { + exprs.flatMap(_.collect { + case e: Generator => true + }).length >= 1 + } + def checkAnalysis(plan: LogicalPlan): Unit = { // We transform up and order the rules so as to catch the first possible failure instead // of the result of cascading resolution failures. @@ -110,6 +116,12 @@ trait CheckAnalysis { failAnalysis( s"unresolved operator ${operator.simpleString}") + case p @ Project(exprs, _) if containsMultipleGenerators(exprs) => + failAnalysis( + s"""Only a single table generating function is allowed in a SELECT clause, found: + | ${exprs.map(_.prettyString).mkString(",")}""".stripMargin) + + case _ => // Analysis successful! 
} } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 21c15ad14fd19..4e5c64bb63c9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -284,12 +284,13 @@ package object dsl { seed: Int = (math.random * 1000).toInt): LogicalPlan = Sample(fraction, withReplacement, seed, logicalPlan) + // TODO specify the output column names def generate( generator: Generator, join: Boolean = false, outer: Boolean = false, alias: Option[String] = None): LogicalPlan = - Generate(generator, join, outer, None, logicalPlan) + Generate(generator, join = join, outer = outer, alias, Nil, logicalPlan) def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan = InsertIntoTable( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 67caadb839ff9..9a6cb048af5ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -42,47 +42,30 @@ abstract class Generator extends Expression { override type EvaluatedType = TraversableOnce[Row] - override lazy val dataType = - ArrayType(StructType(output.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata)))) + // TODO ideally we should return the type of ArrayType(StructType), + // however, we don't keep the output field names in the Generator. + override def dataType: DataType = throw new UnsupportedOperationException override def nullable: Boolean = false /** - * Should be overridden by specific generators. Called only once for each instance to ensure - * that rule application does not change the output schema of a generator. + * The output element data types in structure of Seq[(DataType, Nullable)] + * TODO we probably need to add more information like metadata etc. */ - protected def makeOutput(): Seq[Attribute] - - private var _output: Seq[Attribute] = null - - def output: Seq[Attribute] = { - if (_output == null) { - _output = makeOutput() - } - _output - } + def elementTypes: Seq[(DataType, Boolean)] /** Should be implemented by child classes to perform specific Generators. */ override def eval(input: Row): TraversableOnce[Row] - - /** Overridden `makeCopy` also copies the attributes that are produced by this generator. */ - override def makeCopy(newArgs: Array[AnyRef]): this.type = { - val copy = super.makeCopy(newArgs) - copy._output = _output - copy - } } /** * A generator that produces its output using the provided lambda function. */ case class UserDefinedGenerator( - schema: Seq[Attribute], + elementTypes: Seq[(DataType, Boolean)], function: Row => TraversableOnce[Row], children: Seq[Expression]) - extends Generator{ - - override protected def makeOutput(): Seq[Attribute] = schema + extends Generator { override def eval(input: Row): TraversableOnce[Row] = { // TODO(davies): improve this @@ -98,30 +81,18 @@ case class UserDefinedGenerator( /** * Given an input array produces a sequence of rows for each value in the array. 
*/ -case class Explode(attributeNames: Seq[String], child: Expression) +case class Explode(child: Expression) extends Generator with trees.UnaryNode[Expression] { override lazy val resolved = child.resolved && (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) - private lazy val elementTypes = child.dataType match { + override def elementTypes: Seq[(DataType, Boolean)] = child.dataType match { case ArrayType(et, containsNull) => (et, containsNull) :: Nil case MapType(kt, vt, valueContainsNull) => (kt, false) :: (vt, valueContainsNull) :: Nil } - // TODO: Move this pattern into Generator. - protected def makeOutput() = - if (attributeNames.size == elementTypes.size) { - attributeNames.zip(elementTypes).map { - case (n, (t, nullable)) => AttributeReference(n, t, nullable)() - } - } else { - elementTypes.zipWithIndex.map { - case ((t, nullable), i) => AttributeReference(s"c_$i", t, nullable)() - } - } - override def eval(input: Row): TraversableOnce[Row] = { child.dataType match { case ArrayType(_, _) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index bcbcbeb31c7b5..afcb2ce8b9cb4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -112,6 +112,8 @@ case class Alias(child: Expression, name: String)( extends NamedExpression with trees.UnaryNode[Expression] { override type EvaluatedType = Any + // Alias(Generator, xx) need to be transformed into Generate(generator, ...) + override lazy val resolved = childrenResolved && !child.isInstanceOf[Generator] override def eval(input: Row): Any = child.eval(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 7c80634d2c852..2d03fbfb0d311 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -482,16 +482,16 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] { object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case filter @ Filter(condition, - generate @ Generate(generator, join, outer, alias, grandChild)) => + case filter @ Filter(condition, g: Generate) => // Predicates that reference attributes produced by the `Generate` operator cannot // be pushed below the operator. 
val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { - conjunct => conjunct.references subsetOf grandChild.outputSet + conjunct => conjunct.references subsetOf g.child.outputSet } if (pushDown.nonEmpty) { val pushDownPredicate = pushDown.reduce(And) - val withPushdown = generate.copy(child = Filter(pushDownPredicate, grandChild)) + val withPushdown = Generate(g.generator, join = g.join, outer = g.outer, + g.qualifier, g.generatorOutput, Filter(pushDownPredicate, g.child)) stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown) } else { filter diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 17522976dc2c9..bbc94a7ab3398 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -40,34 +40,43 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend * output of each into a new stream of rows. This operation is similar to a `flatMap` in functional * programming with one important additional feature, which allows the input rows to be joined with * their output. + * @param generator the generator expression * @param join when true, each output row is implicitly joined with the input tuple that produced * it. * @param outer when true, each input row will be output at least once, even if the output of the * given `generator` is empty. `outer` has no effect when `join` is false. - * @param alias when set, this string is applied to the schema of the output of the transformation - * as a qualifier. + * @param qualifier Qualifier for the attributes of generator(UDTF) + * @param generatorOutput The output schema of the Generator. + * @param child Children logical plan node */ case class Generate( generator: Generator, join: Boolean, outer: Boolean, - alias: Option[String], + qualifier: Option[String], + generatorOutput: Seq[Attribute], child: LogicalPlan) extends UnaryNode { - protected def generatorOutput: Seq[Attribute] = { - val output = alias - .map(a => generator.output.map(_.withQualifiers(a :: Nil))) - .getOrElse(generator.output) - if (join && outer) { - output.map(_.withNullability(true)) - } else { - output - } + override lazy val resolved: Boolean = { + generator.resolved && + childrenResolved && + generator.elementTypes.length == generatorOutput.length && + !generatorOutput.exists(!_.resolved) } - override def output: Seq[Attribute] = - if (join) child.output ++ generatorOutput else generatorOutput + // we don't want the gOutput to be taken as part of the expressions + // as that will cause exceptions like unresolved attributes etc. 
+ override def expressions: Seq[Expression] = generator :: Nil + + def output: Seq[Attribute] = { + val qualified = qualifier.map(q => + // prepend the new qualifier to the existed one + generatorOutput.map(a => a.withQualifiers(q +: a.qualifiers)) + ).getOrElse(generatorOutput) + + if (join) child.output ++ qualified else qualified + } } case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index e10ddfdf5127c..7c249215bd6b6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -90,7 +90,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { assert(!Project(Seq(UnresolvedAttribute("a")), testRelation).resolved) - val explode = Explode(Nil, AttributeReference("a", IntegerType, nullable = true)()) + val explode = Explode(AttributeReference("a", IntegerType, nullable = true)()) assert(!Project(Seq(Alias(explode, "explode")()), testRelation).resolved) assert(!Project(Seq(Alias(Count(Literal(1)), "count")()), testRelation).resolved) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 1448098c770aa..45cf695d20b01 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -454,21 +454,21 @@ class FilterPushdownSuite extends PlanTest { test("generate: predicate referenced no generated column") { val originalQuery = { testRelationWithArrayType - .generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr")) + .generate(Explode('c_arr), true, false, Some("arr")) .where(('b >= 5) && ('a > 6)) } val optimized = Optimize(originalQuery.analyze) val correctAnswer = { testRelationWithArrayType .where(('b >= 5) && ('a > 6)) - .generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr")).analyze + .generate(Explode('c_arr), true, false, Some("arr")).analyze } comparePlans(optimized, correctAnswer) } test("generate: part of conjuncts referenced generated column") { - val generator = Explode(Seq("c"), 'c_arr) + val generator = Explode('c_arr) val originalQuery = { testRelationWithArrayType .generate(generator, true, false, Some("arr")) @@ -499,7 +499,7 @@ class FilterPushdownSuite extends PlanTest { test("generate: all conjuncts referenced generated column") { val originalQuery = { testRelationWithArrayType - .generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr")) + .generate(Explode('c_arr), true, false, Some("arr")) .where(('c > 6) || ('b > 5)).analyze } val optimized = Optimize(originalQuery) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 45f5da387692e..03d9834d1d131 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -34,7 +34,7 @@ import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} -import 
org.apache.spark.sql.catalyst.analysis.{UnresolvedRelation, ResolvedStar} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, ResolvedStar} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.{JoinType, Inner} import org.apache.spark.sql.catalyst.plans.logical._ @@ -711,12 +711,16 @@ class DataFrame private[sql]( */ def explode[A <: Product : TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame = { val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] - val attributes = schema.toAttributes + + val elementTypes = schema.toAttributes.map { attr => (attr.dataType, attr.nullable) } + val names = schema.toAttributes.map(_.name) + val rowFunction = f.andThen(_.map(CatalystTypeConverters.convertToCatalyst(_, schema).asInstanceOf[Row])) - val generator = UserDefinedGenerator(attributes, rowFunction, input.map(_.expr)) + val generator = UserDefinedGenerator(elementTypes, rowFunction, input.map(_.expr)) - Generate(generator, join = true, outer = false, None, logicalPlan) + Generate(generator, join = true, outer = false, + qualifier = None, names.map(UnresolvedAttribute(_)), logicalPlan) } /** @@ -733,12 +737,17 @@ class DataFrame private[sql]( : DataFrame = { val dataType = ScalaReflection.schemaFor[B].dataType val attributes = AttributeReference(outputColumn, dataType)() :: Nil + // TODO handle the metadata? + val elementTypes = attributes.map { attr => (attr.dataType, attr.nullable) } + val names = attributes.map(_.name) + def rowFunction(row: Row): TraversableOnce[Row] = { f(row(0).asInstanceOf[A]).map(o => Row(CatalystTypeConverters.convertToCatalyst(o, dataType))) } - val generator = UserDefinedGenerator(attributes, rowFunction, apply(inputColumn).expr :: Nil) + val generator = UserDefinedGenerator(elementTypes, rowFunction, apply(inputColumn).expr :: Nil) - Generate(generator, join = true, outer = false, None, logicalPlan) + Generate(generator, join = true, outer = false, + qualifier = None, names.map(UnresolvedAttribute(_)), logicalPlan) } ///////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index 12271048bb39c..5201e20a10565 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -27,44 +27,34 @@ import org.apache.spark.sql.catalyst.expressions._ * output of each into a new stream of rows. This operation is similar to a `flatMap` in functional * programming with one important additional feature, which allows the input rows to be joined with * their output. + * @param generator the generator expression * @param join when true, each output row is implicitly joined with the input tuple that produced * it. * @param outer when true, each input row will be output at least once, even if the output of the * given `generator` is empty. `outer` has no effect when `join` is false. + * @param output the output attributes of this node, which constructed in analysis phase, + * and we can not change it, as the parent node bound with it already. */ @DeveloperApi case class Generate( generator: Generator, join: Boolean, outer: Boolean, + output: Seq[Attribute], child: SparkPlan) extends UnaryNode { - // This must be a val since the generator output expr ids are not preserved by serialization. 
- protected val generatorOutput: Seq[Attribute] = { - if (join && outer) { - generator.output.map(_.withNullability(true)) - } else { - generator.output - } - } - - // This must be a val since the generator output expr ids are not preserved by serialization. - override val output = - if (join) child.output ++ generatorOutput else generatorOutput - val boundGenerator = BindReferences.bindReference(generator, child.output) override def execute(): RDD[Row] = { if (join) { child.execute().mapPartitions { iter => - val nullValues = Seq.fill(generator.output.size)(Literal(null)) + val nullValues = Seq.fill(generator.elementTypes.size)(Literal(null)) // Used to produce rows with no matches when outer = true. val outerProjection = newProjection(child.output ++ nullValues, child.output) - val joinProjection = - newProjection(child.output ++ generatorOutput, child.output ++ generatorOutput) + val joinProjection = newProjection(output, output) val joinedRow = new JoinedRow iter.flatMap {row => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index e687d01f57520..030ef118f75d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -312,8 +312,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Except(planLater(left), planLater(right)) :: Nil case logical.Intersect(left, right) => execution.Intersect(planLater(left), planLater(right)) :: Nil - case logical.Generate(generator, join, outer, _, child) => - execution.Generate(generator, join = join, outer = outer, planLater(child)) :: Nil + case g @ logical.Generate(generator, join, outer, _, _, child) => + execution.Generate( + generator, join = join, outer = outer, g.output, planLater(child)) :: Nil case logical.OneRowRelation => execution.PhysicalRDD(Nil, singleRowRdd) :: Nil case logical.Repartition(expressions, child) => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 7c6a7df2bd01e..c4a73b3004076 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -249,7 +249,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { catalog.CreateTables :: catalog.PreInsertionCasts :: ExtractPythonUdfs :: - ResolveUdtfsAlias :: sources.PreInsertCastAndRename :: Nil } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index fd305eb480e63..85061f22772dd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -725,12 +725,14 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val alias = getClause("TOK_TABALIAS", clauses).getChildren.head.asInstanceOf[ASTNode].getText - Generate( - nodesToGenerator(clauses), - join = true, - outer = false, - Some(alias.toLowerCase), - withWhere) + val (generator, attributes) = nodesToGenerator(clauses) + Generate( + generator, + join = true, + outer = false, + Some(alias.toLowerCase), + attributes.map(UnresolvedAttribute(_)), + withWhere) }.getOrElse(withWhere) // The projection of the query can either be a normal projection, an 
aggregation @@ -833,12 +835,14 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val alias = getClause("TOK_TABALIAS", clauses).getChildren.head.asInstanceOf[ASTNode].getText - Generate( - nodesToGenerator(clauses), - join = true, - outer = isOuter.nonEmpty, - Some(alias.toLowerCase), - nodeToRelation(relationClause)) + val (generator, attributes) = nodesToGenerator(clauses) + Generate( + generator, + join = true, + outer = isOuter.nonEmpty, + Some(alias.toLowerCase), + attributes.map(UnresolvedAttribute(_)), + nodeToRelation(relationClause)) /* All relations, possibly with aliases or sampling clauses. */ case Token("TOK_TABREF", clauses) => @@ -1311,7 +1315,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val explode = "(?i)explode".r - def nodesToGenerator(nodes: Seq[Node]): Generator = { + def nodesToGenerator(nodes: Seq[Node]): (Generator, Seq[String]) = { val function = nodes.head val attributes = nodes.flatMap { @@ -1321,7 +1325,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C function match { case Token("TOK_FUNCTION", Token(explode(), Nil) :: child :: Nil) => - Explode(attributes, nodeToExpr(child)) + (Explode(nodeToExpr(child)), attributes) case Token("TOK_FUNCTION", Token(functionName, Nil) :: children) => val functionInfo: FunctionInfo = @@ -1329,10 +1333,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C sys.error(s"Couldn't find function $functionName")) val functionClassName = functionInfo.getFunctionClass.getName - HiveGenericUdtf( + (HiveGenericUdtf( new HiveFunctionWrapper(functionClassName), - attributes, - children.map(nodeToExpr)) + children.map(nodeToExpr)), attributes) case a: ASTNode => throw new NotImplementedError( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 47305571e579e..4b6f0ad75f54f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -66,7 +66,7 @@ private[hive] abstract class HiveFunctionRegistry } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) { HiveUdaf(new HiveFunctionWrapper(functionClassName), children) } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdtf(new HiveFunctionWrapper(functionClassName), Nil, children) + HiveGenericUdtf(new HiveFunctionWrapper(functionClassName), children) } else { sys.error(s"No handler for udf ${functionInfo.getFunctionClass}") } @@ -266,7 +266,6 @@ private[hive] case class HiveUdaf( */ private[hive] case class HiveGenericUdtf( funcWrapper: HiveFunctionWrapper, - aliasNames: Seq[String], children: Seq[Expression]) extends Generator with HiveInspectors { @@ -282,23 +281,8 @@ private[hive] case class HiveGenericUdtf( @transient protected lazy val udtInput = new Array[AnyRef](children.length) - protected lazy val outputDataTypes = outputInspector.getAllStructFieldRefs.map { - field => inspectorToDataType(field.getFieldObjectInspector) - } - - override protected def makeOutput() = { - // Use column names when given, otherwise _c1, _c2, ... _cn. 
- if (aliasNames.size == outputDataTypes.size) { - aliasNames.zip(outputDataTypes).map { - case (attrName, attrDataType) => - AttributeReference(attrName, attrDataType, nullable = true)() - } - } else { - outputDataTypes.zipWithIndex.map { - case (attrDataType, i) => - AttributeReference(s"_c$i", attrDataType, nullable = true)() - } - } + lazy val elementTypes = outputInspector.getAllStructFieldRefs.map { + field => (inspectorToDataType(field.getFieldObjectInspector), true) } override def eval(input: Row): TraversableOnce[Row] = { @@ -333,22 +317,6 @@ private[hive] case class HiveGenericUdtf( } } -/** - * Resolve Udtfs Alias. - */ -private[spark] object ResolveUdtfsAlias extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case p @ Project(projectList, _) - if projectList.exists(_.isInstanceOf[MultiAlias]) && projectList.size != 1 => - throw new TreeNodeException(p, "only single Generator supported for SELECT clause") - - case Project(Seq(Alias(udtf @ HiveGenericUdtf(_, _, _), name)), child) => - Generate(udtf.copy(aliasNames = Seq(name)), join = false, outer = false, None, child) - case Project(Seq(MultiAlias(udtf @ HiveGenericUdtf(_, _, _), names)), child) => - Generate(udtf.copy(aliasNames = names), join = false, outer = false, None, child) - } -} - private[hive] case class HiveUdafFunction( funcWrapper: HiveFunctionWrapper, exprs: Seq[Expression], diff --git a/sql/hive/src/test/resources/golden/Specify the udtf output-0-d1f244bce64f22b34ad5bf9fd360b632 b/sql/hive/src/test/resources/golden/Specify the udtf output-0-d1f244bce64f22b34ad5bf9fd360b632 new file mode 100644 index 0000000000000..d00491fd7e5bb --- /dev/null +++ b/sql/hive/src/test/resources/golden/Specify the udtf output-0-d1f244bce64f22b34ad5bf9fd360b632 @@ -0,0 +1 @@ +1 diff --git a/sql/hive/src/test/resources/golden/insert table with generator with column name-0-7ac701cf43e73e9e416888e4df694348 b/sql/hive/src/test/resources/golden/insert table with generator with column name-0-7ac701cf43e73e9e416888e4df694348 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert table with generator with column name-1-5cdf9d51fc0e105e365d82e7611e37f3 b/sql/hive/src/test/resources/golden/insert table with generator with column name-1-5cdf9d51fc0e105e365d82e7611e37f3 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert table with generator with column name-2-f963396461294e06cb7cafe22a1419e4 b/sql/hive/src/test/resources/golden/insert table with generator with column name-2-f963396461294e06cb7cafe22a1419e4 new file mode 100644 index 0000000000000..01e79c32a8c99 --- /dev/null +++ b/sql/hive/src/test/resources/golden/insert table with generator with column name-2-f963396461294e06cb7cafe22a1419e4 @@ -0,0 +1,3 @@ +1 +2 +3 diff --git a/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-0-46bdb27b3359dc81d8c246b9f69d4b82 b/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-0-46bdb27b3359dc81d8c246b9f69d4b82 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-1-cdf6989f3b055257f1692c3bbd80dc73 b/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-1-cdf6989f3b055257f1692c3bbd80dc73 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git 
a/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-2-ab3954b69d7a991bc801a509c3166cc5 b/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-2-ab3954b69d7a991bc801a509c3166cc5 new file mode 100644 index 0000000000000..0c7520f2090dd --- /dev/null +++ b/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-2-ab3954b69d7a991bc801a509c3166cc5 @@ -0,0 +1,3 @@ +86 val_86 +238 val_238 +311 val_311 diff --git a/sql/hive/src/test/resources/golden/insert table with generator without column name-0-7ac701cf43e73e9e416888e4df694348 b/sql/hive/src/test/resources/golden/insert table with generator without column name-0-7ac701cf43e73e9e416888e4df694348 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert table with generator without column name-1-26599718c322ff4f9740040c066d8292 b/sql/hive/src/test/resources/golden/insert table with generator without column name-1-26599718c322ff4f9740040c066d8292 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert table with generator without column name-2-f963396461294e06cb7cafe22a1419e4 b/sql/hive/src/test/resources/golden/insert table with generator without column name-2-f963396461294e06cb7cafe22a1419e4 new file mode 100644 index 0000000000000..01e79c32a8c99 --- /dev/null +++ b/sql/hive/src/test/resources/golden/insert table with generator without column name-2-f963396461294e06cb7cafe22a1419e4 @@ -0,0 +1,3 @@ +1 +2 +3 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 300b1f7920473..ac10b173307d8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -27,7 +27,7 @@ import scala.util.Try import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.spark.{SparkFiles, SparkException} -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{AnalysisException, DataFrame, Row} import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive._ @@ -67,6 +67,40 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } } + createQueryTest("insert table with generator with column name", + """ + | CREATE TABLE gen_tmp (key Int); + | INSERT OVERWRITE TABLE gen_tmp + | SELECT explode(array(1,2,3)) AS val FROM src LIMIT 3; + | SELECT key FROM gen_tmp ORDER BY key ASC; + """.stripMargin) + + createQueryTest("insert table with generator with multiple column names", + """ + | CREATE TABLE gen_tmp (key Int, value String); + | INSERT OVERWRITE TABLE gen_tmp + | SELECT explode(map(key, value)) as (k1, k2) FROM src LIMIT 3; + | SELECT key, value FROM gen_tmp ORDER BY key, value ASC; + """.stripMargin) + + createQueryTest("insert table with generator without column name", + """ + | CREATE TABLE gen_tmp (key Int); + | INSERT OVERWRITE TABLE gen_tmp + | SELECT explode(array(1,2,3)) FROM src LIMIT 3; + | SELECT key FROM gen_tmp ORDER BY key ASC; + """.stripMargin) + + test("multiple generator in projection") { + intercept[AnalysisException] { + sql("SELECT explode(map(key, value)), key FROM src").collect() + } + + intercept[AnalysisException] { + sql("SELECT explode(map(key, value)) as k1, k2, key 
FROM src").collect() + } + } + createQueryTest("! operator", """ |SELECT a FROM ( @@ -456,7 +490,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { createQueryTest("lateral view2", "SELECT * FROM src LATERAL VIEW explode(array(1,2)) tbl") - createQueryTest("lateral view3", "FROM src SELECT key, D.* lateral view explode(array(key+3, key+4)) D as CX") @@ -478,6 +511,9 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { createQueryTest("lateral view6", "SELECT * FROM src LATERAL VIEW explode(map(key+3,key+4)) D as k, v") + createQueryTest("Specify the udtf output", + "SELECT d FROM (SELECT explode(array(1,1)) d FROM src LIMIT 1) t") + test("sampling") { sql("SELECT * FROM src TABLESAMPLE(0.1 PERCENT) s") sql("SELECT * FROM src TABLESAMPLE(100 PERCENT) s") From f83c0f112d04173f4fc2c5eaf0f9cb11d9180077 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 21 Apr 2015 16:24:15 -0700 Subject: [PATCH 074/191] [SPARK-3386] Share and reuse SerializerInstances in shuffle paths This patch modifies several shuffle-related code paths to share and re-use SerializerInstances instead of creating new ones. Some serializers, such as KryoSerializer or SqlSerializer, can be fairly expensive to create or may consume moderate amounts of memory, so it's probably best to avoid unnecessary serializer creation in hot code paths. The key change in this patch is modifying `getDiskWriter()` / `DiskBlockObjectWriter` to accept `SerializerInstance`s instead of `Serializer`s (which are factories for instances). This allows the disk writer's creator to decide whether the serializer instance can be shared or re-used. The rest of the patch modifies several write and read paths to use shared serializers. One big win is in `ShuffleBlockFetcherIterator`, where we used to create a new serializer per received block. Similarly, the shuffle write path used to create a new serializer per file even though in many cases only a single thread would be writing to a file at a time. I made a small serializer reuse optimization in CoarseGrainedExecutorBackend as well, since it seemed like a small and obvious improvement. 
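As a minimal, self-contained sketch of the before/after pattern described here (not part of the patch: the `SerializerReuseSketch` object, the loops, and the byte-array sinks are all illustrative, and `JavaSerializer` is used only because it can be constructed without a running cluster):

```
import java.io.ByteArrayOutputStream

import org.apache.spark.SparkConf
import org.apache.spark.serializer.JavaSerializer

object SerializerReuseSketch {
  def main(args: Array[String]): Unit = {
    val serializer = new JavaSerializer(new SparkConf())

    // Before: every write site asks the Serializer factory for a fresh
    // SerializerInstance, which is cheap here but costly for e.g. Kryo.
    for (i <- 1 to 3) {
      val stream = serializer.newInstance().serializeStream(new ByteArrayOutputStream())
      stream.writeObject(i)
      stream.close()
    }

    // After: the caller creates one SerializerInstance up front and shares it
    // across all writers on the same thread (instances are not thread-safe).
    val instance = serializer.newInstance()
    for (i <- 1 to 3) {
      val stream = instance.serializeStream(new ByteArrayOutputStream())
      stream.writeObject(i)
      stream.close()
    }
  }
}
```

Hoisting `newInstance()` out of the loop is the same decision that the new `getDiskWriter()` signature pushes onto the disk writer's creator.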
Author: Josh Rosen Closes #5606 from JoshRosen/SPARK-3386 and squashes the following commits: f661ce7 [Josh Rosen] Remove thread local; add comment instead 64f8398 [Josh Rosen] Use ThreadLocal for serializer instance in CoarseGrainedExecutorBackend aeb680e [Josh Rosen] [SPARK-3386] Reuse SerializerInstance in shuffle code paths --- .../executor/CoarseGrainedExecutorBackend.scala | 6 +++++- .../spark/shuffle/FileShuffleBlockManager.scala | 6 ++++-- .../org/apache/spark/storage/BlockManager.scala | 8 ++++---- .../apache/spark/storage/BlockObjectWriter.scala | 6 +++--- .../storage/ShuffleBlockFetcherIterator.scala | 6 ++++-- .../util/collection/ExternalAppendOnlyMap.scala | 6 ++---- .../spark/util/collection/ExternalSorter.scala | 14 +++++++++----- .../spark/storage/BlockObjectWriterSuite.scala | 6 +++--- 8 files changed, 34 insertions(+), 24 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 8300f9f2190b9..8af46f3327adb 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -30,6 +30,7 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.worker.WorkerWatcher import org.apache.spark.scheduler.TaskDescription import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ +import org.apache.spark.serializer.SerializerInstance import org.apache.spark.util.{SignalLogger, Utils} private[spark] class CoarseGrainedExecutorBackend( @@ -47,6 +48,10 @@ private[spark] class CoarseGrainedExecutorBackend( var executor: Executor = null @volatile var driver: Option[RpcEndpointRef] = None + // If this CoarseGrainedExecutorBackend is changed to support multiple threads, then this may need + // to be changed so that we don't share the serializer instance across threads + private[this] val ser: SerializerInstance = env.closureSerializer.newInstance() + override def onStart() { import scala.concurrent.ExecutionContext.Implicits.global logInfo("Connecting to driver: " + driverUrl) @@ -83,7 +88,6 @@ private[spark] class CoarseGrainedExecutorBackend( logError("Received LaunchTask command but executor was null") System.exit(1) } else { - val ser = env.closureSerializer.newInstance() val taskDesc = ser.deserialize[TaskDescription](data.value) logInfo("Got assigned task " + taskDesc.taskId) executor.launchTask(this, taskId = taskDesc.taskId, attemptNumber = taskDesc.attemptNumber, diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala index 5be3ed771e534..538e150ead05a 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala @@ -113,11 +113,12 @@ class FileShuffleBlockManager(conf: SparkConf) private var fileGroup: ShuffleFileGroup = null val openStartTime = System.nanoTime + val serializerInstance = serializer.newInstance() val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) { fileGroup = getUnusedFileGroup() Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) - blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializer, bufferSize, + blockManager.getDiskWriter(blockId, fileGroup(bucketId), 
serializerInstance, bufferSize, writeMetrics) } } else { @@ -133,7 +134,8 @@ class FileShuffleBlockManager(conf: SparkConf) logWarning(s"Failed to remove existing shuffle file $blockFile") } } - blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize, writeMetrics) + blockManager.getDiskWriter(blockId, blockFile, serializerInstance, bufferSize, + writeMetrics) } } // Creating the file to write to and creating a disk writer both involve interacting with diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 1aa0ef18de118..145a9c1ae3391 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -37,7 +37,7 @@ import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.ExternalShuffleClient import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo import org.apache.spark.rpc.RpcEnv -import org.apache.spark.serializer.Serializer +import org.apache.spark.serializer.{SerializerInstance, Serializer} import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.util._ @@ -646,13 +646,13 @@ private[spark] class BlockManager( def getDiskWriter( blockId: BlockId, file: File, - serializer: Serializer, + serializerInstance: SerializerInstance, bufferSize: Int, writeMetrics: ShuffleWriteMetrics): BlockObjectWriter = { val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _) val syncWrites = conf.getBoolean("spark.shuffle.sync", false) - new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream, syncWrites, - writeMetrics) + new DiskBlockObjectWriter(blockId, file, serializerInstance, bufferSize, compressStream, + syncWrites, writeMetrics) } /** diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index 0dfc91dfaff85..14833791f7a4d 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -21,7 +21,7 @@ import java.io.{BufferedOutputStream, FileOutputStream, File, OutputStream} import java.nio.channels.FileChannel import org.apache.spark.Logging -import org.apache.spark.serializer.{SerializationStream, Serializer} +import org.apache.spark.serializer.{SerializerInstance, SerializationStream} import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.util.Utils @@ -71,7 +71,7 @@ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) { private[spark] class DiskBlockObjectWriter( blockId: BlockId, file: File, - serializer: Serializer, + serializerInstance: SerializerInstance, bufferSize: Int, compressStream: OutputStream => OutputStream, syncWrites: Boolean, @@ -134,7 +134,7 @@ private[spark] class DiskBlockObjectWriter( ts = new TimeTrackingOutputStream(fos) channel = fos.getChannel() bs = compressStream(new BufferedOutputStream(ts, bufferSize)) - objOut = serializer.newInstance().serializeStream(bs) + objOut = serializerInstance.serializeStream(bs) initialized = true this } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 8f28ef49a8a6f..f3379521d55e2 100644 --- 
a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -27,7 +27,7 @@ import org.apache.spark.{Logging, TaskContext} import org.apache.spark.network.BlockTransferService import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} import org.apache.spark.network.buffer.ManagedBuffer -import org.apache.spark.serializer.Serializer +import org.apache.spark.serializer.{SerializerInstance, Serializer} import org.apache.spark.util.{CompletionIterator, Utils} /** @@ -106,6 +106,8 @@ final class ShuffleBlockFetcherIterator( private[this] val shuffleMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() + private[this] val serializerInstance: SerializerInstance = serializer.newInstance() + /** * Whether the iterator is still active. If isZombie is true, the callback interface will no * longer place fetched blocks into [[results]]. @@ -299,7 +301,7 @@ final class ShuffleBlockFetcherIterator( // the scheduler gets a FetchFailedException. Try(buf.createInputStream()).map { is0 => val is = blockManager.wrapForCompression(blockId, is0) - val iter = serializer.newInstance().deserializeStream(is).asIterator + val iter = serializerInstance.deserializeStream(is).asIterator CompletionIterator[Any, Iterator[Any]](iter, { // Once the iterator is exhausted, release the buffer and set currentResult to null // so we don't release it again in cleanup. diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 9ff4744593d4d..30dd7f22e494f 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -151,8 +151,7 @@ class ExternalAppendOnlyMap[K, V, C]( override protected[this] def spill(collection: SizeTracker): Unit = { val (blockId, file) = diskBlockManager.createTempLocalBlock() curWriteMetrics = new ShuffleWriteMetrics() - var writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize, - curWriteMetrics) + var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics) var objectsWritten = 0 // List of batch sizes (bytes) in the order they are written to disk @@ -179,8 +178,7 @@ class ExternalAppendOnlyMap[K, V, C]( if (objectsWritten == serializerBatchSize) { flush() curWriteMetrics = new ShuffleWriteMetrics() - writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize, - curWriteMetrics) + writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics) } } if (objectsWritten > 0) { diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 035f3767ff554..79a1a8a0dae38 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -272,7 +272,8 @@ private[spark] class ExternalSorter[K, V, C]( // createTempShuffleBlock here; see SPARK-3426 for more context. 
val (blockId, file) = diskBlockManager.createTempShuffleBlock() curWriteMetrics = new ShuffleWriteMetrics() - var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics) + var writer = blockManager.getDiskWriter( + blockId, file, serInstance, fileBufferSize, curWriteMetrics) var objectsWritten = 0 // Objects written since the last flush // List of batch sizes (bytes) in the order they are written to disk @@ -308,7 +309,8 @@ private[spark] class ExternalSorter[K, V, C]( if (objectsWritten == serializerBatchSize) { flush() curWriteMetrics = new ShuffleWriteMetrics() - writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics) + writer = blockManager.getDiskWriter( + blockId, file, serInstance, fileBufferSize, curWriteMetrics) } } if (objectsWritten > 0) { @@ -358,7 +360,9 @@ private[spark] class ExternalSorter[K, V, C]( // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use // createTempShuffleBlock here; see SPARK-3426 for more context. val (blockId, file) = diskBlockManager.createTempShuffleBlock() - blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics).open() + val writer = blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, + curWriteMetrics) + writer.open() } // Creating the file to write to and creating a disk writer both involve interacting with // the disk, and can take a long time in aggregate when we open many files, so should be @@ -749,8 +753,8 @@ private[spark] class ExternalSorter[K, V, C]( // partition and just write everything directly. for ((id, elements) <- this.partitionedIterator) { if (elements.hasNext) { - val writer = blockManager.getDiskWriter( - blockId, outputFile, ser, fileBufferSize, context.taskMetrics.shuffleWriteMetrics.get) + val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize, + context.taskMetrics.shuffleWriteMetrics.get) for (elem <- elements) { writer.write(elem) } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala index 78bbc4ec2c620..003a728cb84a0 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala @@ -30,7 +30,7 @@ class BlockObjectWriterSuite extends FunSuite { val file = new File(Utils.createTempDir(), "somefile") val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, - new JavaSerializer(new SparkConf()), 1024, os => os, true, writeMetrics) + new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) writer.write(Long.box(20)) // Record metrics update on every write @@ -52,7 +52,7 @@ class BlockObjectWriterSuite extends FunSuite { val file = new File(Utils.createTempDir(), "somefile") val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, - new JavaSerializer(new SparkConf()), 1024, os => os, true, writeMetrics) + new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) writer.write(Long.box(20)) // Record metrics update on every write @@ -75,7 +75,7 @@ class BlockObjectWriterSuite extends FunSuite { val file = new File(Utils.createTempDir(), "somefile") val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, - new 
JavaSerializer(new SparkConf()), 1024, os => os, true, writeMetrics) + new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) writer.open() writer.close() From a70e849c7f9e3df5e86113d45b8c4537597cfb29 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 21 Apr 2015 16:35:37 -0700 Subject: [PATCH 075/191] [minor] [build] Set java options when generating mima ignores. The default java options make the call to GenerateMIMAIgnore take forever to run since it's gc'ing all the time. Improve things by setting the perm gen size / max heap size to larger values. Author: Marcelo Vanzin Closes #5615 from vanzin/gen-mima-fix and squashes the following commits: f44e921 [Marcelo Vanzin] [minor] [build] Set java options when generating mima ignores. --- dev/mima | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/dev/mima b/dev/mima index bed5cd042634e..2952fa65d42ff 100755 --- a/dev/mima +++ b/dev/mima @@ -27,16 +27,21 @@ cd "$FWDIR" echo -e "q\n" | build/sbt oldDeps/update rm -f .generated-mima* +generate_mima_ignore() { + SPARK_JAVA_OPTS="-XX:MaxPermSize=1g -Xmx2g" \ + ./bin/spark-class org.apache.spark.tools.GenerateMIMAIgnore +} + # Generate Mima Ignore is called twice, first with latest built jars # on the classpath and then again with previous version jars on the classpath. # Because of a bug in GenerateMIMAIgnore that when old jars are ahead on classpath # it did not process the new classes (which are in assembly jar). -./bin/spark-class org.apache.spark.tools.GenerateMIMAIgnore +generate_mima_ignore export SPARK_CLASSPATH="`find lib_managed \( -name '*spark*jar' -a -type f \) | tr "\\n" ":"`" echo "SPARK_CLASSPATH=$SPARK_CLASSPATH" -./bin/spark-class org.apache.spark.tools.GenerateMIMAIgnore +generate_mima_ignore echo -e "q\n" | build/sbt mima-report-binary-issues | grep -v -e "info.*Resolving" ret_val=$? From 7fe6142cd3c39ec79899878c3deca9d5130d05b1 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Tue, 21 Apr 2015 16:42:45 -0700 Subject: [PATCH 076/191] [SPARK-6065] [MLlib] Optimize word2vec.findSynonyms using blas calls 1. Use blas calls to find the dot product between two vectors. 2. Prevent re-computing the L2 norm of the given vector for each word in the model. Author: MechCoder Closes #5467 from MechCoder/spark-6065 and squashes the following commits: dd0b0b2 [MechCoder] Preallocate wordVectors ffc9240 [MechCoder] Minor 6b74c81 [MechCoder] Switch back to native blas calls da1642d [MechCoder] Explicit types and indexing 64575b0 [MechCoder] Save indexedmap and a wordvecmat instead of matrix fbe0108 [MechCoder] Made the following changes 1. Calculate norms during initialization. 2.
Use Blas calls from linalg.blas 1350cf3 [MechCoder] [SPARK-6065] Optimize word2vec.findSynonynms using blas calls --- .../apache/spark/mllib/feature/Word2Vec.scala | 57 +++++++++++++++++-- 1 file changed, 51 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index b2d9053f70145..98e83112f52ae 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -34,7 +34,7 @@ import org.apache.spark.SparkContext import org.apache.spark.SparkContext._ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.{Vector, Vectors, DenseMatrix, BLAS, DenseVector} import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd._ import org.apache.spark.util.Utils @@ -429,7 +429,36 @@ class Word2Vec extends Serializable with Logging { */ @Experimental class Word2VecModel private[mllib] ( - private val model: Map[String, Array[Float]]) extends Serializable with Saveable { + model: Map[String, Array[Float]]) extends Serializable with Saveable { + + // wordList: Ordered list of words obtained from model. + private val wordList: Array[String] = model.keys.toArray + + // wordIndex: Maps each word to an index, which can retrieve the corresponding + // vector from wordVectors (see below). + private val wordIndex: Map[String, Int] = wordList.zip(0 until model.size).toMap + + // vectorSize: Dimension of each word's vector. + private val vectorSize = model.head._2.size + private val numWords = wordIndex.size + + // wordVectors: Array of length numWords * vectorSize, vector corresponding to the word + // mapped with index i can be retrieved by the slice + // (ind * vectorSize, ind * vectorSize + vectorSize) + // wordVecNorms: Array of length numWords, each value being the Euclidean norm + // of the wordVector. + private val (wordVectors: Array[Float], wordVecNorms: Array[Double]) = { + val wordVectors = new Array[Float](vectorSize * numWords) + val wordVecNorms = new Array[Double](numWords) + var i = 0 + while (i < numWords) { + val vec = model.get(wordList(i)).get + Array.copy(vec, 0, wordVectors, i * vectorSize, vectorSize) + wordVecNorms(i) = blas.snrm2(vectorSize, vec, 1) + i += 1 + } + (wordVectors, wordVecNorms) + } private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = { require(v1.length == v2.length, "Vectors should have the same length") @@ -443,7 +472,7 @@ class Word2VecModel private[mllib] ( override protected def formatVersion = "1.0" def save(sc: SparkContext, path: String): Unit = { - Word2VecModel.SaveLoadV1_0.save(sc, path, model) + Word2VecModel.SaveLoadV1_0.save(sc, path, getVectors) } /** @@ -479,9 +508,23 @@ class Word2VecModel private[mllib] ( */ def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = { require(num > 0, "Number of similar words should > 0") - // TODO: optimize top-k + val fVector = vector.toArray.map(_.toFloat) - model.mapValues(vec => cosineSimilarity(fVector, vec)) + val cosineVec = Array.fill[Float](numWords)(0) + val alpha: Float = 1 + val beta: Float = 0 + + blas.sgemv( + "T", vectorSize, numWords, alpha, wordVectors, vectorSize, fVector, 1, beta, cosineVec, 1) + + // Need not divide with the norm of the given vector since it is constant. 
+ val updatedCosines = new Array[Double](numWords) + var ind = 0 + while (ind < numWords) { + updatedCosines(ind) = cosineVec(ind) / wordVecNorms(ind) + ind += 1 + } + wordList.zip(updatedCosines) .toSeq .sortBy(- _._2) .take(num + 1) @@ -493,7 +536,9 @@ * Returns a map of words to their vector representations. */ def getVectors: Map[String, Array[Float]] = { - model + wordIndex.map { case (word, ind) => + (word, wordVectors.slice(vectorSize * ind, vectorSize * ind + vectorSize)) + } } } From 686dd742e11f6ad0078b7ff9b30b83a118fd8109 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 21 Apr 2015 16:44:52 -0700 Subject: [PATCH 077/191] [SPARK-7036][MLLIB] ALS.train should support DataFrames in PySpark SchemaRDD works with ALS.train in 1.2, so we should continue to support DataFrames for compatibility. coderxiang Author: Xiangrui Meng Closes #5619 from mengxr/SPARK-7036 and squashes the following commits: dfcaf5a [Xiangrui Meng] ALS.train should support DataFrames in PySpark --- python/pyspark/mllib/recommendation.py | 36 +++++++++++++++++++------- 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 80e0a356bb78a..4b7d17d64e947 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -22,6 +22,7 @@ from pyspark.rdd import RDD from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc from pyspark.mllib.util import JavaLoader, JavaSaveable +from pyspark.sql import DataFrame __all__ = ['MatrixFactorizationModel', 'ALS', 'Rating'] @@ -78,18 +79,23 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader): True >>> model = ALS.train(ratings, 1, nonnegative=True, seed=10) - >>> model.predict(2,2) + >>> model.predict(2, 2) + 3.8... + + >>> df = sqlContext.createDataFrame([Rating(1, 1, 1.0), Rating(1, 2, 2.0), Rating(2, 1, 2.0)]) + >>> model = ALS.train(df, 1, nonnegative=True, seed=10) + >>> model.predict(2, 2) 3.8... >>> model = ALS.trainImplicit(ratings, 1, nonnegative=True, seed=10) - >>> model.predict(2,2) + >>> model.predict(2, 2) 0.4... >>> import os, tempfile >>> path = tempfile.mkdtemp() >>> model.save(sc, path) >>> sameModel = MatrixFactorizationModel.load(sc, path) - >>> sameModel.predict(2,2) + >>> sameModel.predict(2, 2) 0.4... >>> sameModel.predictAll(testset).collect() [Rating(... @@ -125,13 +131,20 @@ class ALS(object): @classmethod def _prepare(cls, ratings): - assert isinstance(ratings, RDD), "ratings should be RDD" + if isinstance(ratings, RDD): + pass + elif isinstance(ratings, DataFrame): + ratings = ratings.rdd + else: + raise TypeError("Ratings should be represented by either an RDD or a DataFrame, " + "but got %s." % type(ratings)) first = ratings.first() - if not isinstance(first, Rating): - if isinstance(first, (tuple, list)): - ratings = ratings.map(lambda x: Rating(*x)) - else: - raise ValueError("rating should be RDD of Rating or tuple/list") + if isinstance(first, Rating): + pass + elif isinstance(first, (tuple, list)): + ratings = ratings.map(lambda x: Rating(*x)) + else: + raise TypeError("Expect a Rating or a tuple/list, but got %s."
% type(first)) return ratings @classmethod @@ -152,8 +165,11 @@ def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alp def _test(): import doctest import pyspark.mllib.recommendation + from pyspark.sql import SQLContext globs = pyspark.mllib.recommendation.__dict__.copy() - globs['sc'] = SparkContext('local[4]', 'PythonTest') + sc = SparkContext('local[4]', 'PythonTest') + globs['sc'] = sc + globs['sqlContext'] = SQLContext(sc) (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: From ae036d08170202074b266afd17ce34b689c70b0c Mon Sep 17 00:00:00 2001 From: Alain Date: Tue, 21 Apr 2015 16:46:17 -0700 Subject: [PATCH 078/191] [Minor][MLLIB] Fix a minor formatting bug in toString method in Node.scala add missing comma and space Author: Alain Closes #5621 from AiHe/tree-node-issue and squashes the following commits: 159a7bb [Alain] [Minor][MLLIB] Fix a minor formatting bug in toString methods in Node.scala (cherry picked from commit 4508f01890a723f80d631424ff8eda166a13a727) Signed-off-by: Xiangrui Meng --- .../src/main/scala/org/apache/spark/mllib/tree/model/Node.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 708ba04b567d3..86390a20cb5cc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -52,7 +52,7 @@ class Node ( override def toString: String = { "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " + - "impurity = " + impurity + "split = " + split + ", stats = " + stats + "impurity = " + impurity + ", split = " + split + ", stats = " + stats } /** From b063a61b9852cf9b9d2c905332d2ecb2fd716cc4 Mon Sep 17 00:00:00 2001 From: mweindel Date: Tue, 21 Apr 2015 20:19:33 -0400 Subject: [PATCH 079/191] Avoid warning message about invalid refuse_seconds value in Mesos >=0.21... Starting with version 0.21.0, Apache Mesos is very noisy if the filter parameter refuse_seconds is set to an invalid value like `-1`. I have seen systems with millions of log lines like ``` W0420 18:00:48.773059 32352 hierarchical_allocator_process.hpp:589] Using the default value of 'refuse_seconds' to create the refused resources filter because the input value is negative ``` in the Mesos master INFO and WARNING log files. Therefore the CoarseMesosSchedulerBackend should set the default value for refuse seconds (i.e. 5 seconds) directly. This is no problem for the fine-grained MesosSchedulerBackend, as it uses the value 1 second for this parameter. Author: mweindel Closes #5597 from MartinWeindel/master and squashes the following commits: 2f99ffd [mweindel] Avoid warning message about invalid refuse_seconds value in Mesos >=0.21. 
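As a minimal sketch of the API involved (the `RefuseSecondsSketch` object and the `println` are illustrative only; `Filters` is the Mesos protobuf class the scheduler backend already builds):

```
import org.apache.mesos.Protos.Filters

object RefuseSecondsSketch {
  def main(args: Array[String]): Unit = {
    // Mesos declares 5.0 seconds as the default for refuse_seconds; a negative
    // value makes the allocator fall back to that default and log the warning
    // quoted above for every declined offer, so pass the value explicitly.
    val filters = Filters.newBuilder()
      .setRefuseSeconds(5)
      .build()
    println(s"refuse_seconds = ${filters.getRefuseSeconds}")
  }
}
```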
--- .../scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index b037a4966ced0..82f652dae0378 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -207,7 +207,7 @@ private[spark] class CoarseMesosSchedulerBackend( */ override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { synchronized { - val filters = Filters.newBuilder().setRefuseSeconds(-1).build() + val filters = Filters.newBuilder().setRefuseSeconds(5).build() for (offer <- offers) { val slaveId = offer.getSlaveId.toString From e72c16e30d85cdc394d318b5551698885cfda9b8 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 21 Apr 2015 20:33:57 -0400 Subject: [PATCH 080/191] [SPARK-6014] [core] Revamp Spark shutdown hooks, fix shutdown races. This change adds some new utility code to handle shutdown hooks in Spark. The main goal is to take advantage of Hadoop 2.x's API for shutdown hooks, which allows Spark to register a hook that will run before the one that cleans up HDFS clients, and thus avoids some races that would cause exceptions to show up and other issues such as failure to properly close event logs. Unfortunately, Hadoop 1.x does not have such APIs, so in that case correctness is still left to chance. Author: Marcelo Vanzin Closes #5560 from vanzin/SPARK-6014 and squashes the following commits: edfafb1 [Marcelo Vanzin] Better scaladoc. fcaeedd [Marcelo Vanzin] Merge branch 'master' into SPARK-6014 e7039dc [Marcelo Vanzin] [SPARK-6014] [core] Revamp Spark shutdown hooks, fix shutdown races. --- .../spark/deploy/history/HistoryServer.scala | 6 +- .../spark/deploy/worker/ExecutorRunner.scala | 12 +- .../spark/storage/DiskBlockManager.scala | 18 +-- .../spark/storage/TachyonBlockManager.scala | 24 ++-- .../scala/org/apache/spark/util/Utils.scala | 136 +++++++++++++++--- .../org/apache/spark/util/UtilsSuite.scala | 32 +++-- .../hive/thriftserver/HiveThriftServer2.scala | 9 +- .../hive/thriftserver/SparkSQLCLIDriver.scala | 9 +- .../spark/deploy/yarn/ApplicationMaster.scala | 63 ++++---- 9 files changed, 195 insertions(+), 114 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 72f6048239297..56bef57e55392 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -27,7 +27,7 @@ import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.ui.{SparkUI, UIUtils, WebUI} import org.apache.spark.ui.JettyUtils._ -import org.apache.spark.util.SignalLogger +import org.apache.spark.util.{SignalLogger, Utils} /** * A web server that renders SparkUIs of completed applications. 
@@ -194,9 +194,7 @@ object HistoryServer extends Logging { val server = new HistoryServer(conf, provider, securityManager, port) server.bind() - Runtime.getRuntime().addShutdownHook(new Thread("HistoryServerStopper") { - override def run(): Unit = server.stop() - }) + Utils.addShutdownHook { () => server.stop() } // Wait until the end of the world... or if the HistoryServer process is manually stopped while(true) { Thread.sleep(Int.MaxValue) } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index 7d5acabb95a48..7aa85b732fc87 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -28,6 +28,7 @@ import com.google.common.io.Files import org.apache.spark.{SparkConf, Logging} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged +import org.apache.spark.util.Utils import org.apache.spark.util.logging.FileAppender /** @@ -61,7 +62,7 @@ private[deploy] class ExecutorRunner( // NOTE: This is now redundant with the automated shut-down enforced by the Executor. It might // make sense to remove this in the future. - private var shutdownHook: Thread = null + private var shutdownHook: AnyRef = null private[worker] def start() { workerThread = new Thread("ExecutorRunner for " + fullId) { @@ -69,12 +70,7 @@ private[deploy] class ExecutorRunner( } workerThread.start() // Shutdown hook that kills actors on shutdown. - shutdownHook = new Thread() { - override def run() { - killProcess(Some("Worker shutting down")) - } - } - Runtime.getRuntime.addShutdownHook(shutdownHook) + shutdownHook = Utils.addShutdownHook { () => killProcess(Some("Worker shutting down")) } } /** @@ -106,7 +102,7 @@ private[deploy] class ExecutorRunner( workerThread = null state = ExecutorState.KILLED try { - Runtime.getRuntime.removeShutdownHook(shutdownHook) + Utils.removeShutdownHook(shutdownHook) } catch { case e: IllegalStateException => None } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 2883137872600..7ea5e54f9e1fe 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -138,25 +138,17 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon } } - private def addShutdownHook(): Thread = { - val shutdownHook = new Thread("delete Spark local dirs") { - override def run(): Unit = Utils.logUncaughtExceptions { - logDebug("Shutdown hook called") - DiskBlockManager.this.doStop() - } + private def addShutdownHook(): AnyRef = { + Utils.addShutdownHook { () => + logDebug("Shutdown hook called") + DiskBlockManager.this.doStop() } - Runtime.getRuntime.addShutdownHook(shutdownHook) - shutdownHook } /** Cleanup local dirs and stop shuffle sender. */ private[spark] def stop() { // Remove the shutdown hook. It causes memory leaks if we leave it around. 
- try { - Runtime.getRuntime.removeShutdownHook(shutdownHook) - } catch { - case e: IllegalStateException => None - } + Utils.removeShutdownHook(shutdownHook) doStop() } diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala index af873034215a9..951897cead996 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala @@ -135,21 +135,19 @@ private[spark] class TachyonBlockManager( private def addShutdownHook() { tachyonDirs.foreach(tachyonDir => Utils.registerShutdownDeleteDir(tachyonDir)) - Runtime.getRuntime.addShutdownHook(new Thread("delete Spark tachyon dirs") { - override def run(): Unit = Utils.logUncaughtExceptions { - logDebug("Shutdown hook called") - tachyonDirs.foreach { tachyonDir => - try { - if (!Utils.hasRootAsShutdownDeleteDir(tachyonDir)) { - Utils.deleteRecursively(tachyonDir, client) - } - } catch { - case e: Exception => - logError("Exception while deleting tachyon spark dir: " + tachyonDir, e) + Utils.addShutdownHook { () => + logDebug("Shutdown hook called") + tachyonDirs.foreach { tachyonDir => + try { + if (!Utils.hasRootAsShutdownDeleteDir(tachyonDir)) { + Utils.deleteRecursively(tachyonDir, client) } + } catch { + case e: Exception => + logError("Exception while deleting tachyon spark dir: " + tachyonDir, e) } - client.close() } - }) + client.close() + } } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 1029b0f9fce1e..7b0de1ae55b78 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -21,7 +21,7 @@ import java.io._ import java.lang.management.ManagementFactory import java.net._ import java.nio.ByteBuffer -import java.util.{Properties, Locale, Random, UUID} +import java.util.{PriorityQueue, Properties, Locale, Random, UUID} import java.util.concurrent._ import javax.net.ssl.HttpsURLConnection @@ -30,7 +30,7 @@ import scala.collection.Map import scala.collection.mutable.ArrayBuffer import scala.io.Source import scala.reflect.ClassTag -import scala.util.Try +import scala.util.{Failure, Success, Try} import scala.util.control.{ControlThrowable, NonFatal} import com.google.common.io.{ByteStreams, Files} @@ -64,9 +64,15 @@ private[spark] object CallSite { private[spark] object Utils extends Logging { val random = new Random() + val DEFAULT_SHUTDOWN_PRIORITY = 100 + private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 @volatile private var localRootDirs: Array[String] = null + + private val shutdownHooks = new SparkShutdownHookManager() + shutdownHooks.install() + /** Serialize an object using Java serialization */ def serialize[T](o: T): Array[Byte] = { val bos = new ByteArrayOutputStream() @@ -176,18 +182,16 @@ private[spark] object Utils extends Logging { private val shutdownDeleteTachyonPaths = new scala.collection.mutable.HashSet[String]() // Add a shutdown hook to delete the temp dirs when the JVM exits - Runtime.getRuntime.addShutdownHook(new Thread("delete Spark temp dirs") { - override def run(): Unit = Utils.logUncaughtExceptions { - logDebug("Shutdown hook called") - shutdownDeletePaths.foreach { dirPath => - try { - Utils.deleteRecursively(new File(dirPath)) - } catch { - case e: Exception => logError(s"Exception while deleting Spark temp dir: $dirPath", e) - } + addShutdownHook { () => + logDebug("Shutdown 
hook called") + shutdownDeletePaths.foreach { dirPath => + try { + Utils.deleteRecursively(new File(dirPath)) + } catch { + case e: Exception => logError(s"Exception while deleting Spark temp dir: $dirPath", e) } } - }) + } // Register the path to be deleted via shutdown hook def registerShutdownDeleteDir(file: File) { @@ -613,7 +617,7 @@ private[spark] object Utils extends Logging { } Utils.setupSecureURLConnection(uc, securityMgr) - val timeoutMs = + val timeoutMs = conf.getTimeAsSeconds("spark.files.fetchTimeout", "60s").toInt * 1000 uc.setConnectTimeout(timeoutMs) uc.setReadTimeout(timeoutMs) @@ -1172,7 +1176,7 @@ private[spark] object Utils extends Logging { /** * Execute a block of code that evaluates to Unit, forwarding any uncaught exceptions to the * default UncaughtExceptionHandler - * + * * NOTE: This method is to be called by the spark-started JVM process. */ def tryOrExit(block: => Unit) { @@ -1185,11 +1189,11 @@ private[spark] object Utils extends Logging { } /** - * Execute a block of code that evaluates to Unit, stop SparkContext is there is any uncaught + * Execute a block of code that evaluates to Unit, stop SparkContext is there is any uncaught * exception - * - * NOTE: This method is to be called by the driver-side components to avoid stopping the - * user-started JVM process completely; in contrast, tryOrExit is to be called in the + * + * NOTE: This method is to be called by the driver-side components to avoid stopping the + * user-started JVM process completely; in contrast, tryOrExit is to be called in the * spark-started JVM process . */ def tryOrStopSparkContext(sc: SparkContext)(block: => Unit) { @@ -2132,6 +2136,102 @@ private[spark] object Utils extends Logging { .getOrElse(UserGroupInformation.getCurrentUser().getShortUserName()) } + /** + * Adds a shutdown hook with default priority. + * + * @param hook The code to run during shutdown. + * @return A handle that can be used to unregister the shutdown hook. + */ + def addShutdownHook(hook: () => Unit): AnyRef = { + addShutdownHook(DEFAULT_SHUTDOWN_PRIORITY, hook) + } + + /** + * Adds a shutdown hook with the given priority. Hooks with lower priority values run + * first. + * + * @param hook The code to run during shutdown. + * @return A handle that can be used to unregister the shutdown hook. + */ + def addShutdownHook(priority: Int, hook: () => Unit): AnyRef = { + shutdownHooks.add(priority, hook) + } + + /** + * Remove a previously installed shutdown hook. + * + * @param ref A handle returned by `addShutdownHook`. + * @return Whether the hook was removed. + */ + def removeShutdownHook(ref: AnyRef): Boolean = { + shutdownHooks.remove(ref) + } + +} + +private [util] class SparkShutdownHookManager { + + private val hooks = new PriorityQueue[SparkShutdownHook]() + private var shuttingDown = false + + /** + * Install a hook to run at shutdown and run all registered hooks in order. Hadoop 1.x does not + * have `ShutdownHookManager`, so in that case we just use the JVM's `Runtime` object and hope for + * the best. 
+ */ + def install(): Unit = { + val hookTask = new Runnable() { + override def run(): Unit = runAll() + } + Try(Class.forName("org.apache.hadoop.util.ShutdownHookManager")) match { + case Success(shmClass) => + val fsPriority = classOf[FileSystem].getField("SHUTDOWN_HOOK_PRIORITY").get() + .asInstanceOf[Int] + val shm = shmClass.getMethod("get").invoke(null) + shm.getClass().getMethod("addShutdownHook", classOf[Runnable], classOf[Int]) + .invoke(shm, hookTask, Integer.valueOf(fsPriority + 30)) + + case Failure(_) => + Runtime.getRuntime.addShutdownHook(new Thread(hookTask, "Spark Shutdown Hook")); + } + } + + def runAll(): Unit = synchronized { + shuttingDown = true + while (!hooks.isEmpty()) { + Utils.logUncaughtExceptions(hooks.poll().run()) + } + } + + def add(priority: Int, hook: () => Unit): AnyRef = synchronized { + checkState() + val hookRef = new SparkShutdownHook(priority, hook) + hooks.add(hookRef) + hookRef + } + + def remove(ref: AnyRef): Boolean = synchronized { + checkState() + hooks.remove(ref) + } + + private def checkState(): Unit = { + if (shuttingDown) { + throw new IllegalStateException("Shutdown hooks cannot be modified during shutdown.") + } + } + +} + +private class SparkShutdownHook(private val priority: Int, hook: () => Unit) + extends Comparable[SparkShutdownHook] { + + override def compareTo(other: SparkShutdownHook): Int = { + other.priority - priority + } + + def run(): Unit = hook() + } /** diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index fb97e650ff95c..1ba99803f5a0e 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -17,14 +17,16 @@ package org.apache.spark.util -import scala.util.Random - import java.io.{File, ByteArrayOutputStream, ByteArrayInputStream, FileOutputStream} import java.net.{BindException, ServerSocket, URI} import java.nio.{ByteBuffer, ByteOrder} import java.text.DecimalFormatSymbols import java.util.concurrent.TimeUnit import java.util.Locale +import java.util.PriorityQueue + +import scala.collection.mutable.ListBuffer +import scala.util.Random import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files @@ -36,14 +38,14 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkConf class UtilsSuite extends FunSuite with ResetSystemProperties { - + test("timeConversion") { // Test -1 assert(Utils.timeStringAsSeconds("-1") === -1) - + // Test zero assert(Utils.timeStringAsSeconds("0") === 0) - + assert(Utils.timeStringAsSeconds("1") === 1) assert(Utils.timeStringAsSeconds("1s") === 1) assert(Utils.timeStringAsSeconds("1000ms") === 1) @@ -52,7 +54,7 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { assert(Utils.timeStringAsSeconds("1min") === TimeUnit.MINUTES.toSeconds(1)) assert(Utils.timeStringAsSeconds("1h") === TimeUnit.HOURS.toSeconds(1)) assert(Utils.timeStringAsSeconds("1d") === TimeUnit.DAYS.toSeconds(1)) - + assert(Utils.timeStringAsMs("1") === 1) assert(Utils.timeStringAsMs("1ms") === 1) assert(Utils.timeStringAsMs("1000us") === 1) @@ -61,7 +63,7 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { assert(Utils.timeStringAsMs("1min") === TimeUnit.MINUTES.toMillis(1)) assert(Utils.timeStringAsMs("1h") === TimeUnit.HOURS.toMillis(1)) assert(Utils.timeStringAsMs("1d") === TimeUnit.DAYS.toMillis(1)) - + // Test invalid strings intercept[NumberFormatException] { Utils.timeStringAsMs("This breaks 600s") @@ 
-79,7 +81,7 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { Utils.timeStringAsMs("This 123s breaks") } } - + test("bytesToString") { assert(Utils.bytesToString(10) === "10.0 B") assert(Utils.bytesToString(1500) === "1500.0 B") @@ -466,4 +468,18 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { val newFileName = new File(testFileDir, testFileName) assert(newFileName.isFile()) } + + test("shutdown hook manager") { + val manager = new SparkShutdownHookManager() + val output = new ListBuffer[Int]() + + val hook1 = manager.add(1, () => output += 1) + manager.add(3, () => output += 3) + manager.add(2, () => output += 2) + manager.add(4, () => output += 4) + manager.remove(hook1) + + manager.runAll() + assert(output.toList === List(4, 3, 2)) + } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index c3a3f8c0f41df..832596fc8bee5 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -28,6 +28,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.scheduler.{SparkListenerApplicationEnd, SparkListener} +import org.apache.spark.util.Utils /** * The main entry point for the Spark SQL port of HiveServer2. Starts up a `SparkSQLContext` and a @@ -57,13 +58,7 @@ object HiveThriftServer2 extends Logging { logInfo("Starting SparkContext") SparkSQLEnv.init() - Runtime.getRuntime.addShutdownHook( - new Thread() { - override def run() { - SparkSQLEnv.stop() - } - } - ) + Utils.addShutdownHook { () => SparkSQLEnv.stop() } try { val server = new HiveThriftServer2(SparkSQLEnv.hiveContext) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 85281c6d73a3b..7e307bb4ad1e8 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -40,6 +40,7 @@ import org.apache.thrift.transport.TSocket import org.apache.spark.Logging import org.apache.spark.sql.hive.HiveShim +import org.apache.spark.util.Utils private[hive] object SparkSQLCLIDriver { private var prompt = "spark-sql" @@ -101,13 +102,7 @@ private[hive] object SparkSQLCLIDriver { SessionState.start(sessionState) // Clean up after we exit - Runtime.getRuntime.addShutdownHook( - new Thread() { - override def run() { - SparkSQLEnv.stop() - } - } - ) + Utils.addShutdownHook { () => SparkSQLEnv.stop() } // "-h" option has been passed, so connect to Hive thrift server. 
if (sessionState.getHost != null) { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index f7a84207e9da6..93ae45133ce24 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -25,7 +25,6 @@ import java.net.{Socket, URL} import java.util.concurrent.atomic.AtomicReference import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.util.ShutdownHookManager import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration @@ -95,44 +94,38 @@ private[spark] class ApplicationMaster( logInfo("ApplicationAttemptId: " + appAttemptId) val fs = FileSystem.get(yarnConf) - val cleanupHook = new Runnable { - override def run() { - // If the SparkContext is still registered, shut it down as a best case effort in case - // users do not call sc.stop or do System.exit(). - val sc = sparkContextRef.get() - if (sc != null) { - logInfo("Invoking sc stop from shutdown hook") - sc.stop() - } - val maxAppAttempts = client.getMaxRegAttempts(sparkConf, yarnConf) - val isLastAttempt = client.getAttemptId().getAttemptId() >= maxAppAttempts - - if (!finished) { - // This happens when the user application calls System.exit(). We have the choice - // of either failing or succeeding at this point. We report success to avoid - // retrying applications that have succeeded (System.exit(0)), which means that - // applications that explicitly exit with a non-zero status will also show up as - // succeeded in the RM UI. - finish(finalStatus, - ApplicationMaster.EXIT_SUCCESS, - "Shutdown hook called before final status was reported.") - } - if (!unregistered) { - // we only want to unregister if we don't want the RM to retry - if (finalStatus == FinalApplicationStatus.SUCCEEDED || isLastAttempt) { - unregister(finalStatus, finalMsg) - cleanupStagingDir(fs) - } + Utils.addShutdownHook { () => + // If the SparkContext is still registered, shut it down as a best case effort in case + // users do not call sc.stop or do System.exit(). + val sc = sparkContextRef.get() + if (sc != null) { + logInfo("Invoking sc stop from shutdown hook") + sc.stop() + } + val maxAppAttempts = client.getMaxRegAttempts(sparkConf, yarnConf) + val isLastAttempt = client.getAttemptId().getAttemptId() >= maxAppAttempts + + if (!finished) { + // This happens when the user application calls System.exit(). We have the choice + // of either failing or succeeding at this point. We report success to avoid + // retrying applications that have succeeded (System.exit(0)), which means that + // applications that explicitly exit with a non-zero status will also show up as + // succeeded in the RM UI. + finish(finalStatus, + ApplicationMaster.EXIT_SUCCESS, + "Shutdown hook called before final status was reported.") + } + + if (!unregistered) { + // we only want to unregister if we don't want the RM to retry + if (finalStatus == FinalApplicationStatus.SUCCEEDED || isLastAttempt) { + unregister(finalStatus, finalMsg) + cleanupStagingDir(fs) } } } - // Use higher priority than FileSystem. - assert(ApplicationMaster.SHUTDOWN_HOOK_PRIORITY > FileSystem.SHUTDOWN_HOOK_PRIORITY) - ShutdownHookManager - .get().addShutdownHook(cleanupHook, ApplicationMaster.SHUTDOWN_HOOK_PRIORITY) - // Call this to force generation of secret so it gets populated into the // Hadoop UGI. 
This has to happen before the startUserApplication which does a // doAs in order for the credentials to be passed on to the executor containers. @@ -546,8 +539,6 @@ private[spark] class ApplicationMaster( object ApplicationMaster extends Logging { - val SHUTDOWN_HOOK_PRIORITY: Int = 30 - // exit codes for different causes, no reason behind the values private val EXIT_SUCCESS = 0 private val EXIT_UNCAUGHT_EXCEPTION = 10 From 3134c3fe495862b7687b5aa00d3344d09cd5e08e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 21 Apr 2015 17:49:55 -0700 Subject: [PATCH 081/191] [SPARK-6953] [PySpark] speed up python tests This PR try to speed up some python tests: ``` tests.py 144s -> 103s -41s mllib/classification.py 24s -> 17s -7s mllib/regression.py 27s -> 15s -12s mllib/tree.py 27s -> 13s -14s mllib/tests.py 64s -> 31s -33s streaming/tests.py 185s -> 84s -101s ``` Considering python3, the total saving will be 558s (almost 10 minutes) (core, and streaming run three times, mllib runs twice). During testing, it will show used time for each test file: ``` Run core tests ... Running test: pyspark/rdd.py ... ok (22s) Running test: pyspark/context.py ... ok (16s) Running test: pyspark/conf.py ... ok (4s) Running test: pyspark/broadcast.py ... ok (4s) Running test: pyspark/accumulators.py ... ok (4s) Running test: pyspark/serializers.py ... ok (6s) Running test: pyspark/profiler.py ... ok (5s) Running test: pyspark/shuffle.py ... ok (1s) Running test: pyspark/tests.py ... ok (103s) 144s ``` Author: Reynold Xin Author: Xiangrui Meng Closes #5605 from rxin/python-tests-speed and squashes the following commits: d08542d [Reynold Xin] Merge pull request #14 from mengxr/SPARK-6953 89321ee [Xiangrui Meng] fix seed in tests 3ad2387 [Reynold Xin] Merge pull request #5427 from davies/python_tests --- python/pyspark/mllib/classification.py | 17 ++--- python/pyspark/mllib/regression.py | 25 ++++--- python/pyspark/mllib/tests.py | 69 +++++++++--------- python/pyspark/mllib/tree.py | 15 ++-- python/pyspark/shuffle.py | 7 +- python/pyspark/sql/tests.py | 4 +- python/pyspark/streaming/tests.py | 63 ++++++++++------- python/pyspark/tests.py | 96 ++++++++++++++++---------- python/run-tests | 13 ++-- 9 files changed, 182 insertions(+), 127 deletions(-) diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index eda0b60f8b1e7..a70c664a71fdb 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -86,7 +86,7 @@ class LogisticRegressionModel(LinearClassificationModel): ... LabeledPoint(0.0, [0.0, 1.0]), ... LabeledPoint(1.0, [1.0, 0.0]), ... ] - >>> lrm = LogisticRegressionWithSGD.train(sc.parallelize(data)) + >>> lrm = LogisticRegressionWithSGD.train(sc.parallelize(data), iterations=10) >>> lrm.predict([1.0, 0.0]) 1 >>> lrm.predict([0.0, 1.0]) @@ -95,7 +95,7 @@ class LogisticRegressionModel(LinearClassificationModel): [1, 0] >>> lrm.clearThreshold() >>> lrm.predict([0.0, 1.0]) - 0.123... + 0.279... >>> sparse_data = [ ... LabeledPoint(0.0, SparseVector(2, {0: 0.0})), @@ -103,7 +103,7 @@ class LogisticRegressionModel(LinearClassificationModel): ... LabeledPoint(0.0, SparseVector(2, {0: 1.0})), ... LabeledPoint(1.0, SparseVector(2, {1: 2.0})) ... 
] - >>> lrm = LogisticRegressionWithSGD.train(sc.parallelize(sparse_data)) + >>> lrm = LogisticRegressionWithSGD.train(sc.parallelize(sparse_data), iterations=10) >>> lrm.predict(array([0.0, 1.0])) 1 >>> lrm.predict(array([1.0, 0.0])) @@ -129,7 +129,8 @@ class LogisticRegressionModel(LinearClassificationModel): ... LabeledPoint(1.0, [1.0, 0.0, 0.0]), ... LabeledPoint(2.0, [0.0, 0.0, 1.0]) ... ] - >>> mcm = LogisticRegressionWithLBFGS.train(data=sc.parallelize(multi_class_data), numClasses=3) + >>> data = sc.parallelize(multi_class_data) + >>> mcm = LogisticRegressionWithLBFGS.train(data, iterations=10, numClasses=3) >>> mcm.predict([0.0, 0.5, 0.0]) 0 >>> mcm.predict([0.8, 0.0, 0.0]) @@ -298,7 +299,7 @@ def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType ... LabeledPoint(0.0, [0.0, 1.0]), ... LabeledPoint(1.0, [1.0, 0.0]), ... ] - >>> lrm = LogisticRegressionWithLBFGS.train(sc.parallelize(data)) + >>> lrm = LogisticRegressionWithLBFGS.train(sc.parallelize(data), iterations=10) >>> lrm.predict([1.0, 0.0]) 1 >>> lrm.predict([0.0, 1.0]) @@ -330,14 +331,14 @@ class SVMModel(LinearClassificationModel): ... LabeledPoint(1.0, [2.0]), ... LabeledPoint(1.0, [3.0]) ... ] - >>> svm = SVMWithSGD.train(sc.parallelize(data)) + >>> svm = SVMWithSGD.train(sc.parallelize(data), iterations=10) >>> svm.predict([1.0]) 1 >>> svm.predict(sc.parallelize([[1.0]])).collect() [1] >>> svm.clearThreshold() >>> svm.predict(array([1.0])) - 1.25... + 1.44... >>> sparse_data = [ ... LabeledPoint(0.0, SparseVector(2, {0: -1.0})), @@ -345,7 +346,7 @@ class SVMModel(LinearClassificationModel): ... LabeledPoint(0.0, SparseVector(2, {0: 0.0})), ... LabeledPoint(1.0, SparseVector(2, {1: 2.0})) ... ] - >>> svm = SVMWithSGD.train(sc.parallelize(sparse_data)) + >>> svm = SVMWithSGD.train(sc.parallelize(sparse_data), iterations=10) >>> svm.predict(SparseVector(2, {1: 1.0})) 1 >>> svm.predict(SparseVector(2, {0: -1.0})) diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index a0117c57133ae..4bc6351bdf02f 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -108,7 +108,8 @@ class LinearRegressionModel(LinearRegressionModelBase): ... LabeledPoint(3.0, [2.0]), ... LabeledPoint(2.0, [3.0]) ... ] - >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), initialWeights=np.array([1.0])) + >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), iterations=10, + ... initialWeights=np.array([1.0])) >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 True >>> abs(lrm.predict(np.array([1.0])) - 1) < 0.5 @@ -135,12 +136,13 @@ class LinearRegressionModel(LinearRegressionModelBase): ... LabeledPoint(3.0, SparseVector(1, {0: 2.0})), ... LabeledPoint(2.0, SparseVector(1, {0: 3.0})) ... ] - >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), initialWeights=array([1.0])) + >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), iterations=10, + ... initialWeights=array([1.0])) >>> abs(lrm.predict(array([0.0])) - 0) < 0.5 True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True - >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), iterations=100, step=1.0, + >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), iterations=10, step=1.0, ... miniBatchFraction=1.0, initialWeights=array([1.0]), regParam=0.1, regType="l2", ... intercept=True, validateData=True) >>> abs(lrm.predict(array([0.0])) - 0) < 0.5 @@ -238,7 +240,7 @@ class LassoModel(LinearRegressionModelBase): ... LabeledPoint(3.0, [2.0]), ... 
LabeledPoint(2.0, [3.0]) ... ] - >>> lrm = LassoWithSGD.train(sc.parallelize(data), initialWeights=array([1.0])) + >>> lrm = LassoWithSGD.train(sc.parallelize(data), iterations=10, initialWeights=array([1.0])) >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 True >>> abs(lrm.predict(np.array([1.0])) - 1) < 0.5 @@ -265,12 +267,13 @@ class LassoModel(LinearRegressionModelBase): ... LabeledPoint(3.0, SparseVector(1, {0: 2.0})), ... LabeledPoint(2.0, SparseVector(1, {0: 3.0})) ... ] - >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), initialWeights=array([1.0])) + >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), iterations=10, + ... initialWeights=array([1.0])) >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True - >>> lrm = LassoWithSGD.train(sc.parallelize(data), iterations=100, step=1.0, + >>> lrm = LassoWithSGD.train(sc.parallelize(data), iterations=10, step=1.0, ... regParam=0.01, miniBatchFraction=1.0, initialWeights=array([1.0]), intercept=True, ... validateData=True) >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 @@ -321,7 +324,8 @@ class RidgeRegressionModel(LinearRegressionModelBase): ... LabeledPoint(3.0, [2.0]), ... LabeledPoint(2.0, [3.0]) ... ] - >>> lrm = RidgeRegressionWithSGD.train(sc.parallelize(data), initialWeights=array([1.0])) + >>> lrm = RidgeRegressionWithSGD.train(sc.parallelize(data), iterations=10, + ... initialWeights=array([1.0])) >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 True >>> abs(lrm.predict(np.array([1.0])) - 1) < 0.5 @@ -348,12 +352,13 @@ class RidgeRegressionModel(LinearRegressionModelBase): ... LabeledPoint(3.0, SparseVector(1, {0: 2.0})), ... LabeledPoint(2.0, SparseVector(1, {0: 3.0})) ... ] - >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), initialWeights=array([1.0])) + >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), iterations=10, + ... initialWeights=array([1.0])) >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True - >>> lrm = RidgeRegressionWithSGD.train(sc.parallelize(data), iterations=100, step=1.0, + >>> lrm = RidgeRegressionWithSGD.train(sc.parallelize(data), iterations=10, step=1.0, ... regParam=0.01, miniBatchFraction=1.0, initialWeights=array([1.0]), intercept=True, ... 
validateData=True) >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 @@ -396,7 +401,7 @@ def _test(): from pyspark import SparkContext import pyspark.mllib.regression globs = pyspark.mllib.regression.__dict__.copy() - globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + globs['sc'] = SparkContext('local[2]', 'PythonTest', batchSize=2) (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 8f89e2cee0592..1b008b93bc137 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -36,6 +36,7 @@ else: import unittest +from pyspark import SparkContext from pyspark.mllib.common import _to_java_object_rdd from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\ DenseMatrix, SparseMatrix, Vectors, Matrices @@ -47,7 +48,6 @@ from pyspark.mllib.feature import StandardScaler from pyspark.serializers import PickleSerializer from pyspark.sql import SQLContext -from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase _have_scipy = False try: @@ -58,6 +58,12 @@ pass ser = PickleSerializer() +sc = SparkContext('local[4]', "MLlib tests") + + +class MLlibTestCase(unittest.TestCase): + def setUp(self): + self.sc = sc def _squared_distance(a, b): @@ -67,7 +73,7 @@ def _squared_distance(a, b): return b.squared_distance(a) -class VectorTests(PySparkTestCase): +class VectorTests(MLlibTestCase): def _test_serialize(self, v): self.assertEqual(v, ser.loads(ser.dumps(v))) @@ -212,7 +218,7 @@ def test_dense_matrix_is_transposed(self): self.assertTrue(array_equal(sm.values, [1, 3, 4, 6, 9])) -class ListTests(PySparkTestCase): +class ListTests(MLlibTestCase): """ Test MLlib algorithms on plain lists, to make sure they're passed through @@ -255,7 +261,7 @@ def test_gmm(self): [-6, -7], ]) clusters = GaussianMixture.train(data, 2, convergenceTol=0.001, - maxIterations=100, seed=56) + maxIterations=10, seed=56) labels = clusters.predict(data).collect() self.assertEquals(labels[0], labels[1]) self.assertEquals(labels[2], labels[3]) @@ -266,9 +272,9 @@ def test_gmm_deterministic(self): y = range(0, 100, 10) data = self.sc.parallelize([[a, b] for a, b in zip(x, y)]) clusters1 = GaussianMixture.train(data, 5, convergenceTol=0.001, - maxIterations=100, seed=63) + maxIterations=10, seed=63) clusters2 = GaussianMixture.train(data, 5, convergenceTol=0.001, - maxIterations=100, seed=63) + maxIterations=10, seed=63) for c1, c2 in zip(clusters1.weights, clusters2.weights): self.assertEquals(round(c1, 7), round(c2, 7)) @@ -287,13 +293,13 @@ def test_classification(self): temp_dir = tempfile.mkdtemp() - lr_model = LogisticRegressionWithSGD.train(rdd) + lr_model = LogisticRegressionWithSGD.train(rdd, iterations=10) self.assertTrue(lr_model.predict(features[0]) <= 0) self.assertTrue(lr_model.predict(features[1]) > 0) self.assertTrue(lr_model.predict(features[2]) <= 0) self.assertTrue(lr_model.predict(features[3]) > 0) - svm_model = SVMWithSGD.train(rdd) + svm_model = SVMWithSGD.train(rdd, iterations=10) self.assertTrue(svm_model.predict(features[0]) <= 0) self.assertTrue(svm_model.predict(features[1]) > 0) self.assertTrue(svm_model.predict(features[2]) <= 0) @@ -307,7 +313,7 @@ def test_classification(self): categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories dt_model = DecisionTree.trainClassifier( - rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo) + rdd, numClasses=2, 
categoricalFeaturesInfo=categoricalFeaturesInfo, maxBins=4) self.assertTrue(dt_model.predict(features[0]) <= 0) self.assertTrue(dt_model.predict(features[1]) > 0) self.assertTrue(dt_model.predict(features[2]) <= 0) @@ -319,7 +325,8 @@ def test_classification(self): self.assertEqual(same_dt_model.toDebugString(), dt_model.toDebugString()) rf_model = RandomForest.trainClassifier( - rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=100) + rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=10, + maxBins=4, seed=1) self.assertTrue(rf_model.predict(features[0]) <= 0) self.assertTrue(rf_model.predict(features[1]) > 0) self.assertTrue(rf_model.predict(features[2]) <= 0) @@ -331,7 +338,7 @@ def test_classification(self): self.assertEqual(same_rf_model.toDebugString(), rf_model.toDebugString()) gbt_model = GradientBoostedTrees.trainClassifier( - rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4) self.assertTrue(gbt_model.predict(features[0]) <= 0) self.assertTrue(gbt_model.predict(features[1]) > 0) self.assertTrue(gbt_model.predict(features[2]) <= 0) @@ -360,19 +367,19 @@ def test_regression(self): rdd = self.sc.parallelize(data) features = [p.features.tolist() for p in data] - lr_model = LinearRegressionWithSGD.train(rdd) + lr_model = LinearRegressionWithSGD.train(rdd, iterations=10) self.assertTrue(lr_model.predict(features[0]) <= 0) self.assertTrue(lr_model.predict(features[1]) > 0) self.assertTrue(lr_model.predict(features[2]) <= 0) self.assertTrue(lr_model.predict(features[3]) > 0) - lasso_model = LassoWithSGD.train(rdd) + lasso_model = LassoWithSGD.train(rdd, iterations=10) self.assertTrue(lasso_model.predict(features[0]) <= 0) self.assertTrue(lasso_model.predict(features[1]) > 0) self.assertTrue(lasso_model.predict(features[2]) <= 0) self.assertTrue(lasso_model.predict(features[3]) > 0) - rr_model = RidgeRegressionWithSGD.train(rdd) + rr_model = RidgeRegressionWithSGD.train(rdd, iterations=10) self.assertTrue(rr_model.predict(features[0]) <= 0) self.assertTrue(rr_model.predict(features[1]) > 0) self.assertTrue(rr_model.predict(features[2]) <= 0) @@ -380,35 +387,35 @@ def test_regression(self): categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories dt_model = DecisionTree.trainRegressor( - rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, maxBins=4) self.assertTrue(dt_model.predict(features[0]) <= 0) self.assertTrue(dt_model.predict(features[1]) > 0) self.assertTrue(dt_model.predict(features[2]) <= 0) self.assertTrue(dt_model.predict(features[3]) > 0) rf_model = RandomForest.trainRegressor( - rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=100, seed=1) + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=10, maxBins=4, seed=1) self.assertTrue(rf_model.predict(features[0]) <= 0) self.assertTrue(rf_model.predict(features[1]) > 0) self.assertTrue(rf_model.predict(features[2]) <= 0) self.assertTrue(rf_model.predict(features[3]) > 0) gbt_model = GradientBoostedTrees.trainRegressor( - rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numIterations=4) self.assertTrue(gbt_model.predict(features[0]) <= 0) self.assertTrue(gbt_model.predict(features[1]) > 0) self.assertTrue(gbt_model.predict(features[2]) <= 0) self.assertTrue(gbt_model.predict(features[3]) > 0) try: - LinearRegressionWithSGD.train(rdd, 
initialWeights=array([1.0, 1.0])) - LassoWithSGD.train(rdd, initialWeights=array([1.0, 1.0])) - RidgeRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0])) + LinearRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0]), iterations=10) + LassoWithSGD.train(rdd, initialWeights=array([1.0, 1.0]), iterations=10) + RidgeRegressionWithSGD.train(rdd, initialWeights=array([1.0, 1.0]), iterations=10) except ValueError: self.fail() -class StatTests(PySparkTestCase): +class StatTests(MLlibTestCase): # SPARK-4023 def test_col_with_different_rdds(self): # numpy @@ -438,7 +445,7 @@ def test_col_norms(self): self.assertTrue(math.fabs(summary2.normL2()[0] - expectedNormL2) < 1e-14) -class VectorUDTTests(PySparkTestCase): +class VectorUDTTests(MLlibTestCase): dv0 = DenseVector([]) dv1 = DenseVector([1.0, 2.0]) @@ -472,7 +479,7 @@ def test_infer_schema(self): @unittest.skipIf(not _have_scipy, "SciPy not installed") -class SciPyTests(PySparkTestCase): +class SciPyTests(MLlibTestCase): """ Test both vector operations and MLlib algorithms with SciPy sparse matrices, @@ -613,7 +620,7 @@ def test_regression(self): self.assertTrue(dt_model.predict(features[3]) > 0) -class ChiSqTestTests(PySparkTestCase): +class ChiSqTestTests(MLlibTestCase): def test_goodness_of_fit(self): from numpy import inf @@ -711,13 +718,13 @@ def test_right_number_of_results(self): self.assertIsNotNone(chi[1000]) -class SerDeTest(PySparkTestCase): +class SerDeTest(MLlibTestCase): def test_to_java_object_rdd(self): # SPARK-6660 data = RandomRDDs.uniformRDD(self.sc, 10, 5, seed=0) self.assertEqual(_to_java_object_rdd(data).count(), 10) -class FeatureTest(PySparkTestCase): +class FeatureTest(MLlibTestCase): def test_idf_model(self): data = [ Vectors.dense([1, 2, 6, 0, 2, 3, 1, 1, 0, 0, 3]), @@ -730,13 +737,8 @@ def test_idf_model(self): self.assertEqual(len(idf), 11) -class Word2VecTests(PySparkTestCase): +class Word2VecTests(MLlibTestCase): def test_word2vec_setters(self): - data = [ - ["I", "have", "a", "pen"], - ["I", "like", "soccer", "very", "much"], - ["I", "live", "in", "Tokyo"] - ] model = Word2Vec() \ .setVectorSize(2) \ .setLearningRate(0.01) \ @@ -765,7 +767,7 @@ def test_word2vec_get_vectors(self): self.assertEquals(len(model.getVectors()), 3) -class StandardScalerTests(PySparkTestCase): +class StandardScalerTests(MLlibTestCase): def test_model_setters(self): data = [ [1.0, 2.0, 3.0], @@ -793,3 +795,4 @@ def test_model_transform(self): unittest.main() if not _have_scipy: print("NOTE: SciPy tests were skipped as it does not seem to be installed") + sc.stop() diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 0fe6e4fabe43a..cfcbea573fd22 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -482,13 +482,13 @@ def trainClassifier(cls, data, categoricalFeaturesInfo, ... LabeledPoint(1.0, [3.0]) ... ] >>> - >>> model = GradientBoostedTrees.trainClassifier(sc.parallelize(data), {}) + >>> model = GradientBoostedTrees.trainClassifier(sc.parallelize(data), {}, numIterations=10) >>> model.numTrees() - 100 + 10 >>> model.totalNumNodes() - 300 + 30 >>> print(model) # it already has newline - TreeEnsembleModel classifier with 100 trees + TreeEnsembleModel classifier with 10 trees >>> model.predict([2.0]) 1.0 @@ -541,11 +541,12 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, ... LabeledPoint(1.0, SparseVector(2, {1: 2.0})) ... 
] >>> - >>> model = GradientBoostedTrees.trainRegressor(sc.parallelize(sparse_data), {}) + >>> data = sc.parallelize(sparse_data) + >>> model = GradientBoostedTrees.trainRegressor(data, {}, numIterations=10) >>> model.numTrees() - 100 + 10 >>> model.totalNumNodes() - 102 + 12 >>> model.predict(SparseVector(2, {1: 1.0})) 1.0 >>> model.predict(SparseVector(2, {0: 1.0})) diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index b54baa57ec28a..1d0b16cade8bb 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -486,7 +486,7 @@ def sorted(self, iterator, key=None, reverse=False): goes above the limit. """ global MemoryBytesSpilled, DiskBytesSpilled - batch, limit = 100, self._next_limit() + batch, limit = 100, self.memory_limit chunks, current_chunk = [], [] iterator = iter(iterator) while True: @@ -497,7 +497,7 @@ def sorted(self, iterator, key=None, reverse=False): break used_memory = get_used_memory() - if used_memory > self.memory_limit: + if used_memory > limit: # sort them inplace will save memory current_chunk.sort(key=key, reverse=reverse) path = self._get_path(len(chunks)) @@ -513,13 +513,14 @@ def load(f): chunks.append(load(open(path, 'rb'))) current_chunk = [] gc.collect() + batch //= 2 limit = self._next_limit() MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 DiskBytesSpilled += os.path.getsize(path) os.unlink(path) # data will be deleted after close elif not chunks: - batch = min(batch * 2, 10000) + batch = min(int(batch * 1.5), 10000) current_chunk.sort(key=key, reverse=reverse) if not chunks: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 23e84283679e1..fe43c374f1cb1 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -109,7 +109,7 @@ def setUpClass(cls): os.unlink(cls.tempdir.name) cls.sqlCtx = SQLContext(cls.sc) cls.testData = [Row(key=i, value=str(i)) for i in range(100)] - rdd = cls.sc.parallelize(cls.testData) + rdd = cls.sc.parallelize(cls.testData, 2) cls.df = rdd.toDF() @classmethod @@ -303,7 +303,7 @@ def test_apply_schema(self): abstract = "byte1 short1 float1 time1 map1{} struct1(b) list1[]" schema = _parse_schema_abstract(abstract) typedSchema = _infer_schema_type(rdd.first(), schema) - df = self.sqlCtx.applySchema(rdd, typedSchema) + df = self.sqlCtx.createDataFrame(rdd, typedSchema) r = (127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1), {"a": 1}, Row(b=2), [1, 2, 3]) self.assertEqual(r, tuple(df.first())) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 33f958a601f3a..5fa1e5ef081ab 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -16,14 +16,23 @@ # import os +import sys from itertools import chain import time import operator -import unittest import tempfile import struct from functools import reduce +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + from pyspark.context import SparkConf, SparkContext, RDD from pyspark.streaming.context import StreamingContext from pyspark.streaming.kafka import KafkaUtils @@ -31,19 +40,25 @@ class PySparkStreamingTestCase(unittest.TestCase): - timeout = 20 # seconds - duration = 1 + timeout = 4 # seconds + duration = .2 - def setUp(self): - class_name = self.__class__.__name__ + @classmethod + def setUpClass(cls): + class_name = cls.__name__ conf = 
SparkConf().set("spark.default.parallelism", 1) - self.sc = SparkContext(appName=class_name, conf=conf) - self.sc.setCheckpointDir("/tmp") - # TODO: decrease duration to speed up tests + cls.sc = SparkContext(appName=class_name, conf=conf) + cls.sc.setCheckpointDir("/tmp") + + @classmethod + def tearDownClass(cls): + cls.sc.stop() + + def setUp(self): self.ssc = StreamingContext(self.sc, self.duration) def tearDown(self): - self.ssc.stop() + self.ssc.stop(False) def wait_for(self, result, n): start_time = time.time() @@ -363,13 +378,13 @@ def func(dstream): class WindowFunctionTests(PySparkStreamingTestCase): - timeout = 20 + timeout = 5 def test_window(self): input = [range(1), range(2), range(3), range(4), range(5)] def func(dstream): - return dstream.window(3, 1).count() + return dstream.window(.6, .2).count() expected = [[1], [3], [6], [9], [12], [9], [5]] self._test_func(input, func, expected) @@ -378,7 +393,7 @@ def test_count_by_window(self): input = [range(1), range(2), range(3), range(4), range(5)] def func(dstream): - return dstream.countByWindow(3, 1) + return dstream.countByWindow(.6, .2) expected = [[1], [3], [6], [9], [12], [9], [5]] self._test_func(input, func, expected) @@ -387,7 +402,7 @@ def test_count_by_window_large(self): input = [range(1), range(2), range(3), range(4), range(5), range(6)] def func(dstream): - return dstream.countByWindow(5, 1) + return dstream.countByWindow(1, .2) expected = [[1], [3], [6], [10], [15], [20], [18], [15], [11], [6]] self._test_func(input, func, expected) @@ -396,7 +411,7 @@ def test_count_by_value_and_window(self): input = [range(1), range(2), range(3), range(4), range(5), range(6)] def func(dstream): - return dstream.countByValueAndWindow(5, 1) + return dstream.countByValueAndWindow(1, .2) expected = [[1], [2], [3], [4], [5], [6], [6], [6], [6], [6]] self._test_func(input, func, expected) @@ -405,7 +420,7 @@ def test_group_by_key_and_window(self): input = [[('a', i)] for i in range(5)] def func(dstream): - return dstream.groupByKeyAndWindow(3, 1).mapValues(list) + return dstream.groupByKeyAndWindow(.6, .2).mapValues(list) expected = [[('a', [0])], [('a', [0, 1])], [('a', [0, 1, 2])], [('a', [1, 2, 3])], [('a', [2, 3, 4])], [('a', [3, 4])], [('a', [4])]] @@ -436,8 +451,8 @@ def test_stop_only_streaming_context(self): def test_stop_multiple_times(self): self._add_input_stream() self.ssc.start() - self.ssc.stop() - self.ssc.stop() + self.ssc.stop(False) + self.ssc.stop(False) def test_queue_stream(self): input = [list(range(i + 1)) for i in range(3)] @@ -495,10 +510,7 @@ def func(rdds): self.assertEqual([2, 3, 1], self._take(dstream, 3)) -class CheckpointTests(PySparkStreamingTestCase): - - def setUp(self): - pass +class CheckpointTests(unittest.TestCase): def test_get_or_create(self): inputd = tempfile.mkdtemp() @@ -518,12 +530,12 @@ def setup(): return ssc cpd = tempfile.mkdtemp("test_streaming_cps") - self.ssc = ssc = StreamingContext.getOrCreate(cpd, setup) + ssc = StreamingContext.getOrCreate(cpd, setup) ssc.start() def check_output(n): while not os.listdir(outputd): - time.sleep(0.1) + time.sleep(0.01) time.sleep(1) # make sure mtime is larger than the previous one with open(os.path.join(inputd, str(n)), 'w') as f: f.writelines(["%d\n" % i for i in range(10)]) @@ -553,12 +565,15 @@ def check_output(n): ssc.stop(True, True) time.sleep(1) - self.ssc = ssc = StreamingContext.getOrCreate(cpd, setup) + ssc = StreamingContext.getOrCreate(cpd, setup) ssc.start() check_output(3) + ssc.stop(True, True) class 
KafkaStreamTests(PySparkStreamingTestCase): + timeout = 20 # seconds + duration = 1 def setUp(self): super(KafkaStreamTests, self).setUp() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 75f39d9e75f38..ea63a396da5b8 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -31,7 +31,6 @@ import time import zipfile import random -import itertools import threading import hashlib @@ -49,6 +48,11 @@ xrange = range basestring = str +if sys.version >= "3": + from io import StringIO +else: + from StringIO import StringIO + from pyspark.conf import SparkConf from pyspark.context import SparkContext @@ -196,7 +200,7 @@ def test_external_sort_in_rdd(self): sc = SparkContext(conf=conf) l = list(range(10240)) random.shuffle(l) - rdd = sc.parallelize(l, 2) + rdd = sc.parallelize(l, 4) self.assertEqual(sorted(l), rdd.sortBy(lambda x: x).collect()) sc.stop() @@ -300,6 +304,18 @@ def test_hash_serializer(self): hash(FlattenedValuesSerializer(PickleSerializer())) +class QuietTest(object): + def __init__(self, sc): + self.log4j = sc._jvm.org.apache.log4j + + def __enter__(self): + self.old_level = self.log4j.LogManager.getRootLogger().getLevel() + self.log4j.LogManager.getRootLogger().setLevel(self.log4j.Level.FATAL) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.log4j.LogManager.getRootLogger().setLevel(self.old_level) + + class PySparkTestCase(unittest.TestCase): def setUp(self): @@ -371,15 +387,11 @@ def test_add_py_file(self): # To ensure that we're actually testing addPyFile's effects, check that # this job fails due to `userlibrary` not being on the Python path: # disable logging in log4j temporarily - log4j = self.sc._jvm.org.apache.log4j - old_level = log4j.LogManager.getRootLogger().getLevel() - log4j.LogManager.getRootLogger().setLevel(log4j.Level.FATAL) - def func(x): from userlibrary import UserClass return UserClass().hello() - self.assertRaises(Exception, self.sc.parallelize(range(2)).map(func).first) - log4j.LogManager.getRootLogger().setLevel(old_level) + with QuietTest(self.sc): + self.assertRaises(Exception, self.sc.parallelize(range(2)).map(func).first) # Add the file, so the job should now succeed: path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py") @@ -496,7 +508,8 @@ def test_deleting_input_files(self): filtered_data = data.filter(lambda x: True) self.assertEqual(1, filtered_data.count()) os.unlink(tempFile.name) - self.assertRaises(Exception, lambda: filtered_data.count()) + with QuietTest(self.sc): + self.assertRaises(Exception, lambda: filtered_data.count()) def test_sampling_default_seed(self): # Test for SPARK-3995 (default seed setting) @@ -536,9 +549,9 @@ def test_namedtuple_in_rdd(self): self.assertEqual([jon, jane], theDoes.collect()) def test_large_broadcast(self): - N = 100000 + N = 10000 data = [[float(i) for i in range(300)] for i in range(N)] - bdata = self.sc.broadcast(data) # 270MB + bdata = self.sc.broadcast(data) # 27MB m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() self.assertEqual(N, m) @@ -569,7 +582,7 @@ def test_multiple_broadcasts(self): self.assertEqual(checksum, csum) def test_large_closure(self): - N = 1000000 + N = 200000 data = [float(i) for i in xrange(N)] rdd = self.sc.parallelize(range(1), 1).map(lambda x: len(data)) self.assertEqual(N, rdd.first()) @@ -604,17 +617,18 @@ def test_zip_with_different_number_of_items(self): # different number of partitions b = self.sc.parallelize(range(100, 106), 3) self.assertRaises(ValueError, lambda: a.zip(b)) - # different number 
of batched items in JVM - b = self.sc.parallelize(range(100, 104), 2) - self.assertRaises(Exception, lambda: a.zip(b).count()) - # different number of items in one pair - b = self.sc.parallelize(range(100, 106), 2) - self.assertRaises(Exception, lambda: a.zip(b).count()) - # same total number of items, but different distributions - a = self.sc.parallelize([2, 3], 2).flatMap(range) - b = self.sc.parallelize([3, 2], 2).flatMap(range) - self.assertEqual(a.count(), b.count()) - self.assertRaises(Exception, lambda: a.zip(b).count()) + with QuietTest(self.sc): + # different number of batched items in JVM + b = self.sc.parallelize(range(100, 104), 2) + self.assertRaises(Exception, lambda: a.zip(b).count()) + # different number of items in one pair + b = self.sc.parallelize(range(100, 106), 2) + self.assertRaises(Exception, lambda: a.zip(b).count()) + # same total number of items, but different distributions + a = self.sc.parallelize([2, 3], 2).flatMap(range) + b = self.sc.parallelize([3, 2], 2).flatMap(range) + self.assertEqual(a.count(), b.count()) + self.assertRaises(Exception, lambda: a.zip(b).count()) def test_count_approx_distinct(self): rdd = self.sc.parallelize(range(1000)) @@ -877,7 +891,12 @@ def test_profiler(self): func_names = [func_name for fname, n, func_name in stat_list] self.assertTrue("heavy_foo" in func_names) + old_stdout = sys.stdout + sys.stdout = io = StringIO() self.sc.show_profiles() + self.assertTrue("heavy_foo" in io.getvalue()) + sys.stdout = old_stdout + d = tempfile.gettempdir() self.sc.dump_profiles(d) self.assertTrue("rdd_%d.pstats" % id in os.listdir(d)) @@ -901,7 +920,7 @@ def show(self, id): def do_computation(self): def heavy_foo(x): - for i in range(1 << 20): + for i in range(1 << 18): x = 1 rdd = self.sc.parallelize(range(100)) @@ -1417,7 +1436,7 @@ def test_termination_sigterm(self): self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM)) -class WorkerTests(PySparkTestCase): +class WorkerTests(ReusedPySparkTestCase): def test_cancel_task(self): temp = tempfile.NamedTemporaryFile(delete=True) temp.close() @@ -1432,7 +1451,10 @@ def sleep(x): # start job in background thread def run(): - self.sc.parallelize(range(1), 1).foreach(sleep) + try: + self.sc.parallelize(range(1), 1).foreach(sleep) + except Exception: + pass import threading t = threading.Thread(target=run) t.daemon = True @@ -1473,7 +1495,8 @@ def test_after_exception(self): def raise_exception(_): raise Exception() rdd = self.sc.parallelize(range(100), 1) - self.assertRaises(Exception, lambda: rdd.foreach(raise_exception)) + with QuietTest(self.sc): + self.assertRaises(Exception, lambda: rdd.foreach(raise_exception)) self.assertEqual(100, rdd.map(str).count()) def test_after_jvm_exception(self): @@ -1484,7 +1507,8 @@ def test_after_jvm_exception(self): filtered_data = data.filter(lambda x: True) self.assertEqual(1, filtered_data.count()) os.unlink(tempFile.name) - self.assertRaises(Exception, lambda: filtered_data.count()) + with QuietTest(self.sc): + self.assertRaises(Exception, lambda: filtered_data.count()) rdd = self.sc.parallelize(range(100), 1) self.assertEqual(100, rdd.map(str).count()) @@ -1522,14 +1546,11 @@ def test_with_different_versions_of_python(self): rdd.count() version = sys.version_info sys.version_info = (2, 0, 0) - log4j = self.sc._jvm.org.apache.log4j - old_level = log4j.LogManager.getRootLogger().getLevel() - log4j.LogManager.getRootLogger().setLevel(log4j.Level.FATAL) try: - self.assertRaises(Py4JJavaError, lambda: rdd.count()) + with QuietTest(self.sc): + 
self.assertRaises(Py4JJavaError, lambda: rdd.count()) finally: sys.version_info = version - log4j.LogManager.getRootLogger().setLevel(old_level) class SparkSubmitTests(unittest.TestCase): @@ -1751,9 +1772,14 @@ def test_with_stop(self): def test_progress_api(self): with SparkContext() as sc: sc.setJobGroup('test_progress_api', '', True) - rdd = sc.parallelize(range(10)).map(lambda x: time.sleep(100)) - t = threading.Thread(target=rdd.collect) + + def run(): + try: + rdd.count() + except Exception: + pass + t = threading.Thread(target=run) t.daemon = True t.start() # wait for scheduler to start diff --git a/python/run-tests b/python/run-tests index ed3e819ef30c1..88b63b84fdc27 100755 --- a/python/run-tests +++ b/python/run-tests @@ -28,6 +28,7 @@ cd "$FWDIR/python" FAILED=0 LOG_FILE=unit-tests.log +START=$(date +"%s") rm -f $LOG_FILE @@ -35,8 +36,8 @@ rm -f $LOG_FILE rm -rf metastore warehouse function run_test() { - echo "Running test: $1" | tee -a $LOG_FILE - + echo -en "Running test: $1 ... " | tee -a $LOG_FILE + start=$(date +"%s") SPARK_TESTING=1 time "$FWDIR"/bin/pyspark $1 > $LOG_FILE 2>&1 FAILED=$((PIPESTATUS[0]||$FAILED)) @@ -48,6 +49,9 @@ function run_test() { echo "Had test failures; see logs." echo -en "\033[0m" # No color exit -1 + else + now=$(date +"%s") + echo "ok ($(($now - $start))s)" fi } @@ -161,9 +165,8 @@ if [ $(which pypy) ]; then fi if [[ $FAILED == 0 ]]; then - echo -en "\033[32m" # Green - echo "Tests passed." - echo -en "\033[0m" # No color + now=$(date +"%s") + echo -e "\033[32mTests passed \033[0min $(($now - $START)) seconds" fi # TODO: in the long-run, it would be nice to use a test runner like `nose`. From 41ef78a94105bb995bb14d15d47cbb6ca1638f62 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 21 Apr 2015 17:52:52 -0700 Subject: [PATCH 082/191] Closes #5427 From a0761ec7063f984dcadc8d154f83dd9cfd1c5e0b Mon Sep 17 00:00:00 2001 From: texasmichelle Date: Tue, 21 Apr 2015 18:08:29 -0700 Subject: [PATCH 083/191] [SPARK-1684] [PROJECT INFRA] Merge script should standardize SPARK-XXX prefix Cleans up the pull request title in the merge script to follow conventions outlined in the wiki under Contributing Code. 
https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark#ContributingtoSpark-ContributingCode [MODULE] SPARK-XXXX: Description Author: texasmichelle Closes #5149 from texasmichelle/master and squashes the following commits: 9b6b0a7 [texasmichelle] resolved variable scope issue 7d5fa20 [texasmichelle] only prompt if title has been modified 8c195bb [texasmichelle] removed erroneous line 4f1ed46 [texasmichelle] Deque removal, logic simplifications, & prompt user to pick a title (orig or modified) df73f6a [texasmichelle] reworked regex's to enforce brackets around JIRA ref 43b5aed [texasmichelle] Merge remote-tracking branch 'apache/master' 25229c6 [texasmichelle] Merge remote-tracking branch 'apache/master' aa20a6e [texasmichelle] Move code into main() and add doctest for new text parsing method 48520ba [texasmichelle] SPARK-1684: Corrected import statement 042099d [texasmichelle] SPARK-1684 Merge script should standardize SPARK-XXX prefix 8f4a7d1 [texasmichelle] SPARK-1684 Merge script should standardize SPARK-XXX prefix --- dev/merge_spark_pr.py | 199 +++++++++++++++++++++++++++++------------- 1 file changed, 140 insertions(+), 59 deletions(-) diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 3062e9c3c6651..b69cd15f99f63 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -55,8 +55,6 @@ # Prefix added to temporary branches BRANCH_PREFIX = "PR_TOOL" -os.chdir(SPARK_HOME) - def get_json(url): try: @@ -85,10 +83,6 @@ def continue_maybe(prompt): if result.lower() != "y": fail("Okay, exiting") - -original_head = run_cmd("git rev-parse HEAD")[:8] - - def clean_up(): print "Restoring head pointer to %s" % original_head run_cmd("git checkout %s" % original_head) @@ -101,7 +95,7 @@ def clean_up(): # merge the requested PR and return the merge hash -def merge_pr(pr_num, target_ref): +def merge_pr(pr_num, target_ref, title, body, pr_repo_desc): pr_branch_name = "%s_MERGE_PR_%s" % (BRANCH_PREFIX, pr_num) target_branch_name = "%s_MERGE_PR_%s_%s" % (BRANCH_PREFIX, pr_num, target_ref.upper()) run_cmd("git fetch %s pull/%s/head:%s" % (PR_REMOTE_NAME, pr_num, pr_branch_name)) @@ -274,7 +268,7 @@ def get_version_json(version_str): asf_jira.transition_issue( jira_id, resolve["id"], fixVersions=jira_fix_versions, comment=comment) - print "Succesfully resolved %s with fixVersions=%s!" % (jira_id, fix_versions) + print "Successfully resolved %s with fixVersions=%s!" % (jira_id, fix_versions) def resolve_jira_issues(title, merge_branches, comment): @@ -286,68 +280,155 @@ def resolve_jira_issues(title, merge_branches, comment): resolve_jira_issue(merge_branches, comment, jira_id) -branches = get_json("%s/branches" % GITHUB_API_BASE) -branch_names = filter(lambda x: x.startswith("branch-"), [x['name'] for x in branches]) -# Assumes branch names can be sorted lexicographically -latest_branch = sorted(branch_names, reverse=True)[0] - -pr_num = raw_input("Which pull request would you like to merge? (e.g. 34): ") -pr = get_json("%s/pulls/%s" % (GITHUB_API_BASE, pr_num)) -pr_events = get_json("%s/issues/%s/events" % (GITHUB_API_BASE, pr_num)) +def standardize_jira_ref(text): + """ + Standardize the [SPARK-XXXXX] [MODULE] prefix + Converts "[SPARK-XXX][mllib] Issue", "[MLLib] SPARK-XXX. 
Issue" or "SPARK XXX [MLLIB]: Issue" to "[SPARK-XXX] [MLLIB] Issue" + + >>> standardize_jira_ref("[SPARK-5821] [SQL] ParquetRelation2 CTAS should check if delete is successful") + '[SPARK-5821] [SQL] ParquetRelation2 CTAS should check if delete is successful' + >>> standardize_jira_ref("[SPARK-4123][Project Infra][WIP]: Show new dependencies added in pull requests") + '[SPARK-4123] [PROJECT INFRA] [WIP] Show new dependencies added in pull requests' + >>> standardize_jira_ref("[MLlib] Spark 5954: Top by key") + '[SPARK-5954] [MLLIB] Top by key' + >>> standardize_jira_ref("[SPARK-979] a LRU scheduler for load balancing in TaskSchedulerImpl") + '[SPARK-979] a LRU scheduler for load balancing in TaskSchedulerImpl' + >>> standardize_jira_ref("SPARK-1094 Support MiMa for reporting binary compatibility accross versions.") + '[SPARK-1094] Support MiMa for reporting binary compatibility accross versions.' + >>> standardize_jira_ref("[WIP] [SPARK-1146] Vagrant support for Spark") + '[SPARK-1146] [WIP] Vagrant support for Spark' + >>> standardize_jira_ref("SPARK-1032. If Yarn app fails before registering, app master stays aroun...") + '[SPARK-1032] If Yarn app fails before registering, app master stays aroun...' + >>> standardize_jira_ref("[SPARK-6250][SPARK-6146][SPARK-5911][SQL] Types are now reserved words in DDL parser.") + '[SPARK-6250] [SPARK-6146] [SPARK-5911] [SQL] Types are now reserved words in DDL parser.' + >>> standardize_jira_ref("Additional information for users building from source code") + 'Additional information for users building from source code' + """ + jira_refs = [] + components = [] + + # If the string is compliant, no need to process any further + if (re.search(r'^\[SPARK-[0-9]{3,6}\] (\[[A-Z0-9_\s,]+\] )+\S+', text)): + return text + + # Extract JIRA ref(s): + pattern = re.compile(r'(SPARK[-\s]*[0-9]{3,6})+', re.IGNORECASE) + for ref in pattern.findall(text): + # Add brackets, replace spaces with a dash, & convert to uppercase + jira_refs.append('[' + re.sub(r'\s+', '-', ref.upper()) + ']') + text = text.replace(ref, '') + + # Extract spark component(s): + # Look for alphanumeric chars, spaces, dashes, periods, and/or commas + pattern = re.compile(r'(\[[\w\s,-\.]+\])', re.IGNORECASE) + for component in pattern.findall(text): + components.append(component.upper()) + text = text.replace(component, '') + + # Cleanup any remaining symbols: + pattern = re.compile(r'^\W+(.*)', re.IGNORECASE) + if (pattern.search(text) is not None): + text = pattern.search(text).groups()[0] + + # Assemble full text (JIRA ref(s), module(s), remaining text) + clean_text = ' '.join(jira_refs).strip() + " " + ' '.join(components).strip() + " " + text.strip() + + # Replace multiple spaces with a single space, e.g. if no jira refs and/or components were included + clean_text = re.sub(r'\s+', ' ', clean_text.strip()) + + return clean_text + +def main(): + global original_head + + os.chdir(SPARK_HOME) + original_head = run_cmd("git rev-parse HEAD")[:8] + + branches = get_json("%s/branches" % GITHUB_API_BASE) + branch_names = filter(lambda x: x.startswith("branch-"), [x['name'] for x in branches]) + # Assumes branch names can be sorted lexicographically + latest_branch = sorted(branch_names, reverse=True)[0] + + pr_num = raw_input("Which pull request would you like to merge? (e.g. 
34): ") + pr = get_json("%s/pulls/%s" % (GITHUB_API_BASE, pr_num)) + pr_events = get_json("%s/issues/%s/events" % (GITHUB_API_BASE, pr_num)) + + url = pr["url"] + + # Decide whether to use the modified title or not + modified_title = standardize_jira_ref(pr["title"]) + if modified_title != pr["title"]: + print "I've re-written the title as follows to match the standard format:" + print "Original: %s" % pr["title"] + print "Modified: %s" % modified_title + result = raw_input("Would you like to use the modified title? (y/n): ") + if result.lower() == "y": + title = modified_title + print "Using modified title:" + else: + title = pr["title"] + print "Using original title:" + print title + else: + title = pr["title"] -url = pr["url"] -title = pr["title"] -body = pr["body"] -target_ref = pr["base"]["ref"] -user_login = pr["user"]["login"] -base_ref = pr["head"]["ref"] -pr_repo_desc = "%s/%s" % (user_login, base_ref) + body = pr["body"] + target_ref = pr["base"]["ref"] + user_login = pr["user"]["login"] + base_ref = pr["head"]["ref"] + pr_repo_desc = "%s/%s" % (user_login, base_ref) -# Merged pull requests don't appear as merged in the GitHub API; -# Instead, they're closed by asfgit. -merge_commits = \ - [e for e in pr_events if e["actor"]["login"] == "asfgit" and e["event"] == "closed"] + # Merged pull requests don't appear as merged in the GitHub API; + # Instead, they're closed by asfgit. + merge_commits = \ + [e for e in pr_events if e["actor"]["login"] == "asfgit" and e["event"] == "closed"] -if merge_commits: - merge_hash = merge_commits[0]["commit_id"] - message = get_json("%s/commits/%s" % (GITHUB_API_BASE, merge_hash))["commit"]["message"] + if merge_commits: + merge_hash = merge_commits[0]["commit_id"] + message = get_json("%s/commits/%s" % (GITHUB_API_BASE, merge_hash))["commit"]["message"] - print "Pull request %s has already been merged, assuming you want to backport" % pr_num - commit_is_downloaded = run_cmd(['git', 'rev-parse', '--quiet', '--verify', + print "Pull request %s has already been merged, assuming you want to backport" % pr_num + commit_is_downloaded = run_cmd(['git', 'rev-parse', '--quiet', '--verify', "%s^{commit}" % merge_hash]).strip() != "" - if not commit_is_downloaded: - fail("Couldn't find any merge commit for #%s, you may need to update HEAD." % pr_num) + if not commit_is_downloaded: + fail("Couldn't find any merge commit for #%s, you may need to update HEAD." % pr_num) - print "Found commit %s:\n%s" % (merge_hash, message) - cherry_pick(pr_num, merge_hash, latest_branch) - sys.exit(0) + print "Found commit %s:\n%s" % (merge_hash, message) + cherry_pick(pr_num, merge_hash, latest_branch) + sys.exit(0) -if not bool(pr["mergeable"]): - msg = "Pull request %s is not mergeable in its current form.\n" % pr_num + \ - "Continue? (experts only!)" - continue_maybe(msg) + if not bool(pr["mergeable"]): + msg = "Pull request %s is not mergeable in its current form.\n" % pr_num + \ + "Continue? (experts only!)" + continue_maybe(msg) -print ("\n=== Pull Request #%s ===" % pr_num) -print ("title\t%s\nsource\t%s\ntarget\t%s\nurl\t%s" % ( - title, pr_repo_desc, target_ref, url)) -continue_maybe("Proceed with merging pull request #%s?" % pr_num) + print ("\n=== Pull Request #%s ===" % pr_num) + print ("title\t%s\nsource\t%s\ntarget\t%s\nurl\t%s" % ( + title, pr_repo_desc, target_ref, url)) + continue_maybe("Proceed with merging pull request #%s?" 
% pr_num) -merged_refs = [target_ref] + merged_refs = [target_ref] -merge_hash = merge_pr(pr_num, target_ref) + merge_hash = merge_pr(pr_num, target_ref, title, body, pr_repo_desc) -pick_prompt = "Would you like to pick %s into another branch?" % merge_hash -while raw_input("\n%s (y/n): " % pick_prompt).lower() == "y": - merged_refs = merged_refs + [cherry_pick(pr_num, merge_hash, latest_branch)] + pick_prompt = "Would you like to pick %s into another branch?" % merge_hash + while raw_input("\n%s (y/n): " % pick_prompt).lower() == "y": + merged_refs = merged_refs + [cherry_pick(pr_num, merge_hash, latest_branch)] -if JIRA_IMPORTED: - if JIRA_USERNAME and JIRA_PASSWORD: - continue_maybe("Would you like to update an associated JIRA?") - jira_comment = "Issue resolved by pull request %s\n[%s/%s]" % (pr_num, GITHUB_BASE, pr_num) - resolve_jira_issues(title, merged_refs, jira_comment) + if JIRA_IMPORTED: + if JIRA_USERNAME and JIRA_PASSWORD: + continue_maybe("Would you like to update an associated JIRA?") + jira_comment = "Issue resolved by pull request %s\n[%s/%s]" % (pr_num, GITHUB_BASE, pr_num) + resolve_jira_issues(title, merged_refs, jira_comment) + else: + print "JIRA_USERNAME and JIRA_PASSWORD not set" + print "Exiting without trying to close the associated JIRA." else: - print "JIRA_USERNAME and JIRA_PASSWORD not set" + print "Could not find jira-python library. Run 'sudo pip install jira-python' to install." print "Exiting without trying to close the associated JIRA." -else: - print "Could not find jira-python library. Run 'sudo pip install jira-python' to install." - print "Exiting without trying to close the associated JIRA." + +if __name__ == "__main__": + import doctest + doctest.testmod() + + main() From 3a3f7100f4ead9b7ac50e9711ac50b603ebf6bea Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 21 Apr 2015 18:37:53 -0700 Subject: [PATCH 084/191] [SPARK-6490][Docs] Add docs for rpc configurations Added docs for rpc configurations and also fixed two places that should have been fixed in #5595. Author: zsxwing Closes #5607 from zsxwing/SPARK-6490-docs and squashes the following commits: 25a6736 [zsxwing] Increase the default timeout to 120s 6e37c30 [zsxwing] Update docs 5577540 [zsxwing] Use spark.network.timeout as the default timeout if it presents 4f07174 [zsxwing] Fix unit tests 1c2cf26 [zsxwing] Add docs for rpc configurations --- .../org/apache/spark/util/RpcUtils.scala | 6 ++-- .../org/apache/spark/SparkConfSuite.scala | 2 +- .../org/apache/spark/rpc/RpcEnvSuite.scala | 2 +- docs/configuration.md | 34 +++++++++++++++++-- 4 files changed, 38 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala index 5ae793e0e87a3..f16cc8e7e42c6 100644 --- a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala @@ -48,11 +48,13 @@ object RpcUtils { /** Returns the default Spark timeout to use for RPC ask operations. */ def askTimeout(conf: SparkConf): FiniteDuration = { - conf.getTimeAsSeconds("spark.rpc.askTimeout", "30s") seconds + conf.getTimeAsSeconds("spark.rpc.askTimeout", + conf.get("spark.network.timeout", "120s")) seconds } /** Returns the default Spark timeout to use for RPC remote endpoint lookup. 
*/ def lookupTimeout(conf: SparkConf): FiniteDuration = { - conf.getTimeAsSeconds("spark.rpc.lookupTimeout", "30s") seconds + conf.getTimeAsSeconds("spark.rpc.lookupTimeout", + conf.get("spark.network.timeout", "120s")) seconds } } diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index d7d8014a20498..272e6af0514e4 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -227,7 +227,7 @@ class SparkConfSuite extends FunSuite with LocalSparkContext with ResetSystemPro test("akka deprecated configs") { val conf = new SparkConf() - assert(!conf.contains("spark.rpc.num.retries")) + assert(!conf.contains("spark.rpc.numRetries")) assert(!conf.contains("spark.rpc.retry.wait")) assert(!conf.contains("spark.rpc.askTimeout")) assert(!conf.contains("spark.rpc.lookupTimeout")) diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 5fbda37c7cb88..44c88b00c442a 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -156,7 +156,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val conf = new SparkConf() conf.set("spark.rpc.retry.wait", "0") - conf.set("spark.rpc.num.retries", "1") + conf.set("spark.rpc.numRetries", "1") val anotherEnv = createRpcEnv(conf, "remote", 13345) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-timeout") diff --git a/docs/configuration.md b/docs/configuration.md index d9e9e67026cbb..d587b91124cb8 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -963,8 +963,9 @@ Apart from these, the following properties are also available, and may be useful Default timeout for all network interactions. This config will be used in place of spark.core.connection.ack.wait.timeout, spark.akka.timeout, - spark.storage.blockManagerSlaveTimeoutMs or - spark.shuffle.io.connectionTimeout, if they are not configured. + spark.storage.blockManagerSlaveTimeoutMs, + spark.shuffle.io.connectionTimeout, spark.rpc.askTimeout or + spark.rpc.lookupTimeout if they are not configured. @@ -982,6 +983,35 @@ Apart from these, the following properties are also available, and may be useful This is only relevant for the Spark shell. + + spark.rpc.numRetries + 3 + Number of times to retry before an RPC task gives up. + An RPC task will run at most times of this number. + + + + + spark.rpc.retry.wait + 3s + + Duration for an RPC ask operation to wait before retrying. + + + + spark.rpc.askTimeout + 120s + + Duration for an RPC ask operation to wait before timing out. + + + + spark.rpc.lookupTimeout + 120s + Duration for an RPC remote endpoint lookup operation to wait before timing out. + + + #### Scheduling From 70f9f8ff38560967f2c84de77263a5455c45c495 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Tue, 21 Apr 2015 21:04:04 -0700 Subject: [PATCH 085/191] [MINOR] Comment improvements in ExternalSorter. 1. Clearly specifies the contract/interactions for users of this class. 2. Minor fix in one doc to avoid ambiguity. Author: Patrick Wendell Closes #5620 from pwendell/cleanup and squashes the following commits: 8d8f44f [Patrick Wendell] [Minor] Comment improvements in ExternalSorter. 
--- .../util/collection/ExternalSorter.scala | 27 +++++++++++++------ 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 79a1a8a0dae38..79a695fb62086 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -53,7 +53,18 @@ import org.apache.spark.storage.{BlockObjectWriter, BlockId} * probably want to pass None as the ordering to avoid extra sorting. On the other hand, if you do * want to do combining, having an Ordering is more efficient than not having it. * - * At a high level, this class works as follows: + * Users interact with this class in the following way: + * + * 1. Instantiate an ExternalSorter. + * + * 2. Call insertAll() with a set of records. + * + * 3. Request an iterator() back to traverse sorted/aggregated records. + * - or - + * Invoke writePartitionedFile() to create a file containing sorted/aggregated outputs + * that can be used in Spark's sort shuffle. + * + * At a high level, this class works internally as follows: * * - We repeatedly fill up buffers of in-memory data, using either a SizeTrackingAppendOnlyMap if * we want to combine by key, or an simple SizeTrackingBuffer if we don't. Inside these buffers, @@ -65,11 +76,11 @@ import org.apache.spark.storage.{BlockObjectWriter, BlockId} * aggregation. For each file, we track how many objects were in each partition in memory, so we * don't have to write out the partition ID for every element. * - * - When the user requests an iterator, the spilled files are merged, along with any remaining - * in-memory data, using the same sort order defined above (unless both sorting and aggregation - * are disabled). If we need to aggregate by key, we either use a total ordering from the - * ordering parameter, or read the keys with the same hash code and compare them with each other - * for equality to merge values. + * - When the user requests an iterator or file output, the spilled files are merged, along with + * any remaining in-memory data, using the same sort order defined above (unless both sorting + * and aggregation are disabled). If we need to aggregate by key, we either use a total ordering + * from the ordering parameter, or read the keys with the same hash code and compare them with + * each other for equality to merge values. * * - Users are expected to call stop() at the end to delete all the intermediate files. * @@ -259,8 +270,8 @@ private[spark] class ExternalSorter[K, V, C]( * Spill our in-memory collection to a sorted file that we can merge later (normal code path). * We add this file into spilledFiles to find it later. * - * Alternatively, if bypassMergeSort is true, we spill to separate files for each partition. - * See spillToPartitionedFiles() for that code path. + * This should not be invoked if bypassMergeSort is true. In that case, spillToPartitionedFiles() + * is used to write files for each partition. * * @param collection whichever collection we're using (map or buffer) */ From 607eff0edfc10a1473fa9713a0500bf09f105c13 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 21 Apr 2015 21:44:44 -0700 Subject: [PATCH 086/191] [SPARK-6113] [ML] Small cleanups after original tree API PR This does a few clean-ups. With this PR, all spark.ml tree components have ```private[ml]``` constructors. CC: mengxr Author: Joseph K. 
Bradley Closes #5567 from jkbradley/dt-api-dt2 and squashes the following commits: 2263b5b [Joseph K. Bradley] Added note about tree example issue. bb9f610 [Joseph K. Bradley] Small cleanups after original tree API PR --- .../examples/ml/DecisionTreeExample.scala | 25 ++++++++++++++----- .../spark/ml/impl/tree/treeParams.scala | 4 +-- .../org/apache/spark/ml/tree/Split.scala | 7 +++--- 3 files changed, 25 insertions(+), 11 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala index 921b396e799e7..2cd515c89d3d2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala @@ -44,6 +44,13 @@ import org.apache.spark.sql.{SQLContext, DataFrame} * {{{ * ./bin/run-example ml.DecisionTreeExample [options] * }}} + * Note that Decision Trees can take a large amount of memory. If the run-example command above + * fails, try running via spark-submit and specifying the amount of memory as at least 1g. + * For local mode, run + * {{{ + * ./bin/spark-submit --class org.apache.spark.examples.ml.DecisionTreeExample --driver-memory 1g + * [examples JAR path] [options] + * }}} * If you use it as a template to create your own app, please use `spark-submit` to submit your app. */ object DecisionTreeExample { @@ -70,7 +77,7 @@ object DecisionTreeExample { val parser = new OptionParser[Params]("DecisionTreeExample") { head("DecisionTreeExample: an example decision tree app.") opt[String]("algo") - .text(s"algorithm (Classification, Regression), default: ${defaultParams.algo}") + .text(s"algorithm (classification, regression), default: ${defaultParams.algo}") .action((x, c) => c.copy(algo = x)) opt[Int]("maxDepth") .text(s"max depth of the tree, default: ${defaultParams.maxDepth}") @@ -222,18 +229,23 @@ object DecisionTreeExample { // (1) For classification, re-index classes. val labelColName = if (algo == "classification") "indexedLabel" else "label" if (algo == "classification") { - val labelIndexer = new StringIndexer().setInputCol("labelString").setOutputCol(labelColName) + val labelIndexer = new StringIndexer() + .setInputCol("labelString") + .setOutputCol(labelColName) stages += labelIndexer } // (2) Identify categorical features using VectorIndexer. // Features with more than maxCategories values will be treated as continuous. 
- val featuresIndexer = new VectorIndexer().setInputCol("features") - .setOutputCol("indexedFeatures").setMaxCategories(10) + val featuresIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(10) stages += featuresIndexer // (3) Learn DecisionTree val dt = algo match { case "classification" => - new DecisionTreeClassifier().setFeaturesCol("indexedFeatures") + new DecisionTreeClassifier() + .setFeaturesCol("indexedFeatures") .setLabelCol(labelColName) .setMaxDepth(params.maxDepth) .setMaxBins(params.maxBins) @@ -242,7 +254,8 @@ object DecisionTreeExample { .setCacheNodeIds(params.cacheNodeIds) .setCheckpointInterval(params.checkpointInterval) case "regression" => - new DecisionTreeRegressor().setFeaturesCol("indexedFeatures") + new DecisionTreeRegressor() + .setFeaturesCol("indexedFeatures") .setLabelCol(labelColName) .setMaxDepth(params.maxDepth) .setMaxBins(params.maxBins) diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala index 6f4509f03d033..eb2609faef05a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala @@ -117,7 +117,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams { def setMaxDepth(value: Int): this.type = { require(value >= 0, s"maxDepth parameter must be >= 0. Given bad value: $value") set(maxDepth, value) - this.asInstanceOf[this.type] + this } /** @group getParam */ @@ -283,7 +283,7 @@ private[ml] trait TreeRegressorParams extends Params { def getImpurity: String = getOrDefault(impurity) /** Convert new impurity to old impurity. */ - protected def getOldImpurity: OldImpurity = { + private[ml] def getOldImpurity: OldImpurity = { getImpurity match { case "variance" => OldVariance case _ => diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala index cb940f62990ed..708c769087dd0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala @@ -38,7 +38,7 @@ sealed trait Split extends Serializable { private[tree] def toOld: OldSplit } -private[ml] object Split { +private[tree] object Split { def fromOld(oldSplit: OldSplit, categoricalFeatures: Map[Int, Int]): Split = { oldSplit.featureType match { @@ -58,7 +58,7 @@ private[ml] object Split { * left. Otherwise, it goes right. * @param numCategories Number of categories for this feature. */ -final class CategoricalSplit( +final class CategoricalSplit private[ml] ( override val featureIndex: Int, leftCategories: Array[Double], private val numCategories: Int) @@ -130,7 +130,8 @@ final class CategoricalSplit( * @param threshold If the feature value is <= this threshold, then the split goes left. * Otherwise, it goes right. 
*/ -final class ContinuousSplit(override val featureIndex: Int, val threshold: Double) extends Split { +final class ContinuousSplit private[ml] (override val featureIndex: Int, val threshold: Double) + extends Split { override private[ml] def shouldGoLeft(features: Vector): Boolean = { features(featureIndex) <= threshold From bdc5c16e76c5d0bc147408353b2ba4faa8e914fc Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 21 Apr 2015 22:34:31 -0700 Subject: [PATCH 087/191] [SPARK-6889] [DOCS] CONTRIBUTING.md updates to accompany contribution doc updates Part of the SPARK-6889 doc updates, to accompany wiki updates at https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark See draft text at https://docs.google.com/document/d/1tB9-f9lmxhC32QlOo4E8Z7eGDwHx1_Q3O8uCmRXQTo8/edit# Author: Sean Owen Closes #5623 from srowen/SPARK-6889 and squashes the following commits: 03773b1 [Sean Owen] Part of the SPARK-6889 doc updates, to accompany wiki updates at https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark --- CONTRIBUTING.md | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index b6c6b050fa331..f10d7e277eea3 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,12 +1,16 @@ ## Contributing to Spark -Contributions via GitHub pull requests are gladly accepted from their original -author. Along with any pull requests, please state that the contribution is -your original work and that you license the work to the project under the -project's open source license. Whether or not you state this explicitly, by -submitting any copyrighted material via pull request, email, or other means -you agree to license the material under the project's open source license and -warrant that you have the legal authority to do so. +*Before opening a pull request*, review the +[Contributing to Spark wiki](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark). +It lists steps that are required before creating a PR. In particular, consider: + +- Is the change important and ready enough to ask the community to spend time reviewing? +- Have you searched for existing, related JIRAs and pull requests? +- Is this a new feature that can stand alone as a package on http://spark-packages.org ? +- Is the change being proposed clearly explained and motivated? -Please see the [Contributing to Spark wiki page](https://cwiki.apache.org/SPARK/Contributing+to+Spark) -for more information. +When you contribute code, you affirm that the contribution is your original work and that you +license the work to the project under the project's open source license. Whether or not you +state this explicitly, by submitting any copyrighted material via pull request, email, or +other means you agree to license the material under the project's open source license and +warrant that you have the legal authority to do so. From 33b85620f910c404873d362d27cca1223084913a Mon Sep 17 00:00:00 2001 From: zsxwing Date: Wed, 22 Apr 2015 11:08:59 -0700 Subject: [PATCH 088/191] [SPARK-7052][Core] Add ThreadUtils and move thread methods from Utils to ThreadUtils As per rxin's suggestion in https://github.com/apache/spark/pull/5392/files#r28757176 What's more, there is a race condition in the global shared `daemonThreadFactoryBuilder`: it may be modified by multiple threads concurrently. This PR removes the global `daemonThreadFactoryBuilder` and instead creates a new `ThreadFactoryBuilder` on every call.
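To make the race concrete, here is a minimal before/after sketch (the `ThreadFactorySketch` object and the `namedThreadFactoryOld` name are illustrative only; the "after" shape matches the `ThreadUtils.namedThreadFactory` added by this patch):

```scala
import java.util.concurrent.ThreadFactory

import com.google.common.util.concurrent.ThreadFactoryBuilder

object ThreadFactorySketch {
  // Before: one builder shared by every caller. Two threads calling
  // namedThreadFactoryOld concurrently can interleave setNameFormat and
  // build, so both may receive factories built with the other caller's
  // name format.
  private val sharedBuilder = new ThreadFactoryBuilder().setDaemon(true)

  def namedThreadFactoryOld(prefix: String): ThreadFactory =
    sharedBuilder.setNameFormat(prefix + "-%d").build()

  // After: a fresh builder per call; with no shared mutable state the
  // race disappears, at the negligible cost of one builder allocation.
  def namedThreadFactory(prefix: String): ThreadFactory =
    new ThreadFactoryBuilder().setDaemon(true).setNameFormat(prefix + "-%d").build()
}
```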
Author: zsxwing Closes #5631 from zsxwing/thread-utils and squashes the following commits: 9fe5b0e [zsxwing] Add ThreadUtils and move thread methods from Utils to ThreadUtils --- .../spark/ExecutorAllocationManager.scala | 8 +-- .../org/apache/spark/HeartbeatReceiver.scala | 11 ++- .../deploy/history/FsHistoryProvider.scala | 4 +- .../org/apache/spark/executor/Executor.scala | 7 +- .../spark/network/nio/ConnectionManager.scala | 12 ++-- .../apache/spark/scheduler/DAGScheduler.scala | 4 +- .../spark/scheduler/TaskResultGetter.scala | 4 +- .../CoarseGrainedSchedulerBackend.scala | 6 +- .../cluster/YarnSchedulerBackend.scala | 4 +- .../spark/scheduler/local/LocalBackend.scala | 8 +-- .../storage/BlockManagerMasterEndpoint.scala | 4 +- .../storage/BlockManagerSlaveEndpoint.scala | 4 +- .../org/apache/spark/util/ThreadUtils.scala | 67 +++++++++++++++++++ .../scala/org/apache/spark/util/Utils.scala | 29 -------- .../apache/spark/util/ThreadUtilsSuite.scala | 57 ++++++++++++++++ .../streaming/kafka/KafkaInputDStream.scala | 5 +- .../kafka/ReliableKafkaReceiver.scala | 4 +- .../receiver/ReceivedBlockHandler.scala | 4 +- .../streaming/util/WriteAheadLogManager.scala | 4 +- 19 files changed, 170 insertions(+), 76 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/util/ThreadUtils.scala create mode 100644 core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 4e7bf51fc0622..b986fa87dc2f4 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -17,12 +17,12 @@ package org.apache.spark -import java.util.concurrent.{Executors, TimeUnit} +import java.util.concurrent.TimeUnit import scala.collection.mutable import org.apache.spark.scheduler._ -import org.apache.spark.util.{Clock, SystemClock, Utils} +import org.apache.spark.util.{ThreadUtils, Clock, SystemClock, Utils} /** * An agent that dynamically allocates and removes executors based on the workload. @@ -132,8 +132,8 @@ private[spark] class ExecutorAllocationManager( private val listener = new ExecutorAllocationListener // Executor that handles the scheduling task. - private val executor = Executors.newSingleThreadScheduledExecutor( - Utils.namedThreadFactory("spark-dynamic-executor-allocation")) + private val executor = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("spark-dynamic-executor-allocation") /** * Verify that the settings specified through the config are valid. diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index e3bd16f1cbf24..68d05d5b02537 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -17,7 +17,7 @@ package org.apache.spark -import java.util.concurrent.{ScheduledFuture, TimeUnit, Executors} +import java.util.concurrent.{ScheduledFuture, TimeUnit} import scala.collection.mutable @@ -25,7 +25,7 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEnv, RpcCallContext} import org.apache.spark.storage.BlockManagerId import org.apache.spark.scheduler.{SlaveLost, TaskScheduler} -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} /** * A heartbeat from executors to the driver. 
This is a shared message used by several internal @@ -76,11 +76,10 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) private var timeoutCheckingTask: ScheduledFuture[_] = null - private val timeoutCheckingThread = Executors.newSingleThreadScheduledExecutor( - Utils.namedThreadFactory("heartbeat-timeout-checking-thread")) + private val timeoutCheckingThread = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("heartbeat-timeout-checking-thread") - private val killExecutorThread = Executors.newSingleThreadExecutor( - Utils.namedThreadFactory("kill-executor-thread")) + private val killExecutorThread = ThreadUtils.newDaemonSingleThreadExecutor("kill-executor-thread") override def onStart(): Unit = { timeoutCheckingTask = timeoutCheckingThread.scheduleAtFixedRate(new Runnable { diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 47bdd7749ec3d..9847d5944a390 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -32,7 +32,7 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io.CompressionCodec import org.apache.spark.scheduler._ import org.apache.spark.ui.SparkUI -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} import org.apache.spark.{Logging, SecurityManager, SparkConf} @@ -99,7 +99,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis */ private val replayExecutor: ExecutorService = { if (!conf.contains("spark.testing")) { - Executors.newSingleThreadExecutor(Utils.namedThreadFactory("log-replay-executor")) + ThreadUtils.newDaemonSingleThreadExecutor("log-replay-executor") } else { MoreExecutors.sameThreadExecutor() } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 327d155b38c22..5fc04df5d6a40 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -21,7 +21,7 @@ import java.io.File import java.lang.management.ManagementFactory import java.net.URL import java.nio.ByteBuffer -import java.util.concurrent.{ConcurrentHashMap, Executors, TimeUnit} +import java.util.concurrent.{ConcurrentHashMap, TimeUnit} import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} @@ -76,7 +76,7 @@ private[spark] class Executor( } // Start worker thread pool - private val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker") + private val threadPool = ThreadUtils.newDaemonCachedThreadPool("Executor task launch worker") private val executorSource = new ExecutorSource(threadPool, executorId) if (!isLocal) { @@ -110,8 +110,7 @@ private[spark] class Executor( private val runningTasks = new ConcurrentHashMap[Long, TaskRunner] // Executor for the heartbeat task. 
- private val heartbeater = Executors.newSingleThreadScheduledExecutor( - Utils.namedThreadFactory("driver-heartbeater")) + private val heartbeater = ThreadUtils.newDaemonSingleThreadScheduledExecutor("driver-heartbeater") startDriverHeartbeater() diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 1a68e621eaee7..16e905982cf64 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -36,7 +36,7 @@ import io.netty.util.{Timeout, TimerTask, HashedWheelTimer} import org.apache.spark._ import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer} -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} import scala.util.Try import scala.util.control.NonFatal @@ -79,7 +79,7 @@ private[nio] class ConnectionManager( private val selector = SelectorProvider.provider.openSelector() private val ackTimeoutMonitor = - new HashedWheelTimer(Utils.namedThreadFactory("AckTimeoutMonitor")) + new HashedWheelTimer(ThreadUtils.namedThreadFactory("AckTimeoutMonitor")) private val ackTimeout = conf.getTimeAsSeconds("spark.core.connection.ack.wait.timeout", @@ -102,7 +102,7 @@ private[nio] class ConnectionManager( handlerThreadCount, conf.getInt("spark.core.connection.handler.threads.keepalive", 60), TimeUnit.SECONDS, new LinkedBlockingDeque[Runnable](), - Utils.namedThreadFactory("handle-message-executor")) { + ThreadUtils.namedThreadFactory("handle-message-executor")) { override def afterExecute(r: Runnable, t: Throwable): Unit = { super.afterExecute(r, t) @@ -117,7 +117,7 @@ private[nio] class ConnectionManager( ioThreadCount, conf.getInt("spark.core.connection.io.threads.keepalive", 60), TimeUnit.SECONDS, new LinkedBlockingDeque[Runnable](), - Utils.namedThreadFactory("handle-read-write-executor")) { + ThreadUtils.namedThreadFactory("handle-read-write-executor")) { override def afterExecute(r: Runnable, t: Throwable): Unit = { super.afterExecute(r, t) @@ -134,7 +134,7 @@ private[nio] class ConnectionManager( connectThreadCount, conf.getInt("spark.core.connection.connect.threads.keepalive", 60), TimeUnit.SECONDS, new LinkedBlockingDeque[Runnable](), - Utils.namedThreadFactory("handle-connect-executor")) { + ThreadUtils.namedThreadFactory("handle-connect-executor")) { override def afterExecute(r: Runnable, t: Throwable): Unit = { super.afterExecute(r, t) @@ -160,7 +160,7 @@ private[nio] class ConnectionManager( private val registerRequests = new SynchronizedQueue[SendingConnection] implicit val futureExecContext = ExecutionContext.fromExecutor( - Utils.newDaemonCachedThreadPool("Connection manager future execution context")) + ThreadUtils.newDaemonCachedThreadPool("Connection manager future execution context")) @volatile private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message] = null diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 4a32f8936fb0e..8c4bff4e83afc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -19,7 +19,7 @@ package org.apache.spark.scheduler import java.io.NotSerializableException import java.util.Properties -import java.util.concurrent.{TimeUnit, Executors} +import java.util.concurrent.TimeUnit 
import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack} @@ -129,7 +129,7 @@ class DAGScheduler( private val disallowStageRetryForTest = sc.getConf.getBoolean("spark.test.noStageRetry", false) private val messageScheduler = - Executors.newScheduledThreadPool(1, Utils.namedThreadFactory("dag-scheduler-message")) + ThreadUtils.newDaemonSingleThreadScheduledExecutor("dag-scheduler-message") private[scheduler] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this) taskScheduler.setDAGScheduler(this) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index 3938580aeea59..391827c1d2156 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -26,7 +26,7 @@ import scala.util.control.NonFatal import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.serializer.SerializerInstance -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} /** * Runs a thread pool that deserializes and remotely fetches (if necessary) task results. @@ -35,7 +35,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul extends Logging { private val THREADS = sparkEnv.conf.getInt("spark.resultGetter.threads", 4) - private val getTaskResultExecutor = Utils.newDaemonFixedThreadPool( + private val getTaskResultExecutor = ThreadUtils.newDaemonFixedThreadPool( THREADS, "task-result-getter") protected val serializer = new ThreadLocal[SerializerInstance] { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 63987dfb32695..9656fb76858ea 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler.cluster -import java.util.concurrent.{TimeUnit, Executors} +import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} @@ -26,7 +26,7 @@ import org.apache.spark.rpc._ import org.apache.spark.{ExecutorAllocationClient, Logging, SparkEnv, SparkException, TaskState} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{SerializableBuffer, AkkaUtils, Utils} +import org.apache.spark.util.{ThreadUtils, SerializableBuffer, AkkaUtils, Utils} /** * A scheduler backend that waits for coarse grained executors to connect to it through Akka. 
@@ -73,7 +73,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp private val addressToExecutorId = new HashMap[RpcAddress, String] private val reviveThread = - Executors.newSingleThreadScheduledExecutor(Utils.namedThreadFactory("driver-revive-thread")) + ThreadUtils.newDaemonSingleThreadScheduledExecutor("driver-revive-thread") override def onStart() { // Periodically revive offers to allow delay scheduling to work diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 1406a36a669c5..d987c7d563579 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -24,7 +24,7 @@ import org.apache.spark.rpc._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.ui.JettyUtils -import org.apache.spark.util.{RpcUtils, Utils} +import org.apache.spark.util.{ThreadUtils, RpcUtils} import scala.util.control.NonFatal @@ -97,7 +97,7 @@ private[spark] abstract class YarnSchedulerBackend( private var amEndpoint: Option[RpcEndpointRef] = None private val askAmThreadPool = - Utils.newDaemonCachedThreadPool("yarn-scheduler-ask-am-thread-pool") + ThreadUtils.newDaemonCachedThreadPool("yarn-scheduler-ask-am-thread-pool") implicit val askAmExecutor = ExecutionContext.fromExecutor(askAmThreadPool) override def receive: PartialFunction[Any, Unit] = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index 50ba0b9d5a612..ac5b524517818 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -18,14 +18,14 @@ package org.apache.spark.scheduler.local import java.nio.ByteBuffer -import java.util.concurrent.{Executors, TimeUnit} +import java.util.concurrent.TimeUnit import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.{Executor, ExecutorBackend} import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcCallContext, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer} -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} private case class ReviveOffers() @@ -47,8 +47,8 @@ private[spark] class LocalEndpoint( private val totalCores: Int) extends ThreadSafeRpcEndpoint with Logging { - private val reviveThread = Executors.newSingleThreadScheduledExecutor( - Utils.namedThreadFactory("local-revive-thread")) + private val reviveThread = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("local-revive-thread") private var freeCores = totalCores diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 28c73a7d543ff..4682167912ff0 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -28,7 +28,7 @@ import org.apache.spark.{Logging, SparkConf} import org.apache.spark.annotation.DeveloperApi import 
org.apache.spark.scheduler._ import org.apache.spark.storage.BlockManagerMessages._ -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} /** * BlockManagerMasterEndpoint is an [[ThreadSafeRpcEndpoint]] on the master node to track statuses @@ -51,7 +51,7 @@ class BlockManagerMasterEndpoint( // Mapping from block id to the set of block managers that have the block. private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]] - private val askThreadPool = Utils.newDaemonCachedThreadPool("block-manager-ask-thread-pool") + private val askThreadPool = ThreadUtils.newDaemonCachedThreadPool("block-manager-ask-thread-pool") private implicit val askExecutionContext = ExecutionContext.fromExecutorService(askThreadPool) override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala index 8980fa8eb70e2..543df4e1350dd 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala @@ -20,7 +20,7 @@ package org.apache.spark.storage import scala.concurrent.{ExecutionContext, Future} import org.apache.spark.rpc.{RpcEnv, RpcCallContext, RpcEndpoint} -import org.apache.spark.util.Utils +import org.apache.spark.util.ThreadUtils import org.apache.spark.{Logging, MapOutputTracker, SparkEnv} import org.apache.spark.storage.BlockManagerMessages._ @@ -36,7 +36,7 @@ class BlockManagerSlaveEndpoint( extends RpcEndpoint with Logging { private val asyncThreadPool = - Utils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool") + ThreadUtils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool") private implicit val asyncExecutionContext = ExecutionContext.fromExecutorService(asyncThreadPool) // Operations that involve removing blocks may be slow and should be done asynchronously diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala new file mode 100644 index 0000000000000..098a4b79496b2 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.spark.util + +import java.util.concurrent._ + +import com.google.common.util.concurrent.ThreadFactoryBuilder + +private[spark] object ThreadUtils { + + /** + * Create a thread factory that names threads with a prefix and also sets the threads to daemon. 
+ */ + def namedThreadFactory(prefix: String): ThreadFactory = { + new ThreadFactoryBuilder().setDaemon(true).setNameFormat(prefix + "-%d").build() + } + + /** + * Wrapper over newCachedThreadPool. Thread names are formatted as prefix-ID, where ID is a + * unique, sequentially assigned integer. + */ + def newDaemonCachedThreadPool(prefix: String): ThreadPoolExecutor = { + val threadFactory = namedThreadFactory(prefix) + Executors.newCachedThreadPool(threadFactory).asInstanceOf[ThreadPoolExecutor] + } + + /** + * Wrapper over newFixedThreadPool. Thread names are formatted as prefix-ID, where ID is a + * unique, sequentially assigned integer. + */ + def newDaemonFixedThreadPool(nThreads: Int, prefix: String): ThreadPoolExecutor = { + val threadFactory = namedThreadFactory(prefix) + Executors.newFixedThreadPool(nThreads, threadFactory).asInstanceOf[ThreadPoolExecutor] + } + + /** + * Wrapper over newSingleThreadExecutor. + */ + def newDaemonSingleThreadExecutor(threadName: String): ExecutorService = { + val threadFactory = new ThreadFactoryBuilder().setDaemon(true).setNameFormat(threadName).build() + Executors.newSingleThreadExecutor(threadFactory) + } + + /** + * Wrapper over newSingleThreadScheduledExecutor. + */ + def newDaemonSingleThreadScheduledExecutor(threadName: String): ScheduledExecutorService = { + val threadFactory = new ThreadFactoryBuilder().setDaemon(true).setNameFormat(threadName).build() + Executors.newSingleThreadScheduledExecutor(threadFactory) + } +} diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 7b0de1ae55b78..2feb7341b159b 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -35,7 +35,6 @@ import scala.util.control.{ControlThrowable, NonFatal} import com.google.common.io.{ByteStreams, Files} import com.google.common.net.InetAddresses -import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.commons.lang3.SystemUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, FileUtil, Path} @@ -897,34 +896,6 @@ private[spark] object Utils extends Logging { hostPortParseResults.get(hostPort) } - private val daemonThreadFactoryBuilder: ThreadFactoryBuilder = - new ThreadFactoryBuilder().setDaemon(true) - - /** - * Create a thread factory that names threads with a prefix and also sets the threads to daemon. - */ - def namedThreadFactory(prefix: String): ThreadFactory = { - daemonThreadFactoryBuilder.setNameFormat(prefix + "-%d").build() - } - - /** - * Wrapper over newCachedThreadPool. Thread names are formatted as prefix-ID, where ID is a - * unique, sequentially assigned integer. - */ - def newDaemonCachedThreadPool(prefix: String): ThreadPoolExecutor = { - val threadFactory = namedThreadFactory(prefix) - Executors.newCachedThreadPool(threadFactory).asInstanceOf[ThreadPoolExecutor] - } - - /** - * Wrapper over newFixedThreadPool. Thread names are formatted as prefix-ID, where ID is a - * unique, sequentially assigned integer. - */ - def newDaemonFixedThreadPool(nThreads: Int, prefix: String): ThreadPoolExecutor = { - val threadFactory = namedThreadFactory(prefix) - Executors.newFixedThreadPool(nThreads, threadFactory).asInstanceOf[ThreadPoolExecutor] - } - /** * Return the string to tell how long has passed in milliseconds. 
*/ diff --git a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala new file mode 100644 index 0000000000000..a3aa3e953fbec --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.spark.util + +import java.util.concurrent.{CountDownLatch, TimeUnit} + +import org.scalatest.FunSuite + +class ThreadUtilsSuite extends FunSuite { + + test("newDaemonSingleThreadExecutor") { + val executor = ThreadUtils.newDaemonSingleThreadExecutor("this-is-a-thread-name") + @volatile var threadName = "" + executor.submit(new Runnable { + override def run(): Unit = { + threadName = Thread.currentThread().getName() + } + }) + executor.shutdown() + executor.awaitTermination(10, TimeUnit.SECONDS) + assert(threadName === "this-is-a-thread-name") + } + + test("newDaemonSingleThreadScheduledExecutor") { + val executor = ThreadUtils.newDaemonSingleThreadScheduledExecutor("this-is-a-thread-name") + try { + val latch = new CountDownLatch(1) + @volatile var threadName = "" + executor.schedule(new Runnable { + override def run(): Unit = { + threadName = Thread.currentThread().getName() + latch.countDown() + } + }, 1, TimeUnit.MILLISECONDS) + latch.await(10, TimeUnit.SECONDS) + assert(threadName === "this-is-a-thread-name") + } finally { + executor.shutdownNow() + } + } +} diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala index 4d26b640e8d74..cca0fac0234e1 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala @@ -31,7 +31,7 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.receiver.Receiver -import org.apache.spark.util.Utils +import org.apache.spark.util.ThreadUtils /** * Input stream that pulls messages from a Kafka Broker. 
@@ -111,7 +111,8 @@ class KafkaReceiver[ val topicMessageStreams = consumerConnector.createMessageStreams( topics, keyDecoder, valueDecoder) - val executorPool = Utils.newDaemonFixedThreadPool(topics.values.sum, "KafkaMessageHandler") + val executorPool = + ThreadUtils.newDaemonFixedThreadPool(topics.values.sum, "KafkaMessageHandler") try { // Start the messages handler for each partition topicMessageStreams.values.foreach { streams => diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala index c4a44c1822c39..ea87e960379f1 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala @@ -33,7 +33,7 @@ import org.I0Itec.zkclient.ZkClient import org.apache.spark.{Logging, SparkEnv} import org.apache.spark.storage.{StorageLevel, StreamBlockId} import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, Receiver} -import org.apache.spark.util.Utils +import org.apache.spark.util.ThreadUtils /** * ReliableKafkaReceiver offers the ability to reliably store data into BlockManager without loss. @@ -121,7 +121,7 @@ class ReliableKafkaReceiver[ zkClient = new ZkClient(consumerConfig.zkConnect, consumerConfig.zkSessionTimeoutMs, consumerConfig.zkConnectionTimeoutMs, ZKStringSerializer) - messageHandlerThreadPool = Utils.newDaemonFixedThreadPool( + messageHandlerThreadPool = ThreadUtils.newDaemonFixedThreadPool( topics.values.sum, "KafkaMessageHandler") blockGenerator.start() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala index dcdc27d29c270..297bf04c0c25e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.storage._ import org.apache.spark.streaming.util.{WriteAheadLogFileSegment, WriteAheadLogManager} -import org.apache.spark.util.{Clock, SystemClock, Utils} +import org.apache.spark.util.{ThreadUtils, Clock, SystemClock} /** Trait that represents the metadata related to storage of blocks */ private[streaming] trait ReceivedBlockStoreResult { @@ -150,7 +150,7 @@ private[streaming] class WriteAheadLogBasedBlockHandler( // For processing futures used in parallel block storing into block manager and write ahead log // # threads = 2, so that both writing to BM and WAL can proceed in parallel implicit private val executionContext = ExecutionContext.fromExecutorService( - Utils.newDaemonFixedThreadPool(2, this.getClass.getSimpleName)) + ThreadUtils.newDaemonFixedThreadPool(2, this.getClass.getSimpleName)) /** * This implementation stores the block into the block manager as well as a write ahead log. 
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogManager.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogManager.scala index 6bdfe45dc7f83..38a93cc3c9a1f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogManager.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogManager.scala @@ -25,7 +25,7 @@ import scala.language.postfixOps import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.Logging -import org.apache.spark.util.{Clock, SystemClock, Utils} +import org.apache.spark.util.{ThreadUtils, Clock, SystemClock} import WriteAheadLogManager._ /** @@ -60,7 +60,7 @@ private[streaming] class WriteAheadLogManager( if (callerName.nonEmpty) s" for $callerName" else "" private val threadpoolName = s"WriteAheadLogManager $callerNameTag" implicit private val executionContext = ExecutionContext.fromExecutorService( - Utils.newDaemonFixedThreadPool(1, threadpoolName)) + ThreadUtils.newDaemonSingleThreadExecutor(threadpoolName)) override protected val logName = s"WriteAheadLogManager $callerNameTag" private var currentLogPath: Option[String] = None From cdf0328684f70ddcd49b23c23c1532aeb9caa44e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 22 Apr 2015 11:18:01 -0700 Subject: [PATCH 089/191] [SQL] Rename some apply functions. I was looking at the code gen code and got confused by a few of the use cases of apply, in particular apply on objects. So I went ahead and changed a few of them. Hopefully they are slightly clearer with a proper verb. Author: Reynold Xin Closes #5624 from rxin/apply-rename and squashes the following commits: ee45034 [Reynold Xin] [SQL] Rename some apply functions. --- .../sql/catalyst/AbstractSparkSQLParser.scala | 2 +- .../spark/sql/catalyst/dsl/package.scala | 2 +- .../expressions/codegen/CodeGenerator.scala | 6 +-- .../codegen/GenerateMutableProjection.scala | 2 +- .../codegen/GenerateOrdering.scala | 2 +- .../codegen/GeneratePredicate.scala | 2 +- .../codegen/GenerateProjection.scala | 2 +- .../sql/catalyst/expressions/predicates.scala | 6 +-- .../sql/catalyst/rules/RuleExecutor.scala | 2 +- .../spark/sql/catalyst/SqlParserSuite.scala | 9 ++-- .../sql/catalyst/analysis/AnalysisSuite.scala | 22 ++++---- .../analysis/DecimalPrecisionSuite.scala | 6 +-- .../GeneratedEvaluationSuite.scala | 10 ++-- .../GeneratedMutableEvaluationSuite.scala | 8 +-- .../BooleanSimplificationSuite.scala | 2 +- .../optimizer/CombiningLimitsSuite.scala | 4 +- .../optimizer/ConstantFoldingSuite.scala | 14 ++--- .../ConvertToLocalRelationSuite.scala | 2 +- .../ExpressionOptimizationSuite.scala | 2 +- .../optimizer/FilterPushdownSuite.scala | 52 +++++++++---------- .../optimizer/LikeSimplificationSuite.scala | 8 +-- .../catalyst/optimizer/OptimizeInSuite.scala | 4 +- ...mplifyCaseConversionExpressionsSuite.scala | 8 +-- .../optimizer/UnionPushdownSuite.scala | 7 ++- .../catalyst/trees/RuleExecutorSuite.scala | 6 +-- .../org/apache/spark/sql/SQLContext.scala | 12 ++--- .../spark/sql/execution/SparkPlan.scala | 10 ++-- .../joins/BroadcastNestedLoopJoin.scala | 2 +- .../sql/execution/joins/LeftSemiJoinBNL.scala | 2 +- .../apache/spark/sql/parquet/newParquet.scala | 2 +- .../org/apache/spark/sql/sources/ddl.scala | 4 +- .../apache/spark/sql/hive/HiveContext.scala | 2 +- .../org/apache/spark/sql/hive/HiveQl.scala | 4 +- .../spark/sql/hive/HiveStrategies.scala | 4 +- .../apache/spark/sql/hive/test/TestHive.scala | 2 +- 35 files changed, 117
insertions(+), 117 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala index 3823584287741..1f3c02478bd68 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala @@ -32,7 +32,7 @@ private[sql] object KeywordNormalizer { private[sql] abstract class AbstractSparkSQLParser extends StandardTokenParsers with PackratParsers { - def apply(input: String): LogicalPlan = { + def parse(input: String): LogicalPlan = { // Initialize the Keywords. lexical.initialize(reservedWords) phrase(start)(new lexical.Scanner(input)) match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 4e5c64bb63c9f..5d5aba9644ff7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -296,7 +296,7 @@ package object dsl { InsertIntoTable( analysis.UnresolvedRelation(Seq(tableName)), Map.empty, logicalPlan, overwrite, false) - def analyze: LogicalPlan = EliminateSubQueries(analysis.SimpleAnalyzer(logicalPlan)) + def analyze: LogicalPlan = EliminateSubQueries(analysis.SimpleAnalyzer.execute(logicalPlan)) } object plans { // scalastyle:ignore diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index be2c101d63a63..eeffedb558c1b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -98,11 +98,11 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin }) /** Generates the requested evaluator binding the given expression(s) to the inputSchema. */ - def apply(expressions: InType, inputSchema: Seq[Attribute]): OutType = - apply(bind(expressions, inputSchema)) + def generate(expressions: InType, inputSchema: Seq[Attribute]): OutType = + generate(bind(expressions, inputSchema)) /** Generates the requested evaluator given already bound expression(s). */ - def apply(expressions: InType): OutType = cache.get(canonicalize(expressions)) + def generate(expressions: InType): OutType = cache.get(canonicalize(expressions)) /** * Returns a term name that is unique within this instance of a `CodeGenerator`. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index a419fd7ecb39b..840260703ab74 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -30,7 +30,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu val mutableRowName = newTermName("mutableRow") protected def canonicalize(in: Seq[Expression]): Seq[Expression] = - in.map(ExpressionCanonicalizer(_)) + in.map(ExpressionCanonicalizer.execute) protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = in.map(BindReferences.bindReference(_, inputSchema)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index fc2a2b60703e4..b129c0d898bb7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -30,7 +30,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit import scala.reflect.runtime.universe._ protected def canonicalize(in: Seq[SortOrder]): Seq[SortOrder] = - in.map(ExpressionCanonicalizer(_).asInstanceOf[SortOrder]) + in.map(ExpressionCanonicalizer.execute(_).asInstanceOf[SortOrder]) protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] = in.map(BindReferences.bindReference(_, inputSchema)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index 2a0935c790cf3..40e163024360e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -26,7 +26,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] { import scala.reflect.runtime.{universe => ru} import scala.reflect.runtime.universe._ - protected def canonicalize(in: Expression): Expression = ExpressionCanonicalizer(in) + protected def canonicalize(in: Expression): Expression = ExpressionCanonicalizer.execute(in) protected def bind(in: Expression, inputSchema: Seq[Attribute]): Expression = BindReferences.bindReference(in, inputSchema) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 6f572ff959fb4..d491babc2bff0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -31,7 +31,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { import scala.reflect.runtime.universe._ protected def canonicalize(in: Seq[Expression]): Seq[Expression] = - 
in.map(ExpressionCanonicalizer(_)) + in.map(ExpressionCanonicalizer.execute) protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = in.map(BindReferences.bindReference(_, inputSchema)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index fcd6352079b4d..46522eb9c1264 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -23,10 +23,10 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.types.{DataType, BinaryType, BooleanType, NativeType} object InterpretedPredicate { - def apply(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) = - apply(BindReferences.bindReference(expression, inputSchema)) + def create(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) = + create(BindReferences.bindReference(expression, inputSchema)) - def apply(expression: Expression): (Row => Boolean) = { + def create(expression: Expression): (Row => Boolean) = { (r: Row) => expression.eval(r).asInstanceOf[Boolean] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index c441f0bf24d85..3f9858b0c4a43 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -45,7 +45,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { * Executes the batches of rules defined by the subclass. The batches are executed serially * using the defined execution strategy. Within each batch, rules are also executed serially. 
*/ - def apply(plan: TreeType): TreeType = { + def execute(plan: TreeType): TreeType = { var curPlan = plan batches.foreach { batch => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala index 1a0a0e6154ad2..a652c70560990 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala @@ -49,13 +49,14 @@ class SqlParserSuite extends FunSuite { test("test long keyword") { val parser = new SuperLongKeywordTestParser - assert(TestCommand("NotRealCommand") === parser("ThisIsASuperLongKeyWordTest NotRealCommand")) + assert(TestCommand("NotRealCommand") === + parser.parse("ThisIsASuperLongKeyWordTest NotRealCommand")) } test("test case insensitive") { val parser = new CaseInsensitiveTestParser - assert(TestCommand("NotRealCommand") === parser("EXECUTE NotRealCommand")) - assert(TestCommand("NotRealCommand") === parser("execute NotRealCommand")) - assert(TestCommand("NotRealCommand") === parser("exEcute NotRealCommand")) + assert(TestCommand("NotRealCommand") === parser.parse("EXECUTE NotRealCommand")) + assert(TestCommand("NotRealCommand") === parser.parse("execute NotRealCommand")) + assert(TestCommand("NotRealCommand") === parser.parse("exEcute NotRealCommand")) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 7c249215bd6b6..971e1ff5ec2b8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -42,10 +42,10 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { def caseSensitiveAnalyze(plan: LogicalPlan): Unit = - caseSensitiveAnalyzer.checkAnalysis(caseSensitiveAnalyzer(plan)) + caseSensitiveAnalyzer.checkAnalysis(caseSensitiveAnalyzer.execute(plan)) def caseInsensitiveAnalyze(plan: LogicalPlan): Unit = - caseInsensitiveAnalyzer.checkAnalysis(caseInsensitiveAnalyzer(plan)) + caseInsensitiveAnalyzer.checkAnalysis(caseInsensitiveAnalyzer.execute(plan)) val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)()) val testRelation2 = LocalRelation( @@ -82,7 +82,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { a.select(UnresolvedStar(None)).select('a).unionAll(b.select(UnresolvedStar(None))) } - assert(caseInsensitiveAnalyzer(plan).resolved) + assert(caseInsensitiveAnalyzer.execute(plan).resolved) } test("check project's resolved") { @@ -98,11 +98,11 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { test("analyze project") { assert( - caseSensitiveAnalyzer(Project(Seq(UnresolvedAttribute("a")), testRelation)) === + caseSensitiveAnalyzer.execute(Project(Seq(UnresolvedAttribute("a")), testRelation)) === Project(testRelation.output, testRelation)) assert( - caseSensitiveAnalyzer( + caseSensitiveAnalyzer.execute( Project(Seq(UnresolvedAttribute("TbL.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) === Project(testRelation.output, testRelation)) @@ -115,13 +115,13 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { assert(e.getMessage().toLowerCase.contains("cannot resolve")) assert( - caseInsensitiveAnalyzer( + caseInsensitiveAnalyzer.execute( Project(Seq(UnresolvedAttribute("TbL.a")), 
UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) === Project(testRelation.output, testRelation)) assert( - caseInsensitiveAnalyzer( + caseInsensitiveAnalyzer.execute( Project(Seq(UnresolvedAttribute("tBl.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) === Project(testRelation.output, testRelation)) @@ -134,13 +134,13 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { assert(e.getMessage == "Table Not Found: tAbLe") assert( - caseSensitiveAnalyzer(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation) + caseSensitiveAnalyzer.execute(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation) assert( - caseInsensitiveAnalyzer(UnresolvedRelation(Seq("tAbLe"), None)) === testRelation) + caseInsensitiveAnalyzer.execute(UnresolvedRelation(Seq("tAbLe"), None)) === testRelation) assert( - caseInsensitiveAnalyzer(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation) + caseInsensitiveAnalyzer.execute(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation) } def errorTest( @@ -219,7 +219,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { AttributeReference("d", DecimalType.Unlimited)(), AttributeReference("e", ShortType)()) - val plan = caseInsensitiveAnalyzer( + val plan = caseInsensitiveAnalyzer.execute( testRelation2.select( 'a / Literal(2) as 'div1, 'a / 'b as 'div2, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index 67bec999dfbd1..36b03d1c65e28 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -48,12 +48,12 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { private def checkType(expression: Expression, expectedType: DataType): Unit = { val plan = Project(Seq(Alias(expression, "c")()), relation) - assert(analyzer(plan).schema.fields(0).dataType === expectedType) + assert(analyzer.execute(plan).schema.fields(0).dataType === expectedType) } private def checkComparison(expression: Expression, expectedType: DataType): Unit = { val plan = Project(Alias(expression, "c")() :: Nil, relation) - val comparison = analyzer(plan).collect { + val comparison = analyzer.execute(plan).collect { case Project(Alias(e: BinaryComparison, _) :: Nil, _) => e }.head assert(comparison.left.dataType === expectedType) @@ -64,7 +64,7 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { val plan = Union(Project(Seq(Alias(left, "l")()), relation), Project(Seq(Alias(right, "r")()), relation)) - val (l, r) = analyzer(plan).collect { + val (l, r) = analyzer.execute(plan).collect { case Union(left, right) => (left.output.head, right.output.head) }.head assert(l.dataType === expectedType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala index ef3114fd4dbab..b5ebe4b38e337 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala @@ -29,7 +29,7 @@ class GeneratedEvaluationSuite extends ExpressionEvaluationSuite { expected: Any, inputRow: Row = EmptyRow): Unit = { val plan = try { - 
GenerateMutableProjection(Alias(expression, s"Optimized($expression)")() :: Nil)() + GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)() } catch { case e: Throwable => val evaluated = GenerateProjection.expressionEvaluator(expression) @@ -56,10 +56,10 @@ class GeneratedEvaluationSuite extends ExpressionEvaluationSuite { val futures = (1 to 20).map { _ => future { - GeneratePredicate(EqualTo(Literal(1), Literal(1))) - GenerateProjection(EqualTo(Literal(1), Literal(1)) :: Nil) - GenerateMutableProjection(EqualTo(Literal(1), Literal(1)) :: Nil) - GenerateOrdering(Add(Literal(1), Literal(1)).asc :: Nil) + GeneratePredicate.generate(EqualTo(Literal(1), Literal(1))) + GenerateProjection.generate(EqualTo(Literal(1), Literal(1)) :: Nil) + GenerateMutableProjection.generate(EqualTo(Literal(1), Literal(1)) :: Nil) + GenerateOrdering.generate(Add(Literal(1), Literal(1)).asc :: Nil) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala index bcc0c404d2cfb..97af2e0fd0502 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala @@ -25,13 +25,13 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ */ class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite { override def checkEvaluation( - expression: Expression, - expected: Any, - inputRow: Row = EmptyRow): Unit = { + expression: Expression, + expected: Any, + inputRow: Row = EmptyRow): Unit = { lazy val evaluated = GenerateProjection.expressionEvaluator(expression) val plan = try { - GenerateProjection(Alias(expression, s"Optimized($expression)")() :: Nil) + GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil) } catch { case e: Throwable => fail( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index 72f06e26e05f1..6255578d7fa57 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -61,7 +61,7 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { def checkCondition(input: Expression, expected: Expression): Unit = { val plan = testRelation.where(input).analyze - val actual = Optimize(plan).expressions.head + val actual = Optimize.execute(plan).expressions.head compareConditions(actual, expected) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala index e2ae0d25db1a5..2d16d668fd522 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala @@ -44,7 +44,7 @@ class CombiningLimitsSuite extends PlanTest { .limit(10) .limit(5) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = 
testRelation .select('a) @@ -61,7 +61,7 @@ class CombiningLimitsSuite extends PlanTest { .limit(7) .limit(5) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .select('a) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index 4396bd0dda9a9..14b28e8402610 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -47,7 +47,7 @@ class ConstantFoldingSuite extends PlanTest { .subquery('y) .select('a) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .select('a.attr) @@ -74,7 +74,7 @@ class ConstantFoldingSuite extends PlanTest { Literal(2) * Literal(3) - Literal(6) / (Literal(4) - Literal(2)) )(Literal(9) / Literal(3) as Symbol("9/3")) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation @@ -99,7 +99,7 @@ class ConstantFoldingSuite extends PlanTest { Literal(2) * 'a + Literal(4) as Symbol("c3"), 'a * (Literal(3) + Literal(4)) as Symbol("c4")) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation @@ -127,7 +127,7 @@ class ConstantFoldingSuite extends PlanTest { (Literal(1) === Literal(1) || 'b > 1) && (Literal(1) === Literal(2) || 'b < 10))) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation @@ -144,7 +144,7 @@ class ConstantFoldingSuite extends PlanTest { Cast(Literal("2"), IntegerType) + Literal(3) + 'a as Symbol("c1"), Coalesce(Seq(Cast(Literal("abc"), IntegerType), Literal(3))) as Symbol("c2")) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation @@ -163,7 +163,7 @@ class ConstantFoldingSuite extends PlanTest { Rand + Literal(1) as Symbol("c1"), Sum('a) as Symbol("c2")) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation @@ -210,7 +210,7 @@ class ConstantFoldingSuite extends PlanTest { Contains("abc", Literal.create(null, StringType)) as 'c20 ) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala index cf42d43823399..6841bd9890c97 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala @@ -49,7 +49,7 @@ class ConvertToLocalRelationSuite extends PlanTest { UnresolvedAttribute("a").as("a1"), (UnresolvedAttribute("b") + 1).as("b1")) - val optimized = Optimize(projectOnLocal.analyze) + val optimized = Optimize.execute(projectOnLocal.analyze) comparePlans(optimized, correctAnswer) } diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala index 2f3704be59a9d..a4a3a66b8b229 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExpressionOptimizationSuite.scala @@ -30,7 +30,7 @@ class ExpressionOptimizationSuite extends ExpressionEvaluationSuite { expected: Any, inputRow: Row = EmptyRow): Unit = { val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation) - val optimizedPlan = DefaultOptimizer(plan) + val optimizedPlan = DefaultOptimizer.execute(plan) super.checkEvaluation(optimizedPlan.expressions.head, expected, inputRow) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 45cf695d20b01..aa9708b164efa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -50,7 +50,7 @@ class FilterPushdownSuite extends PlanTest { .subquery('y) .select('a) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .select('a.attr) @@ -65,7 +65,7 @@ class FilterPushdownSuite extends PlanTest { .groupBy('a)('a, Count('b)) .select('a) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .select('a) @@ -81,7 +81,7 @@ class FilterPushdownSuite extends PlanTest { .groupBy('a)('a as 'c, Count('b)) .select('c) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .select('a) @@ -98,7 +98,7 @@ class FilterPushdownSuite extends PlanTest { .select('a) .where('a === 1) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .where('a === 1) @@ -115,7 +115,7 @@ class FilterPushdownSuite extends PlanTest { .where('e === 1) .analyze - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .where('a + 'b === 1) @@ -131,7 +131,7 @@ class FilterPushdownSuite extends PlanTest { .where('a === 1) .where('a === 2) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .where('a === 1 && 'a === 2) @@ -152,7 +152,7 @@ class FilterPushdownSuite extends PlanTest { .where("y.b".attr === 2) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.where('b === 1) val right = testRelation.where('b === 2) val correctAnswer = @@ -170,7 +170,7 @@ class FilterPushdownSuite extends PlanTest { .where("x.b".attr === 1) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.where('b === 1) val right = testRelation val correctAnswer = @@ -188,7 +188,7 @@ class FilterPushdownSuite extends PlanTest { .where("x.b".attr === 1 && "y.b".attr === 2) } - 
val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.where('b === 1) val right = testRelation.where('b === 2) val correctAnswer = @@ -206,7 +206,7 @@ class FilterPushdownSuite extends PlanTest { .where("x.b".attr === 1 && "y.b".attr === 2) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.where('b === 1) val correctAnswer = left.join(y, LeftOuter).where("y.b".attr === 2).analyze @@ -223,7 +223,7 @@ class FilterPushdownSuite extends PlanTest { .where("x.b".attr === 1 && "y.b".attr === 2) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val right = testRelation.where('b === 2).subquery('d) val correctAnswer = x.join(right, RightOuter).where("x.b".attr === 1).analyze @@ -240,7 +240,7 @@ class FilterPushdownSuite extends PlanTest { .where("x.b".attr === 2 && "y.b".attr === 2) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.where('b === 2).subquery('d) val correctAnswer = left.join(y, LeftOuter, Some("d.b".attr === 1)).where("y.b".attr === 2).analyze @@ -257,7 +257,7 @@ class FilterPushdownSuite extends PlanTest { .where("x.b".attr === 2 && "y.b".attr === 2) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val right = testRelation.where('b === 2).subquery('d) val correctAnswer = x.join(right, RightOuter, Some("d.b".attr === 1)).where("x.b".attr === 2).analyze @@ -274,7 +274,7 @@ class FilterPushdownSuite extends PlanTest { .where("x.b".attr === 2 && "y.b".attr === 2) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.where('b === 2).subquery('l) val right = testRelation.where('b === 1).subquery('r) val correctAnswer = @@ -292,7 +292,7 @@ class FilterPushdownSuite extends PlanTest { .where("x.b".attr === 2 && "y.b".attr === 2) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val right = testRelation.where('b === 2).subquery('r) val correctAnswer = x.join(right, RightOuter, Some("r.b".attr === 1)).where("x.b".attr === 2).analyze @@ -309,7 +309,7 @@ class FilterPushdownSuite extends PlanTest { .where("x.b".attr === 2 && "y.b".attr === 2 && "x.c".attr === "y.c".attr) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.where('b === 2).subquery('l) val right = testRelation.where('b === 1).subquery('r) val correctAnswer = @@ -327,7 +327,7 @@ class FilterPushdownSuite extends PlanTest { .where("x.b".attr === 2 && "y.b".attr === 2 && "x.c".attr === "y.c".attr) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.subquery('l) val right = testRelation.where('b === 2).subquery('r) val correctAnswer = @@ -346,7 +346,7 @@ class FilterPushdownSuite extends PlanTest { .where("x.b".attr === 2 && "y.b".attr === 2 && "x.c".attr === "y.c".attr) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.where('b === 2).subquery('l) val right = testRelation.where('b === 1).subquery('r) val correctAnswer = @@ -365,7 +365,7 @@ class FilterPushdownSuite extends PlanTest { 
.where("x.b".attr === 2 && "y.b".attr === 2 && "x.c".attr === "y.c".attr) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.where('a === 3).subquery('l) val right = testRelation.where('b === 2).subquery('r) val correctAnswer = @@ -382,7 +382,7 @@ class FilterPushdownSuite extends PlanTest { val originalQuery = { x.join(y, condition = Some("x.b".attr === "y.b".attr)) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) comparePlans(analysis.EliminateSubQueries(originalQuery.analyze), optimized) } @@ -396,7 +396,7 @@ class FilterPushdownSuite extends PlanTest { .where(("x.b".attr === "y.b".attr) && ("x.a".attr === 1) && ("y.a".attr === 1)) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.where('a === 1).subquery('x) val right = testRelation.where('a === 1).subquery('y) val correctAnswer = @@ -415,7 +415,7 @@ class FilterPushdownSuite extends PlanTest { .where(("x.b".attr === "y.b".attr) && ("x.a".attr === 1)) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.where('a === 1).subquery('x) val right = testRelation.subquery('y) val correctAnswer = @@ -436,7 +436,7 @@ class FilterPushdownSuite extends PlanTest { ("z.a".attr >= 3) && ("z.a".attr === "x.b".attr)) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val lleft = testRelation.where('a >= 3).subquery('z) val left = testRelation.where('a === 1).subquery('x) val right = testRelation.subquery('y) @@ -457,7 +457,7 @@ class FilterPushdownSuite extends PlanTest { .generate(Explode('c_arr), true, false, Some("arr")) .where(('b >= 5) && ('a > 6)) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = { testRelationWithArrayType .where(('b >= 5) && ('a > 6)) @@ -474,7 +474,7 @@ class FilterPushdownSuite extends PlanTest { .generate(generator, true, false, Some("arr")) .where(('b >= 5) && ('c > 6)) } - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val referenceResult = { testRelationWithArrayType .where('b >= 5) @@ -502,7 +502,7 @@ class FilterPushdownSuite extends PlanTest { .generate(Explode('c_arr), true, false, Some("arr")) .where(('c > 6) || ('b > 5)).analyze } - val optimized = Optimize(originalQuery) + val optimized = Optimize.execute(originalQuery) comparePlans(optimized, originalQuery) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala index b10577c8001e2..b3df487c84dc8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala @@ -41,7 +41,7 @@ class LikeSimplificationSuite extends PlanTest { testRelation .where(('a like "abc%") || ('a like "abc\\%")) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .where(StartsWith('a, "abc") || ('a like "abc\\%")) .analyze @@ -54,7 +54,7 @@ class LikeSimplificationSuite extends PlanTest { testRelation 
.where('a like "%xyz") - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .where(EndsWith('a, "xyz")) .analyze @@ -67,7 +67,7 @@ class LikeSimplificationSuite extends PlanTest { testRelation .where(('a like "%mn%") || ('a like "%mn\\%")) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .where(Contains('a, "mn") || ('a like "%mn\\%")) .analyze @@ -80,7 +80,7 @@ class LikeSimplificationSuite extends PlanTest { testRelation .where(('a like "") || ('a like "abc")) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .where(('a === "") || ('a === "abc")) .analyze diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 966bc9ada1e6e..3eb399e68e70c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -49,7 +49,7 @@ class OptimizeInSuite extends PlanTest { .where(In(UnresolvedAttribute("a"), Seq(Literal(1),Literal(2)))) .analyze - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .where(InSet(UnresolvedAttribute("a"), HashSet[Any]() + 1 + 2)) @@ -64,7 +64,7 @@ class OptimizeInSuite extends PlanTest { .where(In(UnresolvedAttribute("a"), Seq(Literal(1),Literal(2), UnresolvedAttribute("b")))) .analyze - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .where(In(UnresolvedAttribute("a"), Seq(Literal(1),Literal(2), UnresolvedAttribute("b")))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala index 22992fb6f50d4..6b1e53cd42b24 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala @@ -41,7 +41,7 @@ class SimplifyCaseConversionExpressionsSuite extends PlanTest { testRelation .select(Upper(Upper('a)) as 'u) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .select(Upper('a) as 'u) @@ -55,7 +55,7 @@ class SimplifyCaseConversionExpressionsSuite extends PlanTest { testRelation .select(Upper(Lower('a)) as 'u) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .select(Upper('a) as 'u) @@ -69,7 +69,7 @@ class SimplifyCaseConversionExpressionsSuite extends PlanTest { testRelation .select(Lower(Upper('a)) as 'l) - val optimized = Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .select(Lower('a) as 'l) .analyze @@ -82,7 +82,7 @@ class SimplifyCaseConversionExpressionsSuite extends PlanTest { testRelation .select(Lower(Lower('a)) as 'l) - val optimized = 
Optimize(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .select(Lower('a) as 'l) .analyze diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala index a54751dfa9a12..a3ad200800b02 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala @@ -17,10 +17,9 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.plans.{PlanTest, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ @@ -41,7 +40,7 @@ class UnionPushdownSuite extends PlanTest { test("union: filter to each side") { val query = testUnion.where('a === 1) - val optimized = Optimize(query.analyze) + val optimized = Optimize.execute(query.analyze) val correctAnswer = Union(testRelation.where('a === 1), testRelation2.where('d === 1)).analyze @@ -52,7 +51,7 @@ class UnionPushdownSuite extends PlanTest { test("union: project to each side") { val query = testUnion.select('b) - val optimized = Optimize(query.analyze) + val optimized = Optimize.execute(query.analyze) val correctAnswer = Union(testRelation.select('b), testRelation2.select('e)).analyze diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala index 4b2d45584045f..2a641c63f87bb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala @@ -34,7 +34,7 @@ class RuleExecutorSuite extends FunSuite { val batches = Batch("once", Once, DecrementLiterals) :: Nil } - assert(ApplyOnce(Literal(10)) === Literal(9)) + assert(ApplyOnce.execute(Literal(10)) === Literal(9)) } test("to fixed point") { @@ -42,7 +42,7 @@ class RuleExecutorSuite extends FunSuite { val batches = Batch("fixedPoint", FixedPoint(100), DecrementLiterals) :: Nil } - assert(ToFixedPoint(Literal(10)) === Literal(0)) + assert(ToFixedPoint.execute(Literal(10)) === Literal(0)) } test("to maxIterations") { @@ -50,6 +50,6 @@ class RuleExecutorSuite extends FunSuite { val batches = Batch("fixedPoint", FixedPoint(10), DecrementLiterals) :: Nil } - assert(ToFixedPoint(Literal(100)) === Literal(90)) + assert(ToFixedPoint.execute(Literal(100)) === Literal(90)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index bcd20c06c6dca..a279b0f07c38a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -132,16 +132,16 @@ class SQLContext(@transient val sparkContext: SparkContext) protected[sql] lazy val optimizer: Optimizer = DefaultOptimizer @transient - protected[sql] val ddlParser = new DDLParser(sqlParser.apply(_)) + protected[sql] val ddlParser = new 
DDLParser(sqlParser.parse(_)) @transient protected[sql] val sqlParser = { val fallback = new catalyst.SqlParser - new SparkSQLParser(fallback(_)) + new SparkSQLParser(fallback.parse(_)) } protected[sql] def parseSql(sql: String): LogicalPlan = { - ddlParser(sql, false).getOrElse(sqlParser(sql)) + ddlParser.parse(sql, false).getOrElse(sqlParser.parse(sql)) } protected[sql] def executeSql(sql: String): this.QueryExecution = executePlan(parseSql(sql)) @@ -1120,12 +1120,12 @@ class SQLContext(@transient val sparkContext: SparkContext) protected[sql] class QueryExecution(val logical: LogicalPlan) { def assertAnalyzed(): Unit = analyzer.checkAnalysis(analyzed) - lazy val analyzed: LogicalPlan = analyzer(logical) + lazy val analyzed: LogicalPlan = analyzer.execute(logical) lazy val withCachedData: LogicalPlan = { assertAnalyzed() cacheManager.useCachedData(analyzed) } - lazy val optimizedPlan: LogicalPlan = optimizer(withCachedData) + lazy val optimizedPlan: LogicalPlan = optimizer.execute(withCachedData) // TODO: Don't just pick the first one... lazy val sparkPlan: SparkPlan = { @@ -1134,7 +1134,7 @@ class SQLContext(@transient val sparkContext: SparkContext) } // executedPlan should not be used to initialize any SparkPlan. It should be // only used for execution. - lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan) + lazy val executedPlan: SparkPlan = prepareForExecution.execute(sparkPlan) /** Internal version of the RDD. Avoids copies and has no schema */ lazy val toRdd: RDD[Row] = executedPlan.execute() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index e159ffe66cb24..59c89800da00f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -144,7 +144,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ log.debug( s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") if (codegenEnabled) { - GenerateProjection(expressions, inputSchema) + GenerateProjection.generate(expressions, inputSchema) } else { new InterpretedProjection(expressions, inputSchema) } @@ -156,7 +156,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ log.debug( s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") if(codegenEnabled) { - GenerateMutableProjection(expressions, inputSchema) + GenerateMutableProjection.generate(expressions, inputSchema) } else { () => new InterpretedMutableProjection(expressions, inputSchema) } @@ -166,15 +166,15 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ protected def newPredicate( expression: Expression, inputSchema: Seq[Attribute]): (Row) => Boolean = { if (codegenEnabled) { - GeneratePredicate(expression, inputSchema) + GeneratePredicate.generate(expression, inputSchema) } else { - InterpretedPredicate(expression, inputSchema) + InterpretedPredicate.create(expression, inputSchema) } } protected def newOrdering(order: Seq[SortOrder], inputSchema: Seq[Attribute]): Ordering[Row] = { if (codegenEnabled) { - GenerateOrdering(order, inputSchema) + GenerateOrdering.generate(order, inputSchema) } else { new RowOrdering(order, inputSchema) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 83b1a83765153..56200f6b8c8a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -59,7 +59,7 @@ case class BroadcastNestedLoopJoin( } @transient private lazy val boundCondition = - InterpretedPredicate( + InterpretedPredicate.create( condition .map(c => BindReferences.bindReference(c, left.output ++ right.output)) .getOrElse(Literal(true))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala index 1fa7e7bd0406c..e06f63f94b78b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala @@ -45,7 +45,7 @@ case class LeftSemiJoinBNL( override def right: SparkPlan = broadcast @transient private lazy val boundCondition = - InterpretedPredicate( + InterpretedPredicate.create( condition .map(c => BindReferences.bindReference(c, left.output ++ right.output)) .getOrElse(Literal(true))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index af7b3c81ae7b2..88466f52bd4e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -611,7 +611,7 @@ private[sql] case class ParquetRelation2( val rawPredicate = partitionPruningPredicates.reduceOption(expressions.And).getOrElse(Literal(true)) - val boundPredicate = InterpretedPredicate(rawPredicate transform { + val boundPredicate = InterpretedPredicate.create(rawPredicate transform { case a: AttributeReference => val index = partitionColumns.indexWhere(a.name == _.name) BoundReference(index, partitionColumns(index).dataType, nullable = true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index 78d494184e759..e7a0685e013d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -38,9 +38,9 @@ private[sql] class DDLParser( parseQuery: String => LogicalPlan) extends AbstractSparkSQLParser with DataTypeParser with Logging { - def apply(input: String, exceptionOnError: Boolean): Option[LogicalPlan] = { + def parse(input: String, exceptionOnError: Boolean): Option[LogicalPlan] = { try { - Some(apply(input)) + Some(parse(input)) } catch { case ddlException: DDLException => throw ddlException case _ if !exceptionOnError => None diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index c4a73b3004076..dd06b2620c5ee 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -93,7 +93,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { if (conf.dialect == "sql") { super.sql(substituted) } else if (conf.dialect == "hiveql") { - val ddlPlan = ddlParserWithHiveQL(sqlText, exceptionOnError = false) + val ddlPlan = ddlParserWithHiveQL.parse(sqlText, exceptionOnError = false) 
DataFrame(this, ddlPlan.getOrElse(HiveQl.parseSql(substituted))) } else { sys.error(s"Unsupported SQL dialect: ${conf.dialect}. Try 'sql' or 'hiveql'") diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 85061f22772dd..0ea6d57b816c6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -144,7 +144,7 @@ private[hive] object HiveQl { protected val hqlParser = { val fallback = new ExtendedHiveQlParser - new SparkSQLParser(fallback(_)) + new SparkSQLParser(fallback.parse(_)) } /** @@ -240,7 +240,7 @@ private[hive] object HiveQl { /** Returns a LogicalPlan for a given HiveQL string. */ - def parseSql(sql: String): LogicalPlan = hqlParser(sql) + def parseSql(sql: String): LogicalPlan = hqlParser.parse(sql) val errorRegEx = "line (\\d+):(\\d+) (.*)".r diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index a6f4fbe8aba06..be9249a8b1f44 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -119,9 +119,9 @@ private[hive] trait HiveStrategies { val inputData = new GenericMutableRow(relation.partitionKeys.size) val pruningCondition = if (codegenEnabled) { - GeneratePredicate(castedPredicate) + GeneratePredicate.generate(castedPredicate) } else { - InterpretedPredicate(castedPredicate) + InterpretedPredicate.create(castedPredicate) } val partitions = relation.hiveQlPartitions.filter { part => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 6570fa1043900..9f17bca083d13 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -185,7 +185,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { logDebug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") referencedTestTables.foreach(loadTestTable) // Proceed with analysis. 
- analyzer(logical) + analyzer.execute(logical) } } From fbe7106d75c6a1624d10793fba6759703bc5c6e6 Mon Sep 17 00:00:00 2001 From: szheng79 Date: Wed, 22 Apr 2015 13:02:55 -0700 Subject: [PATCH 090/191] [SPARK-7039][SQL]JDBCRDD: Add support on type NVARCHAR Issue: https://issues.apache.org/jira/browse/SPARK-7039 Add support to column type NVARCHAR in Sql Server java.sql.Types: http://docs.oracle.com/javase/7/docs/api/java/sql/Types.html Author: szheng79 Closes #5618 from szheng79/patch-1 and squashes the following commits: 10da99c [szheng79] Update JDBCRDD.scala eab0bd8 [szheng79] Add support on type NVARCHAR --- sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index b9022fcd9e3ad..8b1edec20feee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -60,6 +60,7 @@ private[sql] object JDBCRDD extends Logging { case java.sql.Types.NCLOB => StringType case java.sql.Types.NULL => null case java.sql.Types.NUMERIC => DecimalType.Unlimited + case java.sql.Types.NVARCHAR => StringType case java.sql.Types.OTHER => null case java.sql.Types.REAL => DoubleType case java.sql.Types.REF => StringType From baf865ddc2cff9b99d6aeab9861e030da511257f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 22 Apr 2015 15:26:58 -0700 Subject: [PATCH 091/191] [SPARK-7059][SQL] Create a DataFrame join API to facilitate equijoin. Author: Reynold Xin Closes #5638 from rxin/joinUsing and squashes the following commits: 13e9cc9 [Reynold Xin] Code review + Python. b1bd914 [Reynold Xin] [SPARK-7059][SQL] Create a DataFrame join API to facilitate equijoin and self join. --- python/pyspark/sql/dataframe.py | 9 ++++- .../org/apache/spark/sql/DataFrame.scala | 37 +++++++++++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 40 ++++++++++++++----- 3 files changed, 74 insertions(+), 12 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index ca9bf8efb945c..c8c30ce4022c8 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -459,16 +459,23 @@ def join(self, other, joinExprs=None, joinType=None): The following performs a full outer join between ``df1`` and ``df2``. :param other: Right side of the join - :param joinExprs: Join expression + :param joinExprs: a string for join column name, or a join expression (Column). + If joinExprs is a string indicating the name of the join column, + the column must exist on both sides, and this performs an inner equi-join. :param joinType: str, default 'inner'. One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`. 
>>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect() [Row(name=None, height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)] + + >>> df.join(df2, 'name').select(df.name, df2.height).collect() + [Row(name=u'Bob', height=85)] """ if joinExprs is None: jdf = self._jdf.join(other._jdf) + elif isinstance(joinExprs, basestring): + jdf = self._jdf.join(other._jdf, joinExprs) else: assert isinstance(joinExprs, Column), "joinExprs should be Column" if joinType is None: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 03d9834d1d131..ca6ae482eb2ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -342,6 +342,43 @@ class DataFrame private[sql]( Join(logicalPlan, right.logicalPlan, joinType = Inner, None) } + /** + * Inner equi-join with another [[DataFrame]] using the given column. + * + * Different from other join functions, the join column will only appear once in the output, + * i.e. similar to SQL's `JOIN USING` syntax. + * + * {{{ + * // Joining df1 and df2 using the column "user_id" + * df1.join(df2, "user_id") + * }}} + * + * Note that if you perform a self-join using this function without aliasing the input + * [[DataFrame]]s, you will NOT be able to reference any columns after the join, since + * there is no way to disambiguate which side of the join you would like to reference. + * + * @param right Right side of the join operation. + * @param usingColumn Name of the column to join on. This column must exist on both sides. + * @group dfops + */ + def join(right: DataFrame, usingColumn: String): DataFrame = { + // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right + // by creating a new instance for one of the branch. + val joined = sqlContext.executePlan( + Join(logicalPlan, right.logicalPlan, joinType = Inner, None)).analyzed.asInstanceOf[Join] + + // Project only one of the join column. + val joinedCol = joined.right.resolve(usingColumn) + Project( + joined.output.filterNot(_ == joinedCol), + Join( + joined.left, + joined.right, + joinType = Inner, + Some(EqualTo(joined.left.resolve(usingColumn), joined.right.resolve(usingColumn)))) + ) + } + /** * Inner join with another [[DataFrame]], using the given join expression. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index b9b6a400ae195..5ec06d448e50f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -109,15 +109,6 @@ class DataFrameSuite extends QueryTest { assert(testData.head(2).head.schema === testData.schema) } - test("self join") { - val df1 = testData.select(testData("key")).as('df1) - val df2 = testData.select(testData("key")).as('df2) - - checkAnswer( - df1.join(df2, $"df1.key" === $"df2.key"), - sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key").collect().toSeq) - } - test("simple explode") { val df = Seq(Tuple1("a b c"), Tuple1("d e")).toDF("words") @@ -127,8 +118,35 @@ class DataFrameSuite extends QueryTest { ) } - test("self join with aliases") { - val df = Seq(1,2,3).map(i => (i, i.toString)).toDF("int", "str") + test("join - join using") { + val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") + val df2 = Seq(1, 2, 3).map(i => (i, (i + 1).toString)).toDF("int", "str") + + checkAnswer( + df.join(df2, "int"), + Row(1, "1", "2") :: Row(2, "2", "3") :: Row(3, "3", "4") :: Nil) + } + + test("join - join using self join") { + val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") + + // self join + checkAnswer( + df.join(df, "int"), + Row(1, "1", "1") :: Row(2, "2", "2") :: Row(3, "3", "3") :: Nil) + } + + test("join - self join") { + val df1 = testData.select(testData("key")).as('df1) + val df2 = testData.select(testData("key")).as('df2) + + checkAnswer( + df1.join(df2, $"df1.key" === $"df2.key"), + sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key").collect().toSeq) + } + + test("join - using aliases after self join") { + val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") checkAnswer( df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("x.str").count(), Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil) From f4f39981f4f5e88c30eec7d0b107e2c3cdc268c9 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 22 Apr 2015 17:22:26 -0700 Subject: [PATCH 092/191] [SPARK-6827] [MLLIB] Wrap FPGrowthModel.freqItemsets and make it consistent with Java API Make PySpark ```FPGrowthModel.freqItemsets``` consistent with Java/Scala API like ```MatrixFactorizationModel.userFeatures``` It return a RDD with each tuple is composed of an array and a long value. I think it's difficult to implement namedtuples to wrap the output because items of freqItemsets can be any type with arbitrary length which is tedious to impelement corresponding SerDe function. Author: Yanbo Liang Closes #5614 from yanboliang/spark-6827 and squashes the following commits: da8c404 [Yanbo Liang] use namedtuple 5532e78 [Yanbo Liang] Wrap FPGrowthModel.freqItemsets and make it consistent with Java API --- python/pyspark/mllib/fpm.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py index 628ccc01cf3cc..d8df02bdbaba9 100644 --- a/python/pyspark/mllib/fpm.py +++ b/python/pyspark/mllib/fpm.py @@ -15,6 +15,10 @@ # limitations under the License. 
# +import numpy +from numpy import array +from collections import namedtuple + from pyspark import SparkContext from pyspark.rdd import ignore_unicode_prefix from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc @@ -36,14 +40,14 @@ class FPGrowthModel(JavaModelWrapper): >>> rdd = sc.parallelize(data, 2) >>> model = FPGrowth.train(rdd, 0.6, 2) >>> sorted(model.freqItemsets().collect()) - [([u'a'], 4), ([u'c'], 3), ([u'c', u'a'], 3)] + [FreqItemset(items=[u'a'], freq=4), FreqItemset(items=[u'c'], freq=3), ... """ def freqItemsets(self): """ - Get the frequent itemsets of this model + Returns the frequent itemsets of this model. """ - return self.call("getFreqItemsets") + return self.call("getFreqItemsets").map(lambda x: (FPGrowth.FreqItemset(x[0], x[1]))) class FPGrowth(object): @@ -67,6 +71,11 @@ def train(cls, data, minSupport=0.3, numPartitions=-1): model = callMLlibFunc("trainFPGrowthModel", data, float(minSupport), int(numPartitions)) return FPGrowthModel(model) + class FreqItemset(namedtuple("FreqItemset", ["items", "freq"])): + """ + Represents an (items, freq) tuple. + """ + def _test(): import doctest From 04525c077c638a7e615c294ba988e35036554f5f Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Wed, 22 Apr 2015 19:14:28 -0700 Subject: [PATCH 093/191] [SPARK-6967] [SQL] fix date type convertion in jdbcrdd This pr convert java.sql.Date type into Int for JDBCRDD. Author: Daoyuan Wang Closes #5590 from adrian-wang/datebug and squashes the following commits: f897b81 [Daoyuan Wang] add a test case 3c9184c [Daoyuan Wang] fix date type convertion in jdbcrdd --- .../src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala | 4 ++-- .../test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala | 7 +++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index 8b1edec20feee..b975191d41963 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -350,8 +350,8 @@ private[sql] class JDBCRDD( val pos = i + 1 conversions(i) match { case BooleanConversion => mutableRow.setBoolean(i, rs.getBoolean(pos)) - // TODO(davies): convert Date into Int - case DateConversion => mutableRow.update(i, rs.getDate(pos)) + case DateConversion => + mutableRow.update(i, DateUtils.fromJavaDate(rs.getDate(pos))) case DecimalConversion => mutableRow.update(i, rs.getBigDecimal(pos)) case DoubleConversion => mutableRow.setDouble(i, rs.getDouble(pos)) case FloatConversion => mutableRow.setFloat(i, rs.getFloat(pos)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 3596b183d4328..db096af4535a9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -249,6 +249,13 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { assert(rows(0).getAs[java.sql.Timestamp](2).getNanos === 543543543) } + test("test DATE types") { + val rows = TestSQLContext.jdbc(urlWithUserAndPass, "TEST.TIMETYPES").collect() + val cachedRows = TestSQLContext.jdbc(urlWithUserAndPass, "TEST.TIMETYPES").cache().collect() + assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) + assert(cachedRows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) + } + test("H2 
floating-point types") { val rows = sql("SELECT * FROM flttypes").collect() assert(rows(0).getDouble(0) === 1.00000000000000022) // Yes, I meant ==. From b69c4f9b2e8544f1b178db2aefbcaa166f76cb7a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 22 Apr 2015 21:24:22 -0700 Subject: [PATCH 094/191] Disable flaky test: ReceiverSuite "block generator throttling". --- .../test/scala/org/apache/spark/streaming/ReceiverSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala index e7aee6eadbfc7..b84129fd70dd4 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala @@ -155,7 +155,7 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { assert(recordedData.toSet === generatedData.toSet) } - test("block generator throttling") { + ignore("block generator throttling") { val blockGeneratorListener = new FakeBlockGeneratorListener val blockIntervalMs = 100 val maxRate = 1001 From 1b85e08509a0e19dc35b6ab869977254156cdaf1 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 22 Apr 2015 21:35:12 -0700 Subject: [PATCH 095/191] [MLlib] UnaryTransformer nullability should not depend on PrimitiveType. Author: Reynold Xin Closes #5644 from rxin/mllib-nullable and squashes the following commits: a727e5b [Reynold Xin] [MLlib] UnaryTransformer nullability should not depend on primitive types. --- mllib/src/main/scala/org/apache/spark/ml/Transformer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index 7fb87fe452ee6..0acda71ec6045 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -94,7 +94,7 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O throw new IllegalArgumentException(s"Output column ${map(outputCol)} already exists.") } val outputFields = schema.fields :+ - StructField(map(outputCol), outputDataType, !outputDataType.isPrimitive) + StructField(map(outputCol), outputDataType, nullable = false) StructType(outputFields) } From d20686066e978dd12e618e3978f109f05bc412fe Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 22 Apr 2015 21:35:42 -0700 Subject: [PATCH 096/191] [SPARK-7066][MLlib] VectorAssembler should use NumericType not NativeType. Author: Reynold Xin Closes #5642 from rxin/mllib-native-type and squashes the following commits: e23af5b [Reynold Xin] Remove StringType 7cbb205 [Reynold Xin] [SPARK-7066][MLlib] VectorAssembler should use NumericType and StringType, not NativeType. 
--- .../scala/org/apache/spark/ml/feature/VectorAssembler.scala | 5 +++-- .../main/scala/org/apache/spark/sql/types/dataTypes.scala | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index e567e069e7c0b..fd16d3d6c268b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -55,7 +55,8 @@ class VectorAssembler extends Transformer with HasInputCols with HasOutputCol { schema(c).dataType match { case DoubleType => UnresolvedAttribute(c) case t if t.isInstanceOf[VectorUDT] => UnresolvedAttribute(c) - case _: NativeType => Alias(Cast(UnresolvedAttribute(c), DoubleType), s"${c}_double_$uid")() + case _: NumericType => + Alias(Cast(UnresolvedAttribute(c), DoubleType), s"${c}_double_$uid")() } } dataset.select(col("*"), assembleFunc(new Column(CreateStruct(args))).as(map(outputCol))) @@ -67,7 +68,7 @@ class VectorAssembler extends Transformer with HasInputCols with HasOutputCol { val outputColName = map(outputCol) val inputDataTypes = inputColNames.map(name => schema(name).dataType) inputDataTypes.foreach { - case _: NativeType => + case _: NumericType => case t if t.isInstanceOf[VectorUDT] => case other => throw new IllegalArgumentException(s"Data type $other is not supported.") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala index 7cd7bd1914c95..ddf9d664c6826 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala @@ -299,7 +299,7 @@ class NullType private() extends DataType { case object NullType extends NullType -protected[spark] object NativeType { +protected[sql] object NativeType { val all = Seq( IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) @@ -327,7 +327,7 @@ protected[sql] object PrimitiveType { } } -protected[spark] abstract class NativeType extends DataType { +protected[sql] abstract class NativeType extends DataType { private[sql] type JvmType @transient private[sql] val tag: TypeTag[JvmType] private[sql] val ordering: Ordering[JvmType] From 03e85b4a11899f37424cd6e1f8d71f1d704c90bb Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Wed, 22 Apr 2015 21:42:09 -0700 Subject: [PATCH 097/191] [SPARK-7046] Remove InputMetrics from BlockResult This is a code cleanup. The BlockResult class originally contained an InputMetrics object so that InputMetrics could directly be used as the InputMetrics for the whole task. Now we copy the fields out of here, and the presence of this object is confusing because it's only a partial input metrics (it doesn't include the records read). Because this object is no longer useful (and is confusing), it should be removed. 
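To make the new shape concrete, here is a minimal, self-contained sketch of a result type that exposes plain `readMethod` and `bytes` fields instead of carrying a nested metrics object — the stub `DataReadMethod` enumeration and the `BlockResultSketch` wrapper are stand-ins for illustration, not Spark's own definitions:

```scala
// Sketch only: a stand-in enumeration and wrapper, mirroring the simplified field shape.
object BlockResultSketch {
  object DataReadMethod extends Enumeration { val Memory, Disk = Value }

  class BlockResult(
      val data: Iterator[Any],
      val readMethod: DataReadMethod.Value,
      val bytes: Long)

  def main(args: Array[String]): Unit = {
    val result = new BlockResult(Iterator(1, 2, 3), DataReadMethod.Memory, bytes = 42L)
    // A caller copies just the byte count into whatever metrics object it owns.
    println(s"read ${result.bytes} bytes via ${result.readMethod}")
  }
}
```

Callers such as the cache manager then copy `bytes` into the metrics object they already hold for the task, rather than reading a partially populated `InputMetrics` off the block result.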
Author: Kay Ousterhout Closes #5627 from kayousterhout/SPARK-7046 and squashes the following commits: bf64bbe [Kay Ousterhout] Import fix a08ca19 [Kay Ousterhout] [SPARK-7046] Remove InputMetrics from BlockResult --- .../main/scala/org/apache/spark/CacheManager.scala | 5 ++--- .../org/apache/spark/storage/BlockManager.scala | 9 +++------ .../org/apache/spark/storage/BlockManagerSuite.scala | 12 ++++++------ 3 files changed, 11 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala index a96d754744a05..4d20c7369376e 100644 --- a/core/src/main/scala/org/apache/spark/CacheManager.scala +++ b/core/src/main/scala/org/apache/spark/CacheManager.scala @@ -44,10 +44,9 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { blockManager.get(key) match { case Some(blockResult) => // Partition is already materialized, so just return its values - val inputMetrics = blockResult.inputMetrics val existingMetrics = context.taskMetrics - .getInputMetricsForReadMethod(inputMetrics.readMethod) - existingMetrics.incBytesRead(inputMetrics.bytesRead) + .getInputMetricsForReadMethod(blockResult.readMethod) + existingMetrics.incBytesRead(blockResult.bytes) val iter = blockResult.data.asInstanceOf[Iterator[T]] new InterruptibleIterator[T](context, iter) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 145a9c1ae3391..55718e584c195 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -29,7 +29,7 @@ import scala.util.Random import sun.nio.ch.DirectBuffer import org.apache.spark._ -import org.apache.spark.executor._ +import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics} import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} @@ -50,11 +50,8 @@ private[spark] case class ArrayValues(buffer: Array[Any]) extends BlockValues /* Class for returning a fetched block and associated metrics. 
*/ private[spark] class BlockResult( val data: Iterator[Any], - readMethod: DataReadMethod.Value, - bytes: Long) { - val inputMetrics = new InputMetrics(readMethod) - inputMetrics.incBytesRead(bytes) -} + val readMethod: DataReadMethod.Value, + val bytes: Long) /** * Manager running on every node (driver and executors) which provides interfaces for putting and diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 545722b050ee8..7d82a7c66ad1a 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -428,19 +428,19 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach val list1Get = store.get("list1") assert(list1Get.isDefined, "list1 expected to be in store") assert(list1Get.get.data.size === 2) - assert(list1Get.get.inputMetrics.bytesRead === list1SizeEstimate) - assert(list1Get.get.inputMetrics.readMethod === DataReadMethod.Memory) + assert(list1Get.get.bytes === list1SizeEstimate) + assert(list1Get.get.readMethod === DataReadMethod.Memory) val list2MemoryGet = store.get("list2memory") assert(list2MemoryGet.isDefined, "list2memory expected to be in store") assert(list2MemoryGet.get.data.size === 3) - assert(list2MemoryGet.get.inputMetrics.bytesRead === list2SizeEstimate) - assert(list2MemoryGet.get.inputMetrics.readMethod === DataReadMethod.Memory) + assert(list2MemoryGet.get.bytes === list2SizeEstimate) + assert(list2MemoryGet.get.readMethod === DataReadMethod.Memory) val list2DiskGet = store.get("list2disk") assert(list2DiskGet.isDefined, "list2memory expected to be in store") assert(list2DiskGet.get.data.size === 3) // We don't know the exact size of the data on disk, but it should certainly be > 0. - assert(list2DiskGet.get.inputMetrics.bytesRead > 0) - assert(list2DiskGet.get.inputMetrics.readMethod === DataReadMethod.Disk) + assert(list2DiskGet.get.bytes > 0) + assert(list2DiskGet.get.readMethod === DataReadMethod.Disk) } test("in-memory LRU storage") { From d9e70f331fc3999d615ede49fc69a993dc65f272 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 22 Apr 2015 22:18:56 -0700 Subject: [PATCH 098/191] [HOTFIX][SQL] Fix broken cached test Added in #5475. Pointed as broken in #5639. /cc marmbrus Author: Liang-Chi Hsieh Closes #5640 from viirya/fix_cached_test and squashes the following commits: c0cf69a [Liang-Chi Hsieh] Fix broken cached test. 
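The change that follows makes the accumulator-count checks exact and performs each snapshot plus assertion inside a single `Accumulators.synchronized` block, so no other registration can slip in between the read and the check. A minimal sketch of that snapshot-under-lock pattern, with a stand-in registry in place of Spark's `Accumulators.originals`:

```scala
// Sketch only: snapshot a shared registry's size and assert the delta under one lock.
object AccumulatorRegistrySketch {
  private val originals = scala.collection.mutable.Map[Long, String]()

  def register(id: Long, name: String): Unit = synchronized { originals(id) = name }

  def main(args: Array[String]): Unit = synchronized {
    val before = originals.size            // snapshot taken under the lock...
    register(1L, "accumulator for t1")
    register(2L, "accumulator for t2")
    assert(originals.size == before + 2)   // ...so the delta check is race-free
  }
}
```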
--- .../apache/spark/sql/CachedTableSuite.scala | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 01e3b8671071e..0772e5e187425 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -300,19 +300,26 @@ class CachedTableSuite extends QueryTest { } test("Clear accumulators when uncacheTable to prevent memory leaking") { - val accsSize = Accumulators.originals.size - sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") - cacheTable("t1") - cacheTable("t2") + + Accumulators.synchronized { + val accsSize = Accumulators.originals.size + cacheTable("t1") + cacheTable("t2") + assert((accsSize + 2) == Accumulators.originals.size) + } + sql("SELECT * FROM t1").count() sql("SELECT * FROM t2").count() sql("SELECT * FROM t1").count() sql("SELECT * FROM t2").count() - uncacheTable("t1") - uncacheTable("t2") - assert(accsSize >= Accumulators.originals.size) + Accumulators.synchronized { + val accsSize = Accumulators.originals.size + uncacheTable("t1") + uncacheTable("t2") + assert((accsSize - 2) == Accumulators.originals.size) + } } } From 2d33323cadbf58dd1d05ffff998d18cad6a896cd Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 22 Apr 2015 23:54:48 -0700 Subject: [PATCH 099/191] [MLlib] Add support for BooleanType to VectorAssembler. Author: Reynold Xin Closes #5648 from rxin/vectorAssembler-boolean and squashes the following commits: 1bf3d40 [Reynold Xin] [MLlib] Add support for BooleanType to VectorAssembler. 
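The next change widens the input column types `VectorAssembler` accepts so that boolean columns are cast to `Double` alongside the numeric ones. A usage sketch under the assumption that the assembler exposes the usual `setInputCols`/`setOutputCol` setters; the column names and local-mode setup are illustrative only:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.apache.spark.ml.feature.VectorAssembler

object BooleanAssemblerSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("sketch"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // "subscribed" is a BooleanType column; with this change it is cast to Double
    // and assembled together with the numeric columns.
    val df = sc.parallelize(Seq((18, 1.75, true), (25, 1.60, false)))
      .toDF("age", "height", "subscribed")

    val assembler = new VectorAssembler()
      .setInputCols(Array("age", "height", "subscribed"))
      .setOutputCol("features")

    assembler.transform(df).select("features").show()
    sc.stop()
  }
}
```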
--- .../scala/org/apache/spark/ml/feature/VectorAssembler.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index fd16d3d6c268b..7b2a451ca5ee5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -55,7 +55,7 @@ class VectorAssembler extends Transformer with HasInputCols with HasOutputCol { schema(c).dataType match { case DoubleType => UnresolvedAttribute(c) case t if t.isInstanceOf[VectorUDT] => UnresolvedAttribute(c) - case _: NumericType => + case _: NumericType | BooleanType => Alias(Cast(UnresolvedAttribute(c), DoubleType), s"${c}_double_$uid")() } } @@ -68,7 +68,7 @@ class VectorAssembler extends Transformer with HasInputCols with HasOutputCol { val outputColName = map(outputCol) val inputDataTypes = inputColNames.map(name => schema(name).dataType) inputDataTypes.foreach { - case _: NumericType => + case _: NumericType | BooleanType => case t if t.isInstanceOf[VectorUDT] => case other => throw new IllegalArgumentException(s"Data type $other is not supported.") From 29163c520087e89ca322521db1dd8656d86a6f0e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 22 Apr 2015 23:55:20 -0700 Subject: [PATCH 100/191] [SPARK-7068][SQL] Remove PrimitiveType Author: Reynold Xin Closes #5646 from rxin/remove-primitive-type and squashes the following commits: 01b673d [Reynold Xin] [SPARK-7068][SQL] Remove PrimitiveType --- .../apache/spark/sql/types/dataTypes.scala | 70 ++++++++----------- .../spark/sql/parquet/ParquetConverter.scala | 11 +-- .../sql/parquet/ParquetTableOperations.scala | 2 +- .../spark/sql/parquet/ParquetTypes.scala | 6 +- .../apache/spark/sql/parquet/newParquet.scala | 13 ++-- 5 files changed, 48 insertions(+), 54 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala index ddf9d664c6826..42e26e05996dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala @@ -41,6 +41,21 @@ import org.apache.spark.util.Utils object DataType { def fromJson(json: String): DataType = parseDataType(parse(json)) + private val nonDecimalNameToType = { + (Seq(NullType, DateType, TimestampType, BinaryType) ++ NativeType.all) + .map(t => t.typeName -> t).toMap + } + + /** Given the string representation of a type, return its DataType */ + private def nameToType(name: String): DataType = { + val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\d+)\s*\)""".r + name match { + case "decimal" => DecimalType.Unlimited + case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt) + case other => nonDecimalNameToType(other) + } + } + private object JSortedObject { def unapplySeq(value: JValue): Option[List[(String, JValue)]] = value match { case JObject(seq) => Some(seq.toList.sortBy(_._1)) @@ -51,7 +66,7 @@ object DataType { // NOTE: Map fields must be sorted in alphabetical order to keep consistent with the Python side. 
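The diff that follows folds the name-to-type lookup, including the fixed-precision `decimal(p, s)` regex, into the `DataType` companion object and drops the `PrimitiveType` trait. A self-contained sketch of that lookup logic, using a stand-in algebraic type in place of Spark's `DataType` hierarchy:

```scala
// Sketch only: the stand-in SketchType hierarchy is illustrative; the matching
// logic mirrors the nameToType lookup moved into object DataType below.
object NameToTypeSketch {
  sealed trait SketchType
  case object StringSketch extends SketchType
  case object UnlimitedDecimal extends SketchType
  case class FixedDecimal(precision: Int, scale: Int) extends SketchType

  private val nonDecimalNameToType: Map[String, SketchType] = Map("string" -> StringSketch)

  private val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\d+)\s*\)""".r

  def nameToType(name: String): SketchType = name match {
    case "decimal" => UnlimitedDecimal
    case FIXED_DECIMAL(precision, scale) => FixedDecimal(precision.toInt, scale.toInt)
    case other => nonDecimalNameToType(other)
  }

  def main(args: Array[String]): Unit = {
    println(nameToType("decimal(10, 2)")) // FixedDecimal(10,2)
    println(nameToType("decimal"))        // UnlimitedDecimal
    println(nameToType("string"))         // StringSketch
  }
}
```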
private def parseDataType(json: JValue): DataType = json match { case JString(name) => - PrimitiveType.nameToType(name) + nameToType(name) case JSortedObject( ("containsNull", JBool(n)), @@ -190,13 +205,11 @@ object DataType { equalsIgnoreNullability(leftKeyType, rightKeyType) && equalsIgnoreNullability(leftValueType, rightValueType) case (StructType(leftFields), StructType(rightFields)) => - leftFields.size == rightFields.size && - leftFields.zip(rightFields) - .forall{ - case (left, right) => - left.name == right.name && equalsIgnoreNullability(left.dataType, right.dataType) - } - case (left, right) => left == right + leftFields.length == rightFields.length && + leftFields.zip(rightFields).forall { case (l, r) => + l.name == r.name && equalsIgnoreNullability(l.dataType, r.dataType) + } + case (l, r) => l == r } } @@ -225,12 +238,11 @@ object DataType { equalsIgnoreCompatibleNullability(fromValue, toValue) case (StructType(fromFields), StructType(toFields)) => - fromFields.size == toFields.size && - fromFields.zip(toFields).forall { - case (fromField, toField) => - fromField.name == toField.name && - (toField.nullable || !fromField.nullable) && - equalsIgnoreCompatibleNullability(fromField.dataType, toField.dataType) + fromFields.length == toFields.length && + fromFields.zip(toFields).forall { case (fromField, toField) => + fromField.name == toField.name && + (toField.nullable || !fromField.nullable) && + equalsIgnoreCompatibleNullability(fromField.dataType, toField.dataType) } case (fromDataType, toDataType) => fromDataType == toDataType @@ -256,8 +268,6 @@ abstract class DataType { /** The default size of a value of this data type. */ def defaultSize: Int - def isPrimitive: Boolean = false - def typeName: String = this.getClass.getSimpleName.stripSuffix("$").dropRight(4).toLowerCase private[sql] def jsonValue: JValue = typeName @@ -307,26 +317,6 @@ protected[sql] object NativeType { } -protected[sql] trait PrimitiveType extends DataType { - override def isPrimitive: Boolean = true -} - - -protected[sql] object PrimitiveType { - private val nonDecimals = Seq(NullType, DateType, TimestampType, BinaryType) ++ NativeType.all - private val nonDecimalNameToType = nonDecimals.map(t => t.typeName -> t).toMap - - /** Given the string representation of a type, return its DataType */ - private[sql] def nameToType(name: String): DataType = { - val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\d+)\s*\)""".r - name match { - case "decimal" => DecimalType.Unlimited - case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt) - case other => nonDecimalNameToType(other) - } - } -} - protected[sql] abstract class NativeType extends DataType { private[sql] type JvmType @transient private[sql] val tag: TypeTag[JvmType] @@ -346,7 +336,7 @@ protected[sql] abstract class NativeType extends DataType { * @group dataType */ @DeveloperApi -class StringType private() extends NativeType with PrimitiveType { +class StringType private() extends NativeType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "StringType$" in byte code. // Defined with a private constructor so the companion object is the only possible instantiation. 
@@ -373,7 +363,7 @@ case object StringType extends StringType * @group dataType */ @DeveloperApi -class BinaryType private() extends NativeType with PrimitiveType { +class BinaryType private() extends NativeType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "BinaryType$" in byte code. // Defined with a private constructor so the companion object is the only possible instantiation. @@ -407,7 +397,7 @@ case object BinaryType extends BinaryType *@group dataType */ @DeveloperApi -class BooleanType private() extends NativeType with PrimitiveType { +class BooleanType private() extends NativeType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "BooleanType$" in byte code. // Defined with a private constructor so the companion object is the only possible instantiation. @@ -492,7 +482,7 @@ case object DateType extends DateType * * @group dataType */ -abstract class NumericType extends NativeType with PrimitiveType { +abstract class NumericType extends NativeType { // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a // type parameter and and add a numeric annotation (i.e., [JvmType : Numeric]). This gets diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index bc108e37dfb0f..116424539da11 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -146,7 +146,8 @@ private[sql] object CatalystConverter { } } // All other primitive types use the default converter - case ctype: PrimitiveType => { // note: need the type tag here! + case ctype: DataType if ParquetTypesConverter.isPrimitiveType(ctype) => { + // note: need the type tag here! 
new CatalystPrimitiveConverter(parent, fieldIndex) } case _ => throw new RuntimeException( @@ -324,9 +325,9 @@ private[parquet] class CatalystGroupConverter( override def start(): Unit = { current = ArrayBuffer.fill(size)(null) - converters.foreach { - converter => if (!converter.isPrimitive) { - converter.asInstanceOf[CatalystConverter].clearBuffer + converters.foreach { converter => + if (!converter.isPrimitive) { + converter.asInstanceOf[CatalystConverter].clearBuffer() } } } @@ -612,7 +613,7 @@ private[parquet] class CatalystArrayConverter( override def start(): Unit = { if (!converter.isPrimitive) { - converter.asInstanceOf[CatalystConverter].clearBuffer + converter.asInstanceOf[CatalystConverter].clearBuffer() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 1c868da23e060..a938b77578686 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -268,7 +268,7 @@ private[sql] case class InsertIntoParquetTable( val job = new Job(sqlContext.sparkContext.hadoopConfiguration) val writeSupport = - if (child.output.map(_.dataType).forall(_.isPrimitive)) { + if (child.output.map(_.dataType).forall(ParquetTypesConverter.isPrimitiveType)) { log.debug("Initializing MutableRowWriteSupport") classOf[org.apache.spark.sql.parquet.MutableRowWriteSupport] } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index 60e1bec4db8e5..1dc819b5d7b9b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -48,8 +48,10 @@ private[parquet] case class ParquetTypeInfo( length: Option[Int] = None) private[parquet] object ParquetTypesConverter extends Logging { - def isPrimitiveType(ctype: DataType): Boolean = - classOf[PrimitiveType] isAssignableFrom ctype.getClass + def isPrimitiveType(ctype: DataType): Boolean = ctype match { + case _: NumericType | BooleanType | StringType | BinaryType => true + case _: DataType => false + } def toPrimitiveDataType( parquetType: ParquetPrimitiveType, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 88466f52bd4e9..85e60733bc57a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -634,12 +634,13 @@ private[sql] case class ParquetRelation2( // before calling execute(). 
val job = new Job(sqlContext.sparkContext.hadoopConfiguration) - val writeSupport = if (parquetSchema.map(_.dataType).forall(_.isPrimitive)) { - log.debug("Initializing MutableRowWriteSupport") - classOf[MutableRowWriteSupport] - } else { - classOf[RowWriteSupport] - } + val writeSupport = + if (parquetSchema.map(_.dataType).forall(ParquetTypesConverter.isPrimitiveType)) { + log.debug("Initializing MutableRowWriteSupport") + classOf[MutableRowWriteSupport] + } else { + classOf[RowWriteSupport] + } ParquetOutputFormat.setWriteSupportClass(job, writeSupport) From f60bece14f98450b4a71b00d7b58525f06e1f9ed Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 23 Apr 2015 01:43:40 -0700 Subject: [PATCH 101/191] [SPARK-7069][SQL] Rename NativeType -> AtomicType. Also renamed JvmType to InternalType. Author: Reynold Xin Closes #5651 from rxin/native-to-atomic-type and squashes the following commits: cbd4028 [Reynold Xin] [SPARK-7069][SQL] Rename NativeType -> AtomicType. --- .../spark/sql/catalyst/ScalaReflection.scala | 24 ++-- .../sql/catalyst/expressions/arithmetic.scala | 4 +- .../expressions/codegen/CodeGenerator.scala | 18 ++- .../codegen/GenerateProjection.scala | 4 +- .../sql/catalyst/expressions/predicates.scala | 10 +- .../spark/sql/catalyst/expressions/rows.scala | 6 +- .../apache/spark/sql/types/dataTypes.scala | 114 +++++++++--------- .../spark/sql/columnar/ColumnAccessor.scala | 2 +- .../spark/sql/columnar/ColumnBuilder.scala | 4 +- .../spark/sql/columnar/ColumnType.scala | 6 +- .../CompressibleColumnAccessor.scala | 4 +- .../CompressibleColumnBuilder.scala | 4 +- .../compression/CompressionScheme.scala | 10 +- .../compression/compressionSchemes.scala | 42 +++---- .../org/apache/spark/sql/json/JsonRDD.scala | 6 +- .../spark/sql/parquet/ParquetConverter.scala | 12 +- .../sql/parquet/ParquetTableSupport.scala | 2 +- .../spark/sql/columnar/ColumnStatsSuite.scala | 6 +- .../spark/sql/columnar/ColumnTypeSuite.scala | 8 +- .../sql/columnar/ColumnarTestUtils.scala | 6 +- .../compression/DictionaryEncodingSuite.scala | 4 +- .../compression/IntegralDeltaSuite.scala | 6 +- .../compression/RunLengthEncodingSuite.scala | 4 +- .../TestCompressibleColumnBuilder.scala | 6 +- 24 files changed, 159 insertions(+), 153 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index d9521953cad73..c52965507c715 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst -import java.sql.Timestamp - import org.apache.spark.util.Utils import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation @@ -110,7 +108,7 @@ trait ScalaReflection { StructField(p.name.toString, dataType, nullable) }), nullable = true) case t if t <:< typeOf[String] => Schema(StringType, nullable = true) - case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true) + case t if t <:< typeOf[java.sql.Timestamp] => Schema(TimestampType, nullable = true) case t if t <:< typeOf[java.sql.Date] => Schema(DateType, nullable = true) case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true) case t if t <:< typeOf[java.math.BigDecimal] => Schema(DecimalType.Unlimited, nullable = true) @@ -136,20 +134,20 @@ trait ScalaReflection { def typeOfObject: 
PartialFunction[Any, DataType] = { // The data type can be determined without ambiguity. - case obj: BooleanType.JvmType => BooleanType - case obj: BinaryType.JvmType => BinaryType + case obj: Boolean => BooleanType + case obj: Array[Byte] => BinaryType case obj: String => StringType - case obj: StringType.JvmType => StringType - case obj: ByteType.JvmType => ByteType - case obj: ShortType.JvmType => ShortType - case obj: IntegerType.JvmType => IntegerType - case obj: LongType.JvmType => LongType - case obj: FloatType.JvmType => FloatType - case obj: DoubleType.JvmType => DoubleType + case obj: UTF8String => StringType + case obj: Byte => ByteType + case obj: Short => ShortType + case obj: Int => IntegerType + case obj: Long => LongType + case obj: Float => FloatType + case obj: Double => DoubleType case obj: java.sql.Date => DateType case obj: java.math.BigDecimal => DecimalType.Unlimited case obj: Decimal => DecimalType.Unlimited - case obj: TimestampType.JvmType => TimestampType + case obj: java.sql.Timestamp => TimestampType case null => NullType // For other cases, there is no obvious mapping from the type of the given object to a // Catalyst data type. A user should provide his/her specific rules diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 566b34f7c3a6a..140ccd8d3796f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -346,7 +346,7 @@ case class MaxOf(left: Expression, right: Expression) extends Expression { } lazy val ordering = left.dataType match { - case i: NativeType => i.ordering.asInstanceOf[Ordering[Any]] + case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] case other => sys.error(s"Type $other does not support ordered operations") } @@ -391,7 +391,7 @@ case class MinOf(left: Expression, right: Expression) extends Expression { } lazy val ordering = left.dataType match { - case i: NativeType => i.ordering.asInstanceOf[Ordering[Any]] + case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] case other => sys.error(s"Type $other does not support ordered operations") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index eeffedb558c1b..cbe520347385d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -623,7 +623,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin protected def getColumn(inputRow: TermName, dataType: DataType, ordinal: Int) = { dataType match { case StringType => q"$inputRow($ordinal).asInstanceOf[org.apache.spark.sql.types.UTF8String]" - case dt @ NativeType() => q"$inputRow.${accessorForType(dt)}($ordinal)" + case dt: DataType if isNativeType(dt) => q"$inputRow.${accessorForType(dt)}($ordinal)" case _ => q"$inputRow.apply($ordinal).asInstanceOf[${termForType(dataType)}]" } } @@ -635,7 +635,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin value: TermName) = { dataType match { case StringType => q"$destinationRow.update($ordinal, $value)" - case dt @ NativeType() => 
q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)" + case dt: DataType if isNativeType(dt) => + q"$destinationRow.${mutatorForType(dt)}($ordinal, $value)" case _ => q"$destinationRow.update($ordinal, $value)" } } @@ -675,7 +676,18 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin } protected def termForType(dt: DataType) = dt match { - case n: NativeType => n.tag + case n: AtomicType => n.tag case _ => typeTag[Any] } + + /** + * List of data types that have special accessors and setters in [[Row]]. + */ + protected val nativeTypes = + Seq(IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) + + /** + * Returns true if the data type has a special accessor and setter in [[Row]]. + */ + protected def isNativeType(dt: DataType) = nativeTypes.contains(dt) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index d491babc2bff0..584f938445c8c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -109,7 +109,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { q"override def update(i: Int, value: Any): Unit = { ..$cases; $accessorFailure }" } - val specificAccessorFunctions = NativeType.all.map { dataType => + val specificAccessorFunctions = nativeTypes.map { dataType => val ifStatements = expressions.zipWithIndex.flatMap { // getString() is not used by expressions case (e, i) if e.dataType == dataType && dataType != StringType => @@ -135,7 +135,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { } } - val specificMutatorFunctions = NativeType.all.map { dataType => + val specificMutatorFunctions = nativeTypes.map { dataType => val ifStatements = expressions.zipWithIndex.flatMap { // setString() is not used by expressions case (e, i) if e.dataType == dataType && dataType != StringType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 46522eb9c1264..9cb00cb2732ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.types.{DataType, BinaryType, BooleanType, NativeType} +import org.apache.spark.sql.types.{DataType, BinaryType, BooleanType, AtomicType} object InterpretedPredicate { def create(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) = @@ -211,7 +211,7 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso s"Types do not match ${left.dataType} != ${right.dataType}") } left.dataType match { - case i: NativeType => i.ordering.asInstanceOf[Ordering[Any]] + case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] case other => sys.error(s"Type $other does not support ordered operations") } } @@ -240,7 
+240,7 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo s"Types do not match ${left.dataType} != ${right.dataType}") } left.dataType match { - case i: NativeType => i.ordering.asInstanceOf[Ordering[Any]] + case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] case other => sys.error(s"Type $other does not support ordered operations") } } @@ -269,7 +269,7 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryCompar s"Types do not match ${left.dataType} != ${right.dataType}") } left.dataType match { - case i: NativeType => i.ordering.asInstanceOf[Ordering[Any]] + case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] case other => sys.error(s"Type $other does not support ordered operations") } } @@ -298,7 +298,7 @@ case class GreaterThanOrEqual(left: Expression, right: Expression) extends Binar s"Types do not match ${left.dataType} != ${right.dataType}") } left.dataType match { - case i: NativeType => i.ordering.asInstanceOf[Ordering[Any]] + case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] case other => sys.error(s"Type $other does not support ordered operations") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 981373477a4bc..5fd892c42e69c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.types.{UTF8String, DataType, StructType, NativeType} +import org.apache.spark.sql.types.{UTF8String, DataType, StructType, AtomicType} /** * An extended interface to [[Row]] that allows the values for each column to be updated. 
Setting @@ -227,9 +227,9 @@ class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[Row] { return if (order.direction == Ascending) 1 else -1 } else { val comparison = order.dataType match { - case n: NativeType if order.direction == Ascending => + case n: AtomicType if order.direction == Ascending => n.ordering.asInstanceOf[Ordering[Any]].compare(left, right) - case n: NativeType if order.direction == Descending => + case n: AtomicType if order.direction == Descending => n.ordering.asInstanceOf[Ordering[Any]].reverse.compare(left, right) case other => sys.error(s"Type $other does not support ordered operations") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala index 42e26e05996dd..87c7b7599366a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala @@ -42,7 +42,8 @@ object DataType { def fromJson(json: String): DataType = parseDataType(parse(json)) private val nonDecimalNameToType = { - (Seq(NullType, DateType, TimestampType, BinaryType) ++ NativeType.all) + Seq(NullType, DateType, TimestampType, BinaryType, + IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) .map(t => t.typeName -> t).toMap } @@ -309,22 +310,17 @@ class NullType private() extends DataType { case object NullType extends NullType -protected[sql] object NativeType { - val all = Seq( - IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) - - def unapply(dt: DataType): Boolean = all.contains(dt) -} - - -protected[sql] abstract class NativeType extends DataType { - private[sql] type JvmType - @transient private[sql] val tag: TypeTag[JvmType] - private[sql] val ordering: Ordering[JvmType] +/** + * An internal type used to represent everything that is not null, UDTs, arrays, structs, and maps. + */ +protected[sql] abstract class AtomicType extends DataType { + private[sql] type InternalType + @transient private[sql] val tag: TypeTag[InternalType] + private[sql] val ordering: Ordering[InternalType] @transient private[sql] val classTag = ScalaReflectionLock.synchronized { val mirror = runtimeMirror(Utils.getSparkClassLoader) - ClassTag[JvmType](mirror.runtimeClass(tag.tpe)) + ClassTag[InternalType](mirror.runtimeClass(tag.tpe)) } } @@ -336,13 +332,13 @@ protected[sql] abstract class NativeType extends DataType { * @group dataType */ @DeveloperApi -class StringType private() extends NativeType { +class StringType private() extends AtomicType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "StringType$" in byte code. // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type JvmType = UTF8String - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - private[sql] val ordering = implicitly[Ordering[JvmType]] + private[sql] type InternalType = UTF8String + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } + private[sql] val ordering = implicitly[Ordering[InternalType]] /** * The default size of a value of the StringType is 4096 bytes. 
@@ -363,13 +359,13 @@ case object StringType extends StringType * @group dataType */ @DeveloperApi -class BinaryType private() extends NativeType { +class BinaryType private() extends AtomicType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "BinaryType$" in byte code. // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type JvmType = Array[Byte] - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - private[sql] val ordering = new Ordering[JvmType] { + private[sql] type InternalType = Array[Byte] + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } + private[sql] val ordering = new Ordering[InternalType] { def compare(x: Array[Byte], y: Array[Byte]): Int = { for (i <- 0 until x.length; if i < y.length) { val res = x(i).compareTo(y(i)) @@ -397,13 +393,13 @@ case object BinaryType extends BinaryType *@group dataType */ @DeveloperApi -class BooleanType private() extends NativeType { +class BooleanType private() extends AtomicType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "BooleanType$" in byte code. // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type JvmType = Boolean - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - private[sql] val ordering = implicitly[Ordering[JvmType]] + private[sql] type InternalType = Boolean + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } + private[sql] val ordering = implicitly[Ordering[InternalType]] /** * The default size of a value of the BooleanType is 1 byte. @@ -424,15 +420,15 @@ case object BooleanType extends BooleanType * @group dataType */ @DeveloperApi -class TimestampType private() extends NativeType { +class TimestampType private() extends AtomicType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "TimestampType$" in byte code. // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type JvmType = Timestamp + private[sql] type InternalType = Timestamp - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } - private[sql] val ordering = new Ordering[JvmType] { + private[sql] val ordering = new Ordering[InternalType] { def compare(x: Timestamp, y: Timestamp): Int = x.compareTo(y) } @@ -455,15 +451,15 @@ case object TimestampType extends TimestampType * @group dataType */ @DeveloperApi -class DateType private() extends NativeType { +class DateType private() extends AtomicType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "DateType$" in byte code. // Defined with a private constructor so the companion object is the only possible instantiation. 
- private[sql] type JvmType = Int + private[sql] type InternalType = Int - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } - private[sql] val ordering = implicitly[Ordering[JvmType]] + private[sql] val ordering = implicitly[Ordering[InternalType]] /** * The default size of a value of the DateType is 4 bytes. @@ -482,13 +478,13 @@ case object DateType extends DateType * * @group dataType */ -abstract class NumericType extends NativeType { +abstract class NumericType extends AtomicType { // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a // type parameter and and add a numeric annotation (i.e., [JvmType : Numeric]). This gets // desugared by the compiler into an argument to the objects constructor. This means there is no // longer an no argument constructor and thus the JVM cannot serialize the object anymore. - private[sql] val numeric: Numeric[JvmType] + private[sql] val numeric: Numeric[InternalType] } @@ -507,7 +503,7 @@ protected[sql] object IntegralType { protected[sql] sealed abstract class IntegralType extends NumericType { - private[sql] val integral: Integral[JvmType] + private[sql] val integral: Integral[InternalType] } @@ -522,11 +518,11 @@ class LongType private() extends IntegralType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "LongType$" in byte code. // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type JvmType = Long - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] type InternalType = Long + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } private[sql] val numeric = implicitly[Numeric[Long]] private[sql] val integral = implicitly[Integral[Long]] - private[sql] val ordering = implicitly[Ordering[JvmType]] + private[sql] val ordering = implicitly[Ordering[InternalType]] /** * The default size of a value of the LongType is 8 bytes. @@ -552,11 +548,11 @@ class IntegerType private() extends IntegralType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "IntegerType$" in byte code. // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type JvmType = Int - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] type InternalType = Int + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } private[sql] val numeric = implicitly[Numeric[Int]] private[sql] val integral = implicitly[Integral[Int]] - private[sql] val ordering = implicitly[Ordering[JvmType]] + private[sql] val ordering = implicitly[Ordering[InternalType]] /** * The default size of a value of the IntegerType is 4 bytes. @@ -582,11 +578,11 @@ class ShortType private() extends IntegralType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "ShortType$" in byte code. 
// Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type JvmType = Short - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] type InternalType = Short + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } private[sql] val numeric = implicitly[Numeric[Short]] private[sql] val integral = implicitly[Integral[Short]] - private[sql] val ordering = implicitly[Ordering[JvmType]] + private[sql] val ordering = implicitly[Ordering[InternalType]] /** * The default size of a value of the ShortType is 2 bytes. @@ -612,11 +608,11 @@ class ByteType private() extends IntegralType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "ByteType$" in byte code. // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type JvmType = Byte - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] type InternalType = Byte + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } private[sql] val numeric = implicitly[Numeric[Byte]] private[sql] val integral = implicitly[Integral[Byte]] - private[sql] val ordering = implicitly[Ordering[JvmType]] + private[sql] val ordering = implicitly[Ordering[InternalType]] /** * The default size of a value of the ByteType is 1 byte. @@ -641,8 +637,8 @@ protected[sql] object FractionalType { protected[sql] sealed abstract class FractionalType extends NumericType { - private[sql] val fractional: Fractional[JvmType] - private[sql] val asIntegral: Integral[JvmType] + private[sql] val fractional: Fractional[InternalType] + private[sql] val asIntegral: Integral[InternalType] } @@ -665,8 +661,8 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT /** No-arg constructor for kryo. */ protected def this() = this(null) - private[sql] type JvmType = Decimal - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] type InternalType = Decimal + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } private[sql] val numeric = Decimal.DecimalIsFractional private[sql] val fractional = Decimal.DecimalIsFractional private[sql] val ordering = Decimal.DecimalIsFractional @@ -743,11 +739,11 @@ class DoubleType private() extends FractionalType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "DoubleType$" in byte code. // Defined with a private constructor so the companion object is the only possible instantiation. 
- private[sql] type JvmType = Double - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] type InternalType = Double + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } private[sql] val numeric = implicitly[Numeric[Double]] private[sql] val fractional = implicitly[Fractional[Double]] - private[sql] val ordering = implicitly[Ordering[JvmType]] + private[sql] val ordering = implicitly[Ordering[InternalType]] private[sql] val asIntegral = DoubleAsIfIntegral /** @@ -772,11 +768,11 @@ class FloatType private() extends FractionalType { // The companion object and this class is separated so the companion object also subclasses // this type. Otherwise, the companion object would be of type "FloatType$" in byte code. // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type JvmType = Float - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } + private[sql] type InternalType = Float + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } private[sql] val numeric = implicitly[Numeric[Float]] private[sql] val fractional = implicitly[Fractional[Float]] - private[sql] val ordering = implicitly[Ordering[JvmType]] + private[sql] val ordering = implicitly[Ordering[InternalType]] private[sql] val asIntegral = FloatAsIfIntegral /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala index f615fb33a7c35..64449b2659b4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala @@ -61,7 +61,7 @@ private[sql] abstract class BasicColumnAccessor[T <: DataType, JvmType]( protected def underlyingBuffer = buffer } -private[sql] abstract class NativeColumnAccessor[T <: NativeType]( +private[sql] abstract class NativeColumnAccessor[T <: AtomicType]( override protected val buffer: ByteBuffer, override protected val columnType: NativeColumnType[T]) extends BasicColumnAccessor(buffer, columnType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index 00ed70430b84d..aa10af400c815 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -84,10 +84,10 @@ private[sql] abstract class ComplexColumnBuilder[T <: DataType, JvmType]( extends BasicColumnBuilder[T, JvmType](columnStats, columnType) with NullableColumnBuilder -private[sql] abstract class NativeColumnBuilder[T <: NativeType]( +private[sql] abstract class NativeColumnBuilder[T <: AtomicType]( override val columnStats: ColumnStats, override val columnType: NativeColumnType[T]) - extends BasicColumnBuilder[T, T#JvmType](columnStats, columnType) + extends BasicColumnBuilder[T, T#InternalType](columnStats, columnType) with NullableColumnBuilder with AllCompressionSchemes with CompressibleColumnBuilder[T] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index 1b9e0df2dcb5e..20be5ca9d0046 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -101,16 +101,16 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType]( override def toString: String = getClass.getSimpleName.stripSuffix("$") } -private[sql] abstract class NativeColumnType[T <: NativeType]( +private[sql] abstract class NativeColumnType[T <: AtomicType]( val dataType: T, typeId: Int, defaultSize: Int) - extends ColumnType[T, T#JvmType](typeId, defaultSize) { + extends ColumnType[T, T#InternalType](typeId, defaultSize) { /** * Scala TypeTag. Can be used to create primitive arrays and hash tables. */ - def scalaTag: TypeTag[dataType.JvmType] = dataType.tag + def scalaTag: TypeTag[dataType.InternalType] = dataType.tag } private[sql] object INT extends NativeColumnType(IntegerType, 0, 4) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala index d0b602a834dfe..cb205defbb1ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql.columnar.compression import org.apache.spark.sql.catalyst.expressions.MutableRow import org.apache.spark.sql.columnar.{ColumnAccessor, NativeColumnAccessor} -import org.apache.spark.sql.types.NativeType +import org.apache.spark.sql.types.AtomicType -private[sql] trait CompressibleColumnAccessor[T <: NativeType] extends ColumnAccessor { +private[sql] trait CompressibleColumnAccessor[T <: AtomicType] extends ColumnAccessor { this: NativeColumnAccessor[T] => private var decoder: Decoder[T] = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala index b9cfc5df550d1..8e2a1af6dae78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala @@ -22,7 +22,7 @@ import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.Logging import org.apache.spark.sql.Row import org.apache.spark.sql.columnar.{ColumnBuilder, NativeColumnBuilder} -import org.apache.spark.sql.types.NativeType +import org.apache.spark.sql.types.AtomicType /** * A stackable trait that builds optionally compressed byte buffer for a column. 
Memory layout of @@ -41,7 +41,7 @@ import org.apache.spark.sql.types.NativeType * header body * }}} */ -private[sql] trait CompressibleColumnBuilder[T <: NativeType] +private[sql] trait CompressibleColumnBuilder[T <: AtomicType] extends ColumnBuilder with Logging { this: NativeColumnBuilder[T] with WithCompressionSchemes => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala index 879d29bcfa6f6..17c2d9b111188 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala @@ -22,9 +22,9 @@ import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.MutableRow import org.apache.spark.sql.columnar.{ColumnType, NativeColumnType} -import org.apache.spark.sql.types.NativeType +import org.apache.spark.sql.types.AtomicType -private[sql] trait Encoder[T <: NativeType] { +private[sql] trait Encoder[T <: AtomicType] { def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = {} def compressedSize: Int @@ -38,7 +38,7 @@ private[sql] trait Encoder[T <: NativeType] { def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer } -private[sql] trait Decoder[T <: NativeType] { +private[sql] trait Decoder[T <: AtomicType] { def next(row: MutableRow, ordinal: Int): Unit def hasNext: Boolean @@ -49,9 +49,9 @@ private[sql] trait CompressionScheme { def supports(columnType: ColumnType[_, _]): Boolean - def encoder[T <: NativeType](columnType: NativeColumnType[T]): Encoder[T] + def encoder[T <: AtomicType](columnType: NativeColumnType[T]): Encoder[T] - def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]): Decoder[T] + def decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]): Decoder[T] } private[sql] trait WithCompressionSchemes { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala index 8727d71c48bb7..534ae90ddbc8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala @@ -35,16 +35,16 @@ private[sql] case object PassThrough extends CompressionScheme { override def supports(columnType: ColumnType[_, _]): Boolean = true - override def encoder[T <: NativeType](columnType: NativeColumnType[T]): Encoder[T] = { + override def encoder[T <: AtomicType](columnType: NativeColumnType[T]): Encoder[T] = { new this.Encoder[T](columnType) } - override def decoder[T <: NativeType]( + override def decoder[T <: AtomicType]( buffer: ByteBuffer, columnType: NativeColumnType[T]): Decoder[T] = { new this.Decoder(buffer, columnType) } - class Encoder[T <: NativeType](columnType: NativeColumnType[T]) extends compression.Encoder[T] { + class Encoder[T <: AtomicType](columnType: NativeColumnType[T]) extends compression.Encoder[T] { override def uncompressedSize: Int = 0 override def compressedSize: Int = 0 @@ -56,7 +56,7 @@ private[sql] case object PassThrough extends CompressionScheme { } } - class Decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) + class Decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]) 
extends compression.Decoder[T] { override def next(row: MutableRow, ordinal: Int): Unit = { @@ -70,11 +70,11 @@ private[sql] case object PassThrough extends CompressionScheme { private[sql] case object RunLengthEncoding extends CompressionScheme { override val typeId = 1 - override def encoder[T <: NativeType](columnType: NativeColumnType[T]): Encoder[T] = { + override def encoder[T <: AtomicType](columnType: NativeColumnType[T]): Encoder[T] = { new this.Encoder[T](columnType) } - override def decoder[T <: NativeType]( + override def decoder[T <: AtomicType]( buffer: ByteBuffer, columnType: NativeColumnType[T]): Decoder[T] = { new this.Decoder(buffer, columnType) } @@ -84,7 +84,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { case _ => false } - class Encoder[T <: NativeType](columnType: NativeColumnType[T]) extends compression.Encoder[T] { + class Encoder[T <: AtomicType](columnType: NativeColumnType[T]) extends compression.Encoder[T] { private var _uncompressedSize = 0 private var _compressedSize = 0 @@ -152,12 +152,12 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { } } - class Decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) + class Decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]) extends compression.Decoder[T] { private var run = 0 private var valueCount = 0 - private var currentValue: T#JvmType = _ + private var currentValue: T#InternalType = _ override def next(row: MutableRow, ordinal: Int): Unit = { if (valueCount == run) { @@ -181,12 +181,12 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { // 32K unique values allowed val MAX_DICT_SIZE = Short.MaxValue - override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) + override def decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]) : Decoder[T] = { new this.Decoder(buffer, columnType) } - override def encoder[T <: NativeType](columnType: NativeColumnType[T]): Encoder[T] = { + override def encoder[T <: AtomicType](columnType: NativeColumnType[T]): Encoder[T] = { new this.Encoder[T](columnType) } @@ -195,7 +195,7 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { case _ => false } - class Encoder[T <: NativeType](columnType: NativeColumnType[T]) extends compression.Encoder[T] { + class Encoder[T <: AtomicType](columnType: NativeColumnType[T]) extends compression.Encoder[T] { // Size of the input, uncompressed, in bytes. Note that we only count until the dictionary // overflows. private var _uncompressedSize = 0 @@ -208,7 +208,7 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { private var count = 0 // The reverse mapping of _dictionary, i.e. mapping encoded integer to the value itself. - private var values = new mutable.ArrayBuffer[T#JvmType](1024) + private var values = new mutable.ArrayBuffer[T#InternalType](1024) // The dictionary that maps a value to the encoded short integer. private val dictionary = mutable.HashMap.empty[Any, Short] @@ -268,14 +268,14 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { override def compressedSize: Int = if (overflow) Int.MaxValue else dictionarySize + count * 2 } - class Decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) + class Decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]) extends compression.Decoder[T] { private val dictionary = { // TODO Can we clean up this mess? 
Maybe move this to `DataType`? implicit val classTag = { val mirror = runtimeMirror(Utils.getSparkClassLoader) - ClassTag[T#JvmType](mirror.runtimeClass(columnType.scalaTag.tpe)) + ClassTag[T#InternalType](mirror.runtimeClass(columnType.scalaTag.tpe)) } Array.fill(buffer.getInt()) { @@ -296,12 +296,12 @@ private[sql] case object BooleanBitSet extends CompressionScheme { val BITS_PER_LONG = 64 - override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) + override def decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]) : compression.Decoder[T] = { new this.Decoder(buffer).asInstanceOf[compression.Decoder[T]] } - override def encoder[T <: NativeType](columnType: NativeColumnType[T]): compression.Encoder[T] = { + override def encoder[T <: AtomicType](columnType: NativeColumnType[T]): compression.Encoder[T] = { (new this.Encoder).asInstanceOf[compression.Encoder[T]] } @@ -384,12 +384,12 @@ private[sql] case object BooleanBitSet extends CompressionScheme { private[sql] case object IntDelta extends CompressionScheme { override def typeId: Int = 4 - override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) + override def decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]) : compression.Decoder[T] = { new Decoder(buffer, INT).asInstanceOf[compression.Decoder[T]] } - override def encoder[T <: NativeType](columnType: NativeColumnType[T]): compression.Encoder[T] = { + override def encoder[T <: AtomicType](columnType: NativeColumnType[T]): compression.Encoder[T] = { (new Encoder).asInstanceOf[compression.Encoder[T]] } @@ -464,12 +464,12 @@ private[sql] case object IntDelta extends CompressionScheme { private[sql] case object LongDelta extends CompressionScheme { override def typeId: Int = 5 - override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) + override def decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]) : compression.Decoder[T] = { new Decoder(buffer, LONG).asInstanceOf[compression.Decoder[T]] } - override def encoder[T <: NativeType](columnType: NativeColumnType[T]): compression.Encoder[T] = { + override def encoder[T <: AtomicType](columnType: NativeColumnType[T]): compression.Encoder[T] = { (new Encoder).asInstanceOf[compression.Encoder[T]] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 29de7401dda71..6e94e7056eb0b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -183,7 +183,7 @@ private[sql] object JsonRDD extends Logging { private def typeOfPrimitiveValue: PartialFunction[Any, DataType] = { // For Integer values, use LongType by default. 
val useLongType: PartialFunction[Any, DataType] = { - case value: IntegerType.JvmType => LongType + case value: IntegerType.InternalType => LongType } useLongType orElse ScalaReflection.typeOfObject orElse { @@ -411,11 +411,11 @@ private[sql] object JsonRDD extends Logging { desiredType match { case StringType => UTF8String(toString(value)) case _ if value == null || value == "" => null // guard the non string type - case IntegerType => value.asInstanceOf[IntegerType.JvmType] + case IntegerType => value.asInstanceOf[IntegerType.InternalType] case LongType => toLong(value) case DoubleType => toDouble(value) case DecimalType() => toDecimal(value) - case BooleanType => value.asInstanceOf[BooleanType.JvmType] + case BooleanType => value.asInstanceOf[BooleanType.InternalType] case NullType => null case ArrayType(elementType, _) => value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index 116424539da11..36cb5e03bbca7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -90,7 +90,7 @@ private[sql] object CatalystConverter { createConverter(field.copy(dataType = udt.sqlType), fieldIndex, parent) } // For native JVM types we use a converter with native arrays - case ArrayType(elementType: NativeType, false) => { + case ArrayType(elementType: AtomicType, false) => { new CatalystNativeArrayConverter(elementType, fieldIndex, parent) } // This is for other types of arrays, including those with nested fields @@ -118,19 +118,19 @@ private[sql] object CatalystConverter { case ShortType => { new CatalystPrimitiveConverter(parent, fieldIndex) { override def addInt(value: Int): Unit = - parent.updateShort(fieldIndex, value.asInstanceOf[ShortType.JvmType]) + parent.updateShort(fieldIndex, value.asInstanceOf[ShortType.InternalType]) } } case ByteType => { new CatalystPrimitiveConverter(parent, fieldIndex) { override def addInt(value: Int): Unit = - parent.updateByte(fieldIndex, value.asInstanceOf[ByteType.JvmType]) + parent.updateByte(fieldIndex, value.asInstanceOf[ByteType.InternalType]) } } case DateType => { new CatalystPrimitiveConverter(parent, fieldIndex) { override def addInt(value: Int): Unit = - parent.updateDate(fieldIndex, value.asInstanceOf[DateType.JvmType]) + parent.updateDate(fieldIndex, value.asInstanceOf[DateType.InternalType]) } } case d: DecimalType => { @@ -637,13 +637,13 @@ private[parquet] class CatalystArrayConverter( * @param capacity The (initial) capacity of the buffer */ private[parquet] class CatalystNativeArrayConverter( - val elementType: NativeType, + val elementType: AtomicType, val index: Int, protected[parquet] val parent: CatalystConverter, protected[parquet] var capacity: Int = CatalystArrayConverter.INITIAL_ARRAY_SIZE) extends CatalystConverter { - type NativeType = elementType.JvmType + type NativeType = elementType.InternalType private var buffer: Array[NativeType] = elementType.classTag.newArray(capacity) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index e05a4c20b0d41..c45c431438efc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -189,7 +189,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { case t @ StructType(_) => writeStruct( t, value.asInstanceOf[CatalystConverter.StructScalaType[_]]) - case _ => writePrimitive(schema.asInstanceOf[NativeType], value) + case _ => writePrimitive(schema.asInstanceOf[AtomicType], value) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index fec487f1d2c82..7cefcf44061ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -34,7 +34,7 @@ class ColumnStatsSuite extends FunSuite { testColumnStats(classOf[DateColumnStats], DATE, Row(Int.MaxValue, Int.MinValue, 0)) testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, Row(null, null, 0)) - def testColumnStats[T <: NativeType, U <: ColumnStats]( + def testColumnStats[T <: AtomicType, U <: ColumnStats]( columnStatsClass: Class[U], columnType: NativeColumnType[T], initialStatistics: Row): Unit = { @@ -55,8 +55,8 @@ class ColumnStatsSuite extends FunSuite { val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) rows.foreach(columnStats.gatherStats(_, 0)) - val values = rows.take(10).map(_(0).asInstanceOf[T#JvmType]) - val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#JvmType]] + val values = rows.take(10).map(_(0).asInstanceOf[T#InternalType]) + val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]] val stats = columnStats.collectedStatistics assertResult(values.min(ordering), "Wrong lower bound")(stats(0)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index b48bed1871c50..1e105e259dce7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -196,12 +196,12 @@ class ColumnTypeSuite extends FunSuite with Logging { } } - def testNativeColumnType[T <: NativeType]( + def testNativeColumnType[T <: AtomicType]( columnType: NativeColumnType[T], - putter: (ByteBuffer, T#JvmType) => Unit, - getter: (ByteBuffer) => T#JvmType): Unit = { + putter: (ByteBuffer, T#InternalType) => Unit, + getter: (ByteBuffer) => T#InternalType): Unit = { - testColumnType[T, T#JvmType](columnType, putter, getter) + testColumnType[T, T#InternalType](columnType, putter, getter) } def testColumnType[T <: DataType, JvmType]( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala index f76314b9dab5e..75d993e563e06 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala @@ -24,7 +24,7 @@ import scala.util.Random import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.types.{UTF8String, DataType, Decimal, NativeType} +import org.apache.spark.sql.types.{UTF8String, DataType, Decimal, AtomicType} object ColumnarTestUtils { def makeNullRow(length: Int): GenericMutableRow = { @@ -91,9 +91,9 @@ object 
ColumnarTestUtils { row } - def makeUniqueValuesAndSingleValueRows[T <: NativeType]( + def makeUniqueValuesAndSingleValueRows[T <: AtomicType]( columnType: NativeColumnType[T], - count: Int): (Seq[T#JvmType], Seq[GenericMutableRow]) = { + count: Int): (Seq[T#InternalType], Seq[GenericMutableRow]) = { val values = makeUniqueRandomValues(columnType, count) val rows = values.map { value => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala index c82d9799359c7..64b70552eb047 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala @@ -24,14 +24,14 @@ import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ -import org.apache.spark.sql.types.NativeType +import org.apache.spark.sql.types.AtomicType class DictionaryEncodingSuite extends FunSuite { testDictionaryEncoding(new IntColumnStats, INT) testDictionaryEncoding(new LongColumnStats, LONG) testDictionaryEncoding(new StringColumnStats, STRING) - def testDictionaryEncoding[T <: NativeType]( + def testDictionaryEncoding[T <: AtomicType]( columnStats: ColumnStats, columnType: NativeColumnType[T]) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala index 88011631ee4e3..bfd99f143bedc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala @@ -33,7 +33,7 @@ class IntegralDeltaSuite extends FunSuite { columnType: NativeColumnType[I], scheme: CompressionScheme) { - def skeleton(input: Seq[I#JvmType]) { + def skeleton(input: Seq[I#InternalType]) { // ------------- // Tests encoder // ------------- @@ -120,13 +120,13 @@ class IntegralDeltaSuite extends FunSuite { case LONG => Seq(2: Long, 1: Long, 2: Long, 130: Long) } - skeleton(input.map(_.asInstanceOf[I#JvmType])) + skeleton(input.map(_.asInstanceOf[I#InternalType])) } test(s"$scheme: long random series") { // Have to workaround with `Any` since no `ClassTag[I#JvmType]` available here. 
val input = Array.fill[Any](10000)(makeRandomValue(columnType)) - skeleton(input.map(_.asInstanceOf[I#JvmType])) + skeleton(input.map(_.asInstanceOf[I#InternalType])) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala index 08df1db375097..fde7a4595be0e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala @@ -22,7 +22,7 @@ import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ -import org.apache.spark.sql.types.NativeType +import org.apache.spark.sql.types.AtomicType class RunLengthEncodingSuite extends FunSuite { testRunLengthEncoding(new NoopColumnStats, BOOLEAN) @@ -32,7 +32,7 @@ class RunLengthEncodingSuite extends FunSuite { testRunLengthEncoding(new LongColumnStats, LONG) testRunLengthEncoding(new StringColumnStats, STRING) - def testRunLengthEncoding[T <: NativeType]( + def testRunLengthEncoding[T <: AtomicType]( columnStats: ColumnStats, columnType: NativeColumnType[T]) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala index fc8ff3b41d0e6..5268dfe0aa03e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala @@ -18,9 +18,9 @@ package org.apache.spark.sql.columnar.compression import org.apache.spark.sql.columnar._ -import org.apache.spark.sql.types.NativeType +import org.apache.spark.sql.types.AtomicType -class TestCompressibleColumnBuilder[T <: NativeType]( +class TestCompressibleColumnBuilder[T <: AtomicType]( override val columnStats: ColumnStats, override val columnType: NativeColumnType[T], override val schemes: Seq[CompressionScheme]) @@ -32,7 +32,7 @@ class TestCompressibleColumnBuilder[T <: NativeType]( } object TestCompressibleColumnBuilder { - def apply[T <: NativeType]( + def apply[T <: AtomicType]( columnStats: ColumnStats, columnType: NativeColumnType[T], scheme: CompressionScheme): TestCompressibleColumnBuilder[T] = { From a7d65d38f934c5c751ba32aa7ab648c6d16044ab Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Thu, 23 Apr 2015 16:45:26 +0530 Subject: [PATCH 102/191] [HOTFIX] [SQL] Fix compilation for scala 2.11. Author: Prashant Sharma Closes #5652 from ScrapCodes/hf/compilation-fix-scala-2.11 and squashes the following commits: 819ff06 [Prashant Sharma] [HOTFIX] Fix compilation for scala 2.11. 
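For reference, the one-line fix below swaps the Buffer-to-java.util.List conversion to the explicitly named method. A minimal, illustrative sketch of that conversion on the Scala side (only JavaConversions.bufferAsJavaList is taken from the patch; the values and names are made up for the example):

    import scala.collection.{mutable, JavaConversions}

    val outputBuffer: mutable.Buffer[Int] = mutable.Buffer(1, 2, 3)
    // bufferAsJavaList exposes the Scala Buffer as a java.util.List view,
    // without depending on an asJavaList overload to resolve
    val asJava: java.util.List[Int] = JavaConversions.bufferAsJavaList(outputBuffer)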
--- .../test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index fc3ed4a708d46..e02c84872c628 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -162,7 +162,7 @@ public void testCreateDataFrameFromJavaBeans() { Buffer outputBuffer = (Buffer) first.getJavaMap(2).get("hello"); Assert.assertArrayEquals( bean.getC().get("hello"), - Ints.toArray(JavaConversions.asJavaList(outputBuffer))); + Ints.toArray(JavaConversions.bufferAsJavaList(outputBuffer))); Seq d = first.getAs(3); Assert.assertEquals(bean.getD().size(), d.length()); for (int i = 0; i < d.length(); i++) { From 975f53e4f978759db7639cd08498ad8cd0ae2a56 Mon Sep 17 00:00:00 2001 From: Prabeesh K Date: Thu, 23 Apr 2015 10:33:13 -0700 Subject: [PATCH 103/191] [minor][streaming]fixed scala string interpolation error Author: Prabeesh K Closes #5653 from prabeesh/fix and squashes the following commits: 9d7a9f5 [Prabeesh K] fixed scala string interpolation error --- .../org/apache/spark/examples/streaming/MQTTWordCount.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala index f40caad322f59..85b9a54b40baf 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala @@ -56,7 +56,7 @@ object MQTTPublisher { while (true) { try { msgtopic.publish(message) - println(s"Published data. topic: {msgtopic.getName()}; Message: {message}") + println(s"Published data. 
topic: ${msgtopic.getName()}; Message: $message") } catch { case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT => Thread.sleep(10) From cc48e6387abdd909921cb58e0588cdf226556bcd Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Thu, 23 Apr 2015 10:35:22 -0700 Subject: [PATCH 104/191] [SPARK-7044] [SQL] Fix the deadlock in script transformation Author: Cheng Hao Closes #5625 from chenghao-intel/transform and squashes the following commits: 5ec1dd2 [Cheng Hao] fix the deadlock issue in ScriptTransform --- .../hive/execution/ScriptTransformation.scala | 33 ++++++++++++------- .../sql/hive/execution/SQLQuerySuite.scala | 8 +++++ 2 files changed, 29 insertions(+), 12 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index cab0fdd35723a..3eddda3b28c66 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -145,20 +145,29 @@ case class ScriptTransformation( val dataOutputStream = new DataOutputStream(outputStream) val outputProjection = new InterpretedProjection(input, child.output) - iter - .map(outputProjection) - .foreach { row => - if (inputSerde == null) { - val data = row.mkString("", ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"), - ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")).getBytes("utf-8") - - outputStream.write(data) - } else { - val writable = inputSerde.serialize(row.asInstanceOf[GenericRow].values, inputSoi) - prepareWritable(writable).write(dataOutputStream) + // Put the write (output to the pipeline) into a single thread + // and keep the collector in the main thread; + // otherwise it will cause a deadlock if the data size is greater than + // the pipeline / buffer capacity. 
+ new Thread(new Runnable() { + override def run(): Unit = { + iter + .map(outputProjection) + .foreach { row => + if (inputSerde == null) { + val data = row.mkString("", ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"), + ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")).getBytes("utf-8") + + outputStream.write(data) + } else { + val writable = inputSerde.serialize(row.asInstanceOf[GenericRow].values, inputSoi) + prepareWritable(writable).write(dataOutputStream) + } } + outputStream.close() } - outputStream.close() + }).start() + iterator } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 47b4cb9ca61ff..4f8d0ac0e7656 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -561,4 +561,12 @@ class SQLQuerySuite extends QueryTest { sql("select d from dn union all select d * 2 from dn") .queryExecution.analyzed } + + test("test script transform") { + val data = (1 to 100000).map { i => (i, i, i) } + data.toDF("d1", "d2", "d3").registerTempTable("script_trans") + assert(100000 === + sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat' AS (a,b,c) FROM script_trans") + .queryExecution.toRdd.count()) + } } From 534f2a43625fbf1a3a65d09550a19875cd1dce43 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 23 Apr 2015 11:29:34 -0700 Subject: [PATCH 105/191] [SPARK-6752][Streaming] Allow StreamingContext to be recreated from checkpoint and existing SparkContext Currently, if you want to create a StreamingContext from checkpoint information, the system will create a new SparkContext. This prevents StreamingContext from being recreated from checkpoints in managed environments where SparkContext is precreated. The solution in this PR: Introduce the following methods on StreamingContext 1. `new StreamingContext(checkpointDirectory, sparkContext)` Recreate StreamingContext from checkpoint using the provided SparkContext 2. `StreamingContext.getOrCreate(checkpointDirectory, sparkContext, createFunction: SparkContext => StreamingContext)` If a checkpoint file exists, then recreate StreamingContext using the provided SparkContext (that is, call 1.), else create StreamingContext using the provided createFunction TODO: the corresponding Java and Python APIs have to be added as well. Author: Tathagata Das Closes #5428 from tdas/SPARK-6752 and squashes the following commits: 94db63c [Tathagata Das] Fix long line. 524f519 [Tathagata Das] Many changes based on PR comments. eabd092 [Tathagata Das] Added Function0, Java API and unit tests for StreamingContext.getOrCreate 36a7823 [Tathagata Das] Minor changes. 
204814e [Tathagata Das] Added StreamingContext.getOrCreate with existing SparkContext --- .../spark/api/java/function/Function0.java | 27 +++ .../apache/spark/streaming/Checkpoint.scala | 26 ++- .../spark/streaming/StreamingContext.scala | 85 ++++++++-- .../api/java/JavaStreamingContext.scala | 119 ++++++++++++- .../apache/spark/streaming/JavaAPISuite.java | 145 ++++++++++++---- .../spark/streaming/CheckpointSuite.scala | 3 +- .../streaming/StreamingContextSuite.scala | 159 ++++++++++++++++++ 7 files changed, 503 insertions(+), 61 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/api/java/function/Function0.java diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function0.java b/core/src/main/java/org/apache/spark/api/java/function/Function0.java new file mode 100644 index 0000000000000..38e410c5debe6 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/Function0.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; + +/** + * A zero-argument function that returns an R. 
+ */ +public interface Function0 extends Serializable { + public R call() throws Exception; +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 0a50485118588..7bfae253c3a0c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -77,7 +77,8 @@ object Checkpoint extends Logging { } /** Get checkpoint files present in the give directory, ordered by oldest-first */ - def getCheckpointFiles(checkpointDir: String, fs: FileSystem): Seq[Path] = { + def getCheckpointFiles(checkpointDir: String, fsOption: Option[FileSystem] = None): Seq[Path] = { + def sortFunc(path1: Path, path2: Path): Boolean = { val (time1, bk1) = path1.getName match { case REGEX(x, y) => (x.toLong, !y.isEmpty) } val (time2, bk2) = path2.getName match { case REGEX(x, y) => (x.toLong, !y.isEmpty) } @@ -85,6 +86,7 @@ object Checkpoint extends Logging { } val path = new Path(checkpointDir) + val fs = fsOption.getOrElse(path.getFileSystem(new Configuration())) if (fs.exists(path)) { val statuses = fs.listStatus(path) if (statuses != null) { @@ -160,7 +162,7 @@ class CheckpointWriter( } // Delete old checkpoint files - val allCheckpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, fs) + val allCheckpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, Some(fs)) if (allCheckpointFiles.size > 10) { allCheckpointFiles.take(allCheckpointFiles.size - 10).foreach(file => { logInfo("Deleting " + file) @@ -234,15 +236,24 @@ class CheckpointWriter( private[streaming] object CheckpointReader extends Logging { - def read(checkpointDir: String, conf: SparkConf, hadoopConf: Configuration): Option[Checkpoint] = - { + /** + * Read checkpoint files present in the given checkpoint directory. If there are no checkpoint + * files, then return None, else try to return the latest valid checkpoint object. If no + * checkpoint files could be read correctly, then return None (if ignoreReadError = true), + * or throw exception (if ignoreReadError = false). + */ + def read( + checkpointDir: String, + conf: SparkConf, + hadoopConf: Configuration, + ignoreReadError: Boolean = false): Option[Checkpoint] = { val checkpointPath = new Path(checkpointDir) // TODO(rxin): Why is this a def?! 
def fs: FileSystem = checkpointPath.getFileSystem(hadoopConf) // Try to find the checkpoint files - val checkpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, fs).reverse + val checkpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, Some(fs)).reverse if (checkpointFiles.isEmpty) { return None } @@ -282,7 +293,10 @@ object CheckpointReader extends Logging { }) // If none of checkpoint files could be read, then throw exception - throw new SparkException("Failed to read checkpoint from directory " + checkpointPath) + if (!ignoreReadError) { + throw new SparkException(s"Failed to read checkpoint from directory $checkpointPath") + } + None } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index f57f295874645..90c8b47aebce0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -107,6 +107,19 @@ class StreamingContext private[streaming] ( */ def this(path: String) = this(path, new Configuration) + /** + * Recreate a StreamingContext from a checkpoint file using an existing SparkContext. + * @param path Path to the directory that was specified as the checkpoint directory + * @param sparkContext Existing SparkContext + */ + def this(path: String, sparkContext: SparkContext) = { + this( + sparkContext, + CheckpointReader.read(path, sparkContext.conf, sparkContext.hadoopConfiguration).get, + null) + } + + if (sc_ == null && cp_ == null) { throw new Exception("Spark Streaming cannot be initialized with " + "both SparkContext and checkpoint as null") @@ -115,10 +128,12 @@ class StreamingContext private[streaming] ( private[streaming] val isCheckpointPresent = (cp_ != null) private[streaming] val sc: SparkContext = { - if (isCheckpointPresent) { + if (sc_ != null) { + sc_ + } else if (isCheckpointPresent) { new SparkContext(cp_.createSparkConf()) } else { - sc_ + throw new SparkException("Cannot create StreamingContext without a SparkContext") } } @@ -129,7 +144,7 @@ class StreamingContext private[streaming] ( private[streaming] val conf = sc.conf - private[streaming] val env = SparkEnv.get + private[streaming] val env = sc.env private[streaming] val graph: DStreamGraph = { if (isCheckpointPresent) { @@ -174,7 +189,9 @@ class StreamingContext private[streaming] ( /** Register streaming source to metrics system */ private val streamingSource = new StreamingSource(this) - SparkEnv.get.metricsSystem.registerSource(streamingSource) + assert(env != null) + assert(env.metricsSystem != null) + env.metricsSystem.registerSource(streamingSource) /** Enumeration to identify current state of the StreamingContext */ private[streaming] object StreamingContextState extends Enumeration { @@ -621,19 +638,59 @@ object StreamingContext extends Logging { hadoopConf: Configuration = new Configuration(), createOnError: Boolean = false ): StreamingContext = { - val checkpointOption = try { - CheckpointReader.read(checkpointPath, new SparkConf(), hadoopConf) - } catch { - case e: Exception => - if (createOnError) { - None - } else { - throw e - } - } + val checkpointOption = CheckpointReader.read( + checkpointPath, new SparkConf(), hadoopConf, createOnError) checkpointOption.map(new StreamingContext(null, _, null)).getOrElse(creatingFunc()) } + + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. 
+ * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the StreamingContext + * will be created by called the provided `creatingFunc` on the provided `sparkContext`. Note + * that the SparkConf configuration in the checkpoint data will not be restored as the + * SparkContext has already been created. + * + * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program + * @param creatingFunc Function to create a new StreamingContext using the given SparkContext + * @param sparkContext SparkContext using which the StreamingContext will be created + */ + def getOrCreate( + checkpointPath: String, + creatingFunc: SparkContext => StreamingContext, + sparkContext: SparkContext + ): StreamingContext = { + getOrCreate(checkpointPath, creatingFunc, sparkContext, createOnError = false) + } + + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the StreamingContext + * will be created by called the provided `creatingFunc` on the provided `sparkContext`. Note + * that the SparkConf configuration in the checkpoint data will not be restored as the + * SparkContext has already been created. + * + * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program + * @param creatingFunc Function to create a new StreamingContext using the given SparkContext + * @param sparkContext SparkContext using which the StreamingContext will be created + * @param createOnError Whether to create a new StreamingContext if there is an + * error in reading checkpoint data. By default, an exception will be + * thrown on error. + */ + def getOrCreate( + checkpointPath: String, + creatingFunc: SparkContext => StreamingContext, + sparkContext: SparkContext, + createOnError: Boolean + ): StreamingContext = { + val checkpointOption = CheckpointReader.read( + checkpointPath, sparkContext.conf, sparkContext.hadoopConfiguration, createOnError) + checkpointOption.map(new StreamingContext(sparkContext, _, null)) + .getOrElse(creatingFunc(sparkContext)) + } + /** * Find the JAR from which a given class was loaded, to make it easy for users to pass * their JARs to StreamingContext. 
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 4095a7cc84946..572d7d8e8753d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -32,13 +32,14 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2} +import org.apache.spark.api.java.function.{Function0 => JFunction0} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ import org.apache.spark.streaming.scheduler.StreamingListener -import org.apache.hadoop.conf.Configuration -import org.apache.spark.streaming.dstream.{PluggableInputDStream, ReceiverInputDStream, DStream} +import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.receiver.Receiver +import org.apache.hadoop.conf.Configuration /** * A Java-friendly version of [[org.apache.spark.streaming.StreamingContext]] which is the main @@ -655,6 +656,7 @@ object JavaStreamingContext { * @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext */ + @deprecated("use getOrCreate without JavaStreamingContextFactor", "1.4.0") def getOrCreate( checkpointPath: String, factory: JavaStreamingContextFactory @@ -676,6 +678,7 @@ object JavaStreamingContext { * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible * file system */ + @deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0") def getOrCreate( checkpointPath: String, hadoopConf: Configuration, @@ -700,6 +703,7 @@ object JavaStreamingContext { * @param createOnError Whether to create a new JavaStreamingContext if there is an * error in reading checkpoint data. */ + @deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0") def getOrCreate( checkpointPath: String, hadoopConf: Configuration, @@ -712,6 +716,117 @@ object JavaStreamingContext { new JavaStreamingContext(ssc) } + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the provided factory + * will be used to create a JavaStreamingContext. + * + * @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program + * @param creatingFunc Function to create a new JavaStreamingContext + */ + def getOrCreate( + checkpointPath: String, + creatingFunc: JFunction0[JavaStreamingContext] + ): JavaStreamingContext = { + val ssc = StreamingContext.getOrCreate(checkpointPath, () => { + creatingFunc.call().ssc + }) + new JavaStreamingContext(ssc) + } + + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the provided factory + * will be used to create a JavaStreamingContext. 
+ * + * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program + * @param creatingFunc Function to create a new JavaStreamingContext + * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible + * file system + */ + def getOrCreate( + checkpointPath: String, + creatingFunc: JFunction0[JavaStreamingContext], + hadoopConf: Configuration + ): JavaStreamingContext = { + val ssc = StreamingContext.getOrCreate(checkpointPath, () => { + creatingFunc.call().ssc + }, hadoopConf) + new JavaStreamingContext(ssc) + } + + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the provided factory + * will be used to create a JavaStreamingContext. + * + * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program + * @param creatingFunc Function to create a new JavaStreamingContext + * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible + * file system + * @param createOnError Whether to create a new JavaStreamingContext if there is an + * error in reading checkpoint data. + */ + def getOrCreate( + checkpointPath: String, + creatingFunc: JFunction0[JavaStreamingContext], + hadoopConf: Configuration, + createOnError: Boolean + ): JavaStreamingContext = { + val ssc = StreamingContext.getOrCreate(checkpointPath, () => { + creatingFunc.call().ssc + }, hadoopConf, createOnError) + new JavaStreamingContext(ssc) + } + + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the provided factory + * will be used to create a JavaStreamingContext. + * + * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program + * @param creatingFunc Function to create a new JavaStreamingContext + * @param sparkContext SparkContext using which the StreamingContext will be created + */ + def getOrCreate( + checkpointPath: String, + creatingFunc: JFunction[JavaSparkContext, JavaStreamingContext], + sparkContext: JavaSparkContext + ): JavaStreamingContext = { + val ssc = StreamingContext.getOrCreate(checkpointPath, (sparkContext: SparkContext) => { + creatingFunc.call(new JavaSparkContext(sparkContext)).ssc + }, sparkContext.sc) + new JavaStreamingContext(ssc) + } + + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the provided factory + * will be used to create a JavaStreamingContext. + * + * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program + * @param creatingFunc Function to create a new JavaStreamingContext + * @param sparkContext SparkContext using which the StreamingContext will be created + * @param createOnError Whether to create a new JavaStreamingContext if there is an + * error in reading checkpoint data. 
+ */ + def getOrCreate( + checkpointPath: String, + creatingFunc: JFunction[JavaSparkContext, JavaStreamingContext], + sparkContext: JavaSparkContext, + createOnError: Boolean + ): JavaStreamingContext = { + val ssc = StreamingContext.getOrCreate(checkpointPath, (sparkContext: SparkContext) => { + creatingFunc.call(new JavaSparkContext(sparkContext)).ssc + }, sparkContext.sc, createOnError) + new JavaStreamingContext(ssc) + } + /** * Find the JAR from which a given class was loaded, to make it easy for users to pass * their JARs to StreamingContext. diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index 90340753a4eed..cb2e8380b4933 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -22,10 +22,12 @@ import java.nio.charset.Charset; import java.util.*; +import org.apache.commons.lang.mutable.MutableBoolean; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.mapreduce.lib.input.TextInputFormat; + import scala.Tuple2; import org.junit.Assert; @@ -45,6 +47,7 @@ import org.apache.spark.storage.StorageLevel; import org.apache.spark.streaming.api.java.*; import org.apache.spark.util.Utils; +import org.apache.spark.SparkConf; // The test suite itself is Serializable so that anonymous Function implementations can be // serialized, as an alternative to converting these anonymous classes to static inner classes; @@ -929,7 +932,7 @@ public void testPairMap() { // Maps pair -> pair of different type public Tuple2 call(Tuple2 in) throws Exception { return in.swap(); } - }); + }); JavaTestUtils.attachTestOutputStream(reversed); List>> result = JavaTestUtils.runStreams(ssc, 2, 2); @@ -987,12 +990,12 @@ public void testPairMap2() { // Maps pair -> single JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaDStream reversed = pairStream.map( - new Function, Integer>() { - @Override - public Integer call(Tuple2 in) throws Exception { - return in._2(); - } - }); + new Function, Integer>() { + @Override + public Integer call(Tuple2 in) throws Exception { + return in._2(); + } + }); JavaTestUtils.attachTestOutputStream(reversed); List> result = JavaTestUtils.runStreams(ssc, 2, 2); @@ -1123,7 +1126,7 @@ public void testCombineByKey() { JavaPairDStream combined = pairStream.combineByKey( new Function() { - @Override + @Override public Integer call(Integer i) throws Exception { return i; } @@ -1144,14 +1147,14 @@ public void testCountByValue() { Arrays.asList("hello")); List>> expected = Arrays.asList( - Arrays.asList( - new Tuple2("hello", 1L), - new Tuple2("world", 1L)), - Arrays.asList( - new Tuple2("hello", 1L), - new Tuple2("moon", 1L)), - Arrays.asList( - new Tuple2("hello", 1L))); + Arrays.asList( + new Tuple2("hello", 1L), + new Tuple2("world", 1L)), + Arrays.asList( + new Tuple2("hello", 1L), + new Tuple2("moon", 1L)), + Arrays.asList( + new Tuple2("hello", 1L))); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream counted = stream.countByValue(); @@ -1249,17 +1252,17 @@ public void testUpdateStateByKey() { JavaPairDStream updated = pairStream.updateStateByKey( new Function2, Optional, Optional>() { - @Override - public Optional call(List values, Optional state) { - int out = 
0; - if (state.isPresent()) { - out = out + state.get(); - } - for (Integer v: values) { - out = out + v; + @Override + public Optional call(List values, Optional state) { + int out = 0; + if (state.isPresent()) { + out = out + state.get(); + } + for (Integer v : values) { + out = out + v; + } + return Optional.of(out); } - return Optional.of(out); - } }); JavaTestUtils.attachTestOutputStream(updated); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -1292,17 +1295,17 @@ public void testUpdateStateByKeyWithInitial() { JavaPairDStream updated = pairStream.updateStateByKey( new Function2, Optional, Optional>() { - @Override - public Optional call(List values, Optional state) { - int out = 0; - if (state.isPresent()) { - out = out + state.get(); - } - for (Integer v: values) { - out = out + v; + @Override + public Optional call(List values, Optional state) { + int out = 0; + if (state.isPresent()) { + out = out + state.get(); + } + for (Integer v : values) { + out = out + v; + } + return Optional.of(out); } - return Optional.of(out); - } }, new HashPartitioner(1), initialRDD); JavaTestUtils.attachTestOutputStream(updated); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -1328,7 +1331,7 @@ public void testReduceByKeyAndWindowWithInverse() { JavaPairDStream reduceWindowed = pairStream.reduceByKeyAndWindow(new IntegerSum(), new IntegerDifference(), - new Duration(2000), new Duration(1000)); + new Duration(2000), new Duration(1000)); JavaTestUtils.attachTestOutputStream(reduceWindowed); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -1707,6 +1710,74 @@ public Integer call(String s) throws Exception { Utils.deleteRecursively(tempDir); } + @SuppressWarnings("unchecked") + @Test + public void testContextGetOrCreate() throws InterruptedException { + + final SparkConf conf = new SparkConf() + .setMaster("local[2]") + .setAppName("test") + .set("newContext", "true"); + + File emptyDir = Files.createTempDir(); + emptyDir.deleteOnExit(); + StreamingContextSuite contextSuite = new StreamingContextSuite(); + String corruptedCheckpointDir = contextSuite.createCorruptedCheckpoint(); + String checkpointDir = contextSuite.createValidCheckpoint(); + + // Function to create JavaStreamingContext without any output operations + // (used to detect the new context) + final MutableBoolean newContextCreated = new MutableBoolean(false); + Function0 creatingFunc = new Function0() { + public JavaStreamingContext call() { + newContextCreated.setValue(true); + return new JavaStreamingContext(conf, Seconds.apply(1)); + } + }; + + newContextCreated.setValue(false); + ssc = JavaStreamingContext.getOrCreate(emptyDir.getAbsolutePath(), creatingFunc); + Assert.assertTrue("new context not created", newContextCreated.isTrue()); + ssc.stop(); + + newContextCreated.setValue(false); + ssc = JavaStreamingContext.getOrCreate(corruptedCheckpointDir, creatingFunc, + new org.apache.hadoop.conf.Configuration(), true); + Assert.assertTrue("new context not created", newContextCreated.isTrue()); + ssc.stop(); + + newContextCreated.setValue(false); + ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc, + new org.apache.hadoop.conf.Configuration()); + Assert.assertTrue("old context not recovered", newContextCreated.isFalse()); + ssc.stop(); + + // Function to create JavaStreamingContext using existing JavaSparkContext + // without any output operations (used to detect the new context) + Function creatingFunc2 = + new Function() { + public JavaStreamingContext call(JavaSparkContext context) { + 
newContextCreated.setValue(true); + return new JavaStreamingContext(context, Seconds.apply(1)); + } + }; + + JavaSparkContext sc = new JavaSparkContext(conf); + newContextCreated.setValue(false); + ssc = JavaStreamingContext.getOrCreate(emptyDir.getAbsolutePath(), creatingFunc2, sc); + Assert.assertTrue("new context not created", newContextCreated.isTrue()); + ssc.stop(false); + + newContextCreated.setValue(false); + ssc = JavaStreamingContext.getOrCreate(corruptedCheckpointDir, creatingFunc2, sc, true); + Assert.assertTrue("new context not created", newContextCreated.isTrue()); + ssc.stop(false); + + newContextCreated.setValue(false); + ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc2, sc); + Assert.assertTrue("old context not recovered", newContextCreated.isFalse()); + ssc.stop(); + } /* TEST DISABLED: Pending a discussion about checkpoint() semantics with TD @SuppressWarnings("unchecked") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 54c30440a6e8d..6b0a3f91d4d06 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -430,9 +430,8 @@ class CheckpointSuite extends TestSuiteBase { assert(recordedFiles(ssc) === Seq(1, 2, 3) && batchCounter.getNumStartedBatches === 3) } // Wait for a checkpoint to be written - val fs = new Path(checkpointDir).getFileSystem(ssc.sc.hadoopConfiguration) eventually(eventuallyTimeout) { - assert(Checkpoint.getCheckpointFiles(checkpointDir, fs).size === 6) + assert(Checkpoint.getCheckpointFiles(checkpointDir).size === 6) } ssc.stop() // Check that we shut down while the third batch was being processed diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 58353a5f97c8a..4f193322ad33e 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -17,8 +17,10 @@ package org.apache.spark.streaming +import java.io.File import java.util.concurrent.atomic.AtomicInteger +import org.apache.commons.io.FileUtils import org.scalatest.{Assertions, BeforeAndAfter, FunSuite} import org.scalatest.concurrent.Timeouts import org.scalatest.concurrent.Eventually._ @@ -330,6 +332,139 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w } } + test("getOrCreate") { + val conf = new SparkConf().setMaster(master).setAppName(appName) + + // Function to create StreamingContext that has a config to identify it to be new context + var newContextCreated = false + def creatingFunction(): StreamingContext = { + newContextCreated = true + new StreamingContext(conf, batchDuration) + } + + // Call ssc.stop after a body of code + def testGetOrCreate(body: => Unit): Unit = { + newContextCreated = false + try { + body + } finally { + if (ssc != null) { + ssc.stop() + } + ssc = null + } + } + + val emptyPath = Utils.createTempDir().getAbsolutePath() + + // getOrCreate should create new context with empty path + testGetOrCreate { + ssc = StreamingContext.getOrCreate(emptyPath, creatingFunction _) + assert(ssc != null, "no context created") + assert(newContextCreated, "new context not created") + } + + val corrutedCheckpointPath = createCorruptedCheckpoint() + + // getOrCreate 
should throw exception with fake checkpoint file and createOnError = false + intercept[Exception] { + ssc = StreamingContext.getOrCreate(corrutedCheckpointPath, creatingFunction _) + } + + // getOrCreate should throw exception with fake checkpoint file + intercept[Exception] { + ssc = StreamingContext.getOrCreate( + corrutedCheckpointPath, creatingFunction _, createOnError = false) + } + + // getOrCreate should create new context with fake checkpoint file and createOnError = true + testGetOrCreate { + ssc = StreamingContext.getOrCreate( + corrutedCheckpointPath, creatingFunction _, createOnError = true) + assert(ssc != null, "no context created") + assert(newContextCreated, "new context not created") + } + + val checkpointPath = createValidCheckpoint() + + // getOrCreate should recover context with checkpoint path, and recover old configuration + testGetOrCreate { + ssc = StreamingContext.getOrCreate(checkpointPath, creatingFunction _) + assert(ssc != null, "no context created") + assert(!newContextCreated, "old context not recovered") + assert(ssc.conf.get("someKey") === "someValue") + } + } + + test("getOrCreate with existing SparkContext") { + val conf = new SparkConf().setMaster(master).setAppName(appName) + sc = new SparkContext(conf) + + // Function to create StreamingContext that has a config to identify it to be new context + var newContextCreated = false + def creatingFunction(sparkContext: SparkContext): StreamingContext = { + newContextCreated = true + new StreamingContext(sparkContext, batchDuration) + } + + // Call ssc.stop(stopSparkContext = false) after a body of cody + def testGetOrCreate(body: => Unit): Unit = { + newContextCreated = false + try { + body + } finally { + if (ssc != null) { + ssc.stop(stopSparkContext = false) + } + ssc = null + } + } + + val emptyPath = Utils.createTempDir().getAbsolutePath() + + // getOrCreate should create new context with empty path + testGetOrCreate { + ssc = StreamingContext.getOrCreate(emptyPath, creatingFunction _, sc, createOnError = true) + assert(ssc != null, "no context created") + assert(newContextCreated, "new context not created") + assert(ssc.sparkContext === sc, "new StreamingContext does not use existing SparkContext") + } + + val corrutedCheckpointPath = createCorruptedCheckpoint() + + // getOrCreate should throw exception with fake checkpoint file and createOnError = false + intercept[Exception] { + ssc = StreamingContext.getOrCreate(corrutedCheckpointPath, creatingFunction _, sc) + } + + // getOrCreate should throw exception with fake checkpoint file + intercept[Exception] { + ssc = StreamingContext.getOrCreate( + corrutedCheckpointPath, creatingFunction _, sc, createOnError = false) + } + + // getOrCreate should create new context with fake checkpoint file and createOnError = true + testGetOrCreate { + ssc = StreamingContext.getOrCreate( + corrutedCheckpointPath, creatingFunction _, sc, createOnError = true) + assert(ssc != null, "no context created") + assert(newContextCreated, "new context not created") + assert(ssc.sparkContext === sc, "new StreamingContext does not use existing SparkContext") + } + + val checkpointPath = createValidCheckpoint() + + // StreamingContext.getOrCreate should recover context with checkpoint path + testGetOrCreate { + ssc = StreamingContext.getOrCreate(checkpointPath, creatingFunction _, sc) + assert(ssc != null, "no context created") + assert(!newContextCreated, "old context not recovered") + assert(ssc.sparkContext === sc, "new StreamingContext does not use existing SparkContext") + 
assert(!ssc.conf.contains("someKey"), + "recovered StreamingContext unexpectedly has old config") + } + } + test("DStream and generated RDD creation sites") { testPackage.test() } @@ -339,6 +474,30 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w val inputStream = new TestInputStream(s, input, 1) inputStream } + + def createValidCheckpoint(): String = { + val testDirectory = Utils.createTempDir().getAbsolutePath() + val checkpointDirectory = Utils.createTempDir().getAbsolutePath() + val conf = new SparkConf().setMaster(master).setAppName(appName) + conf.set("someKey", "someValue") + ssc = new StreamingContext(conf, batchDuration) + ssc.checkpoint(checkpointDirectory) + ssc.textFileStream(testDirectory).foreachRDD { rdd => rdd.count() } + ssc.start() + eventually(timeout(10000 millis)) { + assert(Checkpoint.getCheckpointFiles(checkpointDirectory).size > 1) + } + ssc.stop() + checkpointDirectory + } + + def createCorruptedCheckpoint(): String = { + val checkpointDirectory = Utils.createTempDir().getAbsolutePath() + val fakeCheckpointFile = Checkpoint.checkpointFile(checkpointDirectory, Time(1000)) + FileUtils.write(new File(fakeCheckpointFile.toString()), "blablabla") + assert(Checkpoint.getCheckpointFiles(checkpointDirectory).nonEmpty) + checkpointDirectory + } } class TestException(msg: String) extends Exception(msg) From c1213e6a92e126ad886d9804cedaf6db3618e602 Mon Sep 17 00:00:00 2001 From: Vinod K C Date: Thu, 23 Apr 2015 12:00:23 -0700 Subject: [PATCH 106/191] [SPARK-7055][SQL]Use correct ClassLoader for JDBC Driver in JDBCRDD.getConnector Author: Vinod K C Closes #5633 from vinodkc/use_correct_classloader_driverload and squashes the following commits: 73c5380 [Vinod K C] Use correct ClassLoader for JDBC Driver --- .../src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index b975191d41963..f326510042122 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -26,6 +26,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{Row, SpecificMutableRow} import org.apache.spark.sql.types._ import org.apache.spark.sql.sources._ +import org.apache.spark.util.Utils private[sql] object JDBCRDD extends Logging { /** @@ -152,7 +153,7 @@ private[sql] object JDBCRDD extends Logging { def getConnector(driver: String, url: String, properties: Properties): () => Connection = { () => { try { - if (driver != null) Class.forName(driver) + if (driver != null) Utils.getContextOrSparkClassLoader.loadClass(driver) } catch { case e: ClassNotFoundException => { logWarning(s"Couldn't find class $driver", e); From 6afde2c7810c363083d0a699b1de02b54c13e6a9 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 23 Apr 2015 13:19:03 -0700 Subject: [PATCH 107/191] [SPARK-7058] Include RDD deserialization time in "task deserialization time" metric The web UI's "task deserialization time" metric is slightly misleading because it does not capture the time taken to deserialize the broadcasted RDD. Author: Josh Rosen Closes #5635 from JoshRosen/SPARK-7058 and squashes the following commits: ed90f75 [Josh Rosen] Update UI tooltip a3743b4 [Josh Rosen] Update comments. 
4f52910 [Josh Rosen] Roll back whitespace change e9cf9f4 [Josh Rosen] Remove unused variable 9f32e55 [Josh Rosen] Expose executorDeserializeTime on Task instead of pushing runtime calculation into Task. 21f5b47 [Josh Rosen] Don't double-count the broadcast deserialization time in task runtime 1752f0e [Josh Rosen] [SPARK-7058] Incorporate RDD deserialization time in task deserialization time metric --- .../main/scala/org/apache/spark/executor/Executor.scala | 8 ++++++-- .../scala/org/apache/spark/scheduler/ResultTask.scala | 2 ++ .../scala/org/apache/spark/scheduler/ShuffleMapTask.scala | 2 ++ core/src/main/scala/org/apache/spark/scheduler/Task.scala | 7 +++++++ core/src/main/scala/org/apache/spark/ui/ToolTips.scala | 4 +++- 5 files changed, 20 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 5fc04df5d6a40..f57e215c3f2ed 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -220,8 +220,12 @@ private[spark] class Executor( val afterSerialization = System.currentTimeMillis() for (m <- task.metrics) { - m.setExecutorDeserializeTime(taskStart - deserializeStartTime) - m.setExecutorRunTime(taskFinish - taskStart) + // Deserialization happens in two parts: first, we deserialize a Task object, which + // includes the Partition. Second, Task.run() deserializes the RDD and function to be run. + m.setExecutorDeserializeTime( + (taskStart - deserializeStartTime) + task.executorDeserializeTime) + // We need to subtract Task.run()'s deserialization time to avoid double-counting + m.setExecutorRunTime((taskFinish - taskStart) - task.executorDeserializeTime) m.setJvmGCTime(computeTotalGcTime() - startGCTime) m.setResultSerializationTime(afterSerialization - beforeSerialization) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index e074ce6ebff0b..c9a124113961f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -53,9 +53,11 @@ private[spark] class ResultTask[T, U]( override def runTask(context: TaskContext): U = { // Deserialize the RDD and the func using the broadcast variables. + val deserializeStartTime = System.currentTimeMillis() val ser = SparkEnv.get.closureSerializer.newInstance() val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)]( ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader) + _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime metrics = Some(context.taskMetrics) func(context, rdd.iterator(partition, context)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 6c7d00069acb2..bd3dd23dfe1ac 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -56,9 +56,11 @@ private[spark] class ShuffleMapTask( override def runTask(context: TaskContext): MapStatus = { // Deserialize the RDD using the broadcast variable. 
+ val deserializeStartTime = System.currentTimeMillis() val ser = SparkEnv.get.closureSerializer.newInstance() val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])]( ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader) + _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime metrics = Some(context.taskMetrics) var writer: ShuffleWriter[Any, Any] = null diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 8b592867ee31d..b09b19e2ac9e7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -87,11 +87,18 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex // initialized when kill() is invoked. @volatile @transient private var _killed = false + protected var _executorDeserializeTime: Long = 0 + /** * Whether the task has been killed. */ def killed: Boolean = _killed + /** + * Returns the amount of time spent deserializing the RDD and function to be run. + */ + def executorDeserializeTime: Long = _executorDeserializeTime + /** * Kills a task by setting the interrupted flag to true. This relies on the upper level Spark * code and user code to properly handle the flag. This function should be idempotent so it can diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala index cae6870c2ab20..24f3236456248 100644 --- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala +++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala @@ -24,7 +24,9 @@ private[spark] object ToolTips { scheduler delay is large, consider decreasing the size of tasks or decreasing the size of task results.""" - val TASK_DESERIALIZATION_TIME = "Time spent deserializing the task closure on the executor." + val TASK_DESERIALIZATION_TIME = + """Time spent deserializing the task closure on the executor, including the time to read the + broadcasted task.""" val SHUFFLE_READ_BLOCKED_TIME = "Time that the task spent blocked waiting for shuffle data to be read from remote machines." 
From 3e91cc273d281053618bfa032bc610e2cf8d8e78 Mon Sep 17 00:00:00 2001 From: wizz Date: Thu, 23 Apr 2015 14:00:07 -0700 Subject: [PATCH 108/191] [SPARK-7085][MLlib] Fix miniBatchFraction parameter in train method called with 4 arguments Author: wizz Closes #5658 from kuromatsu-nobuyuki/SPARK-7085 and squashes the following commits: 6ec2d21 [wizz] Fix miniBatchFraction parameter in train method called with 4 arguments --- .../org/apache/spark/mllib/regression/RidgeRegression.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index 8838ca8c14718..309f9af466457 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -171,7 +171,7 @@ object RidgeRegressionWithSGD { numIterations: Int, stepSize: Double, regParam: Double): RidgeRegressionModel = { - train(input, numIterations, stepSize, regParam, 0.01) + train(input, numIterations, stepSize, regParam, 1.0) } /** From baa83a9a6769c5e119438d65d7264dceb8d743d5 Mon Sep 17 00:00:00 2001 From: WangTaoTheTonic Date: Thu, 23 Apr 2015 17:20:17 -0400 Subject: [PATCH 109/191] [SPARK-6879] [HISTORYSERVER] check if app is completed before clean it up https://issues.apache.org/jira/browse/SPARK-6879 Use `applications` to replace `FileStatus`, and check if the app is completed before cleaning it up. If an exception is thrown, add it to `applications` to wait for the next loop. Author: WangTaoTheTonic Closes #5491 from WangTaoTheTonic/SPARK-6879 and squashes the following commits: 4a533eb [WangTaoTheTonic] treat ACE specially cb45105 [WangTaoTheTonic] rebase d4d5251 [WangTaoTheTonic] per Marcelo's comments d7455d8 [WangTaoTheTonic] slightly change when delete file b0abca5 [WangTaoTheTonic] use global var to store apps to clean 94adfe1 [WangTaoTheTonic] leave expired apps alone to be deleted 9872a9d [WangTaoTheTonic] use the right path fdef4d6 [WangTaoTheTonic] check if app is completed before clean it up --- .../deploy/history/FsHistoryProvider.scala | 32 ++++++++++++------- 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 9847d5944a390..a94ebf6e53750 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -35,7 +35,6 @@ import org.apache.spark.ui.SparkUI import org.apache.spark.util.{ThreadUtils, Utils} import org.apache.spark.{Logging, SecurityManager, SparkConf} - /** * A class that provides application history from event logs stored in the file system. * This provider checks for new finished applications in the background periodically and @@ -76,6 +75,9 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis @volatile private var applications: mutable.LinkedHashMap[String, FsApplicationHistoryInfo] = new mutable.LinkedHashMap() + // List of applications to be deleted by event log cleaner. + private var appsToClean = new mutable.ListBuffer[FsApplicationHistoryInfo] + // Constants used to parse Spark 1.0.0 log directories. 
private[history] val LOG_PREFIX = "EVENT_LOG_" private[history] val SPARK_VERSION_PREFIX = EventLoggingListener.SPARK_VERSION_KEY + "_" @@ -266,34 +268,40 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis */ private def cleanLogs(): Unit = { try { - val statusList = Option(fs.listStatus(new Path(logDir))).map(_.toSeq) - .getOrElse(Seq[FileStatus]()) val maxAge = conf.getTimeAsSeconds("spark.history.fs.cleaner.maxAge", "7d") * 1000 val now = System.currentTimeMillis() val appsToRetain = new mutable.LinkedHashMap[String, FsApplicationHistoryInfo]() + // Scan all logs from the log directory. + // Only completed applications older than the specified max age will be deleted. applications.values.foreach { info => - if (now - info.lastUpdated <= maxAge) { + if (now - info.lastUpdated <= maxAge || !info.completed) { appsToRetain += (info.id -> info) + } else { + appsToClean += info } } applications = appsToRetain - // Scan all logs from the log directory. - // Only directories older than the specified max age will be deleted - statusList.foreach { dir => + val leftToClean = new mutable.ListBuffer[FsApplicationHistoryInfo] + appsToClean.foreach { info => try { - if (now - dir.getModificationTime() > maxAge) { - // if path is a directory and set to true, - // the directory is deleted else throws an exception - fs.delete(dir.getPath, true) + val path = new Path(logDir, info.logPath) + if (fs.exists(path)) { + fs.delete(path, true) } } catch { - case t: IOException => logError(s"IOException in cleaning logs of $dir", t) + case e: AccessControlException => + logInfo(s"No permission to delete ${info.logPath}, ignoring.") + case t: IOException => + logError(s"IOException in cleaning logs of ${info.logPath}", t) + leftToClean += info } } + + appsToClean = leftToClean } catch { case t: Exception => logError("Exception in cleaning logs", t) } From 6d0749cae301ee4bf37632d657de48e75548a523 Mon Sep 17 00:00:00 2001 From: Tijo Thomas Date: Thu, 23 Apr 2015 17:23:15 -0400 Subject: [PATCH 110/191] [SPARK-7087] [BUILD] Fix path issue change version script Author: Tijo Thomas Closes #5656 from tijoparacka/FIX_PATHISSUE_CHANGE_VERSION_SCRIPT and squashes the following commits: ab4f4b1 [Tijo Thomas] removed whitespace 24478c9 [Tijo Thomas] modified to provide the spark base dir while searching for pom and also while changing the vesrion no 7b8e10b [Tijo Thomas] Modified for providing the base directories while finding the list of pom files and also while changing the version no --- dev/change-version-to-2.10.sh | 6 +++--- dev/change-version-to-2.11.sh | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/dev/change-version-to-2.10.sh b/dev/change-version-to-2.10.sh index 15e0c73b4295e..c4adb1f96b7d3 100755 --- a/dev/change-version-to-2.10.sh +++ b/dev/change-version-to-2.10.sh @@ -18,9 +18,9 @@ # # Note that this will not necessarily work as intended with non-GNU sed (e.g. OS X) - -find . -name 'pom.xml' | grep -v target \ +BASEDIR=$(dirname $0)/.. +find $BASEDIR -name 'pom.xml' | grep -v target \ | xargs -I {} sed -i -e 's/\(artifactId.*\)_2.11/\1_2.10/g' {} # Also update in parent POM -sed -i -e '0,/2.112.102.112.10 in parent POM -sed -i -e '0,/2.102.112.102.11 Date: Thu, 23 Apr 2015 14:46:54 -0700 Subject: [PATCH 111/191] [SPARK-7070] [MLLIB] LDA.setBeta should call setTopicConcentration. 
jkbradley Author: Xiangrui Meng Closes #5649 from mengxr/SPARK-7070 and squashes the following commits: c66023c [Xiangrui Meng] setBeta should call setTopicConcentration --- .../scala/org/apache/spark/mllib/clustering/LDA.scala | 2 +- .../org/apache/spark/mllib/clustering/LDASuite.scala | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index 9d63a08e211bc..d006b39acb213 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -177,7 +177,7 @@ class LDA private ( def getBeta: Double = getTopicConcentration /** Alias for [[setTopicConcentration()]] */ - def setBeta(beta: Double): this.type = setBeta(beta) + def setBeta(beta: Double): this.type = setTopicConcentration(beta) /** * Maximum number of iterations for learning. diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index 15de10fd13a19..cc747dabb9968 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -123,6 +123,14 @@ class LDASuite extends FunSuite with MLlibTestSparkContext { assert(termVertexIds.map(i => LDA.index2term(i.toLong)) === termIds) assert(termVertexIds.forall(i => LDA.isTermVertex((i.toLong, 0)))) } + + test("setter alias") { + val lda = new LDA().setAlpha(2.0).setBeta(3.0) + assert(lda.getAlpha === 2.0) + assert(lda.getDocConcentration === 2.0) + assert(lda.getBeta === 3.0) + assert(lda.getTopicConcentration === 3.0) + } } private[clustering] object LDASuite { From 6220d933e5ce4ba890f5d6a50a69b95d319dafb4 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 23 Apr 2015 14:48:19 -0700 Subject: [PATCH 112/191] [SQL] Break dataTypes.scala into multiple files. It was over 1000 lines of code, making it harder to find all the types. Only moved code around, and didn't change any. Author: Reynold Xin Closes #5670 from rxin/break-types and squashes the following commits: 8c59023 [Reynold Xin] Check in missing files. dcd5193 [Reynold Xin] [SQL] Break dataTypes.scala into multiple files. 
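Because the change only moves code between files, anything that imports org.apache.spark.sql.types should compile and behave exactly as before. As a sanity check of the surface touched by the split, here is a small, hypothetical smoke test (the object name is invented; it assumes spark-catalyst is on the classpath) that builds a schema from several of the newly separated types and round-trips it through the JSON support that now lives in DataType.scala:

import org.apache.spark.sql.types._

// Hypothetical smoke test -- not part of the patch.
object TypesSmokeTest {
  def main(args: Array[String]): Unit = {
    // Exercise ArrayType, MapType, DecimalType, StructField and StructType,
    // each of which now lives in its own file.
    val schema = StructType(Seq(
      StructField("id", LongType, nullable = false),
      StructField("name", StringType),
      StructField("scores", ArrayType(DoubleType)),
      StructField("attrs", MapType(StringType, IntegerType)),
      StructField("price", DecimalType(10, 2))))

    // json and fromJson are defined in DataType.scala after the split;
    // the round trip should be lossless.
    val roundTripped = DataType.fromJson(schema.json)
    assert(roundTripped == schema)

    println(schema.simpleString)
    // struct<id:bigint,name:string,scores:array<double>,attrs:map<string,int>,price:decimal(10,2)>
  }
}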
--- .../apache/spark/sql/types/ArrayType.scala | 74 + .../apache/spark/sql/types/BinaryType.scala | 63 + .../apache/spark/sql/types/BooleanType.scala | 51 + .../org/apache/spark/sql/types/ByteType.scala | 54 + .../org/apache/spark/sql/types/DataType.scala | 353 +++++ .../org/apache/spark/sql/types/DateType.scala | 54 + .../apache/spark/sql/types/DecimalType.scala | 110 ++ .../apache/spark/sql/types/DoubleType.scala | 53 + .../apache/spark/sql/types/FloatType.scala | 53 + .../apache/spark/sql/types/IntegerType.scala | 54 + .../org/apache/spark/sql/types/LongType.scala | 54 + .../org/apache/spark/sql/types/MapType.scala | 79 ++ .../org/apache/spark/sql/types/NullType.scala | 39 + .../apache/spark/sql/types/ShortType.scala | 53 + .../apache/spark/sql/types/StringType.scala | 50 + .../apache/spark/sql/types/StructField.scala | 54 + .../apache/spark/sql/types/StructType.scala | 263 ++++ .../spark/sql/types/TimestampType.scala | 57 + .../spark/sql/types/UserDefinedType.scala | 81 ++ .../apache/spark/sql/types/dataTypes.scala | 1224 ----------------- 20 files changed, 1649 insertions(+), 1224 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala new file mode 100644 index 0000000000000..b116163faccad --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import org.json4s.JsonDSL._ + +import org.apache.spark.annotation.DeveloperApi + + +object ArrayType { + /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */ + def apply(elementType: DataType): ArrayType = ArrayType(elementType, containsNull = true) +} + + +/** + * :: DeveloperApi :: + * The data type for collections of multiple values. + * Internally these are represented as columns that contain a ``scala.collection.Seq``. + * + * Please use [[DataTypes.createArrayType()]] to create a specific instance. + * + * An [[ArrayType]] object comprises two fields, `elementType: [[DataType]]` and + * `containsNull: Boolean`. The field of `elementType` is used to specify the type of + * array elements. The field of `containsNull` is used to specify if the array has `null` values. + * + * @param elementType The data type of values. + * @param containsNull Indicates if values have `null` values + * + * @group dataType + */ +@DeveloperApi +case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType { + + /** No-arg constructor for kryo. */ + protected def this() = this(null, false) + + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { + builder.append( + s"$prefix-- element: ${elementType.typeName} (containsNull = $containsNull)\n") + DataType.buildFormattedString(elementType, s"$prefix |", builder) + } + + override private[sql] def jsonValue = + ("type" -> typeName) ~ + ("elementType" -> elementType.jsonValue) ~ + ("containsNull" -> containsNull) + + /** + * The default size of a value of the ArrayType is 100 * the default size of the element type. + * (We assume that there are 100 elements). + */ + override def defaultSize: Int = 100 * elementType.defaultSize + + override def simpleString: String = s"array<${elementType.simpleString}>" + + private[spark] override def asNullable: ArrayType = + ArrayType(elementType.asNullable, containsNull = true) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala new file mode 100644 index 0000000000000..a581a9e9468ef --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BinaryType.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.math.Ordering +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.ScalaReflectionLock + + +/** + * :: DeveloperApi :: + * The data type representing `Array[Byte]` values. + * Please use the singleton [[DataTypes.BinaryType]]. + * + * @group dataType + */ +@DeveloperApi +class BinaryType private() extends AtomicType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "BinaryType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. + + private[sql] type InternalType = Array[Byte] + + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } + + private[sql] val ordering = new Ordering[InternalType] { + def compare(x: Array[Byte], y: Array[Byte]): Int = { + for (i <- 0 until x.length; if i < y.length) { + val res = x(i).compareTo(y(i)) + if (res != 0) return res + } + x.length - y.length + } + } + + /** + * The default size of a value of the BinaryType is 4096 bytes. + */ + override def defaultSize: Int = 4096 + + private[spark] override def asNullable: BinaryType = this +} + + +case object BinaryType extends BinaryType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala new file mode 100644 index 0000000000000..a7f228cefa57a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/BooleanType.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.math.Ordering +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.ScalaReflectionLock + + +/** + * :: DeveloperApi :: + * The data type representing `Boolean` values. Please use the singleton [[DataTypes.BooleanType]]. + * + *@group dataType + */ +@DeveloperApi +class BooleanType private() extends AtomicType { + // The companion object and this class is separated so the companion object also subclasses + // this type. 
Otherwise, the companion object would be of type "BooleanType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. + private[sql] type InternalType = Boolean + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } + private[sql] val ordering = implicitly[Ordering[InternalType]] + + /** + * The default size of a value of the BooleanType is 1 byte. + */ + override def defaultSize: Int = 1 + + private[spark] override def asNullable: BooleanType = this +} + + +case object BooleanType extends BooleanType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala new file mode 100644 index 0000000000000..4d8685796ec76 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.math.{Ordering, Integral, Numeric} +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.ScalaReflectionLock + + +/** + * :: DeveloperApi :: + * The data type representing `Byte` values. Please use the singleton [[DataTypes.ByteType]]. + * + * @group dataType + */ +@DeveloperApi +class ByteType private() extends IntegralType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "ByteType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. + private[sql] type InternalType = Byte + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } + private[sql] val numeric = implicitly[Numeric[Byte]] + private[sql] val integral = implicitly[Integral[Byte]] + private[sql] val ordering = implicitly[Ordering[InternalType]] + + /** + * The default size of a value of the ByteType is 1 byte. + */ + override def defaultSize: Int = 1 + + override def simpleString: String = "tinyint" + + private[spark] override def asNullable: ByteType = this +} + +case object ByteType extends ByteType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala new file mode 100644 index 0000000000000..e6bfcd9adfeb1 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -0,0 +1,353 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.{TypeTag, runtimeMirror} +import scala.util.parsing.combinator.RegexParsers + +import org.json4s._ +import org.json4s.JsonAST.JValue +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.ScalaReflectionLock +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.util.Utils + + +/** + * :: DeveloperApi :: + * The base type of all Spark SQL data types. + * + * @group dataType + */ +@DeveloperApi +abstract class DataType { + /** Matches any expression that evaluates to this DataType */ + def unapply(a: Expression): Boolean = a match { + case e: Expression if e.dataType == this => true + case _ => false + } + + /** The default size of a value of this data type. */ + def defaultSize: Int + + def typeName: String = this.getClass.getSimpleName.stripSuffix("$").dropRight(4).toLowerCase + + private[sql] def jsonValue: JValue = typeName + + def json: String = compact(render(jsonValue)) + + def prettyJson: String = pretty(render(jsonValue)) + + def simpleString: String = typeName + + /** Check if `this` and `other` are the same data type when ignoring nullability + * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). + */ + private[spark] def sameType(other: DataType): Boolean = + DataType.equalsIgnoreNullability(this, other) + + /** Returns the same data type but set all nullability fields are true + * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). + */ + private[spark] def asNullable: DataType +} + + +/** + * An internal type used to represent everything that is not null, UDTs, arrays, structs, and maps. + */ +protected[sql] abstract class AtomicType extends DataType { + private[sql] type InternalType + @transient private[sql] val tag: TypeTag[InternalType] + private[sql] val ordering: Ordering[InternalType] + + @transient private[sql] val classTag = ScalaReflectionLock.synchronized { + val mirror = runtimeMirror(Utils.getSparkClassLoader) + ClassTag[InternalType](mirror.runtimeClass(tag.tpe)) + } +} + + +/** + * :: DeveloperApi :: + * Numeric data types. + * + * @group dataType + */ +abstract class NumericType extends AtomicType { + // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for + // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a + // type parameter and and add a numeric annotation (i.e., [JvmType : Numeric]). This gets + // desugared by the compiler into an argument to the objects constructor. 
This means there is no + // longer an no argument constructor and thus the JVM cannot serialize the object anymore. + private[sql] val numeric: Numeric[InternalType] +} + + +private[sql] object NumericType { + def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType] +} + + +/** Matcher for any expressions that evaluate to [[IntegralType]]s */ +private[sql] object IntegralType { + def unapply(a: Expression): Boolean = a match { + case e: Expression if e.dataType.isInstanceOf[IntegralType] => true + case _ => false + } +} + + +private[sql] abstract class IntegralType extends NumericType { + private[sql] val integral: Integral[InternalType] +} + + + +/** Matcher for any expressions that evaluate to [[FractionalType]]s */ +private[sql] object FractionalType { + def unapply(a: Expression): Boolean = a match { + case e: Expression if e.dataType.isInstanceOf[FractionalType] => true + case _ => false + } +} + + +private[sql] abstract class FractionalType extends NumericType { + private[sql] val fractional: Fractional[InternalType] + private[sql] val asIntegral: Integral[InternalType] +} + + +object DataType { + + def fromJson(json: String): DataType = parseDataType(parse(json)) + + @deprecated("Use DataType.fromJson instead", "1.2.0") + def fromCaseClassString(string: String): DataType = CaseClassStringParser(string) + + private val nonDecimalNameToType = { + Seq(NullType, DateType, TimestampType, BinaryType, + IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) + .map(t => t.typeName -> t).toMap + } + + /** Given the string representation of a type, return its DataType */ + private def nameToType(name: String): DataType = { + val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\d+)\s*\)""".r + name match { + case "decimal" => DecimalType.Unlimited + case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt) + case other => nonDecimalNameToType(other) + } + } + + private object JSortedObject { + def unapplySeq(value: JValue): Option[List[(String, JValue)]] = value match { + case JObject(seq) => Some(seq.toList.sortBy(_._1)) + case _ => None + } + } + + // NOTE: Map fields must be sorted in alphabetical order to keep consistent with the Python side. + private def parseDataType(json: JValue): DataType = json match { + case JString(name) => + nameToType(name) + + case JSortedObject( + ("containsNull", JBool(n)), + ("elementType", t: JValue), + ("type", JString("array"))) => + ArrayType(parseDataType(t), n) + + case JSortedObject( + ("keyType", k: JValue), + ("type", JString("map")), + ("valueContainsNull", JBool(n)), + ("valueType", v: JValue)) => + MapType(parseDataType(k), parseDataType(v), n) + + case JSortedObject( + ("fields", JArray(fields)), + ("type", JString("struct"))) => + StructType(fields.map(parseStructField)) + + case JSortedObject( + ("class", JString(udtClass)), + ("pyClass", _), + ("sqlType", _), + ("type", JString("udt"))) => + Class.forName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]] + } + + private def parseStructField(json: JValue): StructField = json match { + case JSortedObject( + ("metadata", metadata: JObject), + ("name", JString(name)), + ("nullable", JBool(nullable)), + ("type", dataType: JValue)) => + StructField(name, parseDataType(dataType), nullable, Metadata.fromJObject(metadata)) + // Support reading schema when 'metadata' is missing. 
+ case JSortedObject( + ("name", JString(name)), + ("nullable", JBool(nullable)), + ("type", dataType: JValue)) => + StructField(name, parseDataType(dataType), nullable) + } + + private object CaseClassStringParser extends RegexParsers { + protected lazy val primitiveType: Parser[DataType] = + ( "StringType" ^^^ StringType + | "FloatType" ^^^ FloatType + | "IntegerType" ^^^ IntegerType + | "ByteType" ^^^ ByteType + | "ShortType" ^^^ ShortType + | "DoubleType" ^^^ DoubleType + | "LongType" ^^^ LongType + | "BinaryType" ^^^ BinaryType + | "BooleanType" ^^^ BooleanType + | "DateType" ^^^ DateType + | "DecimalType()" ^^^ DecimalType.Unlimited + | fixedDecimalType + | "TimestampType" ^^^ TimestampType + ) + + protected lazy val fixedDecimalType: Parser[DataType] = + ("DecimalType(" ~> "[0-9]+".r) ~ ("," ~> "[0-9]+".r <~ ")") ^^ { + case precision ~ scale => DecimalType(precision.toInt, scale.toInt) + } + + protected lazy val arrayType: Parser[DataType] = + "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ { + case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull) + } + + protected lazy val mapType: Parser[DataType] = + "MapType" ~> "(" ~> dataType ~ "," ~ dataType ~ "," ~ boolVal <~ ")" ^^ { + case t1 ~ _ ~ t2 ~ _ ~ valueContainsNull => MapType(t1, t2, valueContainsNull) + } + + protected lazy val structField: Parser[StructField] = + ("StructField(" ~> "[a-zA-Z0-9_]*".r) ~ ("," ~> dataType) ~ ("," ~> boolVal <~ ")") ^^ { + case name ~ tpe ~ nullable => + StructField(name, tpe, nullable = nullable) + } + + protected lazy val boolVal: Parser[Boolean] = + ( "true" ^^^ true + | "false" ^^^ false + ) + + protected lazy val structType: Parser[DataType] = + "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ { + case fields => StructType(fields) + } + + protected lazy val dataType: Parser[DataType] = + ( arrayType + | mapType + | structType + | primitiveType + ) + + /** + * Parses a string representation of a DataType. + * + * TODO: Generate parser as pickler... + */ + def apply(asString: String): DataType = parseAll(dataType, asString) match { + case Success(result, _) => result + case failure: NoSuccess => + throw new IllegalArgumentException(s"Unsupported dataType: $asString, $failure") + } + } + + protected[types] def buildFormattedString( + dataType: DataType, + prefix: String, + builder: StringBuilder): Unit = { + dataType match { + case array: ArrayType => + array.buildFormattedString(prefix, builder) + case struct: StructType => + struct.buildFormattedString(prefix, builder) + case map: MapType => + map.buildFormattedString(prefix, builder) + case _ => + } + } + + /** + * Compares two types, ignoring nullability of ArrayType, MapType, StructType. + */ + private[types] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = { + (left, right) match { + case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) => + equalsIgnoreNullability(leftElementType, rightElementType) + case (MapType(leftKeyType, leftValueType, _), MapType(rightKeyType, rightValueType, _)) => + equalsIgnoreNullability(leftKeyType, rightKeyType) && + equalsIgnoreNullability(leftValueType, rightValueType) + case (StructType(leftFields), StructType(rightFields)) => + leftFields.length == rightFields.length && + leftFields.zip(rightFields).forall { case (l, r) => + l.name == r.name && equalsIgnoreNullability(l.dataType, r.dataType) + } + case (l, r) => l == r + } + } + + /** + * Compares two types, ignoring compatible nullability of ArrayType, MapType, StructType. 
+ * + * Compatible nullability is defined as follows: + * - If `from` and `to` are ArrayTypes, `from` has a compatible nullability with `to` + * if and only if `to.containsNull` is true, or both of `from.containsNull` and + * `to.containsNull` are false. + * - If `from` and `to` are MapTypes, `from` has a compatible nullability with `to` + * if and only if `to.valueContainsNull` is true, or both of `from.valueContainsNull` and + * `to.valueContainsNull` are false. + * - If `from` and `to` are StructTypes, `from` has a compatible nullability with `to` + * if and only if for all every pair of fields, `to.nullable` is true, or both + * of `fromField.nullable` and `toField.nullable` are false. + */ + private[sql] def equalsIgnoreCompatibleNullability(from: DataType, to: DataType): Boolean = { + (from, to) match { + case (ArrayType(fromElement, fn), ArrayType(toElement, tn)) => + (tn || !fn) && equalsIgnoreCompatibleNullability(fromElement, toElement) + + case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) => + (tn || !fn) && + equalsIgnoreCompatibleNullability(fromKey, toKey) && + equalsIgnoreCompatibleNullability(fromValue, toValue) + + case (StructType(fromFields), StructType(toFields)) => + fromFields.length == toFields.length && + fromFields.zip(toFields).forall { case (fromField, toField) => + fromField.name == toField.name && + (toField.nullable || !fromField.nullable) && + equalsIgnoreCompatibleNullability(fromField.dataType, toField.dataType) + } + + case (fromDataType, toDataType) => fromDataType == toDataType + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala new file mode 100644 index 0000000000000..03f0644bc784c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateType.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.math.Ordering +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.ScalaReflectionLock + + +/** + * :: DeveloperApi :: + * The data type representing `java.sql.Date` values. + * Please use the singleton [[DataTypes.DateType]]. + * + * @group dataType + */ +@DeveloperApi +class DateType private() extends AtomicType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "DateType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. 
+ private[sql] type InternalType = Int + + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } + + private[sql] val ordering = implicitly[Ordering[InternalType]] + + /** + * The default size of a value of the DateType is 4 bytes. + */ + override def defaultSize: Int = 4 + + private[spark] override def asNullable: DateType = this +} + + +case object DateType extends DateType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala new file mode 100644 index 0000000000000..0f8cecd28f7df --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.ScalaReflectionLock +import org.apache.spark.sql.catalyst.expressions.Expression + + +/** Precision parameters for a Decimal */ +case class PrecisionInfo(precision: Int, scale: Int) + + +/** + * :: DeveloperApi :: + * The data type representing `java.math.BigDecimal` values. + * A Decimal that might have fixed precision and scale, or unlimited values for these. + * + * Please use [[DataTypes.createDecimalType()]] to create a specific instance. + * + * @group dataType + */ +@DeveloperApi +case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalType { + + /** No-arg constructor for kryo. */ + protected def this() = this(null) + + private[sql] type InternalType = Decimal + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } + private[sql] val numeric = Decimal.DecimalIsFractional + private[sql] val fractional = Decimal.DecimalIsFractional + private[sql] val ordering = Decimal.DecimalIsFractional + private[sql] val asIntegral = Decimal.DecimalAsIfIntegral + + def precision: Int = precisionInfo.map(_.precision).getOrElse(-1) + + def scale: Int = precisionInfo.map(_.scale).getOrElse(-1) + + override def typeName: String = precisionInfo match { + case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)" + case None => "decimal" + } + + override def toString: String = precisionInfo match { + case Some(PrecisionInfo(precision, scale)) => s"DecimalType($precision,$scale)" + case None => "DecimalType()" + } + + /** + * The default size of a value of the DecimalType is 4096 bytes. 
+ */ + override def defaultSize: Int = 4096 + + override def simpleString: String = precisionInfo match { + case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)" + case None => "decimal(10,0)" + } + + private[spark] override def asNullable: DecimalType = this +} + + +/** Extra factory methods and pattern matchers for Decimals */ +object DecimalType { + val Unlimited: DecimalType = DecimalType(None) + + object Fixed { + def unapply(t: DecimalType): Option[(Int, Int)] = + t.precisionInfo.map(p => (p.precision, p.scale)) + } + + object Expression { + def unapply(e: Expression): Option[(Int, Int)] = e.dataType match { + case t: DecimalType => t.precisionInfo.map(p => (p.precision, p.scale)) + case _ => None + } + } + + def apply(): DecimalType = Unlimited + + def apply(precision: Int, scale: Int): DecimalType = + DecimalType(Some(PrecisionInfo(precision, scale))) + + def unapply(t: DataType): Boolean = t.isInstanceOf[DecimalType] + + def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[DecimalType] + + def isFixed(dataType: DataType): Boolean = dataType match { + case DecimalType.Fixed(_, _) => true + case _ => false + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala new file mode 100644 index 0000000000000..66766623213c9 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.math.{Ordering, Fractional, Numeric} +import scala.math.Numeric.DoubleAsIfIntegral +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.ScalaReflectionLock + +/** + * :: DeveloperApi :: + * The data type representing `Double` values. Please use the singleton [[DataTypes.DoubleType]]. + * + * @group dataType + */ +@DeveloperApi +class DoubleType private() extends FractionalType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "DoubleType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. 
+ private[sql] type InternalType = Double + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } + private[sql] val numeric = implicitly[Numeric[Double]] + private[sql] val fractional = implicitly[Fractional[Double]] + private[sql] val ordering = implicitly[Ordering[InternalType]] + private[sql] val asIntegral = DoubleAsIfIntegral + + /** + * The default size of a value of the DoubleType is 8 bytes. + */ + override def defaultSize: Int = 8 + + private[spark] override def asNullable: DoubleType = this +} + +case object DoubleType extends DoubleType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala new file mode 100644 index 0000000000000..1d5a2f4f6f86c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.math.Numeric.FloatAsIfIntegral +import scala.math.{Ordering, Fractional, Numeric} +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.ScalaReflectionLock + +/** + * :: DeveloperApi :: + * The data type representing `Float` values. Please use the singleton [[DataTypes.FloatType]]. + * + * @group dataType + */ +@DeveloperApi +class FloatType private() extends FractionalType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "FloatType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. + private[sql] type InternalType = Float + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } + private[sql] val numeric = implicitly[Numeric[Float]] + private[sql] val fractional = implicitly[Fractional[Float]] + private[sql] val ordering = implicitly[Ordering[InternalType]] + private[sql] val asIntegral = FloatAsIfIntegral + + /** + * The default size of a value of the FloatType is 4 bytes. 
+ */ + override def defaultSize: Int = 4 + + private[spark] override def asNullable: FloatType = this +} + +case object FloatType extends FloatType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala new file mode 100644 index 0000000000000..74e464c082873 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.math.{Ordering, Integral, Numeric} +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.ScalaReflectionLock + + +/** + * :: DeveloperApi :: + * The data type representing `Int` values. Please use the singleton [[DataTypes.IntegerType]]. + * + * @group dataType + */ +@DeveloperApi +class IntegerType private() extends IntegralType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "IntegerType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. + private[sql] type InternalType = Int + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } + private[sql] val numeric = implicitly[Numeric[Int]] + private[sql] val integral = implicitly[Integral[Int]] + private[sql] val ordering = implicitly[Ordering[InternalType]] + + /** + * The default size of a value of the IntegerType is 4 bytes. + */ + override def defaultSize: Int = 4 + + override def simpleString: String = "int" + + private[spark] override def asNullable: IntegerType = this +} + +case object IntegerType extends IntegerType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala new file mode 100644 index 0000000000000..390675782e5fd --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.math.{Ordering, Integral, Numeric} +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.ScalaReflectionLock + +/** + * :: DeveloperApi :: + * The data type representing `Long` values. Please use the singleton [[DataTypes.LongType]]. + * + * @group dataType + */ +@DeveloperApi +class LongType private() extends IntegralType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "LongType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. + private[sql] type InternalType = Long + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } + private[sql] val numeric = implicitly[Numeric[Long]] + private[sql] val integral = implicitly[Integral[Long]] + private[sql] val ordering = implicitly[Ordering[InternalType]] + + /** + * The default size of a value of the LongType is 8 bytes. + */ + override def defaultSize: Int = 8 + + override def simpleString: String = "bigint" + + private[spark] override def asNullable: LongType = this +} + + +case object LongType extends LongType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala new file mode 100644 index 0000000000000..cfdf493074415 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import org.json4s.JsonAST.JValue +import org.json4s.JsonDSL._ + + +/** + * :: DeveloperApi :: + * The data type for Maps. Keys in a map are not allowed to have `null` values. + * + * Please use [[DataTypes.createMapType()]] to create a specific instance. + * + * @param keyType The data type of map keys. + * @param valueType The data type of map values. + * @param valueContainsNull Indicates if map values have `null` values. + * + * @group dataType + */ +case class MapType( + keyType: DataType, + valueType: DataType, + valueContainsNull: Boolean) extends DataType { + + /** No-arg constructor for kryo. 
*/ + def this() = this(null, null, false) + + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { + builder.append(s"$prefix-- key: ${keyType.typeName}\n") + builder.append(s"$prefix-- value: ${valueType.typeName} " + + s"(valueContainsNull = $valueContainsNull)\n") + DataType.buildFormattedString(keyType, s"$prefix |", builder) + DataType.buildFormattedString(valueType, s"$prefix |", builder) + } + + override private[sql] def jsonValue: JValue = + ("type" -> typeName) ~ + ("keyType" -> keyType.jsonValue) ~ + ("valueType" -> valueType.jsonValue) ~ + ("valueContainsNull" -> valueContainsNull) + + /** + * The default size of a value of the MapType is + * 100 * (the default size of the key type + the default size of the value type). + * (We assume that there are 100 elements). + */ + override def defaultSize: Int = 100 * (keyType.defaultSize + valueType.defaultSize) + + override def simpleString: String = s"map<${keyType.simpleString},${valueType.simpleString}>" + + private[spark] override def asNullable: MapType = + MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true) +} + + +object MapType { + /** + * Construct a [[MapType]] object with the given key type and value type. + * The `valueContainsNull` is true. + */ + def apply(keyType: DataType, valueType: DataType): MapType = + MapType(keyType: DataType, valueType: DataType, valueContainsNull = true) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala new file mode 100644 index 0000000000000..b64b07431fa96 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/NullType.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import org.apache.spark.annotation.DeveloperApi + + +/** + * :: DeveloperApi :: + * The data type representing `NULL` values. Please use the singleton [[DataTypes.NullType]]. + * + * @group dataType + */ +@DeveloperApi +class NullType private() extends DataType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "NullType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. 
+ override def defaultSize: Int = 1 + + private[spark] override def asNullable: NullType = this +} + +case object NullType extends NullType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala new file mode 100644 index 0000000000000..73e9ec780b0af --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.math.{Ordering, Integral, Numeric} +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.ScalaReflectionLock + +/** + * :: DeveloperApi :: + * The data type representing `Short` values. Please use the singleton [[DataTypes.ShortType]]. + * + * @group dataType + */ +@DeveloperApi +class ShortType private() extends IntegralType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "ShortType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. + private[sql] type InternalType = Short + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } + private[sql] val numeric = implicitly[Numeric[Short]] + private[sql] val integral = implicitly[Integral[Short]] + private[sql] val ordering = implicitly[Ordering[InternalType]] + + /** + * The default size of a value of the ShortType is 2 bytes. + */ + override def defaultSize: Int = 2 + + override def simpleString: String = "smallint" + + private[spark] override def asNullable: ShortType = this +} + +case object ShortType extends ShortType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala new file mode 100644 index 0000000000000..134ab0af4e0de --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.math.Ordering +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.ScalaReflectionLock + +/** + * :: DeveloperApi :: + * The data type representing `String` values. Please use the singleton [[DataTypes.StringType]]. + * + * @group dataType + */ +@DeveloperApi +class StringType private() extends AtomicType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "StringType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. + private[sql] type InternalType = UTF8String + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } + private[sql] val ordering = implicitly[Ordering[InternalType]] + + /** + * The default size of a value of the StringType is 4096 bytes. + */ + override def defaultSize: Int = 4096 + + private[spark] override def asNullable: StringType = this +} + +case object StringType extends StringType + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala new file mode 100644 index 0000000000000..83570a5eaee61 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructField.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import org.json4s.JsonAST.JValue +import org.json4s.JsonDSL._ + +/** + * A field inside a StructType. + * @param name The name of this field. + * @param dataType The data type of this field. + * @param nullable Indicates if values of this field can be `null` values. + * @param metadata The metadata of this field. The metadata should be preserved during + * transformation if the content of the column is not modified, e.g, in selection. + */ +case class StructField( + name: String, + dataType: DataType, + nullable: Boolean = true, + metadata: Metadata = Metadata.empty) { + + /** No-arg constructor for kryo. 
*/ + protected def this() = this(null, null) + + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { + builder.append(s"$prefix-- $name: ${dataType.typeName} (nullable = $nullable)\n") + DataType.buildFormattedString(dataType, s"$prefix |", builder) + } + + // override the default toString to be compatible with legacy parquet files. + override def toString: String = s"StructField($name,$dataType,$nullable)" + + private[sql] def jsonValue: JValue = { + ("name" -> name) ~ + ("type" -> dataType.jsonValue) ~ + ("nullable" -> nullable) ~ + ("metadata" -> metadata.jsonValue) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala new file mode 100644 index 0000000000000..d80ffca18ec9a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -0,0 +1,263 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.collection.mutable.ArrayBuffer +import scala.math.max + +import org.json4s.JsonDSL._ + +import org.apache.spark.SparkException +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute} + + +/** + * :: DeveloperApi :: + * A [[StructType]] object can be constructed by + * {{{ + * StructType(fields: Seq[StructField]) + * }}} + * For a [[StructType]] object, one or multiple [[StructField]]s can be extracted by names. + * If multiple [[StructField]]s are extracted, a [[StructType]] object will be returned. + * If a provided name does not have a matching field, it will be ignored. For the case + * of extracting a single StructField, a `null` will be returned. + * Example: + * {{{ + * import org.apache.spark.sql._ + * + * val struct = + * StructType( + * StructField("a", IntegerType, true) :: + * StructField("b", LongType, false) :: + * StructField("c", BooleanType, false) :: Nil) + * + * // Extract a single StructField. + * val singleField = struct("b") + * // singleField: StructField = StructField(b,LongType,false) + * + * // This struct does not have a field called "d". null will be returned. + * val nonExisting = struct("d") + * // nonExisting: StructField = null + * + * // Extract multiple StructFields. Field names are provided in a set. + * // A StructType object will be returned. + * val twoFields = struct(Set("b", "c")) + * // twoFields: StructType = + * // StructType(List(StructField(b,LongType,false), StructField(c,BooleanType,false))) + * + * // Any names without matching fields will be ignored. + * // For the case shown below, "d" will be ignored and + * // it is treated as struct(Set("b", "c")). 
+ * val ignoreNonExisting = struct(Set("b", "c", "d")) + * // ignoreNonExisting: StructType = + * // StructType(List(StructField(b,LongType,false), StructField(c,BooleanType,false))) + * }}} + * + * A [[org.apache.spark.sql.Row]] object is used as a value of the StructType. + * Example: + * {{{ + * import org.apache.spark.sql._ + * + * val innerStruct = + * StructType( + * StructField("f1", IntegerType, true) :: + * StructField("f2", LongType, false) :: + * StructField("f3", BooleanType, false) :: Nil) + * + * val struct = StructType( + * StructField("a", innerStruct, true) :: Nil) + * + * // Create a Row with the schema defined by struct + * val row = Row(Row(1, 2, true)) + * // row: Row = [[1,2,true]] + * }}} + * + * @group dataType + */ +@DeveloperApi +case class StructType(fields: Array[StructField]) extends DataType with Seq[StructField] { + + /** No-arg constructor for kryo. */ + protected def this() = this(null) + + /** Returns all field names in an array. */ + def fieldNames: Array[String] = fields.map(_.name) + + private lazy val fieldNamesSet: Set[String] = fieldNames.toSet + private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap + private lazy val nameToIndex: Map[String, Int] = fieldNames.zipWithIndex.toMap + + /** + * Extracts a [[StructField]] of the given name. If the [[StructType]] object does not + * have a name matching the given name, `null` will be returned. + */ + def apply(name: String): StructField = { + nameToField.getOrElse(name, + throw new IllegalArgumentException(s"""Field "$name" does not exist.""")) + } + + /** + * Returns a [[StructType]] containing [[StructField]]s of the given names, preserving the + * original order of fields. Those names which do not have matching fields will be ignored. + */ + def apply(names: Set[String]): StructType = { + val nonExistFields = names -- fieldNamesSet + if (nonExistFields.nonEmpty) { + throw new IllegalArgumentException( + s"Field ${nonExistFields.mkString(",")} does not exist.") + } + // Preserve the original order of fields. + StructType(fields.filter(f => names.contains(f.name))) + } + + /** + * Returns index of a given field + */ + def fieldIndex(name: String): Int = { + nameToIndex.getOrElse(name, + throw new IllegalArgumentException(s"""Field "$name" does not exist.""")) + } + + protected[sql] def toAttributes: Seq[AttributeReference] = + map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()) + + def treeString: String = { + val builder = new StringBuilder + builder.append("root\n") + val prefix = " |" + fields.foreach(field => field.buildFormattedString(prefix, builder)) + + builder.toString() + } + + def printTreeString(): Unit = println(treeString) + + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { + fields.foreach(field => field.buildFormattedString(prefix, builder)) + } + + override private[sql] def jsonValue = + ("type" -> typeName) ~ + ("fields" -> map(_.jsonValue)) + + override def apply(fieldIndex: Int): StructField = fields(fieldIndex) + + override def length: Int = fields.length + + override def iterator: Iterator[StructField] = fields.iterator + + /** + * The default size of a value of the StructType is the total default sizes of all field types. 
+ */ + override def defaultSize: Int = fields.map(_.dataType.defaultSize).sum + + override def simpleString: String = { + val fieldTypes = fields.map(field => s"${field.name}:${field.dataType.simpleString}") + s"struct<${fieldTypes.mkString(",")}>" + } + + /** + * Merges with another schema (`StructType`). For a struct field A from `this` and a struct field + * B from `that`, + * + * 1. If A and B have the same name and data type, they are merged to a field C with the same name + * and data type. C is nullable if and only if either A or B is nullable. + * 2. If A doesn't exist in `that`, it's included in the result schema. + * 3. If B doesn't exist in `this`, it's also included in the result schema. + * 4. Otherwise, `this` and `that` are considered as conflicting schemas and an exception would be + * thrown. + */ + private[sql] def merge(that: StructType): StructType = + StructType.merge(this, that).asInstanceOf[StructType] + + private[spark] override def asNullable: StructType = { + val newFields = fields.map { + case StructField(name, dataType, nullable, metadata) => + StructField(name, dataType.asNullable, nullable = true, metadata) + } + + StructType(newFields) + } +} + + +object StructType { + + def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray) + + def apply(fields: java.util.List[StructField]): StructType = { + StructType(fields.toArray.asInstanceOf[Array[StructField]]) + } + + protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType = + StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))) + + private[sql] def merge(left: DataType, right: DataType): DataType = + (left, right) match { + case (ArrayType(leftElementType, leftContainsNull), + ArrayType(rightElementType, rightContainsNull)) => + ArrayType( + merge(leftElementType, rightElementType), + leftContainsNull || rightContainsNull) + + case (MapType(leftKeyType, leftValueType, leftContainsNull), + MapType(rightKeyType, rightValueType, rightContainsNull)) => + MapType( + merge(leftKeyType, rightKeyType), + merge(leftValueType, rightValueType), + leftContainsNull || rightContainsNull) + + case (StructType(leftFields), StructType(rightFields)) => + val newFields = ArrayBuffer.empty[StructField] + + leftFields.foreach { + case leftField @ StructField(leftName, leftType, leftNullable, _) => + rightFields + .find(_.name == leftName) + .map { case rightField @ StructField(_, rightType, rightNullable, _) => + leftField.copy( + dataType = merge(leftType, rightType), + nullable = leftNullable || rightNullable) + } + .orElse(Some(leftField)) + .foreach(newFields += _) + } + + rightFields + .filterNot(f => leftFields.map(_.name).contains(f.name)) + .foreach(newFields += _) + + StructType(newFields) + + case (DecimalType.Fixed(leftPrecision, leftScale), + DecimalType.Fixed(rightPrecision, rightScale)) => + DecimalType( + max(leftScale, rightScale) + max(leftPrecision - leftScale, rightPrecision - rightScale), + max(leftScale, rightScale)) + + case (leftUdt: UserDefinedType[_], rightUdt: UserDefinedType[_]) + if leftUdt.userClass == rightUdt.userClass => leftUdt + + case (leftType, rightType) if leftType == rightType => + leftType + + case _ => + throw new SparkException(s"Failed to merge incompatible data types $left and $right") + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala new file mode 100644 index 0000000000000..aebabfc475925 --- 
/dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/TimestampType.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import java.sql.Timestamp + +import scala.math.Ordering +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.catalyst.ScalaReflectionLock + + +/** + * :: DeveloperApi :: + * The data type representing `java.sql.Timestamp` values. + * Please use the singleton [[DataTypes.TimestampType]]. + * + * @group dataType + */ +@DeveloperApi +class TimestampType private() extends AtomicType { + // The companion object and this class is separated so the companion object also subclasses + // this type. Otherwise, the companion object would be of type "TimestampType$" in byte code. + // Defined with a private constructor so the companion object is the only possible instantiation. + private[sql] type InternalType = Timestamp + + @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } + + private[sql] val ordering = new Ordering[InternalType] { + def compare(x: Timestamp, y: Timestamp): Int = x.compareTo(y) + } + + /** + * The default size of a value of the TimestampType is 12 bytes. + */ + override def defaultSize: Int = 12 + + private[spark] override def asNullable: TimestampType = this +} + +case object TimestampType extends TimestampType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala new file mode 100644 index 0000000000000..6b20505c6009a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.types + +import org.json4s.JsonAST.JValue +import org.json4s.JsonDSL._ + +import org.apache.spark.annotation.DeveloperApi + +/** + * ::DeveloperApi:: + * The data type for User Defined Types (UDTs). + * + * This interface allows a user to make their own classes more interoperable with SparkSQL; + * e.g., by creating a [[UserDefinedType]] for a class X, it becomes possible to create + * a `DataFrame` which has class X in the schema. + * + * For SparkSQL to recognize UDTs, the UDT must be annotated with + * [[SQLUserDefinedType]]. + * + * The conversion via `serialize` occurs when instantiating a `DataFrame` from another RDD. + * The conversion via `deserialize` occurs when reading from a `DataFrame`. + */ +@DeveloperApi +abstract class UserDefinedType[UserType] extends DataType with Serializable { + + /** Underlying storage type for this UDT */ + def sqlType: DataType + + /** Paired Python UDT class, if exists. */ + def pyUDT: String = null + + /** + * Convert the user type to a SQL datum + * + * TODO: Can we make this take obj: UserType? The issue is in + * CatalystTypeConverters.convertToCatalyst, where we need to convert Any to UserType. + */ + def serialize(obj: Any): Any + + /** Convert a SQL datum to the user type */ + def deserialize(datum: Any): UserType + + override private[sql] def jsonValue: JValue = { + ("type" -> "udt") ~ + ("class" -> this.getClass.getName) ~ + ("pyClass" -> pyUDT) ~ + ("sqlType" -> sqlType.jsonValue) + } + + /** + * Class object for the UserType + */ + def userClass: java.lang.Class[UserType] + + /** + * The default size of a value of the UserDefinedType is 4096 bytes. + */ + override def defaultSize: Int = 4096 + + /** + * For UDT, asNullable will not change the nullability of its internal sqlType and just returns + * itself. + */ + private[spark] override def asNullable: UserDefinedType[UserType] = this +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala deleted file mode 100644 index 87c7b7599366a..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala +++ /dev/null @@ -1,1224 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.types - -import java.sql.Timestamp - -import scala.collection.mutable.ArrayBuffer -import scala.math._ -import scala.math.Numeric.{FloatAsIfIntegral, DoubleAsIfIntegral} -import scala.reflect.ClassTag -import scala.reflect.runtime.universe.{TypeTag, runtimeMirror, typeTag} -import scala.util.parsing.combinator.RegexParsers - -import org.json4s._ -import org.json4s.JsonAST.JValue -import org.json4s.JsonDSL._ -import org.json4s.jackson.JsonMethods._ - -import org.apache.spark.SparkException -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.ScalaReflectionLock -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} -import org.apache.spark.util.Utils - - -object DataType { - def fromJson(json: String): DataType = parseDataType(parse(json)) - - private val nonDecimalNameToType = { - Seq(NullType, DateType, TimestampType, BinaryType, - IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) - .map(t => t.typeName -> t).toMap - } - - /** Given the string representation of a type, return its DataType */ - private def nameToType(name: String): DataType = { - val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\d+)\s*\)""".r - name match { - case "decimal" => DecimalType.Unlimited - case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt) - case other => nonDecimalNameToType(other) - } - } - - private object JSortedObject { - def unapplySeq(value: JValue): Option[List[(String, JValue)]] = value match { - case JObject(seq) => Some(seq.toList.sortBy(_._1)) - case _ => None - } - } - - // NOTE: Map fields must be sorted in alphabetical order to keep consistent with the Python side. - private def parseDataType(json: JValue): DataType = json match { - case JString(name) => - nameToType(name) - - case JSortedObject( - ("containsNull", JBool(n)), - ("elementType", t: JValue), - ("type", JString("array"))) => - ArrayType(parseDataType(t), n) - - case JSortedObject( - ("keyType", k: JValue), - ("type", JString("map")), - ("valueContainsNull", JBool(n)), - ("valueType", v: JValue)) => - MapType(parseDataType(k), parseDataType(v), n) - - case JSortedObject( - ("fields", JArray(fields)), - ("type", JString("struct"))) => - StructType(fields.map(parseStructField)) - - case JSortedObject( - ("class", JString(udtClass)), - ("pyClass", _), - ("sqlType", _), - ("type", JString("udt"))) => - Class.forName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]] - } - - private def parseStructField(json: JValue): StructField = json match { - case JSortedObject( - ("metadata", metadata: JObject), - ("name", JString(name)), - ("nullable", JBool(nullable)), - ("type", dataType: JValue)) => - StructField(name, parseDataType(dataType), nullable, Metadata.fromJObject(metadata)) - // Support reading schema when 'metadata' is missing. 
- case JSortedObject( - ("name", JString(name)), - ("nullable", JBool(nullable)), - ("type", dataType: JValue)) => - StructField(name, parseDataType(dataType), nullable) - } - - @deprecated("Use DataType.fromJson instead", "1.2.0") - def fromCaseClassString(string: String): DataType = CaseClassStringParser(string) - - private object CaseClassStringParser extends RegexParsers { - protected lazy val primitiveType: Parser[DataType] = - ( "StringType" ^^^ StringType - | "FloatType" ^^^ FloatType - | "IntegerType" ^^^ IntegerType - | "ByteType" ^^^ ByteType - | "ShortType" ^^^ ShortType - | "DoubleType" ^^^ DoubleType - | "LongType" ^^^ LongType - | "BinaryType" ^^^ BinaryType - | "BooleanType" ^^^ BooleanType - | "DateType" ^^^ DateType - | "DecimalType()" ^^^ DecimalType.Unlimited - | fixedDecimalType - | "TimestampType" ^^^ TimestampType - ) - - protected lazy val fixedDecimalType: Parser[DataType] = - ("DecimalType(" ~> "[0-9]+".r) ~ ("," ~> "[0-9]+".r <~ ")") ^^ { - case precision ~ scale => DecimalType(precision.toInt, scale.toInt) - } - - protected lazy val arrayType: Parser[DataType] = - "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ { - case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull) - } - - protected lazy val mapType: Parser[DataType] = - "MapType" ~> "(" ~> dataType ~ "," ~ dataType ~ "," ~ boolVal <~ ")" ^^ { - case t1 ~ _ ~ t2 ~ _ ~ valueContainsNull => MapType(t1, t2, valueContainsNull) - } - - protected lazy val structField: Parser[StructField] = - ("StructField(" ~> "[a-zA-Z0-9_]*".r) ~ ("," ~> dataType) ~ ("," ~> boolVal <~ ")") ^^ { - case name ~ tpe ~ nullable => - StructField(name, tpe, nullable = nullable) - } - - protected lazy val boolVal: Parser[Boolean] = - ( "true" ^^^ true - | "false" ^^^ false - ) - - protected lazy val structType: Parser[DataType] = - "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ { - case fields => StructType(fields) - } - - protected lazy val dataType: Parser[DataType] = - ( arrayType - | mapType - | structType - | primitiveType - ) - - /** - * Parses a string representation of a DataType. - * - * TODO: Generate parser as pickler... - */ - def apply(asString: String): DataType = parseAll(dataType, asString) match { - case Success(result, _) => result - case failure: NoSuccess => - throw new IllegalArgumentException(s"Unsupported dataType: $asString, $failure") - } - } - - protected[types] def buildFormattedString( - dataType: DataType, - prefix: String, - builder: StringBuilder): Unit = { - dataType match { - case array: ArrayType => - array.buildFormattedString(prefix, builder) - case struct: StructType => - struct.buildFormattedString(prefix, builder) - case map: MapType => - map.buildFormattedString(prefix, builder) - case _ => - } - } - - /** - * Compares two types, ignoring nullability of ArrayType, MapType, StructType. 
- */ - private[types] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = { - (left, right) match { - case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) => - equalsIgnoreNullability(leftElementType, rightElementType) - case (MapType(leftKeyType, leftValueType, _), MapType(rightKeyType, rightValueType, _)) => - equalsIgnoreNullability(leftKeyType, rightKeyType) && - equalsIgnoreNullability(leftValueType, rightValueType) - case (StructType(leftFields), StructType(rightFields)) => - leftFields.length == rightFields.length && - leftFields.zip(rightFields).forall { case (l, r) => - l.name == r.name && equalsIgnoreNullability(l.dataType, r.dataType) - } - case (l, r) => l == r - } - } - - /** - * Compares two types, ignoring compatible nullability of ArrayType, MapType, StructType. - * - * Compatible nullability is defined as follows: - * - If `from` and `to` are ArrayTypes, `from` has a compatible nullability with `to` - * if and only if `to.containsNull` is true, or both of `from.containsNull` and - * `to.containsNull` are false. - * - If `from` and `to` are MapTypes, `from` has a compatible nullability with `to` - * if and only if `to.valueContainsNull` is true, or both of `from.valueContainsNull` and - * `to.valueContainsNull` are false. - * - If `from` and `to` are StructTypes, `from` has a compatible nullability with `to` - * if and only if for all every pair of fields, `to.nullable` is true, or both - * of `fromField.nullable` and `toField.nullable` are false. - */ - private[sql] def equalsIgnoreCompatibleNullability(from: DataType, to: DataType): Boolean = { - (from, to) match { - case (ArrayType(fromElement, fn), ArrayType(toElement, tn)) => - (tn || !fn) && equalsIgnoreCompatibleNullability(fromElement, toElement) - - case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) => - (tn || !fn) && - equalsIgnoreCompatibleNullability(fromKey, toKey) && - equalsIgnoreCompatibleNullability(fromValue, toValue) - - case (StructType(fromFields), StructType(toFields)) => - fromFields.length == toFields.length && - fromFields.zip(toFields).forall { case (fromField, toField) => - fromField.name == toField.name && - (toField.nullable || !fromField.nullable) && - equalsIgnoreCompatibleNullability(fromField.dataType, toField.dataType) - } - - case (fromDataType, toDataType) => fromDataType == toDataType - } - } -} - - -/** - * :: DeveloperApi :: - * The base type of all Spark SQL data types. - * - * @group dataType - */ -@DeveloperApi -abstract class DataType { - /** Matches any expression that evaluates to this DataType */ - def unapply(a: Expression): Boolean = a match { - case e: Expression if e.dataType == this => true - case _ => false - } - - /** The default size of a value of this data type. */ - def defaultSize: Int - - def typeName: String = this.getClass.getSimpleName.stripSuffix("$").dropRight(4).toLowerCase - - private[sql] def jsonValue: JValue = typeName - - def json: String = compact(render(jsonValue)) - - def prettyJson: String = pretty(render(jsonValue)) - - def simpleString: String = typeName - - /** Check if `this` and `other` are the same data type when ignoring nullability - * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). - */ - private[spark] def sameType(other: DataType): Boolean = - DataType.equalsIgnoreNullability(this, other) - - /** Returns the same data type but set all nullability fields are true - * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). 
- */ - private[spark] def asNullable: DataType -} - -/** - * :: DeveloperApi :: - * The data type representing `NULL` values. Please use the singleton [[DataTypes.NullType]]. - * - * @group dataType - */ -@DeveloperApi -class NullType private() extends DataType { - // The companion object and this class is separated so the companion object also subclasses - // this type. Otherwise, the companion object would be of type "NullType$" in byte code. - // Defined with a private constructor so the companion object is the only possible instantiation. - override def defaultSize: Int = 1 - - private[spark] override def asNullable: NullType = this -} - -case object NullType extends NullType - - -/** - * An internal type used to represent everything that is not null, UDTs, arrays, structs, and maps. - */ -protected[sql] abstract class AtomicType extends DataType { - private[sql] type InternalType - @transient private[sql] val tag: TypeTag[InternalType] - private[sql] val ordering: Ordering[InternalType] - - @transient private[sql] val classTag = ScalaReflectionLock.synchronized { - val mirror = runtimeMirror(Utils.getSparkClassLoader) - ClassTag[InternalType](mirror.runtimeClass(tag.tpe)) - } -} - - -/** - * :: DeveloperApi :: - * The data type representing `String` values. Please use the singleton [[DataTypes.StringType]]. - * - * @group dataType - */ -@DeveloperApi -class StringType private() extends AtomicType { - // The companion object and this class is separated so the companion object also subclasses - // this type. Otherwise, the companion object would be of type "StringType$" in byte code. - // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type InternalType = UTF8String - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } - private[sql] val ordering = implicitly[Ordering[InternalType]] - - /** - * The default size of a value of the StringType is 4096 bytes. - */ - override def defaultSize: Int = 4096 - - private[spark] override def asNullable: StringType = this -} - -case object StringType extends StringType - - -/** - * :: DeveloperApi :: - * The data type representing `Array[Byte]` values. - * Please use the singleton [[DataTypes.BinaryType]]. - * - * @group dataType - */ -@DeveloperApi -class BinaryType private() extends AtomicType { - // The companion object and this class is separated so the companion object also subclasses - // this type. Otherwise, the companion object would be of type "BinaryType$" in byte code. - // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type InternalType = Array[Byte] - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } - private[sql] val ordering = new Ordering[InternalType] { - def compare(x: Array[Byte], y: Array[Byte]): Int = { - for (i <- 0 until x.length; if i < y.length) { - val res = x(i).compareTo(y(i)) - if (res != 0) return res - } - x.length - y.length - } - } - - /** - * The default size of a value of the BinaryType is 4096 bytes. - */ - override def defaultSize: Int = 4096 - - private[spark] override def asNullable: BinaryType = this -} - -case object BinaryType extends BinaryType - - -/** - * :: DeveloperApi :: - * The data type representing `Boolean` values. Please use the singleton [[DataTypes.BooleanType]]. 
- * - *@group dataType - */ -@DeveloperApi -class BooleanType private() extends AtomicType { - // The companion object and this class is separated so the companion object also subclasses - // this type. Otherwise, the companion object would be of type "BooleanType$" in byte code. - // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type InternalType = Boolean - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } - private[sql] val ordering = implicitly[Ordering[InternalType]] - - /** - * The default size of a value of the BooleanType is 1 byte. - */ - override def defaultSize: Int = 1 - - private[spark] override def asNullable: BooleanType = this -} - -case object BooleanType extends BooleanType - - -/** - * :: DeveloperApi :: - * The data type representing `java.sql.Timestamp` values. - * Please use the singleton [[DataTypes.TimestampType]]. - * - * @group dataType - */ -@DeveloperApi -class TimestampType private() extends AtomicType { - // The companion object and this class is separated so the companion object also subclasses - // this type. Otherwise, the companion object would be of type "TimestampType$" in byte code. - // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type InternalType = Timestamp - - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } - - private[sql] val ordering = new Ordering[InternalType] { - def compare(x: Timestamp, y: Timestamp): Int = x.compareTo(y) - } - - /** - * The default size of a value of the TimestampType is 12 bytes. - */ - override def defaultSize: Int = 12 - - private[spark] override def asNullable: TimestampType = this -} - -case object TimestampType extends TimestampType - - -/** - * :: DeveloperApi :: - * The data type representing `java.sql.Date` values. - * Please use the singleton [[DataTypes.DateType]]. - * - * @group dataType - */ -@DeveloperApi -class DateType private() extends AtomicType { - // The companion object and this class is separated so the companion object also subclasses - // this type. Otherwise, the companion object would be of type "DateType$" in byte code. - // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type InternalType = Int - - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } - - private[sql] val ordering = implicitly[Ordering[InternalType]] - - /** - * The default size of a value of the DateType is 4 bytes. - */ - override def defaultSize: Int = 4 - - private[spark] override def asNullable: DateType = this -} - -case object DateType extends DateType - - -/** - * :: DeveloperApi :: - * Numeric data types. - * - * @group dataType - */ -abstract class NumericType extends AtomicType { - // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for - // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a - // type parameter and and add a numeric annotation (i.e., [JvmType : Numeric]). This gets - // desugared by the compiler into an argument to the objects constructor. This means there is no - // longer an no argument constructor and thus the JVM cannot serialize the object anymore. 
- private[sql] val numeric: Numeric[InternalType] -} - - -protected[sql] object NumericType { - def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType] -} - - -/** Matcher for any expressions that evaluate to [[IntegralType]]s */ -protected[sql] object IntegralType { - def unapply(a: Expression): Boolean = a match { - case e: Expression if e.dataType.isInstanceOf[IntegralType] => true - case _ => false - } -} - - -protected[sql] sealed abstract class IntegralType extends NumericType { - private[sql] val integral: Integral[InternalType] -} - - -/** - * :: DeveloperApi :: - * The data type representing `Long` values. Please use the singleton [[DataTypes.LongType]]. - * - * @group dataType - */ -@DeveloperApi -class LongType private() extends IntegralType { - // The companion object and this class is separated so the companion object also subclasses - // this type. Otherwise, the companion object would be of type "LongType$" in byte code. - // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type InternalType = Long - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } - private[sql] val numeric = implicitly[Numeric[Long]] - private[sql] val integral = implicitly[Integral[Long]] - private[sql] val ordering = implicitly[Ordering[InternalType]] - - /** - * The default size of a value of the LongType is 8 bytes. - */ - override def defaultSize: Int = 8 - - override def simpleString: String = "bigint" - - private[spark] override def asNullable: LongType = this -} - -case object LongType extends LongType - - -/** - * :: DeveloperApi :: - * The data type representing `Int` values. Please use the singleton [[DataTypes.IntegerType]]. - * - * @group dataType - */ -@DeveloperApi -class IntegerType private() extends IntegralType { - // The companion object and this class is separated so the companion object also subclasses - // this type. Otherwise, the companion object would be of type "IntegerType$" in byte code. - // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type InternalType = Int - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } - private[sql] val numeric = implicitly[Numeric[Int]] - private[sql] val integral = implicitly[Integral[Int]] - private[sql] val ordering = implicitly[Ordering[InternalType]] - - /** - * The default size of a value of the IntegerType is 4 bytes. - */ - override def defaultSize: Int = 4 - - override def simpleString: String = "int" - - private[spark] override def asNullable: IntegerType = this -} - -case object IntegerType extends IntegerType - - -/** - * :: DeveloperApi :: - * The data type representing `Short` values. Please use the singleton [[DataTypes.ShortType]]. - * - * @group dataType - */ -@DeveloperApi -class ShortType private() extends IntegralType { - // The companion object and this class is separated so the companion object also subclasses - // this type. Otherwise, the companion object would be of type "ShortType$" in byte code. - // Defined with a private constructor so the companion object is the only possible instantiation. 
- private[sql] type InternalType = Short - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } - private[sql] val numeric = implicitly[Numeric[Short]] - private[sql] val integral = implicitly[Integral[Short]] - private[sql] val ordering = implicitly[Ordering[InternalType]] - - /** - * The default size of a value of the ShortType is 2 bytes. - */ - override def defaultSize: Int = 2 - - override def simpleString: String = "smallint" - - private[spark] override def asNullable: ShortType = this -} - -case object ShortType extends ShortType - - -/** - * :: DeveloperApi :: - * The data type representing `Byte` values. Please use the singleton [[DataTypes.ByteType]]. - * - * @group dataType - */ -@DeveloperApi -class ByteType private() extends IntegralType { - // The companion object and this class is separated so the companion object also subclasses - // this type. Otherwise, the companion object would be of type "ByteType$" in byte code. - // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type InternalType = Byte - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } - private[sql] val numeric = implicitly[Numeric[Byte]] - private[sql] val integral = implicitly[Integral[Byte]] - private[sql] val ordering = implicitly[Ordering[InternalType]] - - /** - * The default size of a value of the ByteType is 1 byte. - */ - override def defaultSize: Int = 1 - - override def simpleString: String = "tinyint" - - private[spark] override def asNullable: ByteType = this -} - -case object ByteType extends ByteType - - -/** Matcher for any expressions that evaluate to [[FractionalType]]s */ -protected[sql] object FractionalType { - def unapply(a: Expression): Boolean = a match { - case e: Expression if e.dataType.isInstanceOf[FractionalType] => true - case _ => false - } -} - - -protected[sql] sealed abstract class FractionalType extends NumericType { - private[sql] val fractional: Fractional[InternalType] - private[sql] val asIntegral: Integral[InternalType] -} - - -/** Precision parameters for a Decimal */ -case class PrecisionInfo(precision: Int, scale: Int) - - -/** - * :: DeveloperApi :: - * The data type representing `java.math.BigDecimal` values. - * A Decimal that might have fixed precision and scale, or unlimited values for these. - * - * Please use [[DataTypes.createDecimalType()]] to create a specific instance. - * - * @group dataType - */ -@DeveloperApi -case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalType { - - /** No-arg constructor for kryo. 
*/ - protected def this() = this(null) - - private[sql] type InternalType = Decimal - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } - private[sql] val numeric = Decimal.DecimalIsFractional - private[sql] val fractional = Decimal.DecimalIsFractional - private[sql] val ordering = Decimal.DecimalIsFractional - private[sql] val asIntegral = Decimal.DecimalAsIfIntegral - - def precision: Int = precisionInfo.map(_.precision).getOrElse(-1) - - def scale: Int = precisionInfo.map(_.scale).getOrElse(-1) - - override def typeName: String = precisionInfo match { - case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)" - case None => "decimal" - } - - override def toString: String = precisionInfo match { - case Some(PrecisionInfo(precision, scale)) => s"DecimalType($precision,$scale)" - case None => "DecimalType()" - } - - /** - * The default size of a value of the DecimalType is 4096 bytes. - */ - override def defaultSize: Int = 4096 - - override def simpleString: String = precisionInfo match { - case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)" - case None => "decimal(10,0)" - } - - private[spark] override def asNullable: DecimalType = this -} - - -/** Extra factory methods and pattern matchers for Decimals */ -object DecimalType { - val Unlimited: DecimalType = DecimalType(None) - - object Fixed { - def unapply(t: DecimalType): Option[(Int, Int)] = - t.precisionInfo.map(p => (p.precision, p.scale)) - } - - object Expression { - def unapply(e: Expression): Option[(Int, Int)] = e.dataType match { - case t: DecimalType => t.precisionInfo.map(p => (p.precision, p.scale)) - case _ => None - } - } - - def apply(): DecimalType = Unlimited - - def apply(precision: Int, scale: Int): DecimalType = - DecimalType(Some(PrecisionInfo(precision, scale))) - - def unapply(t: DataType): Boolean = t.isInstanceOf[DecimalType] - - def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[DecimalType] - - def isFixed(dataType: DataType): Boolean = dataType match { - case DecimalType.Fixed(_, _) => true - case _ => false - } -} - - -/** - * :: DeveloperApi :: - * The data type representing `Double` values. Please use the singleton [[DataTypes.DoubleType]]. - * - * @group dataType - */ -@DeveloperApi -class DoubleType private() extends FractionalType { - // The companion object and this class is separated so the companion object also subclasses - // this type. Otherwise, the companion object would be of type "DoubleType$" in byte code. - // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type InternalType = Double - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } - private[sql] val numeric = implicitly[Numeric[Double]] - private[sql] val fractional = implicitly[Fractional[Double]] - private[sql] val ordering = implicitly[Ordering[InternalType]] - private[sql] val asIntegral = DoubleAsIfIntegral - - /** - * The default size of a value of the DoubleType is 8 bytes. - */ - override def defaultSize: Int = 8 - - private[spark] override def asNullable: DoubleType = this -} - -case object DoubleType extends DoubleType - - -/** - * :: DeveloperApi :: - * The data type representing `Float` values. Please use the singleton [[DataTypes.FloatType]]. 
- * - * @group dataType - */ -@DeveloperApi -class FloatType private() extends FractionalType { - // The companion object and this class is separated so the companion object also subclasses - // this type. Otherwise, the companion object would be of type "FloatType$" in byte code. - // Defined with a private constructor so the companion object is the only possible instantiation. - private[sql] type InternalType = Float - @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } - private[sql] val numeric = implicitly[Numeric[Float]] - private[sql] val fractional = implicitly[Fractional[Float]] - private[sql] val ordering = implicitly[Ordering[InternalType]] - private[sql] val asIntegral = FloatAsIfIntegral - - /** - * The default size of a value of the FloatType is 4 bytes. - */ - override def defaultSize: Int = 4 - - private[spark] override def asNullable: FloatType = this -} - -case object FloatType extends FloatType - - -object ArrayType { - /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */ - def apply(elementType: DataType): ArrayType = ArrayType(elementType, true) -} - - -/** - * :: DeveloperApi :: - * The data type for collections of multiple values. - * Internally these are represented as columns that contain a ``scala.collection.Seq``. - * - * Please use [[DataTypes.createArrayType()]] to create a specific instance. - * - * An [[ArrayType]] object comprises two fields, `elementType: [[DataType]]` and - * `containsNull: Boolean`. The field of `elementType` is used to specify the type of - * array elements. The field of `containsNull` is used to specify if the array has `null` values. - * - * @param elementType The data type of values. - * @param containsNull Indicates if values have `null` values - * - * @group dataType - */ -@DeveloperApi -case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType { - - /** No-arg constructor for kryo. */ - protected def this() = this(null, false) - - private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { - builder.append( - s"$prefix-- element: ${elementType.typeName} (containsNull = $containsNull)\n") - DataType.buildFormattedString(elementType, s"$prefix |", builder) - } - - override private[sql] def jsonValue = - ("type" -> typeName) ~ - ("elementType" -> elementType.jsonValue) ~ - ("containsNull" -> containsNull) - - /** - * The default size of a value of the ArrayType is 100 * the default size of the element type. - * (We assume that there are 100 elements). - */ - override def defaultSize: Int = 100 * elementType.defaultSize - - override def simpleString: String = s"array<${elementType.simpleString}>" - - private[spark] override def asNullable: ArrayType = - ArrayType(elementType.asNullable, containsNull = true) -} - - -/** - * A field inside a StructType. - * @param name The name of this field. - * @param dataType The data type of this field. - * @param nullable Indicates if values of this field can be `null` values. - * @param metadata The metadata of this field. The metadata should be preserved during - * transformation if the content of the column is not modified, e.g, in selection. - */ -case class StructField( - name: String, - dataType: DataType, - nullable: Boolean = true, - metadata: Metadata = Metadata.empty) { - - /** No-arg constructor for kryo. 
*/ - protected def this() = this(null, null) - - private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { - builder.append(s"$prefix-- $name: ${dataType.typeName} (nullable = $nullable)\n") - DataType.buildFormattedString(dataType, s"$prefix |", builder) - } - - // override the default toString to be compatible with legacy parquet files. - override def toString: String = s"StructField($name,$dataType,$nullable)" - - private[sql] def jsonValue: JValue = { - ("name" -> name) ~ - ("type" -> dataType.jsonValue) ~ - ("nullable" -> nullable) ~ - ("metadata" -> metadata.jsonValue) - } -} - - -object StructType { - protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType = - StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))) - - def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray) - - def apply(fields: java.util.List[StructField]): StructType = { - StructType(fields.toArray.asInstanceOf[Array[StructField]]) - } - - private[sql] def merge(left: DataType, right: DataType): DataType = - (left, right) match { - case (ArrayType(leftElementType, leftContainsNull), - ArrayType(rightElementType, rightContainsNull)) => - ArrayType( - merge(leftElementType, rightElementType), - leftContainsNull || rightContainsNull) - - case (MapType(leftKeyType, leftValueType, leftContainsNull), - MapType(rightKeyType, rightValueType, rightContainsNull)) => - MapType( - merge(leftKeyType, rightKeyType), - merge(leftValueType, rightValueType), - leftContainsNull || rightContainsNull) - - case (StructType(leftFields), StructType(rightFields)) => - val newFields = ArrayBuffer.empty[StructField] - - leftFields.foreach { - case leftField @ StructField(leftName, leftType, leftNullable, _) => - rightFields - .find(_.name == leftName) - .map { case rightField @ StructField(_, rightType, rightNullable, _) => - leftField.copy( - dataType = merge(leftType, rightType), - nullable = leftNullable || rightNullable) - } - .orElse(Some(leftField)) - .foreach(newFields += _) - } - - rightFields - .filterNot(f => leftFields.map(_.name).contains(f.name)) - .foreach(newFields += _) - - StructType(newFields) - - case (DecimalType.Fixed(leftPrecision, leftScale), - DecimalType.Fixed(rightPrecision, rightScale)) => - DecimalType( - max(leftScale, rightScale) + max(leftPrecision - leftScale, rightPrecision - rightScale), - max(leftScale, rightScale)) - - case (leftUdt: UserDefinedType[_], rightUdt: UserDefinedType[_]) - if leftUdt.userClass == rightUdt.userClass => leftUdt - - case (leftType, rightType) if leftType == rightType => - leftType - - case _ => - throw new SparkException(s"Failed to merge incompatible data types $left and $right") - } -} - - -/** - * :: DeveloperApi :: - * A [[StructType]] object can be constructed by - * {{{ - * StructType(fields: Seq[StructField]) - * }}} - * For a [[StructType]] object, one or multiple [[StructField]]s can be extracted by names. - * If multiple [[StructField]]s are extracted, a [[StructType]] object will be returned. - * If a provided name does not have a matching field, it will be ignored. For the case - * of extracting a single StructField, a `null` will be returned. - * Example: - * {{{ - * import org.apache.spark.sql._ - * - * val struct = - * StructType( - * StructField("a", IntegerType, true) :: - * StructField("b", LongType, false) :: - * StructField("c", BooleanType, false) :: Nil) - * - * // Extract a single StructField. 
- * val singleField = struct("b") - * // singleField: StructField = StructField(b,LongType,false) - * - * // This struct does not have a field called "d". null will be returned. - * val nonExisting = struct("d") - * // nonExisting: StructField = null - * - * // Extract multiple StructFields. Field names are provided in a set. - * // A StructType object will be returned. - * val twoFields = struct(Set("b", "c")) - * // twoFields: StructType = - * // StructType(List(StructField(b,LongType,false), StructField(c,BooleanType,false))) - * - * // Any names without matching fields will be ignored. - * // For the case shown below, "d" will be ignored and - * // it is treated as struct(Set("b", "c")). - * val ignoreNonExisting = struct(Set("b", "c", "d")) - * // ignoreNonExisting: StructType = - * // StructType(List(StructField(b,LongType,false), StructField(c,BooleanType,false))) - * }}} - * - * A [[org.apache.spark.sql.Row]] object is used as a value of the StructType. - * Example: - * {{{ - * import org.apache.spark.sql._ - * - * val innerStruct = - * StructType( - * StructField("f1", IntegerType, true) :: - * StructField("f2", LongType, false) :: - * StructField("f3", BooleanType, false) :: Nil) - * - * val struct = StructType( - * StructField("a", innerStruct, true) :: Nil) - * - * // Create a Row with the schema defined by struct - * val row = Row(Row(1, 2, true)) - * // row: Row = [[1,2,true]] - * }}} - * - * @group dataType - */ -@DeveloperApi -case class StructType(fields: Array[StructField]) extends DataType with Seq[StructField] { - - /** No-arg constructor for kryo. */ - protected def this() = this(null) - - /** Returns all field names in an array. */ - def fieldNames: Array[String] = fields.map(_.name) - - private lazy val fieldNamesSet: Set[String] = fieldNames.toSet - private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap - private lazy val nameToIndex: Map[String, Int] = fieldNames.zipWithIndex.toMap - - /** - * Extracts a [[StructField]] of the given name. If the [[StructType]] object does not - * have a name matching the given name, `null` will be returned. - */ - def apply(name: String): StructField = { - nameToField.getOrElse(name, - throw new IllegalArgumentException(s"""Field "$name" does not exist.""")) - } - - /** - * Returns a [[StructType]] containing [[StructField]]s of the given names, preserving the - * original order of fields. Those names which do not have matching fields will be ignored. - */ - def apply(names: Set[String]): StructType = { - val nonExistFields = names -- fieldNamesSet - if (nonExistFields.nonEmpty) { - throw new IllegalArgumentException( - s"Field ${nonExistFields.mkString(",")} does not exist.") - } - // Preserve the original order of fields. 
- StructType(fields.filter(f => names.contains(f.name))) - } - - /** - * Returns index of a given field - */ - def fieldIndex(name: String): Int = { - nameToIndex.getOrElse(name, - throw new IllegalArgumentException(s"""Field "$name" does not exist.""")) - } - - protected[sql] def toAttributes: Seq[AttributeReference] = - map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()) - - def treeString: String = { - val builder = new StringBuilder - builder.append("root\n") - val prefix = " |" - fields.foreach(field => field.buildFormattedString(prefix, builder)) - - builder.toString() - } - - def printTreeString(): Unit = println(treeString) - - private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { - fields.foreach(field => field.buildFormattedString(prefix, builder)) - } - - override private[sql] def jsonValue = - ("type" -> typeName) ~ - ("fields" -> map(_.jsonValue)) - - override def apply(fieldIndex: Int): StructField = fields(fieldIndex) - - override def length: Int = fields.length - - override def iterator: Iterator[StructField] = fields.iterator - - /** - * The default size of a value of the StructType is the total default sizes of all field types. - */ - override def defaultSize: Int = fields.map(_.dataType.defaultSize).sum - - override def simpleString: String = { - val fieldTypes = fields.map(field => s"${field.name}:${field.dataType.simpleString}") - s"struct<${fieldTypes.mkString(",")}>" - } - - /** - * Merges with another schema (`StructType`). For a struct field A from `this` and a struct field - * B from `that`, - * - * 1. If A and B have the same name and data type, they are merged to a field C with the same name - * and data type. C is nullable if and only if either A or B is nullable. - * 2. If A doesn't exist in `that`, it's included in the result schema. - * 3. If B doesn't exist in `this`, it's also included in the result schema. - * 4. Otherwise, `this` and `that` are considered as conflicting schemas and an exception would be - * thrown. - */ - private[sql] def merge(that: StructType): StructType = - StructType.merge(this, that).asInstanceOf[StructType] - - private[spark] override def asNullable: StructType = { - val newFields = fields.map { - case StructField(name, dataType, nullable, metadata) => - StructField(name, dataType.asNullable, nullable = true, metadata) - } - - StructType(newFields) - } -} - - -object MapType { - /** - * Construct a [[MapType]] object with the given key type and value type. - * The `valueContainsNull` is true. - */ - def apply(keyType: DataType, valueType: DataType): MapType = - MapType(keyType: DataType, valueType: DataType, valueContainsNull = true) -} - - -/** - * :: DeveloperApi :: - * The data type for Maps. Keys in a map are not allowed to have `null` values. - * - * Please use [[DataTypes.createMapType()]] to create a specific instance. - * - * @param keyType The data type of map keys. - * @param valueType The data type of map values. - * @param valueContainsNull Indicates if map values have `null` values. - * - * @group dataType - */ -case class MapType( - keyType: DataType, - valueType: DataType, - valueContainsNull: Boolean) extends DataType { - - /** No-arg constructor for kryo. 
*/ - def this() = this(null, null, false) - - private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { - builder.append(s"$prefix-- key: ${keyType.typeName}\n") - builder.append(s"$prefix-- value: ${valueType.typeName} " + - s"(valueContainsNull = $valueContainsNull)\n") - DataType.buildFormattedString(keyType, s"$prefix |", builder) - DataType.buildFormattedString(valueType, s"$prefix |", builder) - } - - override private[sql] def jsonValue: JValue = - ("type" -> typeName) ~ - ("keyType" -> keyType.jsonValue) ~ - ("valueType" -> valueType.jsonValue) ~ - ("valueContainsNull" -> valueContainsNull) - - /** - * The default size of a value of the MapType is - * 100 * (the default size of the key type + the default size of the value type). - * (We assume that there are 100 elements). - */ - override def defaultSize: Int = 100 * (keyType.defaultSize + valueType.defaultSize) - - override def simpleString: String = s"map<${keyType.simpleString},${valueType.simpleString}>" - - private[spark] override def asNullable: MapType = - MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true) -} - - -/** - * ::DeveloperApi:: - * The data type for User Defined Types (UDTs). - * - * This interface allows a user to make their own classes more interoperable with SparkSQL; - * e.g., by creating a [[UserDefinedType]] for a class X, it becomes possible to create - * a `DataFrame` which has class X in the schema. - * - * For SparkSQL to recognize UDTs, the UDT must be annotated with - * [[SQLUserDefinedType]]. - * - * The conversion via `serialize` occurs when instantiating a `DataFrame` from another RDD. - * The conversion via `deserialize` occurs when reading from a `DataFrame`. - */ -@DeveloperApi -abstract class UserDefinedType[UserType] extends DataType with Serializable { - - /** Underlying storage type for this UDT */ - def sqlType: DataType - - /** Paired Python UDT class, if exists. */ - def pyUDT: String = null - - /** - * Convert the user type to a SQL datum - * - * TODO: Can we make this take obj: UserType? The issue is in - * CatalystTypeConverters.convertToCatalyst, where we need to convert Any to UserType. - */ - def serialize(obj: Any): Any - - /** Convert a SQL datum to the user type */ - def deserialize(datum: Any): UserType - - override private[sql] def jsonValue: JValue = { - ("type" -> "udt") ~ - ("class" -> this.getClass.getName) ~ - ("pyClass" -> pyUDT) ~ - ("sqlType" -> sqlType.jsonValue) - } - - /** - * Class object for the UserType - */ - def userClass: java.lang.Class[UserType] - - /** - * The default size of a value of the UserDefinedType is 4096 bytes. - */ - override def defaultSize: Int = 4096 - - /** - * For UDT, asNullable will not change the nullability of its internal sqlType and just returns - * itself. - */ - private[spark] override def asNullable: UserDefinedType[UserType] = this -} From 73db132bf503341c7a5cf9409351c282a8464175 Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Thu, 23 Apr 2015 16:08:14 -0700 Subject: [PATCH 113/191] [SPARK-6818] [SPARKR] Support column deletion in SparkR DataFrame API. Author: Sun Rui Closes #5655 from sun-rui/SPARK-6818 and squashes the following commits: 7c66570 [Sun Rui] [SPARK-6818][SPARKR] Support column deletion in SparkR DataFrame API. 
--- R/pkg/R/DataFrame.R | 8 +++++++- R/pkg/inst/tests/test_sparkSQL.R | 5 +++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 861fe1c78b0db..b59b700af5dc9 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -790,9 +790,12 @@ setMethod("$", signature(x = "DataFrame"), setMethod("$<-", signature(x = "DataFrame"), function(x, name, value) { - stopifnot(class(value) == "Column") + stopifnot(class(value) == "Column" || is.null(value)) cols <- columns(x) if (name %in% cols) { + if (is.null(value)) { + cols <- Filter(function(c) { c != name }, cols) + } cols <- lapply(cols, function(c) { if (c == name) { alias(value, name) @@ -802,6 +805,9 @@ setMethod("$<-", signature(x = "DataFrame"), }) nx <- select(x, cols) } else { + if (is.null(value)) { + return(x) + } nx <- withColumn(x, name, value) } x@sdf <- nx@sdf diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 25831ae2d9e18..af7a6c582047a 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -449,6 +449,11 @@ test_that("select operators", { df$age2 <- df$age * 2 expect_equal(columns(df), c("name", "age", "age2")) expect_equal(count(where(df, df$age2 == df$age * 2)), 2) + + df$age2 <- NULL + expect_equal(columns(df), c("name", "age")) + df$age3 <- NULL + expect_equal(columns(df), c("name", "age")) }) test_that("select with column", { From 336f7f5373e5f6960ecd9967d3703c8507e329ec Mon Sep 17 00:00:00 2001 From: Cheolsoo Park Date: Thu, 23 Apr 2015 20:10:55 -0400 Subject: [PATCH 114/191] [SPARK-7037] [CORE] Inconsistent behavior for non-spark config properties in spark-shell and spark-submit When specifying non-spark properties (i.e. names that don't start with spark.) on the command line or in the config file, spark-submit and spark-shell behave differently, which confuses users. Here is the summary:
* spark-submit
  * --conf k=v => silently ignored
  * spark-defaults.conf => applied
* spark-shell
  * --conf k=v => a warning message is shown and the property is ignored
  * spark-defaults.conf => a warning message is shown and the property is ignored

I assume that ignoring non-spark properties is intentional. If so, they should always be ignored with a warning message in all cases.
Author: Cheolsoo Park

Closes #5617 from piaozhexiu/SPARK-7037 and squashes the following commits:

8957950 [Cheolsoo Park] Add IgnoreNonSparkProperties method
fedd01c [Cheolsoo Park] Ignore non-spark properties with a warning message in all cases
---
 .../spark/deploy/SparkSubmitArguments.scala | 22 ++++++++++++++++------
 1 file changed, 16 insertions(+), 6 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
index faa8780288ea3..c896842943f2b 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -77,12 +77,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
     if (verbose) SparkSubmit.printStream.println(s"Using properties file: $propertiesFile")
     Option(propertiesFile).foreach { filename =>
       Utils.getPropertiesFromFile(filename).foreach { case (k, v) =>
-        if (k.startsWith("spark.")) {
-          defaultProperties(k) = v
-          if (verbose) SparkSubmit.printStream.println(s"Adding default property: $k=$v")
-        } else {
-          SparkSubmit.printWarning(s"Ignoring non-spark config property: $k=$v")
-        }
+        defaultProperties(k) = v
+        if (verbose) SparkSubmit.printStream.println(s"Adding default property: $k=$v")
       }
     }
     defaultProperties
@@ -97,6 +93,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
   }
   // Populate `sparkProperties` map from properties file
   mergeDefaultSparkProperties()
+  // Remove keys that don't start with "spark." from `sparkProperties`.
+  ignoreNonSparkProperties()
   // Use `sparkProperties` map along with env vars to fill in any missing parameters
   loadEnvironmentArguments()
 
@@ -117,6 +115,18 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
     }
   }
 
+  /**
+   * Remove keys that don't start with "spark." from `sparkProperties`.
+   */
+  private def ignoreNonSparkProperties(): Unit = {
+    sparkProperties.foreach { case (k, v) =>
+      if (!k.startsWith("spark.")) {
+        sparkProperties -= k
+        SparkSubmit.printWarning(s"Ignoring non-spark config property: $k=$v")
+      }
+    }
+  }
+
   /**
    * Load arguments from environment variables, Spark properties etc.
    */

From 2d010f7afe6ac8e67e07da6bea700e9e8c9e6cc2 Mon Sep 17 00:00:00 2001
From: Yin Huai
Date: Thu, 23 Apr 2015 18:52:55 -0700
Subject: [PATCH 115/191] [SPARK-7060][SQL] Add alias function to python
 dataframe

This PR tries to provide a way to let Python users work around
https://issues.apache.org/jira/browse/SPARK-6231.

Author: Yin Huai

Closes #5634 from yhuai/pythonDFAlias and squashes the following commits:

8465acd [Yin Huai] Add an alias to a Python DF.
---
 python/pyspark/sql/dataframe.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index c8c30ce4022c8..4759f5fe783ad 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -452,6 +452,20 @@ def columns(self):
         """
         return [f.name for f in self.schema.fields]
 
+    @ignore_unicode_prefix
+    def alias(self, alias):
+        """Returns a new :class:`DataFrame` with an alias set.
+ + >>> from pyspark.sql.functions import * + >>> df_as1 = df.alias("df_as1") + >>> df_as2 = df.alias("df_as2") + >>> joined_df = df_as1.join(df_as2, col("df_as1.name") == col("df_as2.name"), 'inner') + >>> joined_df.select(col("df_as1.name"), col("df_as2.name"), col("df_as2.age")).collect() + [Row(name=u'Alice', name=u'Alice', age=2), Row(name=u'Bob', name=u'Bob', age=5)] + """ + assert isinstance(alias, basestring), "alias should be a string" + return DataFrame(getattr(self._jdf, "as")(alias), self.sql_ctx) + @ignore_unicode_prefix def join(self, other, joinExprs=None, joinType=None): """Joins with another :class:`DataFrame`, using the given join expression. From 67bccbda1e3ed7db2753daa7e6ae8b1441356177 Mon Sep 17 00:00:00 2001 From: Ken Geis Date: Thu, 23 Apr 2015 20:45:33 -0700 Subject: [PATCH 116/191] Update sql-programming-guide.md fix typo Author: Ken Geis Closes #5674 from kgeis/patch-1 and squashes the following commits: 5ae67de [Ken Geis] Update sql-programming-guide.md --- docs/sql-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index b2022546268a7..49b1e69f0e9db 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1364,7 +1364,7 @@ the Data Sources API. The following options are supported: driver - The class name of the JDBC driver needed to connect to this URL. This class with be loaded + The class name of the JDBC driver needed to connect to this URL. This class will be loaded on the master and workers before running an JDBC commands to allow the driver to register itself with the JDBC subsystem. From d3a302defc45768492dec9da4c40d78d28997a65 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 23 Apr 2015 21:21:03 -0700 Subject: [PATCH 117/191] [SQL] Fixed expression data type matching. Also took the chance to improve documentation for various types. Author: Reynold Xin Closes #5675 from rxin/data-type-matching-expr and squashes the following commits: 0f31856 [Reynold Xin] One more function documentation. 27c1973 [Reynold Xin] Added more documentation. 336a36d [Reynold Xin] [SQL] Fixed expression data type matching. 
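The pattern-matching idiom this patch documents relies on Scala extractors: giving each data type an unapply(Expression): Boolean member lets a match expression select expressions by the type they evaluate to. A simplified, self-contained model of that idiom (Expr, BinaryT and StringT are hypothetical stand-ins, not Catalyst's classes):

abstract class SimpleDataType {
  // Matches any expression whose dataType equals this type,
  // enabling `case e @ BinaryT() => ...` style patterns.
  def unapply(e: Expr): Boolean = e.dataType == this
}
object BinaryT extends SimpleDataType
object StringT extends SimpleDataType

case class Expr(name: String, dataType: SimpleDataType)

object ExtractorDemo {
  def main(args: Array[String]): Unit = {
    val e = Expr("payload", BinaryT)
    val msg = e match {
      case matched @ BinaryT() => s"${matched.name} evaluates to binary"
      case StringT()           => "evaluates to string"
      case _                   => "something else"
    }
    println(msg) // payload evaluates to binary
  }
}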
--- .../expressions/codegen/CodeGenerator.scala | 2 +- .../org/apache/spark/sql/types/DataType.scala | 50 +++++++++++++++---- 2 files changed, 42 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index cbe520347385d..dbc92fb93e95e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -279,7 +279,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin org.apache.spark.sql.types.UTF8String(${eval.primitiveTerm}.toString) """.children - case EqualTo(e1: BinaryType, e2: BinaryType) => + case EqualTo(e1 @ BinaryType(), e2 @ BinaryType()) => (e1, e2).evaluateAs (BooleanType) { case (eval1, eval2) => q""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index e6bfcd9adfeb1..06bff7d70edbc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -40,32 +40,46 @@ import org.apache.spark.util.Utils */ @DeveloperApi abstract class DataType { - /** Matches any expression that evaluates to this DataType */ - def unapply(a: Expression): Boolean = a match { + /** + * Enables matching against NumericType for expressions: + * {{{ + * case Cast(child @ BinaryType(), StringType) => + * ... + * }}} + */ + private[sql] def unapply(a: Expression): Boolean = a match { case e: Expression if e.dataType == this => true case _ => false } - /** The default size of a value of this data type. */ + /** + * The default size of a value of this data type, used internally for size estimation. + */ def defaultSize: Int + /** Name of the type used in JSON serialization. */ def typeName: String = this.getClass.getSimpleName.stripSuffix("$").dropRight(4).toLowerCase private[sql] def jsonValue: JValue = typeName + /** The compact JSON representation of this data type. */ def json: String = compact(render(jsonValue)) + /** The pretty (i.e. indented) JSON representation of this data type. */ def prettyJson: String = pretty(render(jsonValue)) + /** Readable string representation for the type. */ def simpleString: String = typeName - /** Check if `this` and `other` are the same data type when ignoring nullability - * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). + /** + * Check if `this` and `other` are the same data type when ignoring nullability + * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). */ private[spark] def sameType(other: DataType): Boolean = DataType.equalsIgnoreNullability(this, other) - /** Returns the same data type but set all nullability fields are true + /** + * Returns the same data type but set all nullability fields are true * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). */ private[spark] def asNullable: DataType @@ -104,12 +118,25 @@ abstract class NumericType extends AtomicType { private[sql] object NumericType { + /** + * Enables matching against NumericType for expressions: + * {{{ + * case Cast(child @ NumericType(), StringType) => + * ... 
+ * }}} + */ def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType] } -/** Matcher for any expressions that evaluate to [[IntegralType]]s */ private[sql] object IntegralType { + /** + * Enables matching against IntegralType for expressions: + * {{{ + * case Cast(child @ IntegralType(), StringType) => + * ... + * }}} + */ def unapply(a: Expression): Boolean = a match { case e: Expression if e.dataType.isInstanceOf[IntegralType] => true case _ => false @@ -122,9 +149,14 @@ private[sql] abstract class IntegralType extends NumericType { } - -/** Matcher for any expressions that evaluate to [[FractionalType]]s */ private[sql] object FractionalType { + /** + * Enables matching against FractionalType for expressions: + * {{{ + * case Cast(child @ FractionalType(), StringType) => + * ... + * }}} + */ def unapply(a: Expression): Boolean = a match { case e: Expression if e.dataType.isInstanceOf[FractionalType] => true case _ => false From 4c722d77ae7e77eeaa7531687fa9bd6050344d18 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 23 Apr 2015 22:39:00 -0700 Subject: [PATCH 118/191] Fixed a typo from the previous commit. --- .../src/main/scala/org/apache/spark/sql/types/DataType.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 06bff7d70edbc..0992a7c311ee2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -41,7 +41,7 @@ import org.apache.spark.util.Utils @DeveloperApi abstract class DataType { /** - * Enables matching against NumericType for expressions: + * Enables matching against DataType for expressions: * {{{ * case Cast(child @ BinaryType(), StringType) => * ... From 8509519d8bcf99e2d1b5e21da514d51357f9116d Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Fri, 24 Apr 2015 00:39:29 -0700 Subject: [PATCH 119/191] [SPARK-5894] [ML] Add polynomial mapper See [SPARK-5894](https://issues.apache.org/jira/browse/SPARK-5894). 
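For intuition about the feature this patch adds: expanding a 2-feature vector (x, y) to degree 2 yields (x, y, x*x, x*y, y*y), and in general n features expand to C(n + d, d) monomials of degree at most d, counting the constant term. A quick standalone check of that arithmetic (plain Scala, not the patch's implementation):

object PolyExpansionCheck {
  // C(n + d, d): number of monomials of degree <= d over n variables.
  def expandedSize(n: Int, d: Int): Int =
    (n + 1 to n + d).product / (1 to d).product

  def main(args: Array[String]): Unit = {
    val (x, y) = (2.0, 3.0)
    val degree2 = Seq(x, y, x * x, x * y, y * y)
    println(degree2)            // List(2.0, 3.0, 4.0, 6.0, 9.0)
    println(expandedSize(2, 2)) // 6: the 5 terms above plus the constant 1
  }
}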
Author: Xusen Yin Author: Xiangrui Meng Closes #5245 from yinxusen/SPARK-5894 and squashes the following commits: dc461a6 [Xusen Yin] merge polynomial expansion v2 6d0c3cc [Xusen Yin] Merge branch 'SPARK-5894' of https://github.com/mengxr/spark into mengxr-SPARK-5894 57bfdd5 [Xusen Yin] Merge branch 'master' into SPARK-5894 3d02a7d [Xusen Yin] Merge branch 'master' into SPARK-5894 a067da2 [Xiangrui Meng] a new approach for poly expansion 0789d81 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5894 4e9aed0 [Xusen Yin] fix test suite 95d8fb9 [Xusen Yin] fix sparse vector indices 8d39674 [Xusen Yin] fix sparse vector expansion error 5998dd6 [Xusen Yin] fix dense vector fillin fa3ade3 [Xusen Yin] change the functional code into imperative one to speedup b70e7e1 [Xusen Yin] remove useless case class 6fa236f [Xusen Yin] fix vector slice error daff601 [Xusen Yin] fix index error of sparse vector 6bd0a10 [Xusen Yin] merge repeated features 419f8a2 [Xusen Yin] need to merge same columns 4ebf34e [Xusen Yin] add test suite of polynomial expansion 372227c [Xusen Yin] add polynomial expansion --- .../ml/feature/PolynomialExpansion.scala | 167 ++++++++++++++++++ .../ml/feature/PolynomialExpansionSuite.scala | 104 +++++++++++ 2 files changed, 271 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala new file mode 100644 index 0000000000000..c3a59a361d0e2 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.collection.mutable + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.UnaryTransformer +import org.apache.spark.ml.param.{IntParam, ParamMap} +import org.apache.spark.mllib.linalg._ +import org.apache.spark.sql.types.DataType + +/** + * :: AlphaComponent :: + * Perform feature expansion in a polynomial space. As said in wikipedia of Polynomial Expansion, + * which is available at [[http://en.wikipedia.org/wiki/Polynomial_expansion]], "In mathematics, an + * expansion of a product of sums expresses it as a sum of products by using the fact that + * multiplication distributes over addition". Take a 2-variable feature vector as an example: + * `(x, y)`, if we want to expand it with degree 2, then we get `(x, y, x * x, x * y, y * y)`. 
+ */ +@AlphaComponent +class PolynomialExpansion extends UnaryTransformer[Vector, Vector, PolynomialExpansion] { + + /** + * The polynomial degree to expand, which should be larger than 1. + * @group param + */ + val degree = new IntParam(this, "degree", "the polynomial degree to expand") + setDefault(degree -> 2) + + /** @group getParam */ + def getDegree: Int = getOrDefault(degree) + + /** @group setParam */ + def setDegree(value: Int): this.type = set(degree, value) + + override protected def createTransformFunc(paramMap: ParamMap): Vector => Vector = { v => + val d = paramMap(degree) + PolynomialExpansion.expand(v, d) + } + + override protected def outputDataType: DataType = new VectorUDT() +} + +/** + * The expansion is done via recursion. Given n features and degree d, the size after expansion is + * (n + d choose d) (including 1 and first-order values). For example, let f([a, b, c], 3) be the + * function that expands [a, b, c] to their monomials of degree 3. We have the following recursion: + * + * {{{ + * f([a, b, c], 3) = f([a, b], 3) ++ f([a, b], 2) * c ++ f([a, b], 1) * c^2 ++ [c^3] + * }}} + * + * To handle sparsity, if c is zero, we can skip all monomials that contain it. We remember the + * current index and increment it properly for sparse input. + */ +object PolynomialExpansion { + + private def choose(n: Int, k: Int): Int = { + Range(n, n - k, -1).product / Range(k, 1, -1).product + } + + private def getPolySize(numFeatures: Int, degree: Int): Int = choose(numFeatures + degree, degree) + + private def expandDense( + values: Array[Double], + lastIdx: Int, + degree: Int, + multiplier: Double, + polyValues: Array[Double], + curPolyIdx: Int): Int = { + if (multiplier == 0.0) { + // do nothing + } else if (degree == 0 || lastIdx < 0) { + polyValues(curPolyIdx) = multiplier + } else { + val v = values(lastIdx) + val lastIdx1 = lastIdx - 1 + var alpha = multiplier + var i = 0 + var curStart = curPolyIdx + while (i <= degree && alpha != 0.0) { + curStart = expandDense(values, lastIdx1, degree - i, alpha, polyValues, curStart) + i += 1 + alpha *= v + } + } + curPolyIdx + getPolySize(lastIdx + 1, degree) + } + + private def expandSparse( + indices: Array[Int], + values: Array[Double], + lastIdx: Int, + lastFeatureIdx: Int, + degree: Int, + multiplier: Double, + polyIndices: mutable.ArrayBuilder[Int], + polyValues: mutable.ArrayBuilder[Double], + curPolyIdx: Int): Int = { + if (multiplier == 0.0) { + // do nothing + } else if (degree == 0 || lastIdx < 0) { + polyIndices += curPolyIdx + polyValues += multiplier + } else { + // Skip all zeros at the tail. 
+ val v = values(lastIdx) + val lastIdx1 = lastIdx - 1 + val lastFeatureIdx1 = indices(lastIdx) - 1 + var alpha = multiplier + var curStart = curPolyIdx + var i = 0 + while (i <= degree && alpha != 0.0) { + curStart = expandSparse(indices, values, lastIdx1, lastFeatureIdx1, degree - i, alpha, + polyIndices, polyValues, curStart) + i += 1 + alpha *= v + } + } + curPolyIdx + getPolySize(lastFeatureIdx + 1, degree) + } + + private def expand(dv: DenseVector, degree: Int): DenseVector = { + val n = dv.size + val polySize = getPolySize(n, degree) + val polyValues = new Array[Double](polySize) + expandDense(dv.values, n - 1, degree, 1.0, polyValues, 0) + new DenseVector(polyValues) + } + + private def expand(sv: SparseVector, degree: Int): SparseVector = { + val polySize = getPolySize(sv.size, degree) + val nnz = sv.values.length + val nnzPolySize = getPolySize(nnz, degree) + val polyIndices = mutable.ArrayBuilder.make[Int] + polyIndices.sizeHint(nnzPolySize) + val polyValues = mutable.ArrayBuilder.make[Double] + polyValues.sizeHint(nnzPolySize) + expandSparse( + sv.indices, sv.values, nnz - 1, sv.size - 1, degree, 1.0, polyIndices, polyValues, 0) + new SparseVector(polySize, polyIndices.result(), polyValues.result()) + } + + def expand(v: Vector, degree: Int): Vector = { + v match { + case dv: DenseVector => expand(dv, degree) + case sv: SparseVector => expand(sv, degree) + case _ => throw new IllegalArgumentException + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala new file mode 100644 index 0000000000000..b0a537be42dfd --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.feature + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{Row, SQLContext} +import org.scalatest.exceptions.TestFailedException + +class PolynomialExpansionSuite extends FunSuite with MLlibTestSparkContext { + + @transient var sqlContext: SQLContext = _ + + override def beforeAll(): Unit = { + super.beforeAll() + sqlContext = new SQLContext(sc) + } + + test("Polynomial expansion with default parameter") { + val data = Array( + Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), + Vectors.dense(-2.0, 2.3), + Vectors.dense(0.0, 0.0, 0.0), + Vectors.dense(0.6, -1.1, -3.0), + Vectors.sparse(3, Seq()) + ) + + val twoDegreeExpansion: Array[Vector] = Array( + Vectors.sparse(10, Array(0, 1, 2, 3, 4, 5), Array(1.0, -2.0, 4.0, 2.3, -4.6, 5.29)), + Vectors.dense(1.0, -2.0, 4.0, 2.3, -4.6, 5.29), + Vectors.dense(Array(1.0) ++ Array.fill[Double](9)(0.0)), + Vectors.dense(1.0, 0.6, 0.36, -1.1, -0.66, 1.21, -3.0, -1.8, 3.3, 9.0), + Vectors.sparse(10, Array(0), Array(1.0))) + + val df = sqlContext.createDataFrame(data.zip(twoDegreeExpansion)).toDF("features", "expected") + + val polynomialExpansion = new PolynomialExpansion() + .setInputCol("features") + .setOutputCol("polyFeatures") + + polynomialExpansion.transform(df).select("polyFeatures", "expected").collect().foreach { + case Row(expanded: DenseVector, expected: DenseVector) => + assert(expanded ~== expected absTol 1e-1) + case Row(expanded: SparseVector, expected: SparseVector) => + assert(expanded ~== expected absTol 1e-1) + case _ => + throw new TestFailedException("Unmatched data types after polynomial expansion", 0) + } + } + + test("Polynomial expansion with setter") { + val data = Array( + Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), + Vectors.dense(-2.0, 2.3), + Vectors.dense(0.0, 0.0, 0.0), + Vectors.dense(0.6, -1.1, -3.0), + Vectors.sparse(3, Seq()) + ) + + val threeDegreeExpansion: Array[Vector] = Array( + Vectors.sparse(20, Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), + Array(1.0, -2.0, 4.0, -8.0, 2.3, -4.6, 9.2, 5.29, -10.58, 12.17)), + Vectors.dense(1.0, -2.0, 4.0, -8.0, 2.3, -4.6, 9.2, 5.29, -10.58, 12.17), + Vectors.dense(Array(1.0) ++ Array.fill[Double](19)(0.0)), + Vectors.dense(1.0, 0.6, 0.36, 0.216, -1.1, -0.66, -0.396, 1.21, 0.726, -1.331, -3.0, -1.8, + -1.08, 3.3, 1.98, -3.63, 9.0, 5.4, -9.9, -27.0), + Vectors.sparse(20, Array(0), Array(1.0))) + + val df = sqlContext.createDataFrame(data.zip(threeDegreeExpansion)).toDF("features", "expected") + + val polynomialExpansion = new PolynomialExpansion() + .setInputCol("features") + .setOutputCol("polyFeatures") + .setDegree(3) + + polynomialExpansion.transform(df).select("polyFeatures", "expected").collect().foreach { + case Row(expanded: DenseVector, expected: DenseVector) => + assert(expanded ~== expected absTol 1e-1) + case Row(expanded: SparseVector, expected: SparseVector) => + assert(expanded ~== expected absTol 1e-1) + case _ => + throw new TestFailedException("Unmatched data types after polynomial expansion", 0) + } + } +} + From 78b39c7e0de8c9dc748cfbf8f78578a9524b6a94 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Fri, 24 Apr 2015 08:27:48 -0700 Subject: [PATCH 120/191] [SPARK-7115] [MLLIB] skip the very first 1 in poly expansion yinxusen Author: Xiangrui Meng Closes #5681 from mengxr/SPARK-7115 and squashes the following commits: 9ac27cd 
[Xiangrui Meng] skip the very first 1 in poly expansion --- .../ml/feature/PolynomialExpansion.scala | 22 +++++++++++-------- .../ml/feature/PolynomialExpansionSuite.scala | 22 +++++++++---------- 2 files changed, 24 insertions(+), 20 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala index c3a59a361d0e2..d855f04799ae7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala @@ -87,7 +87,9 @@ object PolynomialExpansion { if (multiplier == 0.0) { // do nothing } else if (degree == 0 || lastIdx < 0) { - polyValues(curPolyIdx) = multiplier + if (curPolyIdx >= 0) { // skip the very first 1 + polyValues(curPolyIdx) = multiplier + } } else { val v = values(lastIdx) val lastIdx1 = lastIdx - 1 @@ -116,8 +118,10 @@ object PolynomialExpansion { if (multiplier == 0.0) { // do nothing } else if (degree == 0 || lastIdx < 0) { - polyIndices += curPolyIdx - polyValues += multiplier + if (curPolyIdx >= 0) { // skip the very first 1 + polyIndices += curPolyIdx + polyValues += multiplier + } } else { // Skip all zeros at the tail. val v = values(lastIdx) @@ -139,8 +143,8 @@ object PolynomialExpansion { private def expand(dv: DenseVector, degree: Int): DenseVector = { val n = dv.size val polySize = getPolySize(n, degree) - val polyValues = new Array[Double](polySize) - expandDense(dv.values, n - 1, degree, 1.0, polyValues, 0) + val polyValues = new Array[Double](polySize - 1) + expandDense(dv.values, n - 1, degree, 1.0, polyValues, -1) new DenseVector(polyValues) } @@ -149,12 +153,12 @@ object PolynomialExpansion { val nnz = sv.values.length val nnzPolySize = getPolySize(nnz, degree) val polyIndices = mutable.ArrayBuilder.make[Int] - polyIndices.sizeHint(nnzPolySize) + polyIndices.sizeHint(nnzPolySize - 1) val polyValues = mutable.ArrayBuilder.make[Double] - polyValues.sizeHint(nnzPolySize) + polyValues.sizeHint(nnzPolySize - 1) expandSparse( - sv.indices, sv.values, nnz - 1, sv.size - 1, degree, 1.0, polyIndices, polyValues, 0) - new SparseVector(polySize, polyIndices.result(), polyValues.result()) + sv.indices, sv.values, nnz - 1, sv.size - 1, degree, 1.0, polyIndices, polyValues, -1) + new SparseVector(polySize - 1, polyIndices.result(), polyValues.result()) } def expand(v: Vector, degree: Int): Vector = { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala index b0a537be42dfd..c1d64fba0aa8f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala @@ -44,11 +44,11 @@ class PolynomialExpansionSuite extends FunSuite with MLlibTestSparkContext { ) val twoDegreeExpansion: Array[Vector] = Array( - Vectors.sparse(10, Array(0, 1, 2, 3, 4, 5), Array(1.0, -2.0, 4.0, 2.3, -4.6, 5.29)), - Vectors.dense(1.0, -2.0, 4.0, 2.3, -4.6, 5.29), - Vectors.dense(Array(1.0) ++ Array.fill[Double](9)(0.0)), - Vectors.dense(1.0, 0.6, 0.36, -1.1, -0.66, 1.21, -3.0, -1.8, 3.3, 9.0), - Vectors.sparse(10, Array(0), Array(1.0))) + Vectors.sparse(9, Array(0, 1, 2, 3, 4), Array(-2.0, 4.0, 2.3, -4.6, 5.29)), + Vectors.dense(-2.0, 4.0, 2.3, -4.6, 5.29), + Vectors.dense(new Array[Double](9)), + Vectors.dense(0.6, 0.36, -1.1, -0.66, 1.21, -3.0, -1.8, 3.3, 9.0), + 
Vectors.sparse(9, Array.empty, Array.empty)) val df = sqlContext.createDataFrame(data.zip(twoDegreeExpansion)).toDF("features", "expected") @@ -76,13 +76,13 @@ class PolynomialExpansionSuite extends FunSuite with MLlibTestSparkContext { ) val threeDegreeExpansion: Array[Vector] = Array( - Vectors.sparse(20, Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), - Array(1.0, -2.0, 4.0, -8.0, 2.3, -4.6, 9.2, 5.29, -10.58, 12.17)), - Vectors.dense(1.0, -2.0, 4.0, -8.0, 2.3, -4.6, 9.2, 5.29, -10.58, 12.17), - Vectors.dense(Array(1.0) ++ Array.fill[Double](19)(0.0)), - Vectors.dense(1.0, 0.6, 0.36, 0.216, -1.1, -0.66, -0.396, 1.21, 0.726, -1.331, -3.0, -1.8, + Vectors.sparse(19, Array(0, 1, 2, 3, 4, 5, 6, 7, 8), + Array(-2.0, 4.0, -8.0, 2.3, -4.6, 9.2, 5.29, -10.58, 12.17)), + Vectors.dense(-2.0, 4.0, -8.0, 2.3, -4.6, 9.2, 5.29, -10.58, 12.17), + Vectors.dense(new Array[Double](19)), + Vectors.dense(0.6, 0.36, 0.216, -1.1, -0.66, -0.396, 1.21, 0.726, -1.331, -3.0, -1.8, -1.08, 3.3, 1.98, -3.63, 9.0, 5.4, -9.9, -27.0), - Vectors.sparse(20, Array(0), Array(1.0))) + Vectors.sparse(19, Array.empty, Array.empty)) val df = sqlContext.createDataFrame(data.zip(threeDegreeExpansion)).toDF("features", "expected") From 6e57d57b32ba2aa0514692074897b5edd34e0dd6 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Fri, 24 Apr 2015 08:29:49 -0700 Subject: [PATCH 121/191] [SPARK-6528] [ML] Add IDF transformer See [SPARK-6528](https://issues.apache.org/jira/browse/SPARK-6528). Add IDF transformer in ML package. Author: Xusen Yin Closes #5266 from yinxusen/SPARK-6528 and squashes the following commits: 741db31 [Xusen Yin] get param from new paramMap d169967 [Xusen Yin] add final to param and IDF class c9c3759 [Xusen Yin] simplify test suite 5867c09 [Xusen Yin] refine IDF transformer with new interfaces 7727cae [Xusen Yin] Merge branch 'master' into SPARK-6528 4338a37 [Xusen Yin] Merge branch 'master' into SPARK-6528 aef2cdf [Xusen Yin] add doc and group for param 5760b49 [Xusen Yin] fix code style 2add691 [Xusen Yin] fix code style and test 03fbecb [Xusen Yin] remove duplicated code 2aa4be0 [Xusen Yin] clean test suite 4802c67 [Xusen Yin] add IDF transformer and test suite --- .../org/apache/spark/ml/feature/IDF.scala | 116 ++++++++++++++++++ .../apache/spark/ml/feature/IDFSuite.scala | 101 +++++++++++++++ 2 files changed, 217 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala new file mode 100644 index 0000000000000..e6a62d998bb97 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml._ +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.mllib.feature +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.sql._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.StructType + +/** + * Params for [[IDF]] and [[IDFModel]]. + */ +private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol { + + /** + * The minimum of documents in which a term should appear. + * @group param + */ + final val minDocFreq = new IntParam( + this, "minDocFreq", "minimum of documents in which a term should appear for filtering") + + setDefault(minDocFreq -> 0) + + /** @group getParam */ + def getMinDocFreq: Int = getOrDefault(minDocFreq) + + /** @group setParam */ + def setMinDocFreq(value: Int): this.type = set(minDocFreq, value) + + /** + * Validate and transform the input schema. + */ + protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = extractParamMap(paramMap) + SchemaUtils.checkColumnType(schema, map(inputCol), new VectorUDT) + SchemaUtils.appendColumn(schema, map(outputCol), new VectorUDT) + } +} + +/** + * :: AlphaComponent :: + * Compute the Inverse Document Frequency (IDF) given a collection of documents. + */ +@AlphaComponent +final class IDF extends Estimator[IDFModel] with IDFBase { + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def fit(dataset: DataFrame, paramMap: ParamMap): IDFModel = { + transformSchema(dataset.schema, paramMap, logging = true) + val map = extractParamMap(paramMap) + val input = dataset.select(map(inputCol)).map { case Row(v: Vector) => v } + val idf = new feature.IDF(map(minDocFreq)).fit(input) + val model = new IDFModel(this, map, idf) + Params.inheritValues(map, this, model) + model + } + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap) + } +} + +/** + * :: AlphaComponent :: + * Model fitted by [[IDF]]. 
+ */ +@AlphaComponent +class IDFModel private[ml] ( + override val parent: IDF, + override val fittingParamMap: ParamMap, + idfModel: feature.IDFModel) + extends Model[IDFModel] with IDFBase { + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + transformSchema(dataset.schema, paramMap, logging = true) + val map = extractParamMap(paramMap) + val idf = udf { vec: Vector => idfModel.transform(vec) } + dataset.withColumn(map(outputCol), idf(col(map(inputCol)))) + } + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala new file mode 100644 index 0000000000000..eaee3443c1f23 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.feature + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{Row, SQLContext} + +class IDFSuite extends FunSuite with MLlibTestSparkContext { + + @transient var sqlContext: SQLContext = _ + + override def beforeAll(): Unit = { + super.beforeAll() + sqlContext = new SQLContext(sc) + } + + def scaleDataWithIDF(dataSet: Array[Vector], model: Vector): Array[Vector] = { + dataSet.map { + case data: DenseVector => + val res = data.toArray.zip(model.toArray).map { case (x, y) => x * y } + Vectors.dense(res) + case data: SparseVector => + val res = data.indices.zip(data.values).map { case (id, value) => + (id, value * model(id)) + } + Vectors.sparse(data.size, res) + } + } + + test("compute IDF with default parameter") { + val numOfFeatures = 4 + val data = Array( + Vectors.sparse(numOfFeatures, Array(1, 3), Array(1.0, 2.0)), + Vectors.dense(0.0, 1.0, 2.0, 3.0), + Vectors.sparse(numOfFeatures, Array(1), Array(1.0)) + ) + val numOfData = data.size + val idf = Vectors.dense(Array(0, 3, 1, 2).map { x => + math.log((numOfData + 1.0) / (x + 1.0)) + }) + val expected = scaleDataWithIDF(data, idf) + + val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected") + + val idfModel = new IDF() + .setInputCol("features") + .setOutputCol("idfValue") + .fit(df) + + idfModel.transform(df).select("idfValue", "expected").collect().foreach { + case Row(x: Vector, y: Vector) => + assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") + } + } + + test("compute IDF with setter") { + val numOfFeatures = 4 + val data = Array( + Vectors.sparse(numOfFeatures, Array(1, 3), Array(1.0, 2.0)), + Vectors.dense(0.0, 1.0, 2.0, 3.0), + Vectors.sparse(numOfFeatures, Array(1), Array(1.0)) + ) + val numOfData = data.size + val idf = Vectors.dense(Array(0, 3, 1, 2).map { x => + if (x > 0) math.log((numOfData + 1.0) / (x + 1.0)) else 0 + }) + val expected = scaleDataWithIDF(data, idf) + + val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected") + + val idfModel = new IDF() + .setInputCol("features") + .setOutputCol("idfValue") + .setMinDocFreq(1) + .fit(df) + + idfModel.transform(df).select("idfValue", "expected").collect().foreach { + case Row(x: Vector, y: Vector) => + assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") + } + } +} From ebb77b2aff085e71906b5de9d266ded89051af82 Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Fri, 24 Apr 2015 11:00:19 -0700 Subject: [PATCH 122/191] [SPARK-7033] [SPARKR] Clean usage of split. Use partition instead where applicable. Author: Sun Rui Closes #5628 from sun-rui/SPARK-7033 and squashes the following commits: 046bc9e [Sun Rui] Clean split usage in tests. d531c86 [Sun Rui] [SPARK-7033][SPARKR] Clean usage of split. Use partition instead where applicable. 
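The renamed SparkR callback corresponds to mapPartitionsWithIndex on the Scala side, where the first argument has always been the partition index. A short sketch mirroring the partIndex * Reduce("+", part) example in the diff below; it assumes a throwaway local SparkContext:

import org.apache.spark.{SparkConf, SparkContext}

object PartitionIndexSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("sketch"))
    val rdd = sc.parallelize(1 to 10, 2)
    // partIndex * sum(part), one value per partition
    val prods = rdd.mapPartitionsWithIndex { (partIndex, part) =>
      Iterator(partIndex * part.sum)
    }.collect()
    println(prods.mkString(", ")) // 0, 40 for partitions (1..5) and (6..10)
    sc.stop()
  }
}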
--- R/pkg/R/RDD.R | 36 ++++++++++++++++++------------------ R/pkg/R/context.R | 20 ++++++++++---------- R/pkg/R/pairRDD.R | 8 ++++---- R/pkg/R/utils.R | 2 +- R/pkg/inst/tests/test_rdd.R | 12 ++++++------ 5 files changed, 39 insertions(+), 39 deletions(-) diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index 128431334ca52..cc09efb1e5418 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -91,8 +91,8 @@ setMethod("initialize", "PipelinedRDD", function(.Object, prev, func, jrdd_val) # NOTE: We use prev_serializedMode to track the serialization mode of prev_JRDD # prev_serializedMode is used during the delayed computation of JRDD in getJRDD } else { - pipelinedFunc <- function(split, iterator) { - func(split, prev@func(split, iterator)) + pipelinedFunc <- function(partIndex, part) { + func(partIndex, prev@func(partIndex, part)) } .Object@func <- cleanClosure(pipelinedFunc) .Object@prev_jrdd <- prev@prev_jrdd # maintain the pipeline @@ -306,7 +306,7 @@ setMethod("numPartitions", signature(x = "RDD"), function(x) { jrdd <- getJRDD(x) - partitions <- callJMethod(jrdd, "splits") + partitions <- callJMethod(jrdd, "partitions") callJMethod(partitions, "size") }) @@ -452,8 +452,8 @@ setMethod("countByValue", setMethod("lapply", signature(X = "RDD", FUN = "function"), function(X, FUN) { - func <- function(split, iterator) { - lapply(iterator, FUN) + func <- function(partIndex, part) { + lapply(part, FUN) } lapplyPartitionsWithIndex(X, func) }) @@ -538,8 +538,8 @@ setMethod("mapPartitions", #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, 1:10, 5L) -#' prod <- lapplyPartitionsWithIndex(rdd, function(split, part) { -#' split * Reduce("+", part) }) +#' prod <- lapplyPartitionsWithIndex(rdd, function(partIndex, part) { +#' partIndex * Reduce("+", part) }) #' collect(prod, flatten = FALSE) # 0, 7, 22, 45, 76 #'} #' @rdname lapplyPartitionsWithIndex @@ -813,7 +813,7 @@ setMethod("distinct", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' rdd <- parallelize(sc, 1:10) # ensure each num is in its own split +#' rdd <- parallelize(sc, 1:10) #' collect(sampleRDD(rdd, FALSE, 0.5, 1618L)) # ~5 distinct elements #' collect(sampleRDD(rdd, TRUE, 0.5, 9L)) # ~5 elements possibly with duplicates #'} @@ -825,14 +825,14 @@ setMethod("sampleRDD", function(x, withReplacement, fraction, seed) { # The sampler: takes a partition and returns its sampled version. - samplingFunc <- function(split, part) { + samplingFunc <- function(partIndex, part) { set.seed(seed) res <- vector("list", length(part)) len <- 0 # Discards some random values to ensure each partition has a # different random seed. - runif(split) + runif(partIndex) for (elem in part) { if (withReplacement) { @@ -989,8 +989,8 @@ setMethod("coalesce", function(x, numPartitions, shuffle = FALSE) { numPartitions <- numToInt(numPartitions) if (shuffle || numPartitions > SparkR::numPartitions(x)) { - func <- function(s, part) { - set.seed(s) # split as seed + func <- function(partIndex, part) { + set.seed(partIndex) # partIndex as seed start <- as.integer(sample(numPartitions, 1) - 1) lapply(seq_along(part), function(i) { @@ -1035,7 +1035,7 @@ setMethod("saveAsObjectFile", #' Save this RDD as a text file, using string representations of elements. 
#' #' @param x The RDD to save -#' @param path The directory where the splits of the text file are saved +#' @param path The directory where the partitions of the text file are saved #' @examples #'\dontrun{ #' sc <- sparkR.init() @@ -1335,10 +1335,10 @@ setMethod("zipWithUniqueId", function(x) { n <- numPartitions(x) - partitionFunc <- function(split, part) { + partitionFunc <- function(partIndex, part) { mapply( function(item, index) { - list(item, (index - 1) * n + split) + list(item, (index - 1) * n + partIndex) }, part, seq_along(part), @@ -1382,11 +1382,11 @@ setMethod("zipWithIndex", startIndices <- Reduce("+", nums, accumulate = TRUE) } - partitionFunc <- function(split, part) { - if (split == 0) { + partitionFunc <- function(partIndex, part) { + if (partIndex == 0) { startIndex <- 0 } else { - startIndex <- startIndices[[split]] + startIndex <- startIndices[[partIndex]] } mapply( diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index ebbb8fba1052d..b4845b6948997 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -17,12 +17,12 @@ # context.R: SparkContext driven functions -getMinSplits <- function(sc, minSplits) { - if (is.null(minSplits)) { +getMinPartitions <- function(sc, minPartitions) { + if (is.null(minPartitions)) { defaultParallelism <- callJMethod(sc, "defaultParallelism") - minSplits <- min(defaultParallelism, 2) + minPartitions <- min(defaultParallelism, 2) } - as.integer(minSplits) + as.integer(minPartitions) } #' Create an RDD from a text file. @@ -33,7 +33,7 @@ getMinSplits <- function(sc, minSplits) { #' #' @param sc SparkContext to use #' @param path Path of file to read. A vector of multiple paths is allowed. -#' @param minSplits Minimum number of splits to be created. If NULL, the default +#' @param minPartitions Minimum number of partitions to be created. If NULL, the default #' value is chosen based on available parallelism. #' @return RDD where each item is of type \code{character} #' @export @@ -42,13 +42,13 @@ getMinSplits <- function(sc, minSplits) { #' sc <- sparkR.init() #' lines <- textFile(sc, "myfile.txt") #'} -textFile <- function(sc, path, minSplits = NULL) { +textFile <- function(sc, path, minPartitions = NULL) { # Allow the user to have a more flexible definiton of the text file path path <- suppressWarnings(normalizePath(path)) #' Convert a string vector of paths to a string containing comma separated paths path <- paste(path, collapse = ",") - jrdd <- callJMethod(sc, "textFile", path, getMinSplits(sc, minSplits)) + jrdd <- callJMethod(sc, "textFile", path, getMinPartitions(sc, minPartitions)) # jrdd is of type JavaRDD[String] RDD(jrdd, "string") } @@ -60,7 +60,7 @@ textFile <- function(sc, path, minSplits = NULL) { #' #' @param sc SparkContext to use #' @param path Path of file to read. A vector of multiple paths is allowed. -#' @param minSplits Minimum number of splits to be created. If NULL, the default +#' @param minPartitions Minimum number of partitions to be created. If NULL, the default #' value is chosen based on available parallelism. #' @return RDD containing serialized R objects. 
#' @seealso saveAsObjectFile @@ -70,13 +70,13 @@ textFile <- function(sc, path, minSplits = NULL) { #' sc <- sparkR.init() #' rdd <- objectFile(sc, "myfile") #'} -objectFile <- function(sc, path, minSplits = NULL) { +objectFile <- function(sc, path, minPartitions = NULL) { # Allow the user to have a more flexible definiton of the text file path path <- suppressWarnings(normalizePath(path)) #' Convert a string vector of paths to a string containing comma separated paths path <- paste(path, collapse = ",") - jrdd <- callJMethod(sc, "objectFile", path, getMinSplits(sc, minSplits)) + jrdd <- callJMethod(sc, "objectFile", path, getMinPartitions(sc, minPartitions)) # Assume the RDD contains serialized R objects. RDD(jrdd, "byte") } diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index 13efebc11c46e..f99b474ff8f2a 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -206,8 +206,8 @@ setMethod("partitionBy", get(name, .broadcastNames) }) jrdd <- getJRDD(x) - # We create a PairwiseRRDD that extends RDD[(Array[Byte], - # Array[Byte])], where the key is the hashed split, the value is + # We create a PairwiseRRDD that extends RDD[(Int, Array[Byte])], + # where the key is the target partition number, the value is # the content (key-val pairs). pairwiseRRDD <- newJObject("org.apache.spark.api.r.PairwiseRRDD", callJMethod(jrdd, "rdd"), @@ -866,8 +866,8 @@ setMethod("sampleByKey", } # The sampler: takes a partition and returns its sampled version. - samplingFunc <- function(split, part) { - set.seed(bitwXor(seed, split)) + samplingFunc <- function(partIndex, part) { + set.seed(bitwXor(seed, partIndex)) res <- vector("list", length(part)) len <- 0 diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 23305d3c67074..0e7b7bd5a5b34 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -501,7 +501,7 @@ appendPartitionLengths <- function(x, other) { # A result RDD. 
mergePartitions <- function(rdd, zip) { serializerMode <- getSerializedMode(rdd) - partitionFunc <- function(split, part) { + partitionFunc <- function(partIndex, part) { len <- length(part) if (len > 0) { if (serializerMode == "byte") { diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R index 3ba7d1716302a..d55af93e3e50a 100644 --- a/R/pkg/inst/tests/test_rdd.R +++ b/R/pkg/inst/tests/test_rdd.R @@ -105,8 +105,8 @@ test_that("several transformations on RDD (a benchmark on PipelinedRDD)", { rdd2 <- rdd for (i in 1:12) rdd2 <- lapplyPartitionsWithIndex( - rdd2, function(split, part) { - part <- as.list(unlist(part) * split + i) + rdd2, function(partIndex, part) { + part <- as.list(unlist(part) * partIndex + i) }) rdd2 <- lapply(rdd2, function(x) x + x) actual <- collect(rdd2) @@ -121,8 +121,8 @@ test_that("PipelinedRDD support actions: cache(), persist(), unpersist(), checkp # PipelinedRDD rdd2 <- lapplyPartitionsWithIndex( rdd2, - function(split, part) { - part <- as.list(unlist(part) * split) + function(partIndex, part) { + part <- as.list(unlist(part) * partIndex) }) cache(rdd2) @@ -174,13 +174,13 @@ test_that("lapply with dependency", { }) test_that("lapplyPartitionsWithIndex on RDDs", { - func <- function(splitIndex, part) { list(splitIndex, Reduce("+", part)) } + func <- function(partIndex, part) { list(partIndex, Reduce("+", part)) } actual <- collect(lapplyPartitionsWithIndex(rdd, func), flatten = FALSE) expect_equal(actual, list(list(0, 15), list(1, 40))) pairsRDD <- parallelize(sc, list(list(1, 2), list(3, 4), list(4, 8)), 1L) partitionByParity <- function(key) { if (key %% 2 == 1) 0 else 1 } - mkTup <- function(splitIndex, part) { list(splitIndex, part) } + mkTup <- function(partIndex, part) { list(partIndex, part) } actual <- collect(lapplyPartitionsWithIndex( partitionBy(pairsRDD, 2L, partitionByParity), mkTup), From caf0136ec5838cf5bf61f39a5b3474a505a6ae11 Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Fri, 24 Apr 2015 12:52:07 -0700 Subject: [PATCH 123/191] [SPARK-6852] [SPARKR] Accept numeric as numPartitions in SparkR. Author: Sun Rui Closes #5613 from sun-rui/SPARK-6852 and squashes the following commits: abaf02e [Sun Rui] Change the type of default numPartitions from integer to numeric in generics.R. 29d67c1 [Sun Rui] [SPARK-6852][SPARKR] Accept numeric as numPartitions in SparkR. --- R/pkg/R/RDD.R | 2 +- R/pkg/R/generics.R | 12 ++++++------ R/pkg/R/pairRDD.R | 24 ++++++++++++------------ 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index cc09efb1e5418..1662d6bb3b1ac 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -967,7 +967,7 @@ setMethod("keyBy", setMethod("repartition", signature(x = "RDD", numPartitions = "numeric"), function(x, numPartitions) { - coalesce(x, numToInt(numPartitions), TRUE) + coalesce(x, numPartitions, TRUE) }) #' Return a new RDD that is reduced into numPartitions partitions. 
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 6c6233390134c..34dbe84051c50 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -60,7 +60,7 @@ setGeneric("countByValue", function(x) { standardGeneric("countByValue") }) #' @rdname distinct #' @export -setGeneric("distinct", function(x, numPartitions = 1L) { standardGeneric("distinct") }) +setGeneric("distinct", function(x, numPartitions = 1) { standardGeneric("distinct") }) #' @rdname filterRDD #' @export @@ -182,7 +182,7 @@ setGeneric("setName", function(x, name) { standardGeneric("setName") }) #' @rdname sortBy #' @export setGeneric("sortBy", - function(x, func, ascending = TRUE, numPartitions = 1L) { + function(x, func, ascending = TRUE, numPartitions = 1) { standardGeneric("sortBy") }) @@ -244,7 +244,7 @@ setGeneric("flatMapValues", function(X, FUN) { standardGeneric("flatMapValues") #' @rdname intersection #' @export -setGeneric("intersection", function(x, other, numPartitions = 1L) { +setGeneric("intersection", function(x, other, numPartitions = 1) { standardGeneric("intersection") }) #' @rdname keys @@ -346,21 +346,21 @@ setGeneric("rightOuterJoin", function(x, y, numPartitions) { standardGeneric("ri #' @rdname sortByKey #' @export setGeneric("sortByKey", - function(x, ascending = TRUE, numPartitions = 1L) { + function(x, ascending = TRUE, numPartitions = 1) { standardGeneric("sortByKey") }) #' @rdname subtract #' @export setGeneric("subtract", - function(x, other, numPartitions = 1L) { + function(x, other, numPartitions = 1) { standardGeneric("subtract") }) #' @rdname subtractByKey #' @export setGeneric("subtractByKey", - function(x, other, numPartitions = 1L) { + function(x, other, numPartitions = 1) { standardGeneric("subtractByKey") }) diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index f99b474ff8f2a..9791e55791bae 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -190,7 +190,7 @@ setMethod("flatMapValues", #' @rdname partitionBy #' @aliases partitionBy,RDD,integer-method setMethod("partitionBy", - signature(x = "RDD", numPartitions = "integer"), + signature(x = "RDD", numPartitions = "numeric"), function(x, numPartitions, partitionFunc = hashCode) { #if (missing(partitionFunc)) { @@ -211,7 +211,7 @@ setMethod("partitionBy", # the content (key-val pairs). pairwiseRRDD <- newJObject("org.apache.spark.api.r.PairwiseRRDD", callJMethod(jrdd, "rdd"), - as.integer(numPartitions), + numToInt(numPartitions), serializedHashFuncBytes, getSerializedMode(x), packageNamesArr, @@ -221,7 +221,7 @@ setMethod("partitionBy", # Create a corresponding partitioner. rPartitioner <- newJObject("org.apache.spark.HashPartitioner", - as.integer(numPartitions)) + numToInt(numPartitions)) # Call partitionBy on the obtained PairwiseRDD. 
javaPairRDD <- callJMethod(pairwiseRRDD, "asJavaPairRDD") @@ -256,7 +256,7 @@ setMethod("partitionBy", #' @rdname groupByKey #' @aliases groupByKey,RDD,integer-method setMethod("groupByKey", - signature(x = "RDD", numPartitions = "integer"), + signature(x = "RDD", numPartitions = "numeric"), function(x, numPartitions) { shuffled <- partitionBy(x, numPartitions) groupVals <- function(part) { @@ -315,7 +315,7 @@ setMethod("groupByKey", #' @rdname reduceByKey #' @aliases reduceByKey,RDD,integer-method setMethod("reduceByKey", - signature(x = "RDD", combineFunc = "ANY", numPartitions = "integer"), + signature(x = "RDD", combineFunc = "ANY", numPartitions = "numeric"), function(x, combineFunc, numPartitions) { reduceVals <- function(part) { vals <- new.env() @@ -422,7 +422,7 @@ setMethod("reduceByKeyLocally", #' @aliases combineByKey,RDD,ANY,ANY,ANY,integer-method setMethod("combineByKey", signature(x = "RDD", createCombiner = "ANY", mergeValue = "ANY", - mergeCombiners = "ANY", numPartitions = "integer"), + mergeCombiners = "ANY", numPartitions = "numeric"), function(x, createCombiner, mergeValue, mergeCombiners, numPartitions) { combineLocally <- function(part) { combiners <- new.env() @@ -483,7 +483,7 @@ setMethod("combineByKey", #' @aliases aggregateByKey,RDD,ANY,ANY,ANY,integer-method setMethod("aggregateByKey", signature(x = "RDD", zeroValue = "ANY", seqOp = "ANY", - combOp = "ANY", numPartitions = "integer"), + combOp = "ANY", numPartitions = "numeric"), function(x, zeroValue, seqOp, combOp, numPartitions) { createCombiner <- function(v) { do.call(seqOp, list(zeroValue, v)) @@ -514,7 +514,7 @@ setMethod("aggregateByKey", #' @aliases foldByKey,RDD,ANY,ANY,integer-method setMethod("foldByKey", signature(x = "RDD", zeroValue = "ANY", - func = "ANY", numPartitions = "integer"), + func = "ANY", numPartitions = "numeric"), function(x, zeroValue, func, numPartitions) { aggregateByKey(x, zeroValue, func, func, numPartitions) }) @@ -553,7 +553,7 @@ setMethod("join", joinTaggedList(v, list(FALSE, FALSE)) } - joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numToInt(numPartitions)), + joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin) }) @@ -582,7 +582,7 @@ setMethod("join", #' @rdname join-methods #' @aliases leftOuterJoin,RDD,RDD-method setMethod("leftOuterJoin", - signature(x = "RDD", y = "RDD", numPartitions = "integer"), + signature(x = "RDD", y = "RDD", numPartitions = "numeric"), function(x, y, numPartitions) { xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) @@ -619,7 +619,7 @@ setMethod("leftOuterJoin", #' @rdname join-methods #' @aliases rightOuterJoin,RDD,RDD-method setMethod("rightOuterJoin", - signature(x = "RDD", y = "RDD", numPartitions = "integer"), + signature(x = "RDD", y = "RDD", numPartitions = "numeric"), function(x, y, numPartitions) { xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) @@ -659,7 +659,7 @@ setMethod("rightOuterJoin", #' @rdname join-methods #' @aliases fullOuterJoin,RDD,RDD-method setMethod("fullOuterJoin", - signature(x = "RDD", y = "RDD", numPartitions = "integer"), + signature(x = "RDD", y = "RDD", numPartitions = "numeric"), function(x, y, numPartitions) { xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) From 
438859eb7c4e605bb4041d9a547a16be9c827c75 Mon Sep 17 00:00:00 2001
From: Calvin Jia
Date: Fri, 24 Apr 2015 17:57:41 -0400
Subject: [PATCH 124/191] [SPARK-6122] [CORE] Upgrade tachyon-client version
 to 0.6.3

This is a reopening of #4867. A short summary of the issues resolved from the
previous PR:

1. HTTPClient version mismatch: Selenium (used for UI tests) requires version
4.3.x, and Tachyon included 4.2.5 through a transitive dependency of its
shaded thrift jar. To address this, Tachyon 0.6.3 will promote the transitive
dependencies of the shaded jar so they can be excluded in spark.

2. Jackson-Mapper-ASL version mismatch: In lower versions of hadoop-client
(i.e. 1.0.4), version 1.0.1 is included. The parquet library used in spark sql
requires version 1.8+. It's unclear to me why upgrading tachyon-client would
cause this dependency to break. The solution was to exclude jackson-mapper-asl
from hadoop-client. It seems that the dependency management in spark-parent
will not work on transitive dependencies; one way to make sure
jackson-mapper-asl is included with the correct version is to add it as a top
level dependency. The best solution would be to exclude the dependency in the
modules which require a higher version, but that did not fix the unit tests.
Any suggestions on the best way to solve this would be appreciated!

Author: Calvin Jia

Closes #5354 from calvinjia/upgrade_tachyon_0.6.3 and squashes the following commits:

0eefe4d [Calvin Jia] Handle httpclient version in maven dependency management. Remove httpclient version setting from profiles.
7c00dfa [Calvin Jia] Set httpclient version to 4.3.2 for selenium. Specify version of httpclient for sql/hive (previously 4.2.5 transitive dependency of libthrift).
9263097 [Calvin Jia] Merge master to test latest changes
dbfc1bd [Calvin Jia] Use Tachyon 0.6.4 for cleaner dependencies.
e2ff80a [Calvin Jia] Exclude the jetty and curator promoted dependencies from tachyon-client.
a3a29da [Calvin Jia] Update tachyon-client exclusions.
0ae6c97 [Calvin Jia] Change tachyon version to 0.6.3
a204df9 [Calvin Jia] Update make distribution tachyon version.
a93c94f [Calvin Jia] Exclude jackson-mapper-asl from hadoop client since it has a lower version than spark's expected version.
a8a923c [Calvin Jia] Exclude httpcomponents from Tachyon
910fabd [Calvin Jia] Update to master
eed9230 [Calvin Jia] Update tachyon version to 0.6.1.
11907b3 [Calvin Jia] Use TachyonURI for tachyon paths instead of strings.
71bf441 [Calvin Jia] Upgrade Tachyon client version to 0.6.0.
---
 assembly/pom.xml                                 | 10 ----------
 core/pom.xml                                     |  6 +++++-
 .../spark/storage/TachyonBlockManager.scala      | 16 ++++++++--------
 .../main/scala/org/apache/spark/util/Utils.scala |  4 +++-
 examples/pom.xml                                 |  5 -----
 launcher/pom.xml                                 |  6 ++++++
 make-distribution.sh                             |  2 +-
 pom.xml                                          | 12 +++++++++++-
 sql/hive/pom.xml                                 |  5 +++++
 9 files changed, 39 insertions(+), 27 deletions(-)

diff --git a/assembly/pom.xml b/assembly/pom.xml
index f1f8b0d3682e2..20593e710dedb 100644
--- a/assembly/pom.xml
+++ b/assembly/pom.xml
@@ -213,16 +213,6 @@
-
-      kinesis-asl
-
-
-        org.apache.httpcomponents
-        httpclient
-        ${commons.httpclient.version}
-
-
-
diff --git a/core/pom.xml b/core/pom.xml
index e80829b7a7f3d..5e89d548cd47f 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -74,6 +74,10 @@
       javax.servlet
       servlet-api
+
+      org.codehaus.jackson
+      jackson-mapper-asl
+
@@ -275,7 +279,7 @@ org.tachyonproject tachyon-client - 0.5.0 + 0.6.4 org.apache.hadoop diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala index 951897cead996..583f1fdf0475b 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala @@ -20,8 +20,8 @@ package org.apache.spark.storage import java.text.SimpleDateFormat import java.util.{Date, Random} -import tachyon.client.TachyonFS -import tachyon.client.TachyonFile +import tachyon.TachyonURI +import tachyon.client.{TachyonFile, TachyonFS} import org.apache.spark.Logging import org.apache.spark.executor.ExecutorExitCode @@ -40,7 +40,7 @@ private[spark] class TachyonBlockManager( val master: String) extends Logging { - val client = if (master != null && master != "") TachyonFS.get(master) else null + val client = if (master != null && master != "") TachyonFS.get(new TachyonURI(master)) else null if (client == null) { logError("Failed to connect to the Tachyon as the master address is not configured") @@ -60,11 +60,11 @@ private[spark] class TachyonBlockManager( addShutdownHook() def removeFile(file: TachyonFile): Boolean = { - client.delete(file.getPath(), false) + client.delete(new TachyonURI(file.getPath()), false) } def fileExists(file: TachyonFile): Boolean = { - client.exist(file.getPath()) + client.exist(new TachyonURI(file.getPath())) } def getFile(filename: String): TachyonFile = { @@ -81,7 +81,7 @@ private[spark] class TachyonBlockManager( if (old != null) { old } else { - val path = tachyonDirs(dirId) + "/" + "%02x".format(subDirId) + val path = new TachyonURI(s"${tachyonDirs(dirId)}/${"%02x".format(subDirId)}") client.mkdir(path) val newDir = client.getFile(path) subDirs(dirId)(subDirId) = newDir @@ -89,7 +89,7 @@ private[spark] class TachyonBlockManager( } } } - val filePath = subDir + "/" + filename + val filePath = new TachyonURI(s"$subDir/$filename") if(!client.exist(filePath)) { client.createFile(filePath) } @@ -113,7 +113,7 @@ private[spark] class TachyonBlockManager( tries += 1 try { tachyonDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536)) - val path = rootDir + "/" + "spark-tachyon-" + tachyonDirId + val path = new TachyonURI(s"$rootDir/spark-tachyon-$tachyonDirId") if (!client.exist(path)) { foundLocalDir = client.mkdir(path) tachyonDir = client.getFile(path) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 2feb7341b159b..667aa168e7ef3 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -42,6 +42,8 @@ import org.apache.hadoop.security.UserGroupInformation import org.apache.log4j.PropertyConfigurator import org.eclipse.jetty.util.MultiException import org.json4s._ + +import tachyon.TachyonURI import tachyon.client.{TachyonFS, TachyonFile} import org.apache.spark._ @@ -955,7 +957,7 @@ private[spark] object Utils extends Logging { * Delete a file or directory and its contents recursively. 
*/ def deleteRecursively(dir: TachyonFile, client: TachyonFS) { - if (!client.delete(dir.getPath(), true)) { + if (!client.delete(new TachyonURI(dir.getPath()), true)) { throw new IOException("Failed to delete the tachyon dir: " + dir) } } diff --git a/examples/pom.xml b/examples/pom.xml index afd7c6d52f0dd..df1717403b673 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -390,11 +390,6 @@ spark-streaming-kinesis-asl_${scala.binary.version} ${project.version} - - org.apache.httpcomponents - httpclient - ${commons.httpclient.version} - diff --git a/launcher/pom.xml b/launcher/pom.xml index 182e5f60218db..ebfa7685eaa18 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -68,6 +68,12 @@ org.apache.hadoop hadoop-client test + + + org.codehaus.jackson + jackson-mapper-asl + + diff --git a/make-distribution.sh b/make-distribution.sh index 738a9c4d69601..cb65932b4abc0 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -32,7 +32,7 @@ SPARK_HOME="$(cd "`dirname "$0"`"; pwd)" DISTDIR="$SPARK_HOME/dist" SPARK_TACHYON=false -TACHYON_VERSION="0.5.0" +TACHYON_VERSION="0.6.4" TACHYON_TGZ="tachyon-${TACHYON_VERSION}-bin.tar.gz" TACHYON_URL="https://github.com/amplab/tachyon/releases/download/v${TACHYON_VERSION}/${TACHYON_TGZ}" diff --git a/pom.xml b/pom.xml index bcc2f57f1af5d..4b0b0c85eff21 100644 --- a/pom.xml +++ b/pom.xml @@ -146,7 +146,7 @@ 0.7.1 1.8.3 1.1.0 - 4.2.6 + 4.3.2 3.4.1 ${project.build.directory}/spark-test-classpath.txt 2.10.4 @@ -420,6 +420,16 @@ jsr305 1.3.9 + + org.apache.httpcomponents + httpclient + ${commons.httpclient.version} + + + org.apache.httpcomponents + httpcore + ${commons.httpclient.version} + org.seleniumhq.selenium selenium-java diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 04440076a26a3..21dce8d8a565a 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -59,6 +59,11 @@ ${hive.group} hive-exec + + org.apache.httpcomponents + httpclient + ${commons.httpclient.version} + org.codehaus.jackson jackson-mapper-asl From d874f8b546d8fae95bc92d8461b8189e51cb731b Mon Sep 17 00:00:00 2001 From: linweizhong Date: Fri, 24 Apr 2015 20:23:19 -0700 Subject: [PATCH 125/191] [PySpark][Minor] Update sql example, so that can read file correctly To run Spark, default will read file from HDFS if we don't set the schema. Author: linweizhong Closes #5684 from Sephiroth-Lin/pyspark_example_minor and squashes the following commits: 19fe145 [linweizhong] Update example sql.py, so that can read file correctly --- examples/src/main/python/sql.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/examples/src/main/python/sql.py b/examples/src/main/python/sql.py index 87d7b088f077b..2c188759328f2 100644 --- a/examples/src/main/python/sql.py +++ b/examples/src/main/python/sql.py @@ -18,6 +18,7 @@ from __future__ import print_function import os +import sys from pyspark import SparkContext from pyspark.sql import SQLContext @@ -50,7 +51,11 @@ # A JSON dataset is pointed to by path. # The path can be either a single text file or a directory storing text files. 
- path = os.path.join(os.environ['SPARK_HOME'], "examples/src/main/resources/people.json") + if len(sys.argv) < 2: + path = "file://" + \ + os.path.join(os.environ['SPARK_HOME'], "examples/src/main/resources/people.json") + else: + path = sys.argv[1] # Create a DataFrame from the file(s) pointed to by path people = sqlContext.jsonFile(path) # root From 59b7cfc41b2c06fbfbf6aca16c1619496a8d1d00 Mon Sep 17 00:00:00 2001 From: Deborah Siegel Date: Fri, 24 Apr 2015 20:25:07 -0700 Subject: [PATCH 126/191] [SPARK-7136][Docs] Spark SQL and DataFrame Guide fix example file and paths Changes example file for Generic Load/Save Functions to users.parquet rather than people.parquet which doesn't exist unless a later example has already been executed. Also adds filepaths. Author: Deborah Siegel Author: DEBORAH SIEGEL Author: DEBORAH SIEGEL Author: DEBORAH SIEGEL Closes #5693 from d3borah/master and squashes the following commits: 4d5e43b [Deborah Siegel] sparkSQL doc change b15a497 [Deborah Siegel] Revert "sparkSQL doc change" 5a2863c [DEBORAH SIEGEL] Merge remote-tracking branch 'upstream/master' 91972fc [DEBORAH SIEGEL] sparkSQL doc change f000e59 [DEBORAH SIEGEL] Merge remote-tracking branch 'upstream/master' db54173 [DEBORAH SIEGEL] fixed aggregateMessages example in graphX doc --- docs/sql-programming-guide.md | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 49b1e69f0e9db..b8233ae06fdf3 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -681,8 +681,8 @@ In the simplest form, the default data source (`parquet` unless otherwise config
{% highlight scala %} -val df = sqlContext.load("people.parquet") -df.select("name", "age").save("namesAndAges.parquet") +val df = sqlContext.load("examples/src/main/resources/users.parquet") +df.select("name", "favorite_color").save("namesAndFavColors.parquet") {% endhighlight %}
@@ -691,8 +691,8 @@ df.select("name", "age").save("namesAndAges.parquet") {% highlight java %} -DataFrame df = sqlContext.load("people.parquet"); -df.select("name", "age").save("namesAndAges.parquet"); +DataFrame df = sqlContext.load("examples/src/main/resources/users.parquet"); +df.select("name", "favorite_color").save("namesAndFavColors.parquet"); {% endhighlight %} @@ -702,8 +702,8 @@ df.select("name", "age").save("namesAndAges.parquet"); {% highlight python %} -df = sqlContext.load("people.parquet") -df.select("name", "age").save("namesAndAges.parquet") +df = sqlContext.load("examples/src/main/resources/users.parquet") +df.select("name", "favorite_color").save("namesAndFavColors.parquet") {% endhighlight %} @@ -722,7 +722,7 @@ using this syntax.
{% highlight scala %} -val df = sqlContext.load("people.json", "json") +val df = sqlContext.load("examples/src/main/resources/people.json", "json") df.select("name", "age").save("namesAndAges.parquet", "parquet") {% endhighlight %} @@ -732,7 +732,7 @@ df.select("name", "age").save("namesAndAges.parquet", "parquet") {% highlight java %} -DataFrame df = sqlContext.load("people.json", "json"); +DataFrame df = sqlContext.load("examples/src/main/resources/people.json", "json"); df.select("name", "age").save("namesAndAges.parquet", "parquet"); {% endhighlight %} @@ -743,7 +743,7 @@ df.select("name", "age").save("namesAndAges.parquet", "parquet"); {% highlight python %} -df = sqlContext.load("people.json", "json") +df = sqlContext.load("examples/src/main/resources/people.json", "json") df.select("name", "age").save("namesAndAges.parquet", "parquet") {% endhighlight %} From cca9905b93483614b330b09b36c6526b551e17dc Mon Sep 17 00:00:00 2001 From: KeheCAI Date: Sat, 25 Apr 2015 08:42:38 -0400 Subject: [PATCH 127/191] update the deprecated CountMinSketchMonoid function to TopPctCMS function http://twitter.github.io/algebird/index.html#com.twitter.algebird.legacy.CountMinSketchMonoid$ The CountMinSketchMonoid has been deprecated since 0.8.1. Newer code should use TopPctCMS.monoid(). ![image](https://cloud.githubusercontent.com/assets/1327396/7269619/d8b48b92-e8d5-11e4-8902-087f630e6308.png) Author: KeheCAI Closes #5629 from caikehe/master and squashes the following commits: e8aa06f [KeheCAI] update algebird-core to version 0.9.0 from 0.8.1 5653351 [KeheCAI] change scala code style 4c0dfd1 [KeheCAI] update the deprecated CountMinSketchMonoid function to TopPctCMS function --- examples/pom.xml | 2 +- .../apache/spark/examples/streaming/TwitterAlgebirdCMS.scala | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/pom.xml b/examples/pom.xml index df1717403b673..5b04b4f8d6ca0 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -245,7 +245,7 @@ com.twitter algebird-core_${scala.binary.version} - 0.8.1 + 0.9.0 org.scalacheck diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala index 62f49530edb12..c10de84a80ffe 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala @@ -18,6 +18,7 @@ package org.apache.spark.examples.streaming import com.twitter.algebird._ +import com.twitter.algebird.CMSHasherImplicits._ import org.apache.spark.SparkConf import org.apache.spark.SparkContext._ @@ -67,7 +68,8 @@ object TwitterAlgebirdCMS { val users = stream.map(status => status.getUser.getId) - val cms = new CountMinSketchMonoid(EPS, DELTA, SEED, PERC) + // val cms = new CountMinSketchMonoid(EPS, DELTA, SEED, PERC) + val cms = TopPctCMS.monoid[Long](EPS, DELTA, SEED, PERC) var globalCMS = cms.zero val mm = new MapMonoid[Long, Int]() var globalExact = Map[Long, Int]() From a61d65fc8b97c01be0fa756b52afdc91c46a8561 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Sat, 25 Apr 2015 10:37:34 -0700 Subject: [PATCH 128/191] Revert "[SPARK-6752][Streaming] Allow StreamingContext to be recreated from checkpoint and existing SparkContext" This reverts commit 534f2a43625fbf1a3a65d09550a19875cd1dce43. 
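For reference, the overloads removed by this revert had allowed driver programs along these lines (a rough sketch based only on the signatures deleted in the diff below; the master, app name, checkpoint path and batch interval are placeholders, not values taken from this patch):

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.streaming.{Seconds, StreamingContext}

    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("checkpoint-demo"))
    // Reverted behavior: reuse the existing SparkContext, recovering the StreamingContext from
    // checkpoint data when it exists, otherwise creating a fresh one via the supplied function.
    val ssc = StreamingContext.getOrCreate(
      "/tmp/checkpoint",  // hypothetical checkpoint directory
      (context: SparkContext) => new StreamingContext(context, Seconds(1)),
      sc)

After the revert, only the getOrCreate variants that take a zero-argument creating function (optionally with a Hadoop configuration and a createOnError flag) remain.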
--- .../spark/api/java/function/Function0.java | 27 --- .../apache/spark/streaming/Checkpoint.scala | 26 +-- .../spark/streaming/StreamingContext.scala | 85 ++-------- .../api/java/JavaStreamingContext.scala | 119 +------------ .../apache/spark/streaming/JavaAPISuite.java | 145 ++++------------ .../spark/streaming/CheckpointSuite.scala | 3 +- .../streaming/StreamingContextSuite.scala | 159 ------------------ 7 files changed, 61 insertions(+), 503 deletions(-) delete mode 100644 core/src/main/java/org/apache/spark/api/java/function/Function0.java diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function0.java b/core/src/main/java/org/apache/spark/api/java/function/Function0.java deleted file mode 100644 index 38e410c5debe6..0000000000000 --- a/core/src/main/java/org/apache/spark/api/java/function/Function0.java +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.api.java.function; - -import java.io.Serializable; - -/** - * A zero-argument function that returns an R. 
- */ -public interface Function0 extends Serializable { - public R call() throws Exception; -} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 7bfae253c3a0c..0a50485118588 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -77,8 +77,7 @@ object Checkpoint extends Logging { } /** Get checkpoint files present in the give directory, ordered by oldest-first */ - def getCheckpointFiles(checkpointDir: String, fsOption: Option[FileSystem] = None): Seq[Path] = { - + def getCheckpointFiles(checkpointDir: String, fs: FileSystem): Seq[Path] = { def sortFunc(path1: Path, path2: Path): Boolean = { val (time1, bk1) = path1.getName match { case REGEX(x, y) => (x.toLong, !y.isEmpty) } val (time2, bk2) = path2.getName match { case REGEX(x, y) => (x.toLong, !y.isEmpty) } @@ -86,7 +85,6 @@ object Checkpoint extends Logging { } val path = new Path(checkpointDir) - val fs = fsOption.getOrElse(path.getFileSystem(new Configuration())) if (fs.exists(path)) { val statuses = fs.listStatus(path) if (statuses != null) { @@ -162,7 +160,7 @@ class CheckpointWriter( } // Delete old checkpoint files - val allCheckpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, Some(fs)) + val allCheckpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, fs) if (allCheckpointFiles.size > 10) { allCheckpointFiles.take(allCheckpointFiles.size - 10).foreach(file => { logInfo("Deleting " + file) @@ -236,24 +234,15 @@ class CheckpointWriter( private[streaming] object CheckpointReader extends Logging { - /** - * Read checkpoint files present in the given checkpoint directory. If there are no checkpoint - * files, then return None, else try to return the latest valid checkpoint object. If no - * checkpoint files could be read correctly, then return None (if ignoreReadError = true), - * or throw exception (if ignoreReadError = false). - */ - def read( - checkpointDir: String, - conf: SparkConf, - hadoopConf: Configuration, - ignoreReadError: Boolean = false): Option[Checkpoint] = { + def read(checkpointDir: String, conf: SparkConf, hadoopConf: Configuration): Option[Checkpoint] = + { val checkpointPath = new Path(checkpointDir) // TODO(rxin): Why is this a def?! 
def fs: FileSystem = checkpointPath.getFileSystem(hadoopConf) // Try to find the checkpoint files - val checkpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, Some(fs)).reverse + val checkpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, fs).reverse if (checkpointFiles.isEmpty) { return None } @@ -293,10 +282,7 @@ object CheckpointReader extends Logging { }) // If none of checkpoint files could be read, then throw exception - if (!ignoreReadError) { - throw new SparkException(s"Failed to read checkpoint from directory $checkpointPath") - } - None + throw new SparkException("Failed to read checkpoint from directory " + checkpointPath) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 90c8b47aebce0..f57f295874645 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -107,19 +107,6 @@ class StreamingContext private[streaming] ( */ def this(path: String) = this(path, new Configuration) - /** - * Recreate a StreamingContext from a checkpoint file using an existing SparkContext. - * @param path Path to the directory that was specified as the checkpoint directory - * @param sparkContext Existing SparkContext - */ - def this(path: String, sparkContext: SparkContext) = { - this( - sparkContext, - CheckpointReader.read(path, sparkContext.conf, sparkContext.hadoopConfiguration).get, - null) - } - - if (sc_ == null && cp_ == null) { throw new Exception("Spark Streaming cannot be initialized with " + "both SparkContext and checkpoint as null") @@ -128,12 +115,10 @@ class StreamingContext private[streaming] ( private[streaming] val isCheckpointPresent = (cp_ != null) private[streaming] val sc: SparkContext = { - if (sc_ != null) { - sc_ - } else if (isCheckpointPresent) { + if (isCheckpointPresent) { new SparkContext(cp_.createSparkConf()) } else { - throw new SparkException("Cannot create StreamingContext without a SparkContext") + sc_ } } @@ -144,7 +129,7 @@ class StreamingContext private[streaming] ( private[streaming] val conf = sc.conf - private[streaming] val env = sc.env + private[streaming] val env = SparkEnv.get private[streaming] val graph: DStreamGraph = { if (isCheckpointPresent) { @@ -189,9 +174,7 @@ class StreamingContext private[streaming] ( /** Register streaming source to metrics system */ private val streamingSource = new StreamingSource(this) - assert(env != null) - assert(env.metricsSystem != null) - env.metricsSystem.registerSource(streamingSource) + SparkEnv.get.metricsSystem.registerSource(streamingSource) /** Enumeration to identify current state of the StreamingContext */ private[streaming] object StreamingContextState extends Enumeration { @@ -638,59 +621,19 @@ object StreamingContext extends Logging { hadoopConf: Configuration = new Configuration(), createOnError: Boolean = false ): StreamingContext = { - val checkpointOption = CheckpointReader.read( - checkpointPath, new SparkConf(), hadoopConf, createOnError) + val checkpointOption = try { + CheckpointReader.read(checkpointPath, new SparkConf(), hadoopConf) + } catch { + case e: Exception => + if (createOnError) { + None + } else { + throw e + } + } checkpointOption.map(new StreamingContext(null, _, null)).getOrElse(creatingFunc()) } - - /** - * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. 
- * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be - * recreated from the checkpoint data. If the data does not exist, then the StreamingContext - * will be created by called the provided `creatingFunc` on the provided `sparkContext`. Note - * that the SparkConf configuration in the checkpoint data will not be restored as the - * SparkContext has already been created. - * - * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program - * @param creatingFunc Function to create a new StreamingContext using the given SparkContext - * @param sparkContext SparkContext using which the StreamingContext will be created - */ - def getOrCreate( - checkpointPath: String, - creatingFunc: SparkContext => StreamingContext, - sparkContext: SparkContext - ): StreamingContext = { - getOrCreate(checkpointPath, creatingFunc, sparkContext, createOnError = false) - } - - /** - * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. - * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be - * recreated from the checkpoint data. If the data does not exist, then the StreamingContext - * will be created by called the provided `creatingFunc` on the provided `sparkContext`. Note - * that the SparkConf configuration in the checkpoint data will not be restored as the - * SparkContext has already been created. - * - * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program - * @param creatingFunc Function to create a new StreamingContext using the given SparkContext - * @param sparkContext SparkContext using which the StreamingContext will be created - * @param createOnError Whether to create a new StreamingContext if there is an - * error in reading checkpoint data. By default, an exception will be - * thrown on error. - */ - def getOrCreate( - checkpointPath: String, - creatingFunc: SparkContext => StreamingContext, - sparkContext: SparkContext, - createOnError: Boolean - ): StreamingContext = { - val checkpointOption = CheckpointReader.read( - checkpointPath, sparkContext.conf, sparkContext.hadoopConfiguration, createOnError) - checkpointOption.map(new StreamingContext(sparkContext, _, null)) - .getOrElse(creatingFunc(sparkContext)) - } - /** * Find the JAR from which a given class was loaded, to make it easy for users to pass * their JARs to StreamingContext. 
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 572d7d8e8753d..4095a7cc84946 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -32,14 +32,13 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2} -import org.apache.spark.api.java.function.{Function0 => JFunction0} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ import org.apache.spark.streaming.scheduler.StreamingListener -import org.apache.spark.streaming.dstream.DStream -import org.apache.spark.streaming.receiver.Receiver import org.apache.hadoop.conf.Configuration +import org.apache.spark.streaming.dstream.{PluggableInputDStream, ReceiverInputDStream, DStream} +import org.apache.spark.streaming.receiver.Receiver /** * A Java-friendly version of [[org.apache.spark.streaming.StreamingContext]] which is the main @@ -656,7 +655,6 @@ object JavaStreamingContext { * @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext */ - @deprecated("use getOrCreate without JavaStreamingContextFactor", "1.4.0") def getOrCreate( checkpointPath: String, factory: JavaStreamingContextFactory @@ -678,7 +676,6 @@ object JavaStreamingContext { * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible * file system */ - @deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0") def getOrCreate( checkpointPath: String, hadoopConf: Configuration, @@ -703,7 +700,6 @@ object JavaStreamingContext { * @param createOnError Whether to create a new JavaStreamingContext if there is an * error in reading checkpoint data. */ - @deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0") def getOrCreate( checkpointPath: String, hadoopConf: Configuration, @@ -716,117 +712,6 @@ object JavaStreamingContext { new JavaStreamingContext(ssc) } - /** - * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. - * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be - * recreated from the checkpoint data. If the data does not exist, then the provided factory - * will be used to create a JavaStreamingContext. - * - * @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program - * @param creatingFunc Function to create a new JavaStreamingContext - */ - def getOrCreate( - checkpointPath: String, - creatingFunc: JFunction0[JavaStreamingContext] - ): JavaStreamingContext = { - val ssc = StreamingContext.getOrCreate(checkpointPath, () => { - creatingFunc.call().ssc - }) - new JavaStreamingContext(ssc) - } - - /** - * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. - * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be - * recreated from the checkpoint data. If the data does not exist, then the provided factory - * will be used to create a JavaStreamingContext. 
- * - * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program - * @param creatingFunc Function to create a new JavaStreamingContext - * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible - * file system - */ - def getOrCreate( - checkpointPath: String, - creatingFunc: JFunction0[JavaStreamingContext], - hadoopConf: Configuration - ): JavaStreamingContext = { - val ssc = StreamingContext.getOrCreate(checkpointPath, () => { - creatingFunc.call().ssc - }, hadoopConf) - new JavaStreamingContext(ssc) - } - - /** - * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. - * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be - * recreated from the checkpoint data. If the data does not exist, then the provided factory - * will be used to create a JavaStreamingContext. - * - * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program - * @param creatingFunc Function to create a new JavaStreamingContext - * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible - * file system - * @param createOnError Whether to create a new JavaStreamingContext if there is an - * error in reading checkpoint data. - */ - def getOrCreate( - checkpointPath: String, - creatingFunc: JFunction0[JavaStreamingContext], - hadoopConf: Configuration, - createOnError: Boolean - ): JavaStreamingContext = { - val ssc = StreamingContext.getOrCreate(checkpointPath, () => { - creatingFunc.call().ssc - }, hadoopConf, createOnError) - new JavaStreamingContext(ssc) - } - - /** - * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. - * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be - * recreated from the checkpoint data. If the data does not exist, then the provided factory - * will be used to create a JavaStreamingContext. - * - * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program - * @param creatingFunc Function to create a new JavaStreamingContext - * @param sparkContext SparkContext using which the StreamingContext will be created - */ - def getOrCreate( - checkpointPath: String, - creatingFunc: JFunction[JavaSparkContext, JavaStreamingContext], - sparkContext: JavaSparkContext - ): JavaStreamingContext = { - val ssc = StreamingContext.getOrCreate(checkpointPath, (sparkContext: SparkContext) => { - creatingFunc.call(new JavaSparkContext(sparkContext)).ssc - }, sparkContext.sc) - new JavaStreamingContext(ssc) - } - - /** - * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. - * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be - * recreated from the checkpoint data. If the data does not exist, then the provided factory - * will be used to create a JavaStreamingContext. - * - * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program - * @param creatingFunc Function to create a new JavaStreamingContext - * @param sparkContext SparkContext using which the StreamingContext will be created - * @param createOnError Whether to create a new JavaStreamingContext if there is an - * error in reading checkpoint data. 
- */ - def getOrCreate( - checkpointPath: String, - creatingFunc: JFunction[JavaSparkContext, JavaStreamingContext], - sparkContext: JavaSparkContext, - createOnError: Boolean - ): JavaStreamingContext = { - val ssc = StreamingContext.getOrCreate(checkpointPath, (sparkContext: SparkContext) => { - creatingFunc.call(new JavaSparkContext(sparkContext)).ssc - }, sparkContext.sc, createOnError) - new JavaStreamingContext(ssc) - } - /** * Find the JAR from which a given class was loaded, to make it easy for users to pass * their JARs to StreamingContext. diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index cb2e8380b4933..90340753a4eed 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -22,12 +22,10 @@ import java.nio.charset.Charset; import java.util.*; -import org.apache.commons.lang.mutable.MutableBoolean; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.mapreduce.lib.input.TextInputFormat; - import scala.Tuple2; import org.junit.Assert; @@ -47,7 +45,6 @@ import org.apache.spark.storage.StorageLevel; import org.apache.spark.streaming.api.java.*; import org.apache.spark.util.Utils; -import org.apache.spark.SparkConf; // The test suite itself is Serializable so that anonymous Function implementations can be // serialized, as an alternative to converting these anonymous classes to static inner classes; @@ -932,7 +929,7 @@ public void testPairMap() { // Maps pair -> pair of different type public Tuple2 call(Tuple2 in) throws Exception { return in.swap(); } - }); + }); JavaTestUtils.attachTestOutputStream(reversed); List>> result = JavaTestUtils.runStreams(ssc, 2, 2); @@ -990,12 +987,12 @@ public void testPairMap2() { // Maps pair -> single JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaDStream reversed = pairStream.map( - new Function, Integer>() { - @Override - public Integer call(Tuple2 in) throws Exception { - return in._2(); - } - }); + new Function, Integer>() { + @Override + public Integer call(Tuple2 in) throws Exception { + return in._2(); + } + }); JavaTestUtils.attachTestOutputStream(reversed); List> result = JavaTestUtils.runStreams(ssc, 2, 2); @@ -1126,7 +1123,7 @@ public void testCombineByKey() { JavaPairDStream combined = pairStream.combineByKey( new Function() { - @Override + @Override public Integer call(Integer i) throws Exception { return i; } @@ -1147,14 +1144,14 @@ public void testCountByValue() { Arrays.asList("hello")); List>> expected = Arrays.asList( - Arrays.asList( - new Tuple2("hello", 1L), - new Tuple2("world", 1L)), - Arrays.asList( - new Tuple2("hello", 1L), - new Tuple2("moon", 1L)), - Arrays.asList( - new Tuple2("hello", 1L))); + Arrays.asList( + new Tuple2("hello", 1L), + new Tuple2("world", 1L)), + Arrays.asList( + new Tuple2("hello", 1L), + new Tuple2("moon", 1L)), + Arrays.asList( + new Tuple2("hello", 1L))); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream counted = stream.countByValue(); @@ -1252,17 +1249,17 @@ public void testUpdateStateByKey() { JavaPairDStream updated = pairStream.updateStateByKey( new Function2, Optional, Optional>() { - @Override - public Optional call(List values, Optional state) { - int out = 
0; - if (state.isPresent()) { - out = out + state.get(); - } - for (Integer v : values) { - out = out + v; - } - return Optional.of(out); + @Override + public Optional call(List values, Optional state) { + int out = 0; + if (state.isPresent()) { + out = out + state.get(); + } + for (Integer v: values) { + out = out + v; } + return Optional.of(out); + } }); JavaTestUtils.attachTestOutputStream(updated); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -1295,17 +1292,17 @@ public void testUpdateStateByKeyWithInitial() { JavaPairDStream updated = pairStream.updateStateByKey( new Function2, Optional, Optional>() { - @Override - public Optional call(List values, Optional state) { - int out = 0; - if (state.isPresent()) { - out = out + state.get(); - } - for (Integer v : values) { - out = out + v; - } - return Optional.of(out); + @Override + public Optional call(List values, Optional state) { + int out = 0; + if (state.isPresent()) { + out = out + state.get(); + } + for (Integer v: values) { + out = out + v; } + return Optional.of(out); + } }, new HashPartitioner(1), initialRDD); JavaTestUtils.attachTestOutputStream(updated); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -1331,7 +1328,7 @@ public void testReduceByKeyAndWindowWithInverse() { JavaPairDStream reduceWindowed = pairStream.reduceByKeyAndWindow(new IntegerSum(), new IntegerDifference(), - new Duration(2000), new Duration(1000)); + new Duration(2000), new Duration(1000)); JavaTestUtils.attachTestOutputStream(reduceWindowed); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -1710,74 +1707,6 @@ public Integer call(String s) throws Exception { Utils.deleteRecursively(tempDir); } - @SuppressWarnings("unchecked") - @Test - public void testContextGetOrCreate() throws InterruptedException { - - final SparkConf conf = new SparkConf() - .setMaster("local[2]") - .setAppName("test") - .set("newContext", "true"); - - File emptyDir = Files.createTempDir(); - emptyDir.deleteOnExit(); - StreamingContextSuite contextSuite = new StreamingContextSuite(); - String corruptedCheckpointDir = contextSuite.createCorruptedCheckpoint(); - String checkpointDir = contextSuite.createValidCheckpoint(); - - // Function to create JavaStreamingContext without any output operations - // (used to detect the new context) - final MutableBoolean newContextCreated = new MutableBoolean(false); - Function0 creatingFunc = new Function0() { - public JavaStreamingContext call() { - newContextCreated.setValue(true); - return new JavaStreamingContext(conf, Seconds.apply(1)); - } - }; - - newContextCreated.setValue(false); - ssc = JavaStreamingContext.getOrCreate(emptyDir.getAbsolutePath(), creatingFunc); - Assert.assertTrue("new context not created", newContextCreated.isTrue()); - ssc.stop(); - - newContextCreated.setValue(false); - ssc = JavaStreamingContext.getOrCreate(corruptedCheckpointDir, creatingFunc, - new org.apache.hadoop.conf.Configuration(), true); - Assert.assertTrue("new context not created", newContextCreated.isTrue()); - ssc.stop(); - - newContextCreated.setValue(false); - ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc, - new org.apache.hadoop.conf.Configuration()); - Assert.assertTrue("old context not recovered", newContextCreated.isFalse()); - ssc.stop(); - - // Function to create JavaStreamingContext using existing JavaSparkContext - // without any output operations (used to detect the new context) - Function creatingFunc2 = - new Function() { - public JavaStreamingContext call(JavaSparkContext context) { - 
newContextCreated.setValue(true); - return new JavaStreamingContext(context, Seconds.apply(1)); - } - }; - - JavaSparkContext sc = new JavaSparkContext(conf); - newContextCreated.setValue(false); - ssc = JavaStreamingContext.getOrCreate(emptyDir.getAbsolutePath(), creatingFunc2, sc); - Assert.assertTrue("new context not created", newContextCreated.isTrue()); - ssc.stop(false); - - newContextCreated.setValue(false); - ssc = JavaStreamingContext.getOrCreate(corruptedCheckpointDir, creatingFunc2, sc, true); - Assert.assertTrue("new context not created", newContextCreated.isTrue()); - ssc.stop(false); - - newContextCreated.setValue(false); - ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc2, sc); - Assert.assertTrue("old context not recovered", newContextCreated.isFalse()); - ssc.stop(); - } /* TEST DISABLED: Pending a discussion about checkpoint() semantics with TD @SuppressWarnings("unchecked") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 6b0a3f91d4d06..54c30440a6e8d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -430,8 +430,9 @@ class CheckpointSuite extends TestSuiteBase { assert(recordedFiles(ssc) === Seq(1, 2, 3) && batchCounter.getNumStartedBatches === 3) } // Wait for a checkpoint to be written + val fs = new Path(checkpointDir).getFileSystem(ssc.sc.hadoopConfiguration) eventually(eventuallyTimeout) { - assert(Checkpoint.getCheckpointFiles(checkpointDir).size === 6) + assert(Checkpoint.getCheckpointFiles(checkpointDir, fs).size === 6) } ssc.stop() // Check that we shut down while the third batch was being processed diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 4f193322ad33e..58353a5f97c8a 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -17,10 +17,8 @@ package org.apache.spark.streaming -import java.io.File import java.util.concurrent.atomic.AtomicInteger -import org.apache.commons.io.FileUtils import org.scalatest.{Assertions, BeforeAndAfter, FunSuite} import org.scalatest.concurrent.Timeouts import org.scalatest.concurrent.Eventually._ @@ -332,139 +330,6 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w } } - test("getOrCreate") { - val conf = new SparkConf().setMaster(master).setAppName(appName) - - // Function to create StreamingContext that has a config to identify it to be new context - var newContextCreated = false - def creatingFunction(): StreamingContext = { - newContextCreated = true - new StreamingContext(conf, batchDuration) - } - - // Call ssc.stop after a body of code - def testGetOrCreate(body: => Unit): Unit = { - newContextCreated = false - try { - body - } finally { - if (ssc != null) { - ssc.stop() - } - ssc = null - } - } - - val emptyPath = Utils.createTempDir().getAbsolutePath() - - // getOrCreate should create new context with empty path - testGetOrCreate { - ssc = StreamingContext.getOrCreate(emptyPath, creatingFunction _) - assert(ssc != null, "no context created") - assert(newContextCreated, "new context not created") - } - - val corrutedCheckpointPath = createCorruptedCheckpoint() - - // getOrCreate 
should throw exception with fake checkpoint file and createOnError = false - intercept[Exception] { - ssc = StreamingContext.getOrCreate(corrutedCheckpointPath, creatingFunction _) - } - - // getOrCreate should throw exception with fake checkpoint file - intercept[Exception] { - ssc = StreamingContext.getOrCreate( - corrutedCheckpointPath, creatingFunction _, createOnError = false) - } - - // getOrCreate should create new context with fake checkpoint file and createOnError = true - testGetOrCreate { - ssc = StreamingContext.getOrCreate( - corrutedCheckpointPath, creatingFunction _, createOnError = true) - assert(ssc != null, "no context created") - assert(newContextCreated, "new context not created") - } - - val checkpointPath = createValidCheckpoint() - - // getOrCreate should recover context with checkpoint path, and recover old configuration - testGetOrCreate { - ssc = StreamingContext.getOrCreate(checkpointPath, creatingFunction _) - assert(ssc != null, "no context created") - assert(!newContextCreated, "old context not recovered") - assert(ssc.conf.get("someKey") === "someValue") - } - } - - test("getOrCreate with existing SparkContext") { - val conf = new SparkConf().setMaster(master).setAppName(appName) - sc = new SparkContext(conf) - - // Function to create StreamingContext that has a config to identify it to be new context - var newContextCreated = false - def creatingFunction(sparkContext: SparkContext): StreamingContext = { - newContextCreated = true - new StreamingContext(sparkContext, batchDuration) - } - - // Call ssc.stop(stopSparkContext = false) after a body of cody - def testGetOrCreate(body: => Unit): Unit = { - newContextCreated = false - try { - body - } finally { - if (ssc != null) { - ssc.stop(stopSparkContext = false) - } - ssc = null - } - } - - val emptyPath = Utils.createTempDir().getAbsolutePath() - - // getOrCreate should create new context with empty path - testGetOrCreate { - ssc = StreamingContext.getOrCreate(emptyPath, creatingFunction _, sc, createOnError = true) - assert(ssc != null, "no context created") - assert(newContextCreated, "new context not created") - assert(ssc.sparkContext === sc, "new StreamingContext does not use existing SparkContext") - } - - val corrutedCheckpointPath = createCorruptedCheckpoint() - - // getOrCreate should throw exception with fake checkpoint file and createOnError = false - intercept[Exception] { - ssc = StreamingContext.getOrCreate(corrutedCheckpointPath, creatingFunction _, sc) - } - - // getOrCreate should throw exception with fake checkpoint file - intercept[Exception] { - ssc = StreamingContext.getOrCreate( - corrutedCheckpointPath, creatingFunction _, sc, createOnError = false) - } - - // getOrCreate should create new context with fake checkpoint file and createOnError = true - testGetOrCreate { - ssc = StreamingContext.getOrCreate( - corrutedCheckpointPath, creatingFunction _, sc, createOnError = true) - assert(ssc != null, "no context created") - assert(newContextCreated, "new context not created") - assert(ssc.sparkContext === sc, "new StreamingContext does not use existing SparkContext") - } - - val checkpointPath = createValidCheckpoint() - - // StreamingContext.getOrCreate should recover context with checkpoint path - testGetOrCreate { - ssc = StreamingContext.getOrCreate(checkpointPath, creatingFunction _, sc) - assert(ssc != null, "no context created") - assert(!newContextCreated, "old context not recovered") - assert(ssc.sparkContext === sc, "new StreamingContext does not use existing SparkContext") - 
assert(!ssc.conf.contains("someKey"), - "recovered StreamingContext unexpectedly has old config") - } - } - test("DStream and generated RDD creation sites") { testPackage.test() } @@ -474,30 +339,6 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w val inputStream = new TestInputStream(s, input, 1) inputStream } - - def createValidCheckpoint(): String = { - val testDirectory = Utils.createTempDir().getAbsolutePath() - val checkpointDirectory = Utils.createTempDir().getAbsolutePath() - val conf = new SparkConf().setMaster(master).setAppName(appName) - conf.set("someKey", "someValue") - ssc = new StreamingContext(conf, batchDuration) - ssc.checkpoint(checkpointDirectory) - ssc.textFileStream(testDirectory).foreachRDD { rdd => rdd.count() } - ssc.start() - eventually(timeout(10000 millis)) { - assert(Checkpoint.getCheckpointFiles(checkpointDirectory).size > 1) - } - ssc.stop() - checkpointDirectory - } - - def createCorruptedCheckpoint(): String = { - val checkpointDirectory = Utils.createTempDir().getAbsolutePath() - val fakeCheckpointFile = Checkpoint.checkpointFile(checkpointDirectory, Time(1000)) - FileUtils.write(new File(fakeCheckpointFile.toString()), "blablabla") - assert(Checkpoint.getCheckpointFiles(checkpointDirectory).nonEmpty) - checkpointDirectory - } } class TestException(msg: String) extends Exception(msg) From a7160c4e3aae22600d05e257d0b4d2428754b8ea Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Sat, 25 Apr 2015 12:27:19 -0700 Subject: [PATCH 129/191] [SPARK-6113] [ML] Tree ensembles for Pipelines API This is a continuation of [https://github.com/apache/spark/pull/5530] (which was for Decision Trees), but for ensembles: Random Forests and Gradient-Boosted Trees. Please refer to the JIRA [https://issues.apache.org/jira/browse/SPARK-6113], the design doc linked from the JIRA, and the previous PR linked above for design discussions. This PR follows the example set by the previous PR for Decision Trees. It includes a few cleanups to Decision Trees. Note: There is one issue which will be addressed in a separate PR: Ensembles' component Models have no parent or fittingParamMap. I plan to submit a separate PR which makes those values in Model be Options. It does not matter much which PR gets merged first. CC: mengxr manishamde codedeft chouqin Author: Joseph K. Bradley Closes #5626 from jkbradley/dt-api-ensembles and squashes the following commits: 729167a [Joseph K. Bradley] small cleanups based on code review bbae2a2 [Joseph K. Bradley] Updated per all comments in code review 855aa9a [Joseph K. Bradley] scala style fix ea3d901 [Joseph K. Bradley] Added GBT to spark.ml, with tests and examples c0f30c1 [Joseph K. Bradley] Added random forests and test suites to spark.ml. Not tested yet. Need to add example as well d045ebd [Joseph K. Bradley] some more updates, but far from done ee1a10b [Joseph K. Bradley] Added files from old PR and did some initial updates. 
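As a usage sketch (not part of this patch's diff), the new ensemble estimators plug into a Pipeline exactly like the earlier DecisionTreeClassifier; the snippet below follows the pattern of the GBTExample added here, and assumes DataFrames named training and test with "labelString" and "features" columns, e.g. as produced by DecisionTreeExample.loadDatasets:

    import org.apache.spark.ml.{Pipeline, PipelineStage}
    import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier}
    import org.apache.spark.ml.feature.{StringIndexer, VectorIndexer}

    // Index string labels and categorical features, then fit a gradient-boosted tree ensemble.
    val labelIndexer = new StringIndexer()
      .setInputCol("labelString")
      .setOutputCol("indexedLabel")
    val featuresIndexer = new VectorIndexer()
      .setInputCol("features")
      .setOutputCol("indexedFeatures")
      .setMaxCategories(10)
    val gbt = new GBTClassifier()
      .setFeaturesCol("indexedFeatures")
      .setLabelCol("indexedLabel")
      .setMaxDepth(5)
      .setMaxIter(10)
    val pipeline = new Pipeline()
      .setStages(Array[PipelineStage](labelIndexer, featuresIndexer, gbt))
    val pipelineModel = pipeline.fit(training)
    val predictions = pipelineModel.transform(test)
    // Retrieve the fitted ensemble from the PipelineModel for inspection.
    val gbtModel = pipelineModel.getModel[GBTClassificationModel](gbt)
    println(gbtModel) // or gbtModel.toDebugString for the full tree structure

RandomForestClassifier, RandomForestRegressor and GBTRegressor follow the same pattern; see the new example apps under examples/src/main/scala/org/apache/spark/examples/ml.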
--- .../examples/ml/DecisionTreeExample.scala | 139 ++++++---- .../apache/spark/examples/ml/GBTExample.scala | 238 +++++++++++++++++ .../examples/ml/RandomForestExample.scala | 248 +++++++++++++++++ .../mllib/GradientBoostedTreesRunner.scala | 1 + .../scala/org/apache/spark/ml/Model.scala | 2 + .../DecisionTreeClassifier.scala | 24 +- .../ml/classification/GBTClassifier.scala | 228 ++++++++++++++++ .../RandomForestClassifier.scala | 185 +++++++++++++ .../spark/ml/impl/tree/treeParams.scala | 249 +++++++++++++++--- .../ml/param/shared/SharedParamsCodeGen.scala | 4 +- .../spark/ml/param/shared/sharedParams.scala | 20 ++ .../ml/regression/DecisionTreeRegressor.scala | 14 +- .../spark/ml/regression/GBTRegressor.scala | 218 +++++++++++++++ .../ml/regression/RandomForestRegressor.scala | 167 ++++++++++++ .../scala/org/apache/spark/ml/tree/Node.scala | 6 +- .../org/apache/spark/ml/tree/Split.scala | 22 +- .../org/apache/spark/ml/tree/treeModels.scala | 46 +++- .../JavaDecisionTreeClassifierSuite.java | 10 +- .../JavaGBTClassifierSuite.java | 100 +++++++ .../JavaRandomForestClassifierSuite.java | 103 ++++++++ .../JavaDecisionTreeRegressorSuite.java | 26 +- .../ml/regression/JavaGBTRegressorSuite.java | 99 +++++++ .../JavaRandomForestRegressorSuite.java | 102 +++++++ .../DecisionTreeClassifierSuite.scala | 2 +- .../classification/GBTClassifierSuite.scala | 136 ++++++++++ .../RandomForestClassifierSuite.scala | 166 ++++++++++++ .../org/apache/spark/ml/impl/TreeTests.scala | 10 +- .../DecisionTreeRegressorSuite.scala | 2 +- .../ml/regression/GBTRegressorSuite.scala | 137 ++++++++++ .../RandomForestRegressorSuite.scala | 122 +++++++++ .../spark/mllib/tree/DecisionTreeSuite.scala | 6 +- 31 files changed, 2658 insertions(+), 174 deletions(-) create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala create mode 100644 mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java create mode 100644 mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java create mode 100644 mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java create mode 100644 mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java create mode 100644 mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala index 2cd515c89d3d2..9002e99d82ad3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala @@ -22,10 
+22,9 @@ import scala.language.reflectiveCalls import scopt.OptionParser -import org.apache.spark.ml.tree.DecisionTreeModel import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.examples.mllib.AbstractParams -import org.apache.spark.ml.{Pipeline, PipelineStage} +import org.apache.spark.ml.{Pipeline, PipelineStage, Transformer} import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier} import org.apache.spark.ml.feature.{VectorIndexer, StringIndexer} import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor} @@ -64,8 +63,6 @@ object DecisionTreeExample { maxBins: Int = 32, minInstancesPerNode: Int = 1, minInfoGain: Double = 0.0, - numTrees: Int = 1, - featureSubsetStrategy: String = "auto", fracTest: Double = 0.2, cacheNodeIds: Boolean = false, checkpointDir: Option[String] = None, @@ -123,8 +120,8 @@ object DecisionTreeExample { .required() .action((x, c) => c.copy(input = x)) checkConfig { params => - if (params.fracTest < 0 || params.fracTest > 1) { - failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].") + if (params.fracTest < 0 || params.fracTest >= 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).") } else { success } @@ -200,9 +197,18 @@ object DecisionTreeExample { throw new IllegalArgumentException("Algo ${params.algo} not supported.") } } - val dataframes = splits.map(_.toDF()).map(labelsToStrings).map(_.cache()) + val dataframes = splits.map(_.toDF()).map(labelsToStrings) + val training = dataframes(0).cache() + val test = dataframes(1).cache() - (dataframes(0), dataframes(1)) + val numTraining = training.count() + val numTest = test.count() + val numFeatures = training.select("features").first().getAs[Vector](0).size + println("Loaded data:") + println(s" numTraining = $numTraining, numTest = $numTest") + println(s" numFeatures = $numFeatures") + + (training, test) } def run(params: Params) { @@ -217,13 +223,6 @@ object DecisionTreeExample { val (training: DataFrame, test: DataFrame) = loadDatasets(sc, params.input, params.dataFormat, params.testInput, algo, params.fracTest) - val numTraining = training.count() - val numTest = test.count() - val numFeatures = training.select("features").first().getAs[Vector](0).size - println("Loaded data:") - println(s" numTraining = $numTraining, numTest = $numTest") - println(s" numFeatures = $numFeatures") - // Set up Pipeline val stages = new mutable.ArrayBuffer[PipelineStage]() // (1) For classification, re-index classes. @@ -241,7 +240,7 @@ object DecisionTreeExample { .setOutputCol("indexedFeatures") .setMaxCategories(10) stages += featuresIndexer - // (3) Learn DecisionTree + // (3) Learn Decision Tree val dt = algo match { case "classification" => new DecisionTreeClassifier() @@ -275,62 +274,86 @@ object DecisionTreeExample { println(s"Training time: $elapsedTime seconds") // Get the trained Decision Tree from the fitted PipelineModel - val treeModel: DecisionTreeModel = algo match { + algo match { case "classification" => - pipelineModel.getModel[DecisionTreeClassificationModel]( + val treeModel = pipelineModel.getModel[DecisionTreeClassificationModel]( dt.asInstanceOf[DecisionTreeClassifier]) + if (treeModel.numNodes < 20) { + println(treeModel.toDebugString) // Print full model. + } else { + println(treeModel) // Print model summary. 
+ } case "regression" => - pipelineModel.getModel[DecisionTreeRegressionModel](dt.asInstanceOf[DecisionTreeRegressor]) - case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.") - } - if (treeModel.numNodes < 20) { - println(treeModel.toDebugString) // Print full model. - } else { - println(treeModel) // Print model summary. - } - - // Predict on training - val trainingFullPredictions = pipelineModel.transform(training).cache() - val trainingPredictions = trainingFullPredictions.select("prediction") - .map(_.getDouble(0)) - val trainingLabels = trainingFullPredictions.select(labelColName).map(_.getDouble(0)) - // Predict on test data - val testFullPredictions = pipelineModel.transform(test).cache() - val testPredictions = testFullPredictions.select("prediction") - .map(_.getDouble(0)) - val testLabels = testFullPredictions.select(labelColName).map(_.getDouble(0)) - - // For classification, print number of classes for reference. - if (algo == "classification") { - val numClasses = - MetadataUtils.getNumClasses(trainingFullPredictions.schema(labelColName)) match { - case Some(n) => n - case None => throw new RuntimeException( - "DecisionTreeExample had unknown failure when indexing labels for classification.") + val treeModel = pipelineModel.getModel[DecisionTreeRegressionModel]( + dt.asInstanceOf[DecisionTreeRegressor]) + if (treeModel.numNodes < 20) { + println(treeModel.toDebugString) // Print full model. + } else { + println(treeModel) // Print model summary. } - println(s"numClasses = $numClasses.") + case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.") } // Evaluate model on training, test data algo match { case "classification" => - val trainingAccuracy = - new MulticlassMetrics(trainingPredictions.zip(trainingLabels)).precision - println(s"Train accuracy = $trainingAccuracy") - val testAccuracy = - new MulticlassMetrics(testPredictions.zip(testLabels)).precision - println(s"Test accuracy = $testAccuracy") + println("Training data results:") + evaluateClassificationModel(pipelineModel, training, labelColName) + println("Test data results:") + evaluateClassificationModel(pipelineModel, test, labelColName) case "regression" => - val trainingRMSE = - new RegressionMetrics(trainingPredictions.zip(trainingLabels)).rootMeanSquaredError - println(s"Training root mean squared error (RMSE) = $trainingRMSE") - val testRMSE = - new RegressionMetrics(testPredictions.zip(testLabels)).rootMeanSquaredError - println(s"Test root mean squared error (RMSE) = $testRMSE") + println("Training data results:") + evaluateRegressionModel(pipelineModel, training, labelColName) + println("Test data results:") + evaluateRegressionModel(pipelineModel, test, labelColName) case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.") } sc.stop() } + + /** + * Evaluate the given ClassificationModel on data. Print the results. + * @param model Must fit ClassificationModel abstraction + * @param data DataFrame with "prediction" and labelColName columns + * @param labelColName Name of the labelCol parameter for the model + * + * TODO: Change model type to ClassificationModel once that API is public. 
SPARK-5995 + */ + private[ml] def evaluateClassificationModel( + model: Transformer, + data: DataFrame, + labelColName: String): Unit = { + val fullPredictions = model.transform(data).cache() + val predictions = fullPredictions.select("prediction").map(_.getDouble(0)) + val labels = fullPredictions.select(labelColName).map(_.getDouble(0)) + // Print number of classes for reference + val numClasses = MetadataUtils.getNumClasses(fullPredictions.schema(labelColName)) match { + case Some(n) => n + case None => throw new RuntimeException( + "Unknown failure when indexing labels for classification.") + } + val accuracy = new MulticlassMetrics(predictions.zip(labels)).precision + println(s" Accuracy ($numClasses classes): $accuracy") + } + + /** + * Evaluate the given RegressionModel on data. Print the results. + * @param model Must fit RegressionModel abstraction + * @param data DataFrame with "prediction" and labelColName columns + * @param labelColName Name of the labelCol parameter for the model + * + * TODO: Change model type to RegressionModel once that API is public. SPARK-5995 + */ + private[ml] def evaluateRegressionModel( + model: Transformer, + data: DataFrame, + labelColName: String): Unit = { + val fullPredictions = model.transform(data).cache() + val predictions = fullPredictions.select("prediction").map(_.getDouble(0)) + val labels = fullPredictions.select(labelColName).map(_.getDouble(0)) + val RMSE = new RegressionMetrics(predictions.zip(labels)).rootMeanSquaredError + println(s" Root mean squared error (RMSE): $RMSE") + } } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala new file mode 100644 index 0000000000000..5fccb142d4c3d --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala @@ -0,0 +1,238 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +import scala.collection.mutable +import scala.language.reflectiveCalls + +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.examples.mllib.AbstractParams +import org.apache.spark.ml.{Pipeline, PipelineStage} +import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier} +import org.apache.spark.ml.feature.{StringIndexer, VectorIndexer} +import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor} +import org.apache.spark.sql.DataFrame + + +/** + * An example runner for decision trees. Run with + * {{{ + * ./bin/run-example ml.GBTExample [options] + * }}} + * Decision Trees and ensembles can take a large amount of memory. 
If the run-example command + * above fails, try running via spark-submit and specifying the amount of memory as at least 1g. + * For local mode, run + * {{{ + * ./bin/spark-submit --class org.apache.spark.examples.ml.GBTExample --driver-memory 1g + * [examples JAR path] [options] + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object GBTExample { + + case class Params( + input: String = null, + testInput: String = "", + dataFormat: String = "libsvm", + algo: String = "classification", + maxDepth: Int = 5, + maxBins: Int = 32, + minInstancesPerNode: Int = 1, + minInfoGain: Double = 0.0, + maxIter: Int = 10, + fracTest: Double = 0.2, + cacheNodeIds: Boolean = false, + checkpointDir: Option[String] = None, + checkpointInterval: Int = 10) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("GBTExample") { + head("GBTExample: an example Gradient-Boosted Trees app.") + opt[String]("algo") + .text(s"algorithm (classification, regression), default: ${defaultParams.algo}") + .action((x, c) => c.copy(algo = x)) + opt[Int]("maxDepth") + .text(s"max depth of the tree, default: ${defaultParams.maxDepth}") + .action((x, c) => c.copy(maxDepth = x)) + opt[Int]("maxBins") + .text(s"max number of bins, default: ${defaultParams.maxBins}") + .action((x, c) => c.copy(maxBins = x)) + opt[Int]("minInstancesPerNode") + .text(s"min number of instances required at child nodes to create the parent split," + + s" default: ${defaultParams.minInstancesPerNode}") + .action((x, c) => c.copy(minInstancesPerNode = x)) + opt[Double]("minInfoGain") + .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}") + .action((x, c) => c.copy(minInfoGain = x)) + opt[Int]("maxIter") + .text(s"number of trees in ensemble, default: ${defaultParams.maxIter}") + .action((x, c) => c.copy(maxIter = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing. If given option testInput, " + + s"this option is ignored. default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) + opt[Boolean]("cacheNodeIds") + .text(s"whether to use node Id cache during training, " + + s"default: ${defaultParams.cacheNodeIds}") + .action((x, c) => c.copy(cacheNodeIds = x)) + opt[String]("checkpointDir") + .text(s"checkpoint directory where intermediate node Id caches will be stored, " + + s"default: ${ + defaultParams.checkpointDir match { + case Some(strVal) => strVal + case None => "None" + } + }") + .action((x, c) => c.copy(checkpointDir = Some(x))) + opt[Int]("checkpointInterval") + .text(s"how often to checkpoint the node Id cache, " + + s"default: ${defaultParams.checkpointInterval}") + .action((x, c) => c.copy(checkpointInterval = x)) + opt[String]("testInput") + .text(s"input path to test dataset. If given, option fracTest is ignored." 
+ + s" default: ${defaultParams.testInput}") + .action((x, c) => c.copy(testInput = x)) + opt[String]("") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(dataFormat = x)) + arg[String]("") + .text("input path to labeled examples") + .required() + .action((x, c) => c.copy(input = x)) + checkConfig { params => + if (params.fracTest < 0 || params.fracTest >= 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).") + } else { + success + } + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf().setAppName(s"GBTExample with $params") + val sc = new SparkContext(conf) + params.checkpointDir.foreach(sc.setCheckpointDir) + val algo = params.algo.toLowerCase + + println(s"GBTExample with parameters:\n$params") + + // Load training and test data and cache it. + val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc, params.input, + params.dataFormat, params.testInput, algo, params.fracTest) + + // Set up Pipeline + val stages = new mutable.ArrayBuffer[PipelineStage]() + // (1) For classification, re-index classes. + val labelColName = if (algo == "classification") "indexedLabel" else "label" + if (algo == "classification") { + val labelIndexer = new StringIndexer() + .setInputCol("labelString") + .setOutputCol(labelColName) + stages += labelIndexer + } + // (2) Identify categorical features using VectorIndexer. + // Features with more than maxCategories values will be treated as continuous. + val featuresIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(10) + stages += featuresIndexer + // (3) Learn GBT + val dt = algo match { + case "classification" => + new GBTClassifier() + .setFeaturesCol("indexedFeatures") + .setLabelCol(labelColName) + .setMaxDepth(params.maxDepth) + .setMaxBins(params.maxBins) + .setMinInstancesPerNode(params.minInstancesPerNode) + .setMinInfoGain(params.minInfoGain) + .setCacheNodeIds(params.cacheNodeIds) + .setCheckpointInterval(params.checkpointInterval) + .setMaxIter(params.maxIter) + case "regression" => + new GBTRegressor() + .setFeaturesCol("indexedFeatures") + .setLabelCol(labelColName) + .setMaxDepth(params.maxDepth) + .setMaxBins(params.maxBins) + .setMinInstancesPerNode(params.minInstancesPerNode) + .setMinInfoGain(params.minInfoGain) + .setCacheNodeIds(params.cacheNodeIds) + .setCheckpointInterval(params.checkpointInterval) + .setMaxIter(params.maxIter) + case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } + stages += dt + val pipeline = new Pipeline().setStages(stages.toArray) + + // Fit the Pipeline + val startTime = System.nanoTime() + val pipelineModel = pipeline.fit(training) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") + + // Get the trained GBT from the fitted PipelineModel + algo match { + case "classification" => + val rfModel = pipelineModel.getModel[GBTClassificationModel](dt.asInstanceOf[GBTClassifier]) + if (rfModel.totalNumNodes < 30) { + println(rfModel.toDebugString) // Print full model. + } else { + println(rfModel) // Print model summary. + } + case "regression" => + val rfModel = pipelineModel.getModel[GBTRegressionModel](dt.asInstanceOf[GBTRegressor]) + if (rfModel.totalNumNodes < 30) { + println(rfModel.toDebugString) // Print full model. 
+ } else { + println(rfModel) // Print model summary. + } + case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } + + // Evaluate model on training, test data + algo match { + case "classification" => + println("Training data results:") + DecisionTreeExample.evaluateClassificationModel(pipelineModel, training, labelColName) + println("Test data results:") + DecisionTreeExample.evaluateClassificationModel(pipelineModel, test, labelColName) + case "regression" => + println("Training data results:") + DecisionTreeExample.evaluateRegressionModel(pipelineModel, training, labelColName) + println("Test data results:") + DecisionTreeExample.evaluateRegressionModel(pipelineModel, test, labelColName) + case _ => + throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } + + sc.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala new file mode 100644 index 0000000000000..9b909324ec82a --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala @@ -0,0 +1,248 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +import scala.collection.mutable +import scala.language.reflectiveCalls + +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.examples.mllib.AbstractParams +import org.apache.spark.ml.{Pipeline, PipelineStage} +import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier} +import org.apache.spark.ml.feature.{StringIndexer, VectorIndexer} +import org.apache.spark.ml.regression.{RandomForestRegressionModel, RandomForestRegressor} +import org.apache.spark.sql.DataFrame + + +/** + * An example runner for decision trees. Run with + * {{{ + * ./bin/run-example ml.RandomForestExample [options] + * }}} + * Decision Trees and ensembles can take a large amount of memory. If the run-example command + * above fails, try running via spark-submit and specifying the amount of memory as at least 1g. + * For local mode, run + * {{{ + * ./bin/spark-submit --class org.apache.spark.examples.ml.RandomForestExample --driver-memory 1g + * [examples JAR path] [options] + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. 
+ */ +object RandomForestExample { + + case class Params( + input: String = null, + testInput: String = "", + dataFormat: String = "libsvm", + algo: String = "classification", + maxDepth: Int = 5, + maxBins: Int = 32, + minInstancesPerNode: Int = 1, + minInfoGain: Double = 0.0, + numTrees: Int = 10, + featureSubsetStrategy: String = "auto", + fracTest: Double = 0.2, + cacheNodeIds: Boolean = false, + checkpointDir: Option[String] = None, + checkpointInterval: Int = 10) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("RandomForestExample") { + head("RandomForestExample: an example random forest app.") + opt[String]("algo") + .text(s"algorithm (classification, regression), default: ${defaultParams.algo}") + .action((x, c) => c.copy(algo = x)) + opt[Int]("maxDepth") + .text(s"max depth of the tree, default: ${defaultParams.maxDepth}") + .action((x, c) => c.copy(maxDepth = x)) + opt[Int]("maxBins") + .text(s"max number of bins, default: ${defaultParams.maxBins}") + .action((x, c) => c.copy(maxBins = x)) + opt[Int]("minInstancesPerNode") + .text(s"min number of instances required at child nodes to create the parent split," + + s" default: ${defaultParams.minInstancesPerNode}") + .action((x, c) => c.copy(minInstancesPerNode = x)) + opt[Double]("minInfoGain") + .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}") + .action((x, c) => c.copy(minInfoGain = x)) + opt[Int]("numTrees") + .text(s"number of trees in ensemble, default: ${defaultParams.numTrees}") + .action((x, c) => c.copy(numTrees = x)) + opt[String]("featureSubsetStrategy") + .text(s"number of features to use per node (supported:" + + s" ${RandomForestClassifier.supportedFeatureSubsetStrategies.mkString(",")})," + + s" default: ${defaultParams.numTrees}") + .action((x, c) => c.copy(featureSubsetStrategy = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing. If given option testInput, " + + s"this option is ignored. default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) + opt[Boolean]("cacheNodeIds") + .text(s"whether to use node Id cache during training, " + + s"default: ${defaultParams.cacheNodeIds}") + .action((x, c) => c.copy(cacheNodeIds = x)) + opt[String]("checkpointDir") + .text(s"checkpoint directory where intermediate node Id caches will be stored, " + + s"default: ${ + defaultParams.checkpointDir match { + case Some(strVal) => strVal + case None => "None" + } + }") + .action((x, c) => c.copy(checkpointDir = Some(x))) + opt[Int]("checkpointInterval") + .text(s"how often to checkpoint the node Id cache, " + + s"default: ${defaultParams.checkpointInterval}") + .action((x, c) => c.copy(checkpointInterval = x)) + opt[String]("testInput") + .text(s"input path to test dataset. If given, option fracTest is ignored." 
+ + s" default: ${defaultParams.testInput}") + .action((x, c) => c.copy(testInput = x)) + opt[String]("") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(dataFormat = x)) + arg[String]("") + .text("input path to labeled examples") + .required() + .action((x, c) => c.copy(input = x)) + checkConfig { params => + if (params.fracTest < 0 || params.fracTest >= 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).") + } else { + success + } + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf().setAppName(s"RandomForestExample with $params") + val sc = new SparkContext(conf) + params.checkpointDir.foreach(sc.setCheckpointDir) + val algo = params.algo.toLowerCase + + println(s"RandomForestExample with parameters:\n$params") + + // Load training and test data and cache it. + val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc, params.input, + params.dataFormat, params.testInput, algo, params.fracTest) + + // Set up Pipeline + val stages = new mutable.ArrayBuffer[PipelineStage]() + // (1) For classification, re-index classes. + val labelColName = if (algo == "classification") "indexedLabel" else "label" + if (algo == "classification") { + val labelIndexer = new StringIndexer() + .setInputCol("labelString") + .setOutputCol(labelColName) + stages += labelIndexer + } + // (2) Identify categorical features using VectorIndexer. + // Features with more than maxCategories values will be treated as continuous. + val featuresIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(10) + stages += featuresIndexer + // (3) Learn Random Forest + val dt = algo match { + case "classification" => + new RandomForestClassifier() + .setFeaturesCol("indexedFeatures") + .setLabelCol(labelColName) + .setMaxDepth(params.maxDepth) + .setMaxBins(params.maxBins) + .setMinInstancesPerNode(params.minInstancesPerNode) + .setMinInfoGain(params.minInfoGain) + .setCacheNodeIds(params.cacheNodeIds) + .setCheckpointInterval(params.checkpointInterval) + .setFeatureSubsetStrategy(params.featureSubsetStrategy) + .setNumTrees(params.numTrees) + case "regression" => + new RandomForestRegressor() + .setFeaturesCol("indexedFeatures") + .setLabelCol(labelColName) + .setMaxDepth(params.maxDepth) + .setMaxBins(params.maxBins) + .setMinInstancesPerNode(params.minInstancesPerNode) + .setMinInfoGain(params.minInfoGain) + .setCacheNodeIds(params.cacheNodeIds) + .setCheckpointInterval(params.checkpointInterval) + .setFeatureSubsetStrategy(params.featureSubsetStrategy) + .setNumTrees(params.numTrees) + case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } + stages += dt + val pipeline = new Pipeline().setStages(stages.toArray) + + // Fit the Pipeline + val startTime = System.nanoTime() + val pipelineModel = pipeline.fit(training) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") + + // Get the trained Random Forest from the fitted PipelineModel + algo match { + case "classification" => + val rfModel = pipelineModel.getModel[RandomForestClassificationModel]( + dt.asInstanceOf[RandomForestClassifier]) + if (rfModel.totalNumNodes < 30) { + println(rfModel.toDebugString) // Print full model. + } else { + println(rfModel) // Print model summary. 
+ } + case "regression" => + val rfModel = pipelineModel.getModel[RandomForestRegressionModel]( + dt.asInstanceOf[RandomForestRegressor]) + if (rfModel.totalNumNodes < 30) { + println(rfModel.toDebugString) // Print full model. + } else { + println(rfModel) // Print model summary. + } + case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } + + // Evaluate model on training, test data + algo match { + case "classification" => + println("Training data results:") + DecisionTreeExample.evaluateClassificationModel(pipelineModel, training, labelColName) + println("Test data results:") + DecisionTreeExample.evaluateClassificationModel(pipelineModel, test, labelColName) + case "regression" => + println("Training data results:") + DecisionTreeExample.evaluateRegressionModel(pipelineModel, training, labelColName) + println("Test data results:") + DecisionTreeExample.evaluateRegressionModel(pipelineModel, test, labelColName) + case _ => + throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } + + sc.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala index 431ead8c0c165..0763a7736305a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala @@ -25,6 +25,7 @@ import org.apache.spark.mllib.tree.GradientBoostedTrees import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo} import org.apache.spark.util.Utils + /** * An example runner for Gradient Boosting using decision trees as weak learners. Run with * {{{ diff --git a/mllib/src/main/scala/org/apache/spark/ml/Model.scala b/mllib/src/main/scala/org/apache/spark/ml/Model.scala index cae5082b51196..a491bc7ee8295 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Model.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Model.scala @@ -30,11 +30,13 @@ import org.apache.spark.ml.param.ParamMap abstract class Model[M <: Model[M]] extends Transformer { /** * The parent estimator that produced this model. + * Note: For ensembles' component Models, this value can be null. */ val parent: Estimator[M] /** * Fitting parameters, such that parent.fit(..., fittingParamMap) could reproduce the model. + * Note: For ensembles' component Models, this value can be null. */ val fittingParamMap: ParamMap } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 3855e396b5534..ee2a8dc6db171 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -43,8 +43,7 @@ import org.apache.spark.sql.DataFrame @AlphaComponent final class DecisionTreeClassifier extends Predictor[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel] - with DecisionTreeParams - with TreeClassifierParams { + with DecisionTreeParams with TreeClassifierParams { // Override parameter setters from parent trait for Java API compatibility. 
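[Editor's note] For context, a minimal usage sketch of the Pipelines tree-ensemble API added by this patch (RandomForestClassifier plus the StringIndexer/VectorIndexer preprocessing used by the example runners above). This is illustrative only and not part of the patch: the helper name `fitForest` is hypothetical, and the input DataFrame is assumed to have "labelString" and "features" columns, as produced by the loadDatasets helper in DecisionTreeExample.

{{{
import org.apache.spark.ml.{Pipeline, PipelineModel, PipelineStage}
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.feature.{StringIndexer, VectorIndexer}
import org.apache.spark.sql.DataFrame

// Fit a random forest on a DataFrame with "labelString" and "features" columns.
def fitForest(training: DataFrame): PipelineModel = {
  val labelIndexer = new StringIndexer()      // map string labels to label indices
    .setInputCol("labelString")
    .setOutputCol("indexedLabel")
  val featureIndexer = new VectorIndexer()    // detect categorical features (<= 10 values)
    .setInputCol("features")
    .setOutputCol("indexedFeatures")
    .setMaxCategories(10)
  val rf = new RandomForestClassifier()
    .setLabelCol("indexedLabel")
    .setFeaturesCol("indexedFeatures")
    .setNumTrees(10)
  val pipeline = new Pipeline()
    .setStages(Array[PipelineStage](labelIndexer, featureIndexer, rf))
  pipeline.fit(training)  // the returned model's transform() adds a "prediction" column
}
}}}

The same skeleton applies to GBTClassifier, GBTRegressor, and RandomForestRegressor; only the final estimator stage changes.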
@@ -59,11 +58,9 @@ final class DecisionTreeClassifier override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) - override def setCacheNodeIds(value: Boolean): this.type = - super.setCacheNodeIds(value) + override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) - override def setCheckpointInterval(value: Int): this.type = - super.setCheckpointInterval(value) + override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) override def setImpurity(value: String): this.type = super.setImpurity(value) @@ -75,8 +72,9 @@ final class DecisionTreeClassifier val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema(paramMap(labelCol))) match { case Some(n: Int) => n case None => throw new IllegalArgumentException("DecisionTreeClassifier was given input" + - s" with invalid label column, without the number of classes specified.") - // TODO: Automatically index labels. + s" with invalid label column ${paramMap(labelCol)}, without the number of classes" + + " specified. See StringIndexer.") + // TODO: Automatically index labels: SPARK-7126 } val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap) val strategy = getOldStrategy(categoricalFeatures, numClasses) @@ -85,18 +83,16 @@ final class DecisionTreeClassifier } /** (private[ml]) Create a Strategy instance to use with the old API. */ - override private[ml] def getOldStrategy( + private[ml] def getOldStrategy( categoricalFeatures: Map[Int, Int], numClasses: Int): OldStrategy = { - val strategy = super.getOldStrategy(categoricalFeatures, numClasses) - strategy.algo = OldAlgo.Classification - strategy.setImpurity(getOldImpurity) - strategy + super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity, + subsamplingRate = 1.0) } } object DecisionTreeClassifier { - /** Accessor for supported impurities */ + /** Accessor for supported impurities: entropy, gini */ final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala new file mode 100644 index 0000000000000..d2e052fbbbf22 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.classification + +import com.github.fommil.netlib.BLAS.{getInstance => blas} + +import org.apache.spark.Logging +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor} +import org.apache.spark.ml.impl.tree._ +import org.apache.spark.ml.param.{Param, Params, ParamMap} +import org.apache.spark.ml.regression.DecisionTreeRegressionModel +import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel} +import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{GradientBoostedTrees => OldGBT} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} +import org.apache.spark.mllib.tree.loss.{Loss => OldLoss, LogLoss => OldLogLoss} +import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + + +/** + * :: AlphaComponent :: + * + * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]] + * learning algorithm for classification. + * It supports binary labels, as well as both continuous and categorical features. + * Note: Multiclass labels are not currently supported. + */ +@AlphaComponent +final class GBTClassifier + extends Predictor[Vector, GBTClassifier, GBTClassificationModel] + with GBTParams with TreeClassifierParams with Logging { + + // Override parameter setters from parent trait for Java API compatibility. + + // Parameters from TreeClassifierParams: + + override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + + override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + + override def setMinInstancesPerNode(value: Int): this.type = + super.setMinInstancesPerNode(value) + + override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + + override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + + override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) + + override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) + + /** + * The impurity setting is ignored for GBT models. + * Individual trees are built using impurity "Variance." + */ + override def setImpurity(value: String): this.type = { + logWarning("GBTClassifier.setImpurity should NOT be used") + this + } + + // Parameters from TreeEnsembleParams: + + override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value) + + override def setSeed(value: Long): this.type = { + logWarning("The 'seed' parameter is currently ignored by Gradient Boosting.") + super.setSeed(value) + } + + // Parameters from GBTParams: + + override def setMaxIter(value: Int): this.type = super.setMaxIter(value) + + override def setStepSize(value: Double): this.type = super.setStepSize(value) + + // Parameters for GBTClassifier: + + /** + * Loss function which GBT tries to minimize. (case-insensitive) + * Supported: "logistic" + * (default = logistic) + * @group param + */ + val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + + " tries to minimize (case-insensitive). 
Supported options:" + + s" ${GBTClassifier.supportedLossTypes.mkString(", ")}") + + setDefault(lossType -> "logistic") + + /** @group setParam */ + def setLossType(value: String): this.type = { + val lossStr = value.toLowerCase + require(GBTClassifier.supportedLossTypes.contains(lossStr), "GBTClassifier was given bad loss" + + s" type: $value. Supported options: ${GBTClassifier.supportedLossTypes.mkString(", ")}") + set(lossType, lossStr) + this + } + + /** @group getParam */ + def getLossType: String = getOrDefault(lossType) + + /** (private[ml]) Convert new loss to old loss. */ + override private[ml] def getOldLossType: OldLoss = { + getLossType match { + case "logistic" => OldLogLoss + case _ => + // Should never happen because of check in setter method. + throw new RuntimeException(s"GBTClassifier was given bad loss type: $getLossType") + } + } + + override protected def train( + dataset: DataFrame, + paramMap: ParamMap): GBTClassificationModel = { + val categoricalFeatures: Map[Int, Int] = + MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol))) + val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema(paramMap(labelCol))) match { + case Some(n: Int) => n + case None => throw new IllegalArgumentException("GBTClassifier was given input" + + s" with invalid label column ${paramMap(labelCol)}, without the number of classes" + + " specified. See StringIndexer.") + // TODO: Automatically index labels: SPARK-7126 + } + require(numClasses == 2, + s"GBTClassifier only supports binary classification but was given numClasses = $numClasses") + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap) + val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) + val oldGBT = new OldGBT(boostingStrategy) + val oldModel = oldGBT.run(oldDataset) + GBTClassificationModel.fromOld(oldModel, this, paramMap, categoricalFeatures) + } +} + +object GBTClassifier { + // The losses below should be lowercase. + /** Accessor for supported loss settings: logistic */ + final val supportedLossTypes: Array[String] = Array("logistic").map(_.toLowerCase) +} + +/** + * :: AlphaComponent :: + * + * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]] + * model for classification. + * It supports binary labels, as well as both continuous and categorical features. + * Note: Multiclass labels are not currently supported. + * @param _trees Decision trees in the ensemble. + * @param _treeWeights Weights for the decision trees in the ensemble. + */ +@AlphaComponent +final class GBTClassificationModel( + override val parent: GBTClassifier, + override val fittingParamMap: ParamMap, + private val _trees: Array[DecisionTreeRegressionModel], + private val _treeWeights: Array[Double]) + extends PredictionModel[Vector, GBTClassificationModel] + with TreeEnsembleModel with Serializable { + + require(numTrees > 0, "GBTClassificationModel requires at least 1 tree.") + require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" + + s" of non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).") + + override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] + + override def treeWeights: Array[Double] = _treeWeights + + override protected def predict(features: Vector): Double = { + // TODO: Override transform() to broadcast model: SPARK-7127 + // TODO: When we add a generic Boosting class, handle transform there? 
SPARK-7129 + // Classifies by thresholding sum of weighted tree predictions + val treePredictions = _trees.map(_.rootNode.predict(features)) + val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) + if (prediction > 0.0) 1.0 else 0.0 + } + + override protected def copy(): GBTClassificationModel = { + val m = new GBTClassificationModel(parent, fittingParamMap, _trees, _treeWeights) + Params.inheritValues(this.extractParamMap(), this, m) + m + } + + override def toString: String = { + s"GBTClassificationModel with $numTrees trees" + } + + /** (private[ml]) Convert to a model in the old API */ + private[ml] def toOld: OldGBTModel = { + new OldGBTModel(OldAlgo.Classification, _trees.map(_.toOld), _treeWeights) + } +} + +private[ml] object GBTClassificationModel { + + /** (private[ml]) Convert a model from the old API */ + def fromOld( + oldModel: OldGBTModel, + parent: GBTClassifier, + fittingParamMap: ParamMap, + categoricalFeatures: Map[Int, Int]): GBTClassificationModel = { + require(oldModel.algo == OldAlgo.Classification, "Cannot convert GradientBoostedTreesModel" + + s" with algo=${oldModel.algo} (old API) to GBTClassificationModel (new API).") + val newTrees = oldModel.trees.map { tree => + // parent, fittingParamMap for each tree is null since there are no good ways to set these. + DecisionTreeRegressionModel.fromOld(tree, null, null, categoricalFeatures) + } + new GBTClassificationModel(parent, fittingParamMap, newTrees, oldModel.treeWeights) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala new file mode 100644 index 0000000000000..cfd6508fce890 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.classification + +import scala.collection.mutable + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor} +import org.apache.spark.ml.impl.tree._ +import org.apache.spark.ml.param.{Params, ParamMap} +import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel} +import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{RandomForest => OldRandomForest} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} +import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + + +/** + * :: AlphaComponent :: + * + * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] learning algorithm for + * classification. + * It supports both binary and multiclass labels, as well as both continuous and categorical + * features. + */ +@AlphaComponent +final class RandomForestClassifier + extends Predictor[Vector, RandomForestClassifier, RandomForestClassificationModel] + with RandomForestParams with TreeClassifierParams { + + // Override parameter setters from parent trait for Java API compatibility. + + // Parameters from TreeClassifierParams: + + override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + + override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + + override def setMinInstancesPerNode(value: Int): this.type = + super.setMinInstancesPerNode(value) + + override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + + override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + + override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) + + override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) + + override def setImpurity(value: String): this.type = super.setImpurity(value) + + // Parameters from TreeEnsembleParams: + + override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value) + + override def setSeed(value: Long): this.type = super.setSeed(value) + + // Parameters from RandomForestParams: + + override def setNumTrees(value: Int): this.type = super.setNumTrees(value) + + override def setFeatureSubsetStrategy(value: String): this.type = + super.setFeatureSubsetStrategy(value) + + override protected def train( + dataset: DataFrame, + paramMap: ParamMap): RandomForestClassificationModel = { + val categoricalFeatures: Map[Int, Int] = + MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol))) + val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema(paramMap(labelCol))) match { + case Some(n: Int) => n + case None => throw new IllegalArgumentException("RandomForestClassifier was given input" + + s" with invalid label column ${paramMap(labelCol)}, without the number of classes" + + " specified. 
See StringIndexer.") + // TODO: Automatically index labels: SPARK-7126 + } + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap) + val strategy = + super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity) + val oldModel = OldRandomForest.trainClassifier( + oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt) + RandomForestClassificationModel.fromOld(oldModel, this, paramMap, categoricalFeatures) + } +} + +object RandomForestClassifier { + /** Accessor for supported impurity settings: entropy, gini */ + final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities + + /** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */ + final val supportedFeatureSubsetStrategies: Array[String] = + RandomForestParams.supportedFeatureSubsetStrategies +} + +/** + * :: AlphaComponent :: + * + * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for classification. + * It supports both binary and multiclass labels, as well as both continuous and categorical + * features. + * @param _trees Decision trees in the ensemble. + * Warning: These have null parents. + */ +@AlphaComponent +final class RandomForestClassificationModel private[ml] ( + override val parent: RandomForestClassifier, + override val fittingParamMap: ParamMap, + private val _trees: Array[DecisionTreeClassificationModel]) + extends PredictionModel[Vector, RandomForestClassificationModel] + with TreeEnsembleModel with Serializable { + + require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.") + + override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] + + // Note: We may add support for weights (based on tree performance) later on. + private lazy val _treeWeights: Array[Double] = Array.fill[Double](numTrees)(1.0) + + override def treeWeights: Array[Double] = _treeWeights + + override protected def predict(features: Vector): Double = { + // TODO: Override transform() to broadcast model. SPARK-7127 + // TODO: When we add a generic Bagging class, handle transform there: SPARK-7128 + // Classifies using majority votes. + // Ignore the weights since all are 1.0 for now. 
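+    // With unit weights this is a plain majority vote over the trees' predicted classes;
+    // ties are broken arbitrarily by maxBy over the vote map.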
+ val votes = mutable.Map.empty[Int, Double] + _trees.view.foreach { tree => + val prediction = tree.rootNode.predict(features).toInt + votes(prediction) = votes.getOrElse(prediction, 0.0) + 1.0 // 1.0 = weight + } + votes.maxBy(_._2)._1 + } + + override protected def copy(): RandomForestClassificationModel = { + val m = new RandomForestClassificationModel(parent, fittingParamMap, _trees) + Params.inheritValues(this.extractParamMap(), this, m) + m + } + + override def toString: String = { + s"RandomForestClassificationModel with $numTrees trees" + } + + /** (private[ml]) Convert to a model in the old API */ + private[ml] def toOld: OldRandomForestModel = { + new OldRandomForestModel(OldAlgo.Classification, _trees.map(_.toOld)) + } +} + +private[ml] object RandomForestClassificationModel { + + /** (private[ml]) Convert a model from the old API */ + def fromOld( + oldModel: OldRandomForestModel, + parent: RandomForestClassifier, + fittingParamMap: ParamMap, + categoricalFeatures: Map[Int, Int]): RandomForestClassificationModel = { + require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" + + s" with algo=${oldModel.algo} (old API) to RandomForestClassificationModel (new API).") + val newTrees = oldModel.trees.map { tree => + // parent, fittingParamMap for each tree is null since there are no good ways to set these. + DecisionTreeClassificationModel.fromOld(tree, null, null, categoricalFeatures) + } + new RandomForestClassificationModel(parent, fittingParamMap, newTrees) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala index eb2609faef05a..ab6281b9b2e34 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala @@ -20,9 +20,12 @@ package org.apache.spark.ml.impl.tree import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.impl.estimator.PredictorParams import org.apache.spark.ml.param._ -import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} +import org.apache.spark.ml.param.shared.{HasSeed, HasMaxIter} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, + BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impurity.{Gini => OldGini, Entropy => OldEntropy, Impurity => OldImpurity, Variance => OldVariance} +import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} /** @@ -117,79 +120,68 @@ private[ml] trait DecisionTreeParams extends PredictorParams { def setMaxDepth(value: Int): this.type = { require(value >= 0, s"maxDepth parameter must be >= 0. Given bad value: $value") set(maxDepth, value) - this } /** @group getParam */ - def getMaxDepth: Int = getOrDefault(maxDepth) + final def getMaxDepth: Int = getOrDefault(maxDepth) /** @group setParam */ def setMaxBins(value: Int): this.type = { require(value >= 2, s"maxBins parameter must be >= 2. Given bad value: $value") set(maxBins, value) - this } /** @group getParam */ - def getMaxBins: Int = getOrDefault(maxBins) + final def getMaxBins: Int = getOrDefault(maxBins) /** @group setParam */ def setMinInstancesPerNode(value: Int): this.type = { require(value >= 1, s"minInstancesPerNode parameter must be >= 1. 
Given bad value: $value") set(minInstancesPerNode, value) - this } /** @group getParam */ - def getMinInstancesPerNode: Int = getOrDefault(minInstancesPerNode) + final def getMinInstancesPerNode: Int = getOrDefault(minInstancesPerNode) /** @group setParam */ - def setMinInfoGain(value: Double): this.type = { - set(minInfoGain, value) - this - } + def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group getParam */ - def getMinInfoGain: Double = getOrDefault(minInfoGain) + final def getMinInfoGain: Double = getOrDefault(minInfoGain) /** @group expertSetParam */ def setMaxMemoryInMB(value: Int): this.type = { require(value > 0, s"maxMemoryInMB parameter must be > 0. Given bad value: $value") set(maxMemoryInMB, value) - this } /** @group expertGetParam */ - def getMaxMemoryInMB: Int = getOrDefault(maxMemoryInMB) + final def getMaxMemoryInMB: Int = getOrDefault(maxMemoryInMB) /** @group expertSetParam */ - def setCacheNodeIds(value: Boolean): this.type = { - set(cacheNodeIds, value) - this - } + def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** @group expertGetParam */ - def getCacheNodeIds: Boolean = getOrDefault(cacheNodeIds) + final def getCacheNodeIds: Boolean = getOrDefault(cacheNodeIds) /** @group expertSetParam */ def setCheckpointInterval(value: Int): this.type = { require(value >= 1, s"checkpointInterval parameter must be >= 1. Given bad value: $value") set(checkpointInterval, value) - this } /** @group expertGetParam */ - def getCheckpointInterval: Int = getOrDefault(checkpointInterval) + final def getCheckpointInterval: Int = getOrDefault(checkpointInterval) - /** - * Create a Strategy instance to use with the old API. - * NOTE: The caller should set impurity and subsamplingRate (which is set to 1.0, - * the default for single trees). - */ + /** (private[ml]) Create a Strategy instance to use with the old API. */ private[ml] def getOldStrategy( categoricalFeatures: Map[Int, Int], - numClasses: Int): OldStrategy = { - val strategy = OldStrategy.defaultStategy(OldAlgo.Classification) + numClasses: Int, + oldAlgo: OldAlgo.Algo, + oldImpurity: OldImpurity, + subsamplingRate: Double): OldStrategy = { + val strategy = OldStrategy.defaultStategy(oldAlgo) + strategy.impurity = oldImpurity strategy.checkpointInterval = getCheckpointInterval strategy.maxBins = getMaxBins strategy.maxDepth = getMaxDepth @@ -199,13 +191,13 @@ private[ml] trait DecisionTreeParams extends PredictorParams { strategy.useNodeIdCache = getCacheNodeIds strategy.numClasses = numClasses strategy.categoricalFeaturesInfo = categoricalFeatures - strategy.subsamplingRate = 1.0 // default for individual trees + strategy.subsamplingRate = subsamplingRate strategy } } /** - * (private trait) Parameters for Decision Tree-based classification algorithms. + * Parameters for Decision Tree-based classification algorithms. */ private[ml] trait TreeClassifierParams extends Params { @@ -215,7 +207,7 @@ private[ml] trait TreeClassifierParams extends Params { * (default = gini) * @group param */ - val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + + final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + " information gain calculation (case-insensitive). Supported options:" + s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}") @@ -228,11 +220,10 @@ private[ml] trait TreeClassifierParams extends Params { s"Tree-based classifier was given unrecognized impurity: $value." 
+ s" Supported options: ${TreeClassifierParams.supportedImpurities.mkString(", ")}") set(impurity, impurityStr) - this } /** @group getParam */ - def getImpurity: String = getOrDefault(impurity) + final def getImpurity: String = getOrDefault(impurity) /** Convert new impurity to old impurity. */ private[ml] def getOldImpurity: OldImpurity = { @@ -249,11 +240,11 @@ private[ml] trait TreeClassifierParams extends Params { private[ml] object TreeClassifierParams { // These options should be lowercase. - val supportedImpurities: Array[String] = Array("entropy", "gini").map(_.toLowerCase) + final val supportedImpurities: Array[String] = Array("entropy", "gini").map(_.toLowerCase) } /** - * (private trait) Parameters for Decision Tree-based regression algorithms. + * Parameters for Decision Tree-based regression algorithms. */ private[ml] trait TreeRegressorParams extends Params { @@ -263,7 +254,7 @@ private[ml] trait TreeRegressorParams extends Params { * (default = variance) * @group param */ - val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + + final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + " information gain calculation (case-insensitive). Supported options:" + s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}") @@ -276,11 +267,10 @@ private[ml] trait TreeRegressorParams extends Params { s"Tree-based regressor was given unrecognized impurity: $value." + s" Supported options: ${TreeRegressorParams.supportedImpurities.mkString(", ")}") set(impurity, impurityStr) - this } /** @group getParam */ - def getImpurity: String = getOrDefault(impurity) + final def getImpurity: String = getOrDefault(impurity) /** Convert new impurity to old impurity. */ private[ml] def getOldImpurity: OldImpurity = { @@ -296,5 +286,186 @@ private[ml] trait TreeRegressorParams extends Params { private[ml] object TreeRegressorParams { // These options should be lowercase. - val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase) + final val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase) +} + +/** + * :: DeveloperApi :: + * Parameters for Decision Tree-based ensemble algorithms. + * + * Note: Marked as private and DeveloperApi since this may be made public in the future. + */ +@DeveloperApi +private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed { + + /** + * Fraction of the training data used for learning each decision tree. + * (default = 1.0) + * @group param + */ + final val subsamplingRate: DoubleParam = new DoubleParam(this, "subsamplingRate", + "Fraction of the training data used for learning each decision tree.") + + setDefault(subsamplingRate -> 1.0) + + /** @group setParam */ + def setSubsamplingRate(value: Double): this.type = { + require(value > 0.0 && value <= 1.0, + s"Subsampling rate must be in range (0,1]. Bad rate: $value") + set(subsamplingRate, value) + } + + /** @group getParam */ + final def getSubsamplingRate: Double = getOrDefault(subsamplingRate) + + /** @group setParam */ + def setSeed(value: Long): this.type = set(seed, value) + + /** + * Create a Strategy instance to use with the old API. + * NOTE: The caller should set impurity and seed. 
+ */ + private[ml] def getOldStrategy( + categoricalFeatures: Map[Int, Int], + numClasses: Int, + oldAlgo: OldAlgo.Algo, + oldImpurity: OldImpurity): OldStrategy = { + super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo, oldImpurity, getSubsamplingRate) + } +} + +/** + * :: DeveloperApi :: + * Parameters for Random Forest algorithms. + * + * Note: Marked as private and DeveloperApi since this may be made public in the future. + */ +@DeveloperApi +private[ml] trait RandomForestParams extends TreeEnsembleParams { + + /** + * Number of trees to train (>= 1). + * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done. + * TODO: Change to always do bootstrapping (simpler). SPARK-7130 + * (default = 20) + * @group param + */ + final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)") + + /** + * The number of features to consider for splits at each tree node. + * Supported options: + * - "auto": Choose automatically for task: + * If numTrees == 1, set to "all." + * If numTrees > 1 (forest), set to "sqrt" for classification and + * to "onethird" for regression. + * - "all": use all features + * - "onethird": use 1/3 of the features + * - "sqrt": use sqrt(number of features) + * - "log2": use log2(number of features) + * (default = "auto") + * + * These various settings are based on the following references: + * - log2: tested in Breiman (2001) + * - sqrt: recommended by Breiman manual for random forests + * - The defaults of sqrt (classification) and onethird (regression) match the R randomForest + * package. + * @see [[http://www.stat.berkeley.edu/~breiman/randomforest2001.pdf Breiman (2001)]] + * @see [[http://www.stat.berkeley.edu/~breiman/Using_random_forests_V3.1.pdf Breiman manual for + * random forests]] + * + * @group param + */ + final val featureSubsetStrategy: Param[String] = new Param[String](this, "featureSubsetStrategy", + "The number of features to consider for splits at each tree node." + + s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}") + + setDefault(numTrees -> 20, featureSubsetStrategy -> "auto") + + /** @group setParam */ + def setNumTrees(value: Int): this.type = { + require(value >= 1, s"Random Forest numTrees parameter cannot be $value; it must be >= 1.") + set(numTrees, value) + } + + /** @group getParam */ + final def getNumTrees: Int = getOrDefault(numTrees) + + /** @group setParam */ + def setFeatureSubsetStrategy(value: String): this.type = { + val strategyStr = value.toLowerCase + require(RandomForestParams.supportedFeatureSubsetStrategies.contains(strategyStr), + s"RandomForestParams was given unrecognized featureSubsetStrategy: $value. Supported" + + s" options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}") + set(featureSubsetStrategy, strategyStr) + } + + /** @group getParam */ + final def getFeatureSubsetStrategy: String = getOrDefault(featureSubsetStrategy) +} + +private[ml] object RandomForestParams { + // These options should be lowercase. + final val supportedFeatureSubsetStrategies: Array[String] = + Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase) +} + +/** + * :: DeveloperApi :: + * Parameters for Gradient-Boosted Tree algorithms. + * + * Note: Marked as private and DeveloperApi since this may be made public in the future. + */ +@DeveloperApi +private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter { + + /** + * Step size (a.k.a. 
learning rate) in interval (0, 1] for shrinking the contribution of each + * estimator. + * (default = 0.1) + * @group param + */ + final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size (a.k.a." + + " learning rate) in interval (0, 1] for shrinking the contribution of each estimator") + + /* TODO: Add this doc when we add this param. SPARK-7132 + * Threshold for stopping early when runWithValidation is used. + * If the error rate on the validation input changes by less than the validationTol, + * then learning will stop early (before [[numIterations]]). + * This parameter is ignored when run is used. + * (default = 1e-5) + * @group param + */ + // final val validationTol: DoubleParam = new DoubleParam(this, "validationTol", "") + // validationTol -> 1e-5 + + setDefault(maxIter -> 20, stepSize -> 0.1) + + /** @group setParam */ + def setMaxIter(value: Int): this.type = { + require(value >= 1, s"Gradient Boosting maxIter parameter cannot be $value; it must be >= 1.") + set(maxIter, value) + } + + /** @group setParam */ + def setStepSize(value: Double): this.type = { + require(value > 0.0 && value <= 1.0, + s"GBT given invalid step size ($value). Value should be in (0,1].") + set(stepSize, value) + } + + /** @group getParam */ + final def getStepSize: Double = getOrDefault(stepSize) + + /** (private[ml]) Create a BoostingStrategy instance to use with the old API. */ + private[ml] def getOldBoostingStrategy( + categoricalFeatures: Map[Int, Int], + oldAlgo: OldAlgo.Algo): OldBoostingStrategy = { + val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 2, oldAlgo, OldVariance) + // NOTE: The old API does not support "seed" so we ignore it. + new OldBoostingStrategy(strategy, getOldLossType, getMaxIter, getStepSize) + } + + /** Get old Gradient Boosting Loss type */ + private[ml] def getOldLossType: OldLoss } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 95d7e64790c79..e88c48741e99f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -45,7 +45,8 @@ private[shared] object SharedParamsCodeGen { ParamDesc[Array[String]]("inputCols", "input column names"), ParamDesc[String]("outputCol", "output column name"), ParamDesc[Int]("checkpointInterval", "checkpoint interval"), - ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true"))) + ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")), + ParamDesc[Long]("seed", "random seed", Some("Utils.random.nextLong()"))) val code = genSharedParams(params) val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala" @@ -154,6 +155,7 @@ private[shared] object SharedParamsCodeGen { | |import org.apache.spark.annotation.DeveloperApi |import org.apache.spark.ml.param._ + |import org.apache.spark.util.Utils | |// DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen. 
| diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 72b08bf276483..a860b8834cff9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.param.shared import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.param._ +import org.apache.spark.util.Utils // DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen. @@ -256,4 +257,23 @@ trait HasFitIntercept extends Params { /** @group getParam */ final def getFitIntercept: Boolean = getOrDefault(fitIntercept) } + +/** + * :: DeveloperApi :: + * Trait for shared param seed (default: Utils.random.nextLong()). + */ +@DeveloperApi +trait HasSeed extends Params { + + /** + * Param for random seed. + * @group param + */ + final val seed: LongParam = new LongParam(this, "seed", "random seed") + + setDefault(seed, Utils.random.nextLong()) + + /** @group getParam */ + final def getSeed: Long = getOrDefault(seed) +} // scalastyle:on diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 49a8b77acf960..756725a64b0f3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -42,8 +42,7 @@ import org.apache.spark.sql.DataFrame @AlphaComponent final class DecisionTreeRegressor extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel] - with DecisionTreeParams - with TreeRegressorParams { + with DecisionTreeParams with TreeRegressorParams { // Override parameter setters from parent trait for Java API compatibility. @@ -60,8 +59,7 @@ final class DecisionTreeRegressor override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) - override def setCheckpointInterval(value: Int): this.type = - super.setCheckpointInterval(value) + override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) override def setImpurity(value: String): this.type = super.setImpurity(value) @@ -78,15 +76,13 @@ final class DecisionTreeRegressor /** (private[ml]) Create a Strategy instance to use with the old API. */ private[ml] def getOldStrategy(categoricalFeatures: Map[Int, Int]): OldStrategy = { - val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 0) - strategy.algo = OldAlgo.Regression - strategy.setImpurity(getOldImpurity) - strategy + super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity, + subsamplingRate = 1.0) } } object DecisionTreeRegressor { - /** Accessor for supported impurities */ + /** Accessor for supported impurities: variance */ final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala new file mode 100644 index 0000000000000..c784cf39ed31a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import com.github.fommil.netlib.BLAS.{getInstance => blas} + +import org.apache.spark.Logging +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor} +import org.apache.spark.ml.impl.tree._ +import org.apache.spark.ml.param.{Params, ParamMap, Param} +import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel} +import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{GradientBoostedTrees => OldGBT} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} +import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss => OldLoss, + SquaredError => OldSquaredError} +import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + + +/** + * :: AlphaComponent :: + * + * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]] + * learning algorithm for regression. + * It supports both continuous and categorical features. + */ +@AlphaComponent +final class GBTRegressor + extends Predictor[Vector, GBTRegressor, GBTRegressionModel] + with GBTParams with TreeRegressorParams with Logging { + + // Override parameter setters from parent trait for Java API compatibility. + + // Parameters from TreeRegressorParams: + + override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + + override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + + override def setMinInstancesPerNode(value: Int): this.type = + super.setMinInstancesPerNode(value) + + override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + + override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + + override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) + + override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) + + /** + * The impurity setting is ignored for GBT models. + * Individual trees are built using impurity "Variance." 
+ */ + override def setImpurity(value: String): this.type = { + logWarning("GBTRegressor.setImpurity should NOT be used") + this + } + + // Parameters from TreeEnsembleParams: + + override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value) + + override def setSeed(value: Long): this.type = { + logWarning("The 'seed' parameter is currently ignored by Gradient Boosting.") + super.setSeed(value) + } + + // Parameters from GBTParams: + + override def setMaxIter(value: Int): this.type = super.setMaxIter(value) + + override def setStepSize(value: Double): this.type = super.setStepSize(value) + + // Parameters for GBTRegressor: + + /** + * Loss function which GBT tries to minimize. (case-insensitive) + * Supported: "squared" (L2) and "absolute" (L1) + * (default = squared) + * @group param + */ + val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + + " tries to minimize (case-insensitive). Supported options:" + + s" ${GBTRegressor.supportedLossTypes.mkString(", ")}") + + setDefault(lossType -> "squared") + + /** @group setParam */ + def setLossType(value: String): this.type = { + val lossStr = value.toLowerCase + require(GBTRegressor.supportedLossTypes.contains(lossStr), "GBTRegressor was given bad loss" + + s" type: $value. Supported options: ${GBTRegressor.supportedLossTypes.mkString(", ")}") + set(lossType, lossStr) + this + } + + /** @group getParam */ + def getLossType: String = getOrDefault(lossType) + + /** (private[ml]) Convert new loss to old loss. */ + override private[ml] def getOldLossType: OldLoss = { + getLossType match { + case "squared" => OldSquaredError + case "absolute" => OldAbsoluteError + case _ => + // Should never happen because of check in setter method. + throw new RuntimeException(s"GBTRegressorParams was given bad loss type: $getLossType") + } + } + + override protected def train( + dataset: DataFrame, + paramMap: ParamMap): GBTRegressionModel = { + val categoricalFeatures: Map[Int, Int] = + MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol))) + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap) + val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression) + val oldGBT = new OldGBT(boostingStrategy) + val oldModel = oldGBT.run(oldDataset) + GBTRegressionModel.fromOld(oldModel, this, paramMap, categoricalFeatures) + } +} + +object GBTRegressor { + // The losses below should be lowercase. + /** Accessor for supported loss settings: squared (L2), absolute (L1) */ + final val supportedLossTypes: Array[String] = Array("squared", "absolute").map(_.toLowerCase) +} + +/** + * :: AlphaComponent :: + * + * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]] + * model for regression. + * It supports both continuous and categorical features. + * @param _trees Decision trees in the ensemble. + * @param _treeWeights Weights for the decision trees in the ensemble. 
+ */
+@AlphaComponent
+final class GBTRegressionModel(
+    override val parent: GBTRegressor,
+    override val fittingParamMap: ParamMap,
+    private val _trees: Array[DecisionTreeRegressionModel],
+    private val _treeWeights: Array[Double])
+  extends PredictionModel[Vector, GBTRegressionModel]
+  with TreeEnsembleModel with Serializable {
+
+  require(numTrees > 0, "GBTRegressionModel requires at least 1 tree.")
+  require(_trees.length == _treeWeights.length, "GBTRegressionModel given trees, treeWeights of" +
+    s" non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).")
+
+  override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
+
+  override def treeWeights: Array[Double] = _treeWeights
+
+  override protected def predict(features: Vector): Double = {
+    // TODO: Override transform() to broadcast model. SPARK-7127
+    // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
+    // Predicts the weighted sum of the individual tree predictions; unlike the classifier,
+    // the sum is returned directly rather than thresholded into a class label.
+    val treePredictions = _trees.map(_.rootNode.predict(features))
+    blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
+  }
+
+  override protected def copy(): GBTRegressionModel = {
+    val m = new GBTRegressionModel(parent, fittingParamMap, _trees, _treeWeights)
+    Params.inheritValues(this.extractParamMap(), this, m)
+    m
+  }
+
+  override def toString: String = {
+    s"GBTRegressionModel with $numTrees trees"
+  }
+
+  /** (private[ml]) Convert to a model in the old API */
+  private[ml] def toOld: OldGBTModel = {
+    new OldGBTModel(OldAlgo.Regression, _trees.map(_.toOld), _treeWeights)
+  }
+}
+
+private[ml] object GBTRegressionModel {
+
+  /** (private[ml]) Convert a model from the old API */
+  def fromOld(
+      oldModel: OldGBTModel,
+      parent: GBTRegressor,
+      fittingParamMap: ParamMap,
+      categoricalFeatures: Map[Int, Int]): GBTRegressionModel = {
+    require(oldModel.algo == OldAlgo.Regression, "Cannot convert GradientBoostedTreesModel" +
+      s" with algo=${oldModel.algo} (old API) to GBTRegressionModel (new API).")
+    val newTrees = oldModel.trees.map { tree =>
+      // parent, fittingParamMap for each tree is null since there are no good ways to set these.
+      DecisionTreeRegressionModel.fromOld(tree, null, null, categoricalFeatures)
+    }
+    new GBTRegressionModel(parent, fittingParamMap, newTrees, oldModel.treeWeights)
+  }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
new file mode 100644
index 0000000000000..2171ef3d32c26
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor} +import org.apache.spark.ml.impl.tree.{RandomForestParams, TreeRegressorParams} +import org.apache.spark.ml.param.{Params, ParamMap} +import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel} +import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{RandomForest => OldRandomForest} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} +import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + + +/** + * :: AlphaComponent :: + * + * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] learning algorithm for regression. + * It supports both continuous and categorical features. + */ +@AlphaComponent +final class RandomForestRegressor + extends Predictor[Vector, RandomForestRegressor, RandomForestRegressionModel] + with RandomForestParams with TreeRegressorParams { + + // Override parameter setters from parent trait for Java API compatibility. + + // Parameters from TreeRegressorParams: + + override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + + override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + + override def setMinInstancesPerNode(value: Int): this.type = + super.setMinInstancesPerNode(value) + + override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + + override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + + override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) + + override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) + + override def setImpurity(value: String): this.type = super.setImpurity(value) + + // Parameters from TreeEnsembleParams: + + override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value) + + override def setSeed(value: Long): this.type = super.setSeed(value) + + // Parameters from RandomForestParams: + + override def setNumTrees(value: Int): this.type = super.setNumTrees(value) + + override def setFeatureSubsetStrategy(value: String): this.type = + super.setFeatureSubsetStrategy(value) + + override protected def train( + dataset: DataFrame, + paramMap: ParamMap): RandomForestRegressionModel = { + val categoricalFeatures: Map[Int, Int] = + MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol))) + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap) + val strategy = + super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity) + val oldModel = OldRandomForest.trainRegressor( + oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt) + RandomForestRegressionModel.fromOld(oldModel, this, paramMap, categoricalFeatures) + } +} + +object RandomForestRegressor { + /** Accessor for supported impurity settings: variance */ + final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities + + /** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */ + final val 
supportedFeatureSubsetStrategies: Array[String] = + RandomForestParams.supportedFeatureSubsetStrategies +} + +/** + * :: AlphaComponent :: + * + * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for regression. + * It supports both continuous and categorical features. + * @param _trees Decision trees in the ensemble. + */ +@AlphaComponent +final class RandomForestRegressionModel private[ml] ( + override val parent: RandomForestRegressor, + override val fittingParamMap: ParamMap, + private val _trees: Array[DecisionTreeRegressionModel]) + extends PredictionModel[Vector, RandomForestRegressionModel] + with TreeEnsembleModel with Serializable { + + require(numTrees > 0, "RandomForestRegressionModel requires at least 1 tree.") + + override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] + + // Note: We may add support for weights (based on tree performance) later on. + private lazy val _treeWeights: Array[Double] = Array.fill[Double](numTrees)(1.0) + + override def treeWeights: Array[Double] = _treeWeights + + override protected def predict(features: Vector): Double = { + // TODO: Override transform() to broadcast model. SPARK-7127 + // TODO: When we add a generic Bagging class, handle transform there. SPARK-7128 + // Predict average of tree predictions. + // Ignore the weights since all are 1.0 for now. + _trees.map(_.rootNode.predict(features)).sum / numTrees + } + + override protected def copy(): RandomForestRegressionModel = { + val m = new RandomForestRegressionModel(parent, fittingParamMap, _trees) + Params.inheritValues(this.extractParamMap(), this, m) + m + } + + override def toString: String = { + s"RandomForestRegressionModel with $numTrees trees" + } + + /** (private[ml]) Convert to a model in the old API */ + private[ml] def toOld: OldRandomForestModel = { + new OldRandomForestModel(OldAlgo.Regression, _trees.map(_.toOld)) + } +} + +private[ml] object RandomForestRegressionModel { + + /** (private[ml]) Convert a model from the old API */ + def fromOld( + oldModel: OldRandomForestModel, + parent: RandomForestRegressor, + fittingParamMap: ParamMap, + categoricalFeatures: Map[Int, Int]): RandomForestRegressionModel = { + require(oldModel.algo == OldAlgo.Regression, "Cannot convert RandomForestModel" + + s" with algo=${oldModel.algo} (old API) to RandomForestRegressionModel (new API).") + val newTrees = oldModel.trees.map { tree => + // parent, fittingParamMap for each tree is null since there are no good ways to set these. + DecisionTreeRegressionModel.fromOld(tree, null, null, categoricalFeatures) + } + new RandomForestRegressionModel(parent, fittingParamMap, newTrees) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index d6e2203d9f937..d2dec0c76cb12 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -28,9 +28,9 @@ import org.apache.spark.mllib.tree.model.{InformationGainStats => OldInformation sealed abstract class Node extends Serializable { // TODO: Add aggregate stats (once available). This will happen after we move the DecisionTree - // code into the new API and deprecate the old API. + // code into the new API and deprecate the old API. 
SPARK-3727 - /** Prediction this node makes (or would make, if it is an internal node) */ + /** Prediction a leaf node makes, or which an internal node would make if it were a leaf node */ def prediction: Double /** Impurity measure at this node (for training data) */ @@ -194,7 +194,7 @@ private object InternalNode { s"$featureStr > ${contSplit.threshold}" } case catSplit: CategoricalSplit => - val categoriesStr = catSplit.getLeftCategories.mkString("{", ",", "}") + val categoriesStr = catSplit.leftCategories.mkString("{", ",", "}") if (left) { s"$featureStr in $categoriesStr" } else { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala index 708c769087dd0..90f1d052764d3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala @@ -44,7 +44,7 @@ private[tree] object Split { oldSplit.featureType match { case OldFeatureType.Categorical => new CategoricalSplit(featureIndex = oldSplit.feature, - leftCategories = oldSplit.categories.toArray, categoricalFeatures(oldSplit.feature)) + _leftCategories = oldSplit.categories.toArray, categoricalFeatures(oldSplit.feature)) case OldFeatureType.Continuous => new ContinuousSplit(featureIndex = oldSplit.feature, threshold = oldSplit.threshold) } @@ -54,30 +54,30 @@ private[tree] object Split { /** * Split which tests a categorical feature. * @param featureIndex Index of the feature to test - * @param leftCategories If the feature value is in this set of categories, then the split goes - * left. Otherwise, it goes right. + * @param _leftCategories If the feature value is in this set of categories, then the split goes + * left. Otherwise, it goes right. * @param numCategories Number of categories for this feature. */ final class CategoricalSplit private[ml] ( override val featureIndex: Int, - leftCategories: Array[Double], + _leftCategories: Array[Double], private val numCategories: Int) extends Split { - require(leftCategories.forall(cat => 0 <= cat && cat < numCategories), "Invalid leftCategories" + - s" (should be in range [0, $numCategories)): ${leftCategories.mkString(",")}") + require(_leftCategories.forall(cat => 0 <= cat && cat < numCategories), "Invalid leftCategories" + + s" (should be in range [0, $numCategories)): ${_leftCategories.mkString(",")}") /** * If true, then "categories" is the set of categories for splitting to the left, and vice versa. */ - private val isLeft: Boolean = leftCategories.length <= numCategories / 2 + private val isLeft: Boolean = _leftCategories.length <= numCategories / 2 /** Set of categories determining the splitting rule, along with [[isLeft]]. 
*/ private val categories: Set[Double] = { if (isLeft) { - leftCategories.toSet + _leftCategories.toSet } else { - setComplement(leftCategories.toSet) + setComplement(_leftCategories.toSet) } } @@ -107,13 +107,13 @@ final class CategoricalSplit private[ml] ( } /** Get sorted categories which split to the left */ - def getLeftCategories: Array[Double] = { + def leftCategories: Array[Double] = { val cats = if (isLeft) categories else setComplement(categories) cats.toArray.sorted } /** Get sorted categories which split to the right */ - def getRightCategories: Array[Double] = { + def rightCategories: Array[Double] = { val cats = if (isLeft) setComplement(categories) else categories cats.toArray.sorted } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index 8e3bc3849dcf0..1929f9d02156e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -17,18 +17,13 @@ package org.apache.spark.ml.tree -import org.apache.spark.annotation.AlphaComponent - /** - * :: AlphaComponent :: - * * Abstraction for Decision Tree models. * - * TODO: Add support for predicting probabilities and raw predictions + * TODO: Add support for predicting probabilities and raw predictions SPARK-3727 */ -@AlphaComponent -trait DecisionTreeModel { +private[ml] trait DecisionTreeModel { /** Root of the decision tree */ def rootNode: Node @@ -58,3 +53,40 @@ trait DecisionTreeModel { header + rootNode.subtreeToString(2) } } + +/** + * Abstraction for models which are ensembles of decision trees + * + * TODO: Add support for predicting probabilities and raw predictions SPARK-3727 + */ +private[ml] trait TreeEnsembleModel { + + // Note: We use getTrees since subclasses of TreeEnsembleModel will store subclasses of + // DecisionTreeModel. + + /** Trees in this ensemble. Warning: These have null parent Estimators. */ + def trees: Array[DecisionTreeModel] + + /** Weights for each tree, zippable with [[trees]] */ + def treeWeights: Array[Double] + + /** Summary of the model */ + override def toString: String = { + // Implementing classes should generally override this method to be more descriptive. + s"TreeEnsembleModel with $numTrees trees" + } + + /** Full description of model */ + def toDebugString: String = { + val header = toString + "\n" + header + trees.zip(treeWeights).zipWithIndex.map { case ((tree, weight), treeIndex) => + s" Tree $treeIndex (weight $weight):\n" + tree.rootNode.subtreeToString(4) + }.fold("")(_ + _) + } + + /** Number of trees in ensemble */ + val numTrees: Int = trees.length + + /** Total number of nodes, summed over all trees in the ensemble. 
*/ + lazy val totalNumNodes: Int = trees.map(_.numNodes).sum +} diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java index 43b8787f9dd7e..60f25e5cce437 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.ml.classification; -import java.io.File; import java.io.Serializable; import java.util.HashMap; import java.util.Map; @@ -32,7 +31,6 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.DataFrame; -import org.apache.spark.util.Utils; public class JavaDecisionTreeClassifierSuite implements Serializable { @@ -57,7 +55,7 @@ public void runDT() { double B = -1.5; JavaRDD data = sc.parallelize( - LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); Map categoricalFeatures = new HashMap(); DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); @@ -71,8 +69,8 @@ public void runDT() { .setCacheNodeIds(false) .setCheckpointInterval(10) .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern - for (int i = 0; i < DecisionTreeClassifier.supportedImpurities().length; ++i) { - dt.setImpurity(DecisionTreeClassifier.supportedImpurities()[i]); + for (String impurity: DecisionTreeClassifier.supportedImpurities()) { + dt.setImpurity(impurity); } DecisionTreeClassificationModel model = dt.fit(dataFrame); @@ -82,7 +80,7 @@ public void runDT() { model.toDebugString(); /* - // TODO: Add test once save/load are implemented. + // TODO: Add test once save/load are implemented. SPARK-6725 File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark"); String path = tempDir.toURI().toString(); try { diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java new file mode 100644 index 0000000000000..3c69467fa119e --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.classification; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.impl.TreeTests; +import org.apache.spark.mllib.classification.LogisticRegressionSuite; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; + + +public class JavaGBTClassifierSuite implements Serializable { + + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaGBTClassifierSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void runDT() { + int nPoints = 20; + double A = 2.0; + double B = -1.5; + + JavaRDD data = sc.parallelize( + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); + Map categoricalFeatures = new HashMap(); + DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); + + // This tests setters. Training with various options is tested in Scala. + GBTClassifier rf = new GBTClassifier() + .setMaxDepth(2) + .setMaxBins(10) + .setMinInstancesPerNode(5) + .setMinInfoGain(0.0) + .setMaxMemoryInMB(256) + .setCacheNodeIds(false) + .setCheckpointInterval(10) + .setSubsamplingRate(1.0) + .setSeed(1234) + .setMaxIter(3) + .setStepSize(0.1) + .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern + for (String lossType: GBTClassifier.supportedLossTypes()) { + rf.setLossType(lossType); + } + GBTClassificationModel model = rf.fit(dataFrame); + + model.transform(dataFrame); + model.totalNumNodes(); + model.toDebugString(); + model.trees(); + model.treeWeights(); + + /* + // TODO: Add test once save/load are implemented. SPARK-6725 + File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark"); + String path = tempDir.toURI().toString(); + try { + model3.save(sc.sc(), path); + GBTClassificationModel sameModel = GBTClassificationModel.load(sc.sc(), path); + TreeTests.checkEqual(model3, sameModel); + } finally { + Utils.deleteRecursively(tempDir); + } + */ + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java new file mode 100644 index 0000000000000..32d0b3856b7e2 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.classification; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.impl.TreeTests; +import org.apache.spark.mllib.classification.LogisticRegressionSuite; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; + + +public class JavaRandomForestClassifierSuite implements Serializable { + + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaRandomForestClassifierSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void runDT() { + int nPoints = 20; + double A = 2.0; + double B = -1.5; + + JavaRDD data = sc.parallelize( + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); + Map categoricalFeatures = new HashMap(); + DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); + + // This tests setters. Training with various options is tested in Scala. + RandomForestClassifier rf = new RandomForestClassifier() + .setMaxDepth(2) + .setMaxBins(10) + .setMinInstancesPerNode(5) + .setMinInfoGain(0.0) + .setMaxMemoryInMB(256) + .setCacheNodeIds(false) + .setCheckpointInterval(10) + .setSubsamplingRate(1.0) + .setSeed(1234) + .setNumTrees(3) + .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern + for (String impurity: RandomForestClassifier.supportedImpurities()) { + rf.setImpurity(impurity); + } + for (String featureSubsetStrategy: RandomForestClassifier.supportedFeatureSubsetStrategies()) { + rf.setFeatureSubsetStrategy(featureSubsetStrategy); + } + RandomForestClassificationModel model = rf.fit(dataFrame); + + model.transform(dataFrame); + model.totalNumNodes(); + model.toDebugString(); + model.trees(); + model.treeWeights(); + + /* + // TODO: Add test once save/load are implemented. 
SPARK-6725 + File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark"); + String path = tempDir.toURI().toString(); + try { + model3.save(sc.sc(), path); + RandomForestClassificationModel sameModel = + RandomForestClassificationModel.load(sc.sc(), path); + TreeTests.checkEqual(model3, sameModel); + } finally { + Utils.deleteRecursively(tempDir); + } + */ + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java index a3a339004f31c..71b041818d7ee 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.ml.regression; -import java.io.File; import java.io.Serializable; import java.util.HashMap; import java.util.Map; @@ -32,7 +31,6 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.DataFrame; -import org.apache.spark.util.Utils; public class JavaDecisionTreeRegressorSuite implements Serializable { @@ -57,22 +55,22 @@ public void runDT() { double B = -1.5; JavaRDD data = sc.parallelize( - LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); Map categoricalFeatures = new HashMap(); DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); // This tests setters. Training with various options is tested in Scala. DecisionTreeRegressor dt = new DecisionTreeRegressor() - .setMaxDepth(2) - .setMaxBins(10) - .setMinInstancesPerNode(5) - .setMinInfoGain(0.0) - .setMaxMemoryInMB(256) - .setCacheNodeIds(false) - .setCheckpointInterval(10) - .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern - for (int i = 0; i < DecisionTreeRegressor.supportedImpurities().length; ++i) { - dt.setImpurity(DecisionTreeRegressor.supportedImpurities()[i]); + .setMaxDepth(2) + .setMaxBins(10) + .setMinInstancesPerNode(5) + .setMinInfoGain(0.0) + .setMaxMemoryInMB(256) + .setCacheNodeIds(false) + .setCheckpointInterval(10) + .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern + for (String impurity: DecisionTreeRegressor.supportedImpurities()) { + dt.setImpurity(impurity); } DecisionTreeRegressionModel model = dt.fit(dataFrame); @@ -82,7 +80,7 @@ public void runDT() { model.toDebugString(); /* - // TODO: Add test once save/load are implemented. + // TODO: Add test once save/load are implemented. SPARK-6725 File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark"); String path = tempDir.toURI().toString(); try { diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java new file mode 100644 index 0000000000000..fc8c13db07e6f --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.impl.TreeTests; +import org.apache.spark.mllib.classification.LogisticRegressionSuite; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; + + +public class JavaGBTRegressorSuite implements Serializable { + + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaGBTRegressorSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void runDT() { + int nPoints = 20; + double A = 2.0; + double B = -1.5; + + JavaRDD data = sc.parallelize( + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); + Map categoricalFeatures = new HashMap(); + DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); + + GBTRegressor rf = new GBTRegressor() + .setMaxDepth(2) + .setMaxBins(10) + .setMinInstancesPerNode(5) + .setMinInfoGain(0.0) + .setMaxMemoryInMB(256) + .setCacheNodeIds(false) + .setCheckpointInterval(10) + .setSubsamplingRate(1.0) + .setSeed(1234) + .setMaxIter(3) + .setStepSize(0.1) + .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern + for (String lossType: GBTRegressor.supportedLossTypes()) { + rf.setLossType(lossType); + } + GBTRegressionModel model = rf.fit(dataFrame); + + model.transform(dataFrame); + model.totalNumNodes(); + model.toDebugString(); + model.trees(); + model.treeWeights(); + + /* + // TODO: Add test once save/load are implemented. SPARK-6725 + File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark"); + String path = tempDir.toURI().toString(); + try { + model2.save(sc.sc(), path); + GBTRegressionModel sameModel = GBTRegressionModel.load(sc.sc(), path); + TreeTests.checkEqual(model2, sameModel); + } finally { + Utils.deleteRecursively(tempDir); + } + */ + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java new file mode 100644 index 0000000000000..e306ebadfe7cf --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.classification.LogisticRegressionSuite; +import org.apache.spark.ml.impl.TreeTests; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; + + +public class JavaRandomForestRegressorSuite implements Serializable { + + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaRandomForestRegressorSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void runDT() { + int nPoints = 20; + double A = 2.0; + double B = -1.5; + + JavaRDD data = sc.parallelize( + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); + Map categoricalFeatures = new HashMap(); + DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); + + // This tests setters. Training with various options is tested in Scala. + RandomForestRegressor rf = new RandomForestRegressor() + .setMaxDepth(2) + .setMaxBins(10) + .setMinInstancesPerNode(5) + .setMinInfoGain(0.0) + .setMaxMemoryInMB(256) + .setCacheNodeIds(false) + .setCheckpointInterval(10) + .setSubsamplingRate(1.0) + .setSeed(1234) + .setNumTrees(3) + .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern + for (String impurity: RandomForestRegressor.supportedImpurities()) { + rf.setImpurity(impurity); + } + for (String featureSubsetStrategy: RandomForestRegressor.supportedFeatureSubsetStrategies()) { + rf.setFeatureSubsetStrategy(featureSubsetStrategy); + } + RandomForestRegressionModel model = rf.fit(dataFrame); + + model.transform(dataFrame); + model.totalNumNodes(); + model.toDebugString(); + model.trees(); + model.treeWeights(); + + /* + // TODO: Add test once save/load are implemented. 
SPARK-6725 + File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark"); + String path = tempDir.toURI().toString(); + try { + model2.save(sc.sc(), path); + RandomForestRegressionModel sameModel = RandomForestRegressionModel.load(sc.sc(), path); + TreeTests.checkEqual(model2, sameModel); + } finally { + Utils.deleteRecursively(tempDir); + } + */ + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index af88595df5245..9b31adecdcb1c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -230,7 +230,7 @@ class DecisionTreeClassifierSuite extends FunSuite with MLlibTestSparkContext { // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// - // TODO: Reinstate test once save/load are implemented + // TODO: Reinstate test once save/load are implemented SPARK-6725 /* test("model save/load") { val tempDir = Utils.createTempDir() diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala new file mode 100644 index 0000000000000..e6ccc2c93cba8 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.scalatest.FunSuite + +import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + + +/** + * Test suite for [[GBTClassifier]]. 
+ */ +class GBTClassifierSuite extends FunSuite with MLlibTestSparkContext { + + import GBTClassifierSuite.compareAPIs + + // Combinations for estimators, learning rates and subsamplingRate + private val testCombinations = + Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 0.5, 0.75), (10, 0.1, 0.75)) + + private var data: RDD[LabeledPoint] = _ + private var trainData: RDD[LabeledPoint] = _ + private var validationData: RDD[LabeledPoint] = _ + + override def beforeAll() { + super.beforeAll() + data = sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100), 2) + trainData = + sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 120), 2) + validationData = + sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80), 2) + } + + test("Binary classification with continuous features: Log Loss") { + val categoricalFeatures = Map.empty[Int, Int] + testCombinations.foreach { + case (maxIter, learningRate, subsamplingRate) => + val gbt = new GBTClassifier() + .setMaxDepth(2) + .setSubsamplingRate(subsamplingRate) + .setLossType("logistic") + .setMaxIter(maxIter) + .setStepSize(learningRate) + compareAPIs(data, None, gbt, categoricalFeatures) + } + } + + // TODO: Reinstate test once runWithValidation is implemented SPARK-7132 + /* + test("runWithValidation stops early and performs better on a validation dataset") { + val categoricalFeatures = Map.empty[Int, Int] + // Set maxIter large enough so that it stops early. + val maxIter = 20 + GBTClassifier.supportedLossTypes.foreach { loss => + val gbt = new GBTClassifier() + .setMaxIter(maxIter) + .setMaxDepth(2) + .setLossType(loss) + .setValidationTol(0.0) + compareAPIs(trainData, None, gbt, categoricalFeatures) + compareAPIs(trainData, Some(validationData), gbt, categoricalFeatures) + } + } + */ + + ///////////////////////////////////////////////////////////////////////////// + // Tests of model save/load + ///////////////////////////////////////////////////////////////////////////// + + // TODO: Reinstate test once save/load are implemented SPARK-6725 + /* + test("model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray + val treeWeights = Array(0.1, 0.3, 1.1) + val oldModel = new OldGBTModel(OldAlgo.Classification, trees, treeWeights) + val newModel = GBTClassificationModel.fromOld(oldModel) + + // Save model, load it back, and compare. + try { + newModel.save(sc, path) + val sameNewModel = GBTClassificationModel.load(sc, path) + TreeTests.checkEqual(newModel, sameNewModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } + */ +} + +private object GBTClassifierSuite { + + /** + * Train 2 models on the given dataset, one using the old API and one using the new API. + * Convert the old model to the new format, compare them, and fail if they are not exactly equal. + */ + def compareAPIs( + data: RDD[LabeledPoint], + validationData: Option[RDD[LabeledPoint]], + gbt: GBTClassifier, + categoricalFeatures: Map[Int, Int]): Unit = { + val oldBoostingStrategy = + gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) + val oldGBT = new OldGBT(oldBoostingStrategy) + val oldModel = oldGBT.run(data) + val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2) + val newModel = gbt.fit(newData) + // Use parent, fittingParamMap from newTree since these are not checked anyways. 
+ val oldModelAsNew = GBTClassificationModel.fromOld(oldModel, newModel.parent, + newModel.fittingParamMap, categoricalFeatures) + TreeTests.checkEqual(oldModelAsNew, newModel) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala new file mode 100644 index 0000000000000..ed41a9664f94f --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.scalatest.FunSuite + +import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + + +/** + * Test suite for [[RandomForestClassifier]]. + */ +class RandomForestClassifierSuite extends FunSuite with MLlibTestSparkContext { + + import RandomForestClassifierSuite.compareAPIs + + private var orderedLabeledPoints50_1000: RDD[LabeledPoint] = _ + private var orderedLabeledPoints5_20: RDD[LabeledPoint] = _ + + override def beforeAll() { + super.beforeAll() + orderedLabeledPoints50_1000 = + sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)) + orderedLabeledPoints5_20 = + sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 5, 20)) + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests calling train() + ///////////////////////////////////////////////////////////////////////////// + + def binaryClassificationTestWithContinuousFeatures(rf: RandomForestClassifier) { + val categoricalFeatures = Map.empty[Int, Int] + val numClasses = 2 + val newRF = rf + .setImpurity("Gini") + .setMaxDepth(2) + .setNumTrees(1) + .setFeatureSubsetStrategy("auto") + .setSeed(123) + compareAPIs(orderedLabeledPoints50_1000, newRF, categoricalFeatures, numClasses) + } + + test("Binary classification with continuous features:" + + " comparing DecisionTree vs. RandomForest(numTrees = 1)") { + val rf = new RandomForestClassifier() + binaryClassificationTestWithContinuousFeatures(rf) + } + + test("Binary classification with continuous features and node Id cache:" + + " comparing DecisionTree vs. 
RandomForest(numTrees = 1)") { + val rf = new RandomForestClassifier() + .setCacheNodeIds(true) + binaryClassificationTestWithContinuousFeatures(rf) + } + + test("alternating categorical and continuous features with multiclass labels to test indexing") { + val arr = Array( + LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0, 3.0, 1.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0, 1.0, 2.0)), + LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0, 6.0, 3.0)), + LabeledPoint(2.0, Vectors.dense(0.0, 2.0, 1.0, 3.0, 2.0)) + ) + val rdd = sc.parallelize(arr) + val categoricalFeatures = Map(0 -> 3, 2 -> 2, 4 -> 4) + val numClasses = 3 + + val rf = new RandomForestClassifier() + .setImpurity("Gini") + .setMaxDepth(5) + .setNumTrees(2) + .setFeatureSubsetStrategy("sqrt") + .setSeed(12345) + compareAPIs(rdd, rf, categoricalFeatures, numClasses) + } + + test("subsampling rate in RandomForest"){ + val rdd = orderedLabeledPoints5_20 + val categoricalFeatures = Map.empty[Int, Int] + val numClasses = 2 + + val rf1 = new RandomForestClassifier() + .setImpurity("Gini") + .setMaxDepth(2) + .setCacheNodeIds(true) + .setNumTrees(3) + .setFeatureSubsetStrategy("auto") + .setSeed(123) + compareAPIs(rdd, rf1, categoricalFeatures, numClasses) + + val rf2 = rf1.setSubsamplingRate(0.5) + compareAPIs(rdd, rf2, categoricalFeatures, numClasses) + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests of model save/load + ///////////////////////////////////////////////////////////////////////////// + + // TODO: Reinstate test once save/load are implemented SPARK-6725 + /* + test("model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + val trees = + Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Classification)).toArray + val oldModel = new OldRandomForestModel(OldAlgo.Classification, trees) + val newModel = RandomForestClassificationModel.fromOld(oldModel) + + // Save model, load it back, and compare. + try { + newModel.save(sc, path) + val sameNewModel = RandomForestClassificationModel.load(sc, path) + TreeTests.checkEqual(newModel, sameNewModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } + */ +} + +private object RandomForestClassifierSuite { + + /** + * Train 2 models on the given dataset, one using the old API and one using the new API. + * Convert the old model to the new format, compare them, and fail if they are not exactly equal. + */ + def compareAPIs( + data: RDD[LabeledPoint], + rf: RandomForestClassifier, + categoricalFeatures: Map[Int, Int], + numClasses: Int): Unit = { + val oldStrategy = + rf.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, rf.getOldImpurity) + val oldModel = OldRandomForest.trainClassifier( + data, oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt) + val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses) + val newModel = rf.fit(newData) + // Use parent, fittingParamMap from newTree since these are not checked anyways. 
+ val oldModelAsNew = RandomForestClassificationModel.fromOld(oldModel, newModel.parent, + newModel.fittingParamMap, categoricalFeatures) + TreeTests.checkEqual(oldModelAsNew, newModel) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala index 2e57d4ce37f1d..1505ad872536b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala @@ -23,8 +23,7 @@ import org.scalatest.FunSuite import org.apache.spark.api.java.JavaRDD import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} -import org.apache.spark.ml.impl.tree._ -import org.apache.spark.ml.tree.{DecisionTreeModel, InternalNode, LeafNode, Node} +import org.apache.spark.ml.tree._ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD import org.apache.spark.sql.{SQLContext, DataFrame} @@ -111,22 +110,19 @@ private[ml] object TreeTests extends FunSuite { } } - // TODO: Reinstate after adding ensembles /** * Check if the two models are exactly the same. * If the models are not equal, this throws an exception. */ - /* def checkEqual(a: TreeEnsembleModel, b: TreeEnsembleModel): Unit = { try { - a.getTrees.zip(b.getTrees).foreach { case (treeA, treeB) => + a.trees.zip(b.trees).foreach { case (treeA, treeB) => TreeTests.checkEqual(treeA, treeB) } - assert(a.getTreeWeights === b.getTreeWeights) + assert(a.treeWeights === b.treeWeights) } catch { case ex: Exception => throw new AssertionError( "checkEqual failed since the two tree ensembles were not identical") } } - */ } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 0b40fe33fae9d..c87a171b4b229 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -66,7 +66,7 @@ class DecisionTreeRegressorSuite extends FunSuite with MLlibTestSparkContext { // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// - // TODO: test("model save/load") + // TODO: test("model save/load") SPARK-6725 } private[ml] object DecisionTreeRegressorSuite extends FunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala new file mode 100644 index 0000000000000..4aec36948ac92 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.scalatest.FunSuite + +import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + + +/** + * Test suite for [[GBTRegressor]]. + */ +class GBTRegressorSuite extends FunSuite with MLlibTestSparkContext { + + import GBTRegressorSuite.compareAPIs + + // Combinations for estimators, learning rates and subsamplingRate + private val testCombinations = + Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 0.5, 0.75), (10, 0.1, 0.75)) + + private var data: RDD[LabeledPoint] = _ + private var trainData: RDD[LabeledPoint] = _ + private var validationData: RDD[LabeledPoint] = _ + + override def beforeAll() { + super.beforeAll() + data = sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100), 2) + trainData = + sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 120), 2) + validationData = + sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80), 2) + } + + test("Regression with continuous features: SquaredError") { + val categoricalFeatures = Map.empty[Int, Int] + GBTRegressor.supportedLossTypes.foreach { loss => + testCombinations.foreach { + case (maxIter, learningRate, subsamplingRate) => + val gbt = new GBTRegressor() + .setMaxDepth(2) + .setSubsamplingRate(subsamplingRate) + .setLossType(loss) + .setMaxIter(maxIter) + .setStepSize(learningRate) + compareAPIs(data, None, gbt, categoricalFeatures) + } + } + } + + // TODO: Reinstate test once runWithValidation is implemented SPARK-7132 + /* + test("runWithValidation stops early and performs better on a validation dataset") { + val categoricalFeatures = Map.empty[Int, Int] + // Set maxIter large enough so that it stops early. + val maxIter = 20 + GBTRegressor.supportedLossTypes.foreach { loss => + val gbt = new GBTRegressor() + .setMaxIter(maxIter) + .setMaxDepth(2) + .setLossType(loss) + .setValidationTol(0.0) + compareAPIs(trainData, None, gbt, categoricalFeatures) + compareAPIs(trainData, Some(validationData), gbt, categoricalFeatures) + } + } + */ + + ///////////////////////////////////////////////////////////////////////////// + // Tests of model save/load + ///////////////////////////////////////////////////////////////////////////// + + // TODO: Reinstate test once save/load are implemented SPARK-6725 + /* + test("model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray + val treeWeights = Array(0.1, 0.3, 1.1) + val oldModel = new OldGBTModel(OldAlgo.Regression, trees, treeWeights) + val newModel = GBTRegressionModel.fromOld(oldModel) + + // Save model, load it back, and compare. + try { + newModel.save(sc, path) + val sameNewModel = GBTRegressionModel.load(sc, path) + TreeTests.checkEqual(newModel, sameNewModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } + */ +} + +private object GBTRegressorSuite { + + /** + * Train 2 models on the given dataset, one using the old API and one using the new API. 
+ * Convert the old model to the new format, compare them, and fail if they are not exactly equal. + */ + def compareAPIs( + data: RDD[LabeledPoint], + validationData: Option[RDD[LabeledPoint]], + gbt: GBTRegressor, + categoricalFeatures: Map[Int, Int]): Unit = { + val oldBoostingStrategy = gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression) + val oldGBT = new OldGBT(oldBoostingStrategy) + val oldModel = oldGBT.run(data) + val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) + val newModel = gbt.fit(newData) + // Use parent, fittingParamMap from newTree since these are not checked anyways. + val oldModelAsNew = GBTRegressionModel.fromOld(oldModel, newModel.parent, + newModel.fittingParamMap, categoricalFeatures) + TreeTests.checkEqual(oldModelAsNew, newModel) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala new file mode 100644 index 0000000000000..c6dc1cc29b6ff --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.scalatest.FunSuite + +import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + + +/** + * Test suite for [[RandomForestRegressor]]. + */ +class RandomForestRegressorSuite extends FunSuite with MLlibTestSparkContext { + + import RandomForestRegressorSuite.compareAPIs + + private var orderedLabeledPoints50_1000: RDD[LabeledPoint] = _ + + override def beforeAll() { + super.beforeAll() + orderedLabeledPoints50_1000 = + sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)) + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests calling train() + ///////////////////////////////////////////////////////////////////////////// + + def regressionTestWithContinuousFeatures(rf: RandomForestRegressor) { + val categoricalFeaturesInfo = Map.empty[Int, Int] + val newRF = rf + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(10) + .setNumTrees(1) + .setFeatureSubsetStrategy("auto") + .setSeed(123) + compareAPIs(orderedLabeledPoints50_1000, newRF, categoricalFeaturesInfo) + } + + test("Regression with continuous features:" + + " comparing DecisionTree vs. 
RandomForest(numTrees = 1)") { + val rf = new RandomForestRegressor() + regressionTestWithContinuousFeatures(rf) + } + + test("Regression with continuous features and node Id cache :" + + " comparing DecisionTree vs. RandomForest(numTrees = 1)") { + val rf = new RandomForestRegressor() + .setCacheNodeIds(true) + regressionTestWithContinuousFeatures(rf) + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests of model save/load + ///////////////////////////////////////////////////////////////////////////// + + // TODO: Reinstate test once save/load are implemented SPARK-6725 + /* + test("model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray + val oldModel = new OldRandomForestModel(OldAlgo.Regression, trees) + val newModel = RandomForestRegressionModel.fromOld(oldModel) + + // Save model, load it back, and compare. + try { + newModel.save(sc, path) + val sameNewModel = RandomForestRegressionModel.load(sc, path) + TreeTests.checkEqual(newModel, sameNewModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } + */ +} + +private object RandomForestRegressorSuite extends FunSuite { + + /** + * Train 2 models on the given dataset, one using the old API and one using the new API. + * Convert the old model to the new format, compare them, and fail if they are not exactly equal. + */ + def compareAPIs( + data: RDD[LabeledPoint], + rf: RandomForestRegressor, + categoricalFeatures: Map[Int, Int]): Unit = { + val oldStrategy = + rf.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, rf.getOldImpurity) + val oldModel = OldRandomForest.trainRegressor( + data, oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt) + val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) + val newModel = rf.fit(newData) + // Use parent, fittingParamMap from newTree since these are not checked anyways. + val oldModelAsNew = RandomForestRegressionModel.fromOld(oldModel, newModel.parent, + newModel.fittingParamMap, categoricalFeatures) + TreeTests.checkEqual(oldModelAsNew, newModel) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 249b8eae19b17..ce983eb27fa35 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -998,7 +998,7 @@ object DecisionTreeSuite extends FunSuite { node.split = Some(new Split(feature = 1, threshold = 0.0, Categorical, categories = List(0.0, 1.0))) } - // TODO: The information gain stats should be consistent with the same info stored in children. + // TODO: The information gain stats should be consistent with info in children: SPARK-7131 node.stats = Some(new InformationGainStats(gain = 0.1, impurity = 0.2, leftImpurity = 0.3, rightImpurity = 0.4, new Predict(1.0, 0.4), new Predict(0.0, 0.6))) node @@ -1006,9 +1006,9 @@ object DecisionTreeSuite extends FunSuite { /** * Create a tree model. This is deterministic and contains a variety of node and feature types. - * TODO: Update this to be a correct tree (with matching probabilities, impurities, etc.) 
+ * TODO: Update to be a correct tree (with matching probabilities, impurities, etc.): SPARK-7131 */ - private[mllib] def createModel(algo: Algo): DecisionTreeModel = { + private[spark] def createModel(algo: Algo): DecisionTreeModel = { val topNode = createInternalNode(id = 1, Continuous) val (node2, node3) = (createLeafNode(id = 2), createInternalNode(id = 3, Categorical)) val (node6, node7) = (createLeafNode(id = 6), createLeafNode(id = 7)) From aa6966ff34dacc83c3ca675b5109b05e35015469 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Sat, 25 Apr 2015 13:43:39 -0700 Subject: [PATCH 130/191] [SQL] Update SQL readme to include instructions on generating golden answer files based on Hive 0.13.1. Author: Yin Huai Closes #5702 from yhuai/howToGenerateGoldenFiles and squashes the following commits: 9c4a7f8 [Yin Huai] Update readme to include instructions on generating golden answer files based on Hive 0.13.1. --- sql/README.md | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/sql/README.md b/sql/README.md index 237620e3fa808..46aec7cef7984 100644 --- a/sql/README.md +++ b/sql/README.md @@ -12,7 +12,10 @@ Spark SQL is broken up into four subprojects: Other dependencies for developers --------------------------------- -In order to create new hive test cases , you will need to set several environmental variables. +In order to create new hive test cases (i.e. a test suite based on `HiveComparisonTest`), +you will need to setup your development environment based on the following instructions. + +If you are working with Hive 0.12.0, you will need to set several environmental variables as follows. ``` export HIVE_HOME="/hive/build/dist" @@ -20,6 +23,24 @@ export HIVE_DEV_HOME="/hive/" export HADOOP_HOME="/hadoop-1.0.4" ``` +If you are working with Hive 0.13.1, the following steps are needed: + +1. Download Hive's [0.13.1](https://hive.apache.org/downloads.html) and set `HIVE_HOME` with `export HIVE_HOME=""`. Please do not set `HIVE_DEV_HOME` (See [SPARK-4119](https://issues.apache.org/jira/browse/SPARK-4119)). +2. Set `HADOOP_HOME` with `export HADOOP_HOME=""` +3. Download all Hive 0.13.1a jars (Hive jars actually used by Spark) from [here](http://mvnrepository.com/artifact/org.spark-project.hive) and replace corresponding original 0.13.1 jars in `$HIVE_HOME/lib`. +4. Download [Kryo 2.21 jar](http://mvnrepository.com/artifact/com.esotericsoftware.kryo/kryo/2.21) (Note: 2.22 jar does not work) and [Javolution 5.5.1 jar](http://mvnrepository.com/artifact/javolution/javolution/5.5.1) to `$HIVE_HOME/lib`. +5. This step is optional. But, when generating golden answer files, if a Hive query fails and you find that Hive tries to talk to HDFS or you find weird runtime NPEs, set the following in your test suite... + +``` +val testTempDir = Utils.createTempDir() +// We have to use kryo to let Hive correctly serialize some plans. +sql("set hive.plan.serialization.format=kryo") +// Explicitly set fs to local fs. +sql(s"set fs.default.name=file://$testTempDir/") +// Ask Hive to run jobs in-process as a single map and reduce task. +sql("set mapred.job.tracker=local") +``` + Using the console ================= An interactive scala console can be invoked by running `build/sbt hive/console`. 
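For context, a suite built on `HiveComparisonTest` registers each Hive query whose golden answer file should be generated once the environment above is configured. A minimal sketch in Scala (the suite name is invented here, and `createQueryTest` is assumed to be the registration helper exposed by `HiveComparisonTest`; verify against the framework in your checkout):

```
package org.apache.spark.sql.hive.execution

// Minimal sketch of a golden-file test suite. With HIVE_HOME/HADOOP_HOME set
// as described above, running this suite (re)generates the golden answer
// files that HiveComparisonTest compares against on subsequent runs.
class GoldenFileExampleSuite extends HiveComparisonTest {
  // Registers a named test whose answer is captured from Hive as a golden file.
  createQueryTest("simple key filter",
    "SELECT key, value FROM src WHERE key < 10 ORDER BY key")
}
```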
From a11c8683c76c67f45749a1b50a0912a731fd2487 Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Sat, 25 Apr 2015 18:07:34 -0400 Subject: [PATCH 131/191] [SPARK-7092] Update spark scala version to 2.11.6 Author: Prashant Sharma Closes #5662 from ScrapCodes/SPARK-7092/scala-update-2.11.6 and squashes the following commits: 58cf4f9 [Prashant Sharma] [SPARK-7092] Update spark scala version to 2.11.6 --- pom.xml | 4 ++-- .../src/main/scala/org/apache/spark/repl/SparkIMain.scala | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pom.xml b/pom.xml index 4b0b0c85eff21..9fbce1d639d8b 100644 --- a/pom.xml +++ b/pom.xml @@ -1745,9 +1745,9 @@ scala-2.11 - 2.11.2 + 2.11.6 2.11 - 2.12 + 2.12.1 jline diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala index 1bb62c84abddc..1cb910f376060 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -1129,7 +1129,7 @@ class SparkIMain(@BeanProperty val factory: ScriptEngineFactory, initialSettings def apply(line: String): Result = debugging(s"""parse("$line")""") { var isIncomplete = false - currentRun.reporting.withIncompleteHandler((_, _) => isIncomplete = true) { + currentRun.parsing.withIncompleteHandler((_, _) => isIncomplete = true) { reporter.reset() val trees = newUnitParser(line).parseStats() if (reporter.hasErrors) Error From f5473c2bbf66cc1144a90b4c29f3ce54ad7cc419 Mon Sep 17 00:00:00 2001 From: Nishkam Ravi Date: Sat, 25 Apr 2015 20:02:23 -0400 Subject: [PATCH 132/191] [SPARK-6014] [CORE] [HOTFIX] Add try-catch block around ShutDownHook Add a try/catch block around removeShutDownHook else IllegalStateException thrown in YARN cluster mode (see https://github.com/apache/spark/pull/4690) cc andrewor14, srowen Author: Nishkam Ravi Author: nishkamravi2 Author: nravi Closes #5672 from nishkamravi2/master_nravi and squashes the following commits: 0f1abd0 [nishkamravi2] Update Utils.scala 474e3bf [nishkamravi2] Update DiskBlockManager.scala 97c383e [nishkamravi2] Update Utils.scala 8691e0c [Nishkam Ravi] Add a try/catch block around Utils.removeShutdownHook 2be1e76 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 1c13b79 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi bad4349 [nishkamravi2] Update Main.java 36a6f87 [Nishkam Ravi] Minor changes and bug fixes b7f4ae7 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 4a45d6a [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 458af39 [Nishkam Ravi] Locate the jar using getLocation, obviates the need to pass assembly path as an argument d9658d6 [Nishkam Ravi] Changes for SPARK-6406 ccdc334 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 3faa7a4 [Nishkam Ravi] Launcher library changes (SPARK-6406) 345206a [Nishkam Ravi] spark-class merge Merge branch 'master_nravi' of https://github.com/nishkamravi2/spark into master_nravi ac58975 [Nishkam Ravi] spark-class changes 06bfeb0 [nishkamravi2] Update spark-class 35af990 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 32c3ab3 [nishkamravi2] Update AbstractCommandBuilder.java 4bd4489 [nishkamravi2] Update AbstractCommandBuilder.java 746f35b [Nishkam Ravi] "hadoop" string in the assembly name should not be 
mandatory (everywhere else in spark we mandate spark-assembly*hadoop*.jar) bfe96e0 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi ee902fa [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi d453197 [nishkamravi2] Update NewHadoopRDD.scala 6f41a1d [nishkamravi2] Update NewHadoopRDD.scala 0ce2c32 [nishkamravi2] Update HadoopRDD.scala f7e33c2 [Nishkam Ravi] Merge branch 'master_nravi' of https://github.com/nishkamravi2/spark into master_nravi ba1eb8b [Nishkam Ravi] Try-catch block around the two occurrences of removeShutDownHook. Deletion of semi-redundant occurrences of expensive operation inShutDown. 71d0e17 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 494d8c0 [nishkamravi2] Update DiskBlockManager.scala 3c5ddba [nishkamravi2] Update DiskBlockManager.scala f0d12de [Nishkam Ravi] Workaround for IllegalStateException caused by recent changes to BlockManager.stop 79ea8b4 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi b446edc [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 5c9a4cb [nishkamravi2] Update TaskSetManagerSuite.scala 535295a [nishkamravi2] Update TaskSetManager.scala 3e1b616 [Nishkam Ravi] Modify test for maxResultSize 9f6583e [Nishkam Ravi] Changes to maxResultSize code (improve error message and add condition to check if maxResultSize > 0) 5f8f9ed [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 636a9ff [nishkamravi2] Update YarnAllocator.scala 8f76c8b [Nishkam Ravi] Doc change for yarn memory overhead 35daa64 [Nishkam Ravi] Slight change in the doc for yarn memory overhead 5ac2ec1 [Nishkam Ravi] Remove out dac1047 [Nishkam Ravi] Additional documentation for yarn memory overhead issue 42c2c3d [Nishkam Ravi] Additional changes for yarn memory overhead issue 362da5e [Nishkam Ravi] Additional changes for yarn memory overhead c726bd9 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi f00fa31 [Nishkam Ravi] Improving logging for AM memoryOverhead 1cf2d1e [nishkamravi2] Update YarnAllocator.scala ebcde10 [Nishkam Ravi] Modify default YARN memory_overhead-- from an additive constant to a multiplier (redone to resolve merge conflicts) 2e69f11 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi efd688a [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark 2b630f9 [nravi] Accept memory input as "30g", "512M" instead of an int value, to be consistent with rest of Spark 3bf8fad [nravi] Merge branch 'master' of https://github.com/apache/spark 5423a03 [nravi] Merge branch 'master' of https://github.com/apache/spark eb663ca [nravi] Merge branch 'master' of https://github.com/apache/spark df2aeb1 [nravi] Improved fix for ConcurrentModificationIssue (Spark-1097, Hadoop-10456) 6b840f0 [nravi] Undo the fix for SPARK-1758 (the problem is fixed) 5108700 [nravi] Fix in Spark for the Concurrent thread modification issue (SPARK-1097, HADOOP-10456) 681b36f [nravi] Fix for SPARK-1758: failing test org.apache.spark.JavaAPISuite.wholeTextFiles --- .../scala/org/apache/spark/storage/DiskBlockManager.scala | 7 ++++++- core/src/main/scala/org/apache/spark/util/Utils.scala | 3 +-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala 
b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 7ea5e54f9e1fe..5764c16902c66 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -148,7 +148,12 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon /** Cleanup local dirs and stop shuffle sender. */ private[spark] def stop() { // Remove the shutdown hook. It causes memory leaks if we leave it around. - Utils.removeShutdownHook(shutdownHook) + try { + Utils.removeShutdownHook(shutdownHook) + } catch { + case e: Exception => + logError(s"Exception while removing shutdown hook.", e) + } doStop() } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 667aa168e7ef3..c6c6df7cfa56e 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2172,7 +2172,7 @@ private [util] class SparkShutdownHookManager { def runAll(): Unit = synchronized { shuttingDown = true while (!hooks.isEmpty()) { - Utils.logUncaughtExceptions(hooks.poll().run()) + Try(Utils.logUncaughtExceptions(hooks.poll().run())) } } @@ -2184,7 +2184,6 @@ private [util] class SparkShutdownHookManager { } def remove(ref: AnyRef): Boolean = synchronized { - checkState() hooks.remove(ref) } From 9a5bbe05fc1b1141e139d32661821fef47d7a13c Mon Sep 17 00:00:00 2001 From: Alain Date: Sun, 26 Apr 2015 07:14:24 -0400 Subject: [PATCH 133/191] [MINOR] [MLLIB] Refactor toString method in MLLIB MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. predict(predict.toString) has already output prefix “predict” thus it’s duplicated to print ", predict = " again 2. 
there are some extra spaces Author: Alain Closes #5687 from AiHe/tree-node-issue-2 and squashes the following commits: 9862b9a [Alain] Pass scala coding style checking 44ba947 [Alain] Minor][MLLIB] Format toString method in MLLIB bdc402f [Alain] [Minor][MLLIB] Fix a formatting bug in toString method in Node 426eee7 [Alain] [Minor][MLLIB] Fix a formatting bug in toString method in Node.scala --- .../main/scala/org/apache/spark/mllib/linalg/Vectors.scala | 2 +- .../org/apache/spark/mllib/regression/LabeledPoint.scala | 2 +- .../apache/spark/mllib/tree/model/InformationGainStats.scala | 4 ++-- .../main/scala/org/apache/spark/mllib/tree/model/Node.scala | 4 ++-- .../scala/org/apache/spark/mllib/tree/model/Predict.scala | 4 +--- .../main/scala/org/apache/spark/mllib/tree/model/Split.scala | 5 ++--- 6 files changed, 9 insertions(+), 12 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 4ef171f4f0419..166c00cff634d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -526,7 +526,7 @@ class SparseVector( s" ${values.size} values.") override def toString: String = - "(%s,%s,%s)".format(size, indices.mkString("[", ",", "]"), values.mkString("[", ",", "]")) + s"($size,${indices.mkString("[", ",", "]")},${values.mkString("[", ",", "]")})" override def toArray: Array[Double] = { val data = new Array[Double](size) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala index 2067b36f246b3..d5fea822ad77b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala @@ -32,7 +32,7 @@ import org.apache.spark.SparkException @BeanInfo case class LabeledPoint(label: Double, features: Vector) { override def toString: String = { - "(%s,%s)".format(label, features) + s"($label,$features)" } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index f209fdafd3653..2d087c967f679 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -39,8 +39,8 @@ class InformationGainStats( val rightPredict: Predict) extends Serializable { override def toString: String = { - "gain = %f, impurity = %f, left impurity = %f, right impurity = %f" - .format(gain, impurity, leftImpurity, rightImpurity) + s"gain = $gain, impurity = $impurity, left impurity = $leftImpurity, " + + s"right impurity = $rightImpurity" } override def equals(o: Any): Boolean = o match { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 86390a20cb5cc..431a839817eac 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -51,8 +51,8 @@ class Node ( var stats: Option[InformationGainStats]) extends Serializable with Logging { override def toString: String = { - "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " + - "impurity = " + impurity + ", split = " + split + ", stats = " 
+ stats + s"id = $id, isLeaf = $isLeaf, predict = $predict, impurity = $impurity, " + + s"split = $split, stats = $stats" } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala index 25990af7c6cf7..5cbe7c280dbee 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala @@ -29,9 +29,7 @@ class Predict( val predict: Double, val prob: Double = 0.0) extends Serializable { - override def toString: String = { - "predict = %f, prob = %f".format(predict, prob) - } + override def toString: String = s"$predict (prob = $prob)" override def equals(other: Any): Boolean = { other match { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala index fb35e70a8d077..be6c9b3de5479 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala @@ -39,8 +39,8 @@ case class Split( categories: List[Double]) { override def toString: String = { - "Feature = " + feature + ", threshold = " + threshold + ", featureType = " + featureType + - ", categories = " + categories + s"Feature = $feature, threshold = $threshold, featureType = $featureType, " + + s"categories = $categories" } } @@ -68,4 +68,3 @@ private[tree] class DummyHighSplit(feature: Int, featureType: FeatureType) */ private[tree] class DummyCategoricalSplit(feature: Int, featureType: FeatureType) extends Split(feature, Double.MaxValue, featureType, List()) - From ca55dc95b777d96b27d4e4c0457dd25145dcd6e9 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 26 Apr 2015 11:46:58 -0700 Subject: [PATCH 134/191] [SPARK-7152][SQL] Add a Column expression for partition ID. Author: Reynold Xin Closes #5705 from rxin/df-pid and squashes the following commits: 401018f [Reynold Xin] [SPARK-7152][SQL] Add a Column expression for partition ID. --- python/pyspark/sql/functions.py | 30 +++++++++----- .../expressions/SparkPartitionID.scala | 39 +++++++++++++++++++ .../sql/execution/expressions/package.scala | 23 +++++++++++ .../org/apache/spark/sql/functions.scala | 29 +++++++++----- .../spark/sql/ColumnExpressionSuite.scala | 8 ++++ 5 files changed, 110 insertions(+), 19 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/package.scala diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index bb47923f24b82..f48b7b5d10af7 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -75,6 +75,20 @@ def _(col): __all__.sort() +def approxCountDistinct(col, rsd=None): + """Returns a new :class:`Column` for approximate distinct count of ``col``. + + >>> df.agg(approxCountDistinct(df.age).alias('c')).collect() + [Row(c=2)] + """ + sc = SparkContext._active_spark_context + if rsd is None: + jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col)) + else: + jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col), rsd) + return Column(jc) + + def countDistinct(col, *cols): """Returns a new :class:`Column` for distinct count of ``col`` or ``cols``. 
@@ -89,18 +103,16 @@ def countDistinct(col, *cols): return Column(jc) -def approxCountDistinct(col, rsd=None): - """Returns a new :class:`Column` for approximate distinct count of ``col``. +def sparkPartitionId(): + """Returns a column for partition ID of the Spark task. - >>> df.agg(approxCountDistinct(df.age).alias('c')).collect() - [Row(c=2)] + Note that this is indeterministic because it depends on data partitioning and task scheduling. + + >>> df.repartition(1).select(sparkPartitionId().alias("pid")).collect() + [Row(pid=0), Row(pid=0)] """ sc = SparkContext._active_spark_context - if rsd is None: - jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col)) - else: - jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col), rsd) - return Column(jc) + return Column(sc._jvm.functions.sparkPartitionId()) class UserDefinedFunction(object): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala new file mode 100644 index 0000000000000..fe7607c6ac340 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.expressions + +import org.apache.spark.TaskContext +import org.apache.spark.sql.catalyst.expressions.{Row, Expression} +import org.apache.spark.sql.catalyst.trees +import org.apache.spark.sql.types.{IntegerType, DataType} + + +/** + * Expression that returns the current partition id of the Spark task. + */ +case object SparkPartitionID extends Expression with trees.LeafNode[Expression] { + self: Product => + + override type EvaluatedType = Int + + override def nullable: Boolean = false + + override def dataType: DataType = IntegerType + + override def eval(input: Row): Int = TaskContext.get().partitionId() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/package.scala new file mode 100644 index 0000000000000..568b7ac2c5987 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/package.scala @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +/** + * Package containing expressions that are specific to Spark runtime. + */ +package object expressions diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index ff91e1d74bc2c..9738fd4f93bad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -276,6 +276,13 @@ object functions { // Non-aggregate functions ////////////////////////////////////////////////////////////////////////////////////////////// + /** + * Computes the absolute value. + * + * @group normal_funcs + */ + def abs(e: Column): Column = Abs(e.expr) + /** * Returns the first column that is not null. * {{{ @@ -287,6 +294,13 @@ object functions { @scala.annotation.varargs def coalesce(e: Column*): Column = Coalesce(e.map(_.expr)) + /** + * Converts a string exprsesion to lower case. + * + * @group normal_funcs + */ + def lower(e: Column): Column = Lower(e.expr) + /** * Unary minus, i.e. negate the expression. * {{{ @@ -317,18 +331,13 @@ object functions { def not(e: Column): Column = !e /** - * Converts a string expression to upper case. + * Partition ID of the Spark task. * - * @group normal_funcs - */ - def upper(e: Column): Column = Upper(e.expr) - - /** - * Converts a string exprsesion to lower case. + * Note that this is indeterministic because it depends on data partitioning and task scheduling. * * @group normal_funcs */ - def lower(e: Column): Column = Lower(e.expr) + def sparkPartitionId(): Column = execution.expressions.SparkPartitionID /** * Computes the square root of the specified float value. @@ -338,11 +347,11 @@ object functions { def sqrt(e: Column): Column = Sqrt(e.expr) /** - * Computes the absolutle value. + * Converts a string expression to upper case. 
* * @group normal_funcs */ - def abs(e: Column): Column = Abs(e.expr) + def upper(e: Column): Column = Upper(e.expr) ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index bc8fae100db6a..904073b8cb2aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -310,6 +310,14 @@ class ColumnExpressionSuite extends QueryTest { ) } + test("sparkPartitionId") { + val df = TestSQLContext.sparkContext.parallelize(1 to 1, 1).map(i => (i, i)).toDF("a", "b") + checkAnswer( + df.select(sparkPartitionId()), + Row(0) + ) + } + test("lift alias out of cast") { compareExpressions( col("1234").as("name").cast("int").expr, From d188b8bad82836bf654e57f9dd4e1ddde1d530f4 Mon Sep 17 00:00:00 2001 From: wangfei Date: Sun, 26 Apr 2015 21:08:47 -0700 Subject: [PATCH 135/191] [SQL][Minor] rename DataTypeParser.apply to DataTypeParser.parse rename DataTypeParser.apply to DataTypeParser.parse to make it more clear and readable. /cc rxin Author: wangfei Closes #5710 from scwf/apply and squashes the following commits: c319977 [wangfei] rename apply to parse --- .../org/apache/spark/sql/catalyst/planning/patterns.scala | 2 +- .../scala/org/apache/spark/sql/types/DataTypeParser.scala | 2 +- .../org/apache/spark/sql/types/DataTypeParserSuite.scala | 4 ++-- sql/core/src/main/scala/org/apache/spark/sql/Column.scala | 2 +- .../org/apache/spark/sql/hive/HiveMetastoreCatalog.scala | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 9c8c643f7d17a..4574934d910db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -92,7 +92,7 @@ object PhysicalOperation extends PredicateHelper { } def collectAliases(fields: Seq[Expression]): Map[Attribute, Expression] = fields.collect { - case a @ Alias(child, _) => a.toAttribute.asInstanceOf[Attribute] -> child + case a @ Alias(child, _) => a.toAttribute -> child }.toMap def substitute(aliases: Map[Attribute, Expression])(expr: Expression): Expression = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala index 5163f05879e42..04f3379afb38d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala @@ -108,7 +108,7 @@ private[sql] object DataTypeParser { override val lexical = new SqlLexical } - def apply(dataTypeString: String): DataType = dataTypeParser.toDataType(dataTypeString) + def parse(dataTypeString: String): DataType = dataTypeParser.toDataType(dataTypeString) } /** The exception thrown from the [[DataTypeParser]]. 
*/ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala index 169125264a803..3e7cf7cbb5e63 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala @@ -23,13 +23,13 @@ class DataTypeParserSuite extends FunSuite { def checkDataType(dataTypeString: String, expectedDataType: DataType): Unit = { test(s"parse ${dataTypeString.replace("\n", "")}") { - assert(DataTypeParser(dataTypeString) === expectedDataType) + assert(DataTypeParser.parse(dataTypeString) === expectedDataType) } } def unsupported(dataTypeString: String): Unit = { test(s"$dataTypeString is not supported") { - intercept[DataTypeException](DataTypeParser(dataTypeString)) + intercept[DataTypeException](DataTypeParser.parse(dataTypeString)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index edb229c059e6b..33f9d0b37d006 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -647,7 +647,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * * @group expr_ops */ - def cast(to: String): Column = cast(DataTypeParser(to)) + def cast(to: String): Column = cast(DataTypeParser.parse(to)) /** * Returns an ordering used in sorting. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index f1c0bd92aa23d..4d222cf88e5e8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -871,7 +871,7 @@ private[hive] case class MetastoreRelation private[hive] object HiveMetastoreTypes { - def toDataType(metastoreType: String): DataType = DataTypeParser(metastoreType) + def toDataType(metastoreType: String): DataType = DataTypeParser.parse(metastoreType) def toMetastoreType(dt: DataType): String = dt match { case ArrayType(elementType, _) => s"array<${toMetastoreType(elementType)}>" From 82bb7fd41a2c7992e0aea69623c504bd439744f7 Mon Sep 17 00:00:00 2001 From: baishuo Date: Mon, 27 Apr 2015 14:08:05 +0800 Subject: [PATCH 136/191] [SPARK-6505] [SQL] Remove the reflection call in HiveFunctionWrapper MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit according liancheng‘s comment in https://issues.apache.org/jira/browse/SPARK-6505, this patch remove the reflection call in HiveFunctionWrapper, and implement the functions named "deserializeObjectByKryo" and "serializeObjectByKryo" according the functions with the save name in org.apache.hadoop.hive.ql.exec.Utilities.java Author: baishuo Closes #5660 from baishuo/SPARK-6505-20150423 and squashes the following commits: ae61ec4 [baishuo] modify code style 78d9fa3 [baishuo] modify code style 0b522a7 [baishuo] modify code style a5ff9c7 [baishuo] Remove the reflection call in HiveFunctionWrapper --- .../org/apache/spark/sql/hive/Shim13.scala | 44 +++++++++---------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala index d331c210e8939..dbc5e029e2047 100644 --- 
a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala +++ b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala @@ -19,11 +19,15 @@ package org.apache.spark.sql.hive import java.rmi.server.UID import java.util.{Properties, ArrayList => JArrayList} +import java.io.{OutputStream, InputStream} import scala.collection.JavaConversions._ import scala.language.implicitConversions +import scala.reflect.ClassTag import com.esotericsoftware.kryo.Kryo +import com.esotericsoftware.kryo.io.Input +import com.esotericsoftware.kryo.io.Output import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.common.StatsSetupConst @@ -46,6 +50,7 @@ import org.apache.hadoop.{io => hadoopIo} import org.apache.spark.Logging import org.apache.spark.sql.types.{Decimal, DecimalType, UTF8String} +import org.apache.spark.util.Utils._ /** * This class provides the UDF creation and also the UDF instance serialization and @@ -61,39 +66,34 @@ private[hive] case class HiveFunctionWrapper(var functionClassName: String) // for Serialization def this() = this(null) - import org.apache.spark.util.Utils._ - @transient - private val methodDeSerialize = { - val method = classOf[Utilities].getDeclaredMethod( - "deserializeObjectByKryo", - classOf[Kryo], - classOf[java.io.InputStream], - classOf[Class[_]]) - method.setAccessible(true) - - method + def deserializeObjectByKryo[T: ClassTag]( + kryo: Kryo, + in: InputStream, + clazz: Class[_]): T = { + val inp = new Input(in) + val t: T = kryo.readObject(inp,clazz).asInstanceOf[T] + inp.close() + t } @transient - private val methodSerialize = { - val method = classOf[Utilities].getDeclaredMethod( - "serializeObjectByKryo", - classOf[Kryo], - classOf[Object], - classOf[java.io.OutputStream]) - method.setAccessible(true) - - method + def serializeObjectByKryo( + kryo: Kryo, + plan: Object, + out: OutputStream ) { + val output: Output = new Output(out) + kryo.writeObject(output, plan) + output.close() } def deserializePlan[UDFType](is: java.io.InputStream, clazz: Class[_]): UDFType = { - methodDeSerialize.invoke(null, Utilities.runtimeSerializationKryo.get(), is, clazz) + deserializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), is, clazz) .asInstanceOf[UDFType] } def serializePlan(function: AnyRef, out: java.io.OutputStream): Unit = { - methodSerialize.invoke(null, Utilities.runtimeSerializationKryo.get(), function, out) + serializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), function, out) } private var instance: AnyRef = null From 998aac21f0a0588a70f8cf123ae4080163c612fb Mon Sep 17 00:00:00 2001 From: Misha Chernetsov Date: Mon, 27 Apr 2015 11:27:56 -0700 Subject: [PATCH 137/191] [SPARK-4925] Publish Spark SQL hive-thriftserver maven artifact turned on hive-thriftserver profile in release script Author: Misha Chernetsov Closes #5429 from chernetsov/master and squashes the following commits: 9cc36af [Misha Chernetsov] [SPARK-4925] Publish Spark SQL hive-thriftserver maven artifact turned on hive-thriftserver profile in release script for scala 2.10 --- dev/create-release/create-release.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index b5a67dd783b93..3dbb35f7054a2 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -119,7 +119,7 @@ if [[ ! 
"$@" =~ --skip-publish ]]; then rm -rf $SPARK_REPO build/mvn -DskipTests -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ - -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ + -Pyarn -Phive -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ clean install ./dev/change-version-to-2.11.sh From 7078f6028bf012235c664b02ec3541cbb0a248a7 Mon Sep 17 00:00:00 2001 From: Jeff Harrison Date: Mon, 27 Apr 2015 13:38:25 -0700 Subject: [PATCH 138/191] [SPARK-6856] [R] Make RDD information more useful in SparkR Author: Jeff Harrison Closes #5667 from His-name-is-Joof/joofspark and squashes the following commits: f8814a6 [Jeff Harrison] newline added after RDD show() output 4d9d972 [Jeff Harrison] Merge branch 'master' into joofspark 9d2295e [Jeff Harrison] parallelize with 1:10 878b830 [Jeff Harrison] Merge branch 'master' into joofspark c8c0b80 [Jeff Harrison] add test for RDD function show() 123be65 [Jeff Harrison] SPARK-6856 --- R/pkg/R/RDD.R | 5 +++++ R/pkg/inst/tests/test_rdd.R | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index 1662d6bb3b1ac..f90c26b253455 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -66,6 +66,11 @@ setMethod("initialize", "RDD", function(.Object, jrdd, serializedMode, .Object }) +setMethod("show", "RDD", + function(.Object) { + cat(paste(callJMethod(.Object@jrdd, "toString"), "\n", sep="")) + }) + setMethod("initialize", "PipelinedRDD", function(.Object, prev, func, jrdd_val) { .Object@env <- new.env() .Object@env$isCached <- FALSE diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R index d55af93e3e50a..03207353c31c6 100644 --- a/R/pkg/inst/tests/test_rdd.R +++ b/R/pkg/inst/tests/test_rdd.R @@ -759,6 +759,11 @@ test_that("collectAsMap() on a pairwise RDD", { expect_equal(vals, list(`1` = "a", `2` = "b")) }) +test_that("show()", { + rdd <- parallelize(sc, list(1:10)) + expect_output(show(rdd), "ParallelCollectionRDD\\[\\d+\\] at parallelize at RRDD\\.scala:\\d+") +}) + test_that("sampleByKey() on pairwise RDDs", { rdd <- parallelize(sc, 1:2000) pairsRDD <- lapply(rdd, function(x) { if (x %% 2 == 0) list("a", x) else list("b", x) }) From ef82bddc11d1aea42e22d2f85613a869cbe9a990 Mon Sep 17 00:00:00 2001 From: tedyu Date: Mon, 27 Apr 2015 14:42:40 -0700 Subject: [PATCH 139/191] SPARK-7107 Add parameter for zookeeper.znode.parent to hbase_inputformat... ....py Author: tedyu Closes #5673 from tedyu/master and squashes the following commits: ab7c72b [tedyu] SPARK-7107 Adjust indentation to pass Python style tests 6e25939 [tedyu] Adjust line length to be shorter than 100 characters 18d172a [tedyu] SPARK-7107 Add parameter for zookeeper.znode.parent to hbase_inputformat.py --- examples/src/main/python/hbase_inputformat.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/src/main/python/hbase_inputformat.py b/examples/src/main/python/hbase_inputformat.py index e17819d5feb76..5b82a14fba413 100644 --- a/examples/src/main/python/hbase_inputformat.py +++ b/examples/src/main/python/hbase_inputformat.py @@ -54,8 +54,9 @@ Run with example jar: ./bin/spark-submit --driver-class-path /path/to/example/jar \ - /path/to/examples/hbase_inputformat.py + /path/to/examples/hbase_inputformat.py
[<znode>] Assumes you have some data in HBase already, running on <host>, in <table>
+ optionally, you can specify parent znode for your hbase cluster - """, file=sys.stderr) exit(-1) @@ -64,6 +65,9 @@ sc = SparkContext(appName="HBaseInputFormat") conf = {"hbase.zookeeper.quorum": host, "hbase.mapreduce.inputtable": table} + if len(sys.argv) > 3: + conf = {"hbase.zookeeper.quorum": host, "zookeeper.znode.parent": sys.argv[3], + "hbase.mapreduce.inputtable": table} keyConv = "org.apache.spark.examples.pythonconverters.ImmutableBytesWritableToStringConverter" valueConv = "org.apache.spark.examples.pythonconverters.HBaseResultToStringConverter" From ca9f4ebb8e510e521bf4df0331375ddb385fb9d2 Mon Sep 17 00:00:00 2001 From: hlin09 Date: Mon, 27 Apr 2015 15:04:37 -0700 Subject: [PATCH 140/191] [SPARK-6991] [SPARKR] Adds support for zipPartitions. Author: hlin09 Closes #5568 from hlin09/zipPartitions and squashes the following commits: 12c08a5 [hlin09] Fix comments d2d32db [hlin09] Merge branch 'master' into zipPartitions ec56d2f [hlin09] Fix test. 27655d3 [hlin09] Adds support for zipPartitions. --- R/pkg/NAMESPACE | 1 + R/pkg/R/RDD.R | 46 +++++++++++++++++++++++++ R/pkg/R/generics.R | 5 +++ R/pkg/inst/tests/test_binary_function.R | 33 ++++++++++++++++++ 4 files changed, 85 insertions(+) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 80283643861ac..e077eace74375 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -71,6 +71,7 @@ exportMethods( "unpersist", "value", "values", + "zipPartitions", "zipRDD", "zipWithIndex", "zipWithUniqueId" diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index f90c26b253455..a3a0421a0746d 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -1595,3 +1595,49 @@ setMethod("intersection", keys(filterRDD(cogroup(rdd1, rdd2, numPartitions = numPartitions), filterFunction)) }) + +#' Zips an RDD's partitions with one (or more) RDD(s). +#' Same as zipPartitions in Spark. +#' +#' @param ... RDDs to be zipped. +#' @param func A function to transform zipped partitions. +#' @return A new RDD by applying a function to the zipped partitions. +#' Assumes that all the RDDs have the *same number of partitions*, but +#' does *not* require them to have the same number of elements in each partition. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2 +#' rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4 +#' rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6 +#' collect(zipPartitions(rdd1, rdd2, rdd3, +#' func = function(x, y, z) { list(list(x, y, z))} )) +#' # list(list(1, c(1,2), c(1,2,3)), list(2, c(3,4), c(4,5,6))) +#'} +#' @rdname zipRDD +#' @aliases zipPartitions,RDD +setMethod("zipPartitions", + "RDD", + function(..., func) { + rrdds <- list(...) + if (length(rrdds) == 1) { + return(rrdds[[1]]) + } + nPart <- sapply(rrdds, numPartitions) + if (length(unique(nPart)) != 1) { + stop("Can only zipPartitions RDDs which have the same number of partitions.") + } + + rrdds <- lapply(rrdds, function(rdd) { + mapPartitionsWithIndex(rdd, function(partIndex, part) { + print(length(part)) + list(list(partIndex, part)) + }) + }) + union.rdd <- Reduce(unionRDD, rrdds) + zipped.rdd <- values(groupByKey(union.rdd, numPartitions = nPart[1])) + res <- mapPartitions(zipped.rdd, function(plist) { + do.call(func, plist[[1]]) + }) + res + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 34dbe84051c50..e88729387ef95 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -217,6 +217,11 @@ setGeneric("unpersist", function(x, ...) 
{ standardGeneric("unpersist") }) #' @export setGeneric("zipRDD", function(x, other) { standardGeneric("zipRDD") }) +#' @rdname zipRDD +#' @export +setGeneric("zipPartitions", function(..., func) { standardGeneric("zipPartitions") }, + signature = "...") + #' @rdname zipWithIndex #' @seealso zipWithUniqueId #' @export diff --git a/R/pkg/inst/tests/test_binary_function.R b/R/pkg/inst/tests/test_binary_function.R index c15553ba28517..6785a7bdae8cb 100644 --- a/R/pkg/inst/tests/test_binary_function.R +++ b/R/pkg/inst/tests/test_binary_function.R @@ -66,3 +66,36 @@ test_that("cogroup on two RDDs", { expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) }) + +test_that("zipPartitions() on RDDs", { + rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2 + rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4 + rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6 + actual <- collect(zipPartitions(rdd1, rdd2, rdd3, + func = function(x, y, z) { list(list(x, y, z))} )) + expect_equal(actual, + list(list(1, c(1,2), c(1,2,3)), list(2, c(3,4), c(4,5,6)))) + + mockFile = c("Spark is pretty.", "Spark is awesome.") + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName, 1) + actual <- collect(zipPartitions(rdd, rdd, + func = function(x, y) { list(paste(x, y, sep = "\n")) })) + expected <- list(paste(mockFile, mockFile, sep = "\n")) + expect_equal(actual, expected) + + rdd1 <- parallelize(sc, 0:1, 1) + actual <- collect(zipPartitions(rdd1, rdd, + func = function(x, y) { list(x + nchar(y)) })) + expected <- list(0:1 + nchar(mockFile)) + expect_equal(actual, expected) + + rdd <- map(rdd, function(x) { x }) + actual <- collect(zipPartitions(rdd, rdd1, + func = function(x, y) { list(y + nchar(x)) })) + expect_equal(actual, expected) + + unlink(fileName) +}) From b9de9e040aff371c6acf9b3f3d1ff8b360c0cd56 Mon Sep 17 00:00:00 2001 From: Steven She Date: Mon, 27 Apr 2015 18:55:02 -0400 Subject: [PATCH 141/191] [SPARK-7103] Fix crash with SparkContext.union when RDD has no partitioner Added a check to the SparkContext.union method to check that a partitioner is defined on all RDDs when instantiating a PartitionerAwareUnionRDD. Author: Steven She Closes #5679 from stevencanopy/SPARK-7103 and squashes the following commits: 5a3d846 [Steven She] SPARK-7103: Fix crash with SparkContext.union when at least one RDD has no partitioner --- .../scala/org/apache/spark/SparkContext.scala | 2 +- .../spark/rdd/PartitionerAwareUnionRDD.scala | 1 + .../scala/org/apache/spark/rdd/RDDSuite.scala | 21 +++++++++++++++++++ 3 files changed, 23 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 86269eac52db0..ea4ddcc2e265d 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1055,7 +1055,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli /** Build the union of a list of RDDs. 
*/ def union[T: ClassTag](rdds: Seq[RDD[T]]): RDD[T] = { val partitioners = rdds.flatMap(_.partitioner).toSet - if (partitioners.size == 1) { + if (rdds.forall(_.partitioner.isDefined) && partitioners.size == 1) { new PartitionerAwareUnionRDD(this, rdds) } else { new UnionRDD(this, rdds) diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala index 92b0641d0fb6e..7598ff617b399 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala @@ -60,6 +60,7 @@ class PartitionerAwareUnionRDD[T: ClassTag]( var rdds: Seq[RDD[T]] ) extends RDD[T](sc, rdds.map(x => new OneToOneDependency(x))) { require(rdds.length > 0) + require(rdds.forall(_.partitioner.isDefined)) require(rdds.flatMap(_.partitioner).toSet.size == 1, "Parent RDDs have different partitioners: " + rdds.flatMap(_.partitioner)) diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index df42faab64505..ef8c36a28655b 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -99,6 +99,27 @@ class RDDSuite extends FunSuite with SharedSparkContext { assert(sc.union(Seq(nums, nums)).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4)) } + test("SparkContext.union creates UnionRDD if at least one RDD has no partitioner") { + val rddWithPartitioner = sc.parallelize(Seq(1->true)).partitionBy(new HashPartitioner(1)) + val rddWithNoPartitioner = sc.parallelize(Seq(2->true)) + val unionRdd = sc.union(rddWithNoPartitioner, rddWithPartitioner) + assert(unionRdd.isInstanceOf[UnionRDD[_]]) + } + + test("SparkContext.union creates PartitionAwareUnionRDD if all RDDs have partitioners") { + val rddWithPartitioner = sc.parallelize(Seq(1->true)).partitionBy(new HashPartitioner(1)) + val unionRdd = sc.union(rddWithPartitioner, rddWithPartitioner) + assert(unionRdd.isInstanceOf[PartitionerAwareUnionRDD[_]]) + } + + test("PartitionAwareUnionRDD raises exception if at least one RDD has no partitioner") { + val rddWithPartitioner = sc.parallelize(Seq(1->true)).partitionBy(new HashPartitioner(1)) + val rddWithNoPartitioner = sc.parallelize(Seq(2->true)) + intercept[IllegalArgumentException] { + new PartitionerAwareUnionRDD(sc, Seq(rddWithNoPartitioner, rddWithPartitioner)) + } + } + test("partitioner aware union") { def makeRDDWithPartitioner(seq: Seq[Int]): RDD[Int] = { sc.makeRDD(seq, 1) From 8e1c00dbf4b60962908626dead744e5d73c8085e Mon Sep 17 00:00:00 2001 From: Hong Shen Date: Mon, 27 Apr 2015 18:57:31 -0400 Subject: [PATCH 142/191] [SPARK-6738] [CORE] Improve estimate the size of a large array Currently, SizeEstimator.visitArray is not correct in the follow case, ``` array size > 200, elem has the share object ``` when I add a debug log in SizeTracker.scala: ``` System.err.println(s"numUpdates:$numUpdates, size:$ts, bytesPerUpdate:$bytesPerUpdate, cost time:$b") ``` I get the following log: ``` numUpdates:1, size:262448, bytesPerUpdate:0.0, cost time:35 numUpdates:2, size:420698, bytesPerUpdate:158250.0, cost time:35 numUpdates:4, size:420754, bytesPerUpdate:28.0, cost time:32 numUpdates:7, size:420754, bytesPerUpdate:0.0, cost time:27 numUpdates:12, size:420754, bytesPerUpdate:0.0, cost time:28 numUpdates:20, size:420754, bytesPerUpdate:0.0, cost time:25 numUpdates:32, size:420754, bytesPerUpdate:0.0, cost 
time:21 numUpdates:52, size:420754, bytesPerUpdate:0.0, cost time:20 numUpdates:84, size:420754, bytesPerUpdate:0.0, cost time:20 numUpdates:135, size:420754, bytesPerUpdate:0.0, cost time:20 numUpdates:216, size:420754, bytesPerUpdate:0.0, cost time:11 numUpdates:346, size:420754, bytesPerUpdate:0.0, cost time:6 numUpdates:554, size:488911, bytesPerUpdate:327.67788461538464, cost time:8 numUpdates:887, size:2312259426, bytesPerUpdate:6942253.798798799, cost time:198 15/04/21 14:27:26 INFO collection.ExternalAppendOnlyMap: Thread 51 spilling in-memory map of 3.0 GB to disk (1 time so far) 15/04/21 14:27:26 INFO collection.ExternalAppendOnlyMap: /data11/yarnenv/local/usercache/spark/appcache/application_1426746631567_11745/spark-local-20150421142719-c001/30/temp_local_066af981-c2fc-4b70-a00e-110e23006fbc ``` But in fact the file size is only 162K: ``` $ ll -h /data11/yarnenv/local/usercache/spark/appcache/application_1426746631567_11745/spark-local-20150421142719-c001/30/temp_local_066af981-c2fc-4b70-a00e-110e23006fbc -rw-r----- 1 spark users 162K Apr 21 14:27 /data11/yarnenv/local/usercache/spark/appcache/application_1426746631567_11745/spark-local-20150421142719-c001/30/temp_local_066af981-c2fc-4b70-a00e-110e23006fbc ``` In order to test case, I change visitArray to: ``` var size = 0l for (i <- 0 until length) { val obj = JArray.get(array, i) size += SizeEstimator.estimate(obj, state.visited).toLong } state.size += size ``` I get the following log: ``` ... 14895 277016088 566.9046118590662 time:8470 23832 281840544 552.3308270676691 time:8031 38132 289891824 539.8294729775092 time:7897 61012 302803640 563.0265734265735 time:13044 97620 322904416 564.3276223776223 time:13554 15/04/14 11:46:43 INFO collection.ExternalAppendOnlyMap: Thread 51 spilling in-memory map of 314.5 MB to disk (1 time so far) 15/04/14 11:46:43 INFO collection.ExternalAppendOnlyMap: /data1/yarnenv/local/usercache/spark/appcache/application_1426746631567_8477/spark-local-20150414114020-2fcb/14/temp_local_5b6b98d5-5bfa-47e2-8216-059482ccbda0 ``` the file size is 85M. ``` $ ll -h /data1/yarnenv/local/usercache/spark/appcache/application_1426746631567_8477/spark- local-20150414114020-2fcb/14/ total 85M -rw-r----- 1 spark users 85M Apr 14 11:46 temp_local_5b6b98d5-5bfa-47e2-8216-059482ccbda0 ``` The following log is when I use this patch, ``` .... numUpdates:32, size:365484, bytesPerUpdate:0.0, cost time:7 numUpdates:52, size:365484, bytesPerUpdate:0.0, cost time:5 numUpdates:84, size:365484, bytesPerUpdate:0.0, cost time:5 numUpdates:135, size:372208, bytesPerUpdate:131.84313725490196, cost time:86 numUpdates:216, size:379020, bytesPerUpdate:84.09876543209876, cost time:21 numUpdates:346, size:1865208, bytesPerUpdate:11432.215384615385, cost time:23 numUpdates:554, size:2052380, bytesPerUpdate:899.8653846153846, cost time:16 numUpdates:887, size:2142820, bytesPerUpdate:271.59159159159157, cost time:15 .. 
numUpdates:14895, size:251675500, bytesPerUpdate:438.5263157894737, cost time:13 numUpdates:23832, size:257010268, bytesPerUpdate:596.9305135951662, cost time:14 numUpdates:38132, size:263922396, bytesPerUpdate:483.3655944055944, cost time:15 numUpdates:61012, size:268962596, bytesPerUpdate:220.28846153846155, cost time:24 numUpdates:97620, size:286980644, bytesPerUpdate:492.1888111888112, cost time:22 15/04/21 14:45:12 INFO collection.ExternalAppendOnlyMap: Thread 53 spilling in-memory map of 328.7 MB to disk (1 time so far) 15/04/21 14:45:12 INFO collection.ExternalAppendOnlyMap: /data4/yarnenv/local/usercache/spark/appcache/application_1426746631567_11758/spark-local-20150421144456-a2a5/2a/temp_local_9c109510-af16-4468-8f23-48cad04da88f ``` the file size is 88M. ``` $ ll -h /data4/yarnenv/local/usercache/spark/appcache/application_1426746631567_11758/spark-local-20150421144456-a2a5/2a/ total 88M -rw-r----- 1 spark users 88M Apr 21 14:45 temp_local_9c109510-af16-4468-8f23-48cad04da88f ``` Author: Hong Shen Closes #5608 from shenh062326/my_change5 and squashes the following commits: 5506bae [Hong Shen] Fix compile error c275dd3 [Hong Shen] Alter code style fe202a2 [Hong Shen] Change the code style and add documentation. a9fca84 [Hong Shen] Add test case for SizeEstimator 4877eee [Hong Shen] Improve estimate the size of a large array a2ea7ac [Hong Shen] Alter code style 4c28e36 [Hong Shen] Improve estimate the size of a large array --- .../org/apache/spark/util/SizeEstimator.scala | 45 ++++++++++++------- .../spark/util/SizeEstimatorSuite.scala | 18 ++++++++ 2 files changed, 48 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index 26ffbf9350388..4dd7ab9e0767b 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -179,7 +179,7 @@ private[spark] object SizeEstimator extends Logging { } // Estimate the size of arrays larger than ARRAY_SIZE_FOR_SAMPLING by sampling. - private val ARRAY_SIZE_FOR_SAMPLING = 200 + private val ARRAY_SIZE_FOR_SAMPLING = 400 private val ARRAY_SAMPLE_SIZE = 100 // should be lower than ARRAY_SIZE_FOR_SAMPLING private def visitArray(array: AnyRef, arrayClass: Class[_], state: SearchState) { @@ -204,25 +204,40 @@ private[spark] object SizeEstimator extends Logging { } } else { // Estimate the size of a large array by sampling elements without replacement. - var size = 0.0 + // To exclude the shared objects that the array elements may link, sample twice + // and use the min one to caculate array size. 
val rand = new Random(42) - val drawn = new OpenHashSet[Int](ARRAY_SAMPLE_SIZE) - var numElementsDrawn = 0 - while (numElementsDrawn < ARRAY_SAMPLE_SIZE) { - var index = 0 - do { - index = rand.nextInt(length) - } while (drawn.contains(index)) - drawn.add(index) - val elem = ScalaRunTime.array_apply(array, index).asInstanceOf[AnyRef] - size += SizeEstimator.estimate(elem, state.visited) - numElementsDrawn += 1 - } - state.size += ((length / (ARRAY_SAMPLE_SIZE * 1.0)) * size).toLong + val drawn = new OpenHashSet[Int](2 * ARRAY_SAMPLE_SIZE) + val s1 = sampleArray(array, state, rand, drawn, length) + val s2 = sampleArray(array, state, rand, drawn, length) + val size = math.min(s1, s2) + state.size += math.max(s1, s2) + + (size * ((length - ARRAY_SAMPLE_SIZE) / (ARRAY_SAMPLE_SIZE))).toLong } } } + private def sampleArray( + array: AnyRef, + state: SearchState, + rand: Random, + drawn: OpenHashSet[Int], + length: Int): Long = { + var size = 0L + for (i <- 0 until ARRAY_SAMPLE_SIZE) { + var index = 0 + do { + index = rand.nextInt(length) + } while (drawn.contains(index)) + drawn.add(index) + val obj = ScalaRunTime.array_apply(array, index).asInstanceOf[AnyRef] + if (obj != null) { + size += SizeEstimator.estimate(obj, state.visited).toLong + } + } + size + } + private def primitiveSize(cls: Class[_]): Long = { if (cls == classOf[Byte]) { BYTE_SIZE diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala index 67a9f75ff2187..28915bd53354e 100644 --- a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.util +import scala.collection.mutable.ArrayBuffer + import org.scalatest.{BeforeAndAfterEach, BeforeAndAfterAll, FunSuite, PrivateMethodTester} class DummyClass1 {} @@ -96,6 +98,22 @@ class SizeEstimatorSuite // Past size 100, our samples 100 elements, but we should still get the right size. assertResult(28016)(SizeEstimator.estimate(Array.fill(1000)(new DummyClass3))) + + val arr = new Array[Char](100000) + assertResult(200016)(SizeEstimator.estimate(arr)) + assertResult(480032)(SizeEstimator.estimate(Array.fill(10000)(new DummyString(arr)))) + + val buf = new ArrayBuffer[DummyString]() + for (i <- 0 until 5000) { + buf.append(new DummyString(new Array[Char](10))) + } + assertResult(340016)(SizeEstimator.estimate(buf.toArray)) + + for (i <- 0 until 5000) { + buf.append(new DummyString(arr)) + } + assertResult(683912)(SizeEstimator.estimate(buf.toArray)) + // If an array contains the *same* element many times, we should only count it once. val d1 = new DummyClass1 // 10 pointers plus 8-byte object From 5d45e1f60059e2f2fc8ad64778b9ddcc8887c570 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 27 Apr 2015 19:46:17 -0400 Subject: [PATCH 143/191] [SPARK-3090] [CORE] Stop SparkContext if user forgets to. Set up a shutdown hook to try to stop the Spark context in case the user forgets to do it. The main effect is that any open logs files are flushed and closed, which is particularly interesting for event logs. Author: Marcelo Vanzin Closes #5696 from vanzin/SPARK-3090 and squashes the following commits: 3b554b5 [Marcelo Vanzin] [SPARK-3090] [core] Stop SparkContext if user forgets to. 
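As a rough, self-contained illustration of the idea behind this patch (not code from the patch itself): registering a JVM shutdown hook guarantees cleanup runs even when the caller forgets to stop a resource explicitly. Spark's internal `Utils.addShutdownHook`, used in the diff below, layers a priority on top of this basic mechanism so the SparkContext hook can be ordered relative to other hooks such as the YARN ApplicationMaster's. The `EventLoggingService` class and all names here are invented for the sketch.

```scala
import java.util.concurrent.atomic.AtomicBoolean

// Hypothetical resource whose buffered output (think event logs) must be flushed on exit.
class EventLoggingService {
  private val stopped = new AtomicBoolean(false)
  def stop(): Unit = {
    if (stopped.compareAndSet(false, true)) {
      println("flushing and closing log files") // idempotent: safe if called twice
    }
  }
}

object ShutdownHookSketch {
  def main(args: Array[String]): Unit = {
    val service = new EventLoggingService
    // Runs at JVM exit even if the user forgets to call stop(); it does not help
    // if the JVM is killed outright, which matches the caveat noted in the patch.
    Runtime.getRuntime.addShutdownHook(new Thread(new Runnable {
      override def run(): Unit = service.stop()
    }))
    // ... application work; calling service.stop() explicitly remains the preferred path.
  }
}
```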
--- .../scala/org/apache/spark/SparkContext.scala | 38 ++++++++++++------- .../scala/org/apache/spark/util/Utils.scala | 10 ++++- .../spark/deploy/yarn/ApplicationMaster.scala | 10 +---- 3 files changed, 35 insertions(+), 23 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index ea4ddcc2e265d..65b903a55d5bd 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -223,6 +223,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli private var _listenerBusStarted: Boolean = false private var _jars: Seq[String] = _ private var _files: Seq[String] = _ + private var _shutdownHookRef: AnyRef = _ /* ------------------------------------------------------------------------------------- * | Accessors and public fields. These provide access to the internal state of the | @@ -517,6 +518,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _taskScheduler.postStartHook() _env.metricsSystem.registerSource(new DAGSchedulerSource(dagScheduler)) _env.metricsSystem.registerSource(new BlockManagerSource(_env.blockManager)) + + // Make sure the context is stopped if the user forgets about it. This avoids leaving + // unfinished event logs around after the JVM exits cleanly. It doesn't help if the JVM + // is killed, though. + _shutdownHookRef = Utils.addShutdownHook(Utils.SPARK_CONTEXT_SHUTDOWN_PRIORITY) { () => + logInfo("Invoking stop() from shutdown hook") + stop() + } } catch { case NonFatal(e) => logError("Error initializing SparkContext.", e) @@ -1481,6 +1490,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli logInfo("SparkContext already stopped.") return } + if (_shutdownHookRef != null) { + Utils.removeShutdownHook(_shutdownHookRef) + } postApplicationEnd() _ui.foreach(_.stop()) @@ -1891,7 +1903,7 @@ object SparkContext extends Logging { * * Access to this field is guarded by SPARK_CONTEXT_CONSTRUCTOR_LOCK. */ - private val activeContext: AtomicReference[SparkContext] = + private val activeContext: AtomicReference[SparkContext] = new AtomicReference[SparkContext](null) /** @@ -1944,11 +1956,11 @@ object SparkContext extends Logging { } /** - * This function may be used to get or instantiate a SparkContext and register it as a - * singleton object. Because we can only have one active SparkContext per JVM, - * this is useful when applications may wish to share a SparkContext. + * This function may be used to get or instantiate a SparkContext and register it as a + * singleton object. Because we can only have one active SparkContext per JVM, + * this is useful when applications may wish to share a SparkContext. * - * Note: This function cannot be used to create multiple SparkContext instances + * Note: This function cannot be used to create multiple SparkContext instances * even if multiple contexts are allowed. */ def getOrCreate(config: SparkConf): SparkContext = { @@ -1961,17 +1973,17 @@ object SparkContext extends Logging { activeContext.get() } } - + /** - * This function may be used to get or instantiate a SparkContext and register it as a - * singleton object. Because we can only have one active SparkContext per JVM, + * This function may be used to get or instantiate a SparkContext and register it as a + * singleton object. 
Because we can only have one active SparkContext per JVM, * this is useful when applications may wish to share a SparkContext. - * + * * This method allows not passing a SparkConf (useful if just retrieving). - * - * Note: This function cannot be used to create multiple SparkContext instances - * even if multiple contexts are allowed. - */ + * + * Note: This function cannot be used to create multiple SparkContext instances + * even if multiple contexts are allowed. + */ def getOrCreate(): SparkContext = { getOrCreate(new SparkConf()) } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index c6c6df7cfa56e..342bc9a06db47 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -67,6 +67,12 @@ private[spark] object Utils extends Logging { val DEFAULT_SHUTDOWN_PRIORITY = 100 + /** + * The shutdown priority of the SparkContext instance. This is lower than the default + * priority, so that by default hooks are run before the context is shut down. + */ + val SPARK_CONTEXT_SHUTDOWN_PRIORITY = 50 + private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 @volatile private var localRootDirs: Array[String] = null @@ -2116,7 +2122,7 @@ private[spark] object Utils extends Logging { * @return A handle that can be used to unregister the shutdown hook. */ def addShutdownHook(hook: () => Unit): AnyRef = { - addShutdownHook(DEFAULT_SHUTDOWN_PRIORITY, hook) + addShutdownHook(DEFAULT_SHUTDOWN_PRIORITY)(hook) } /** @@ -2126,7 +2132,7 @@ private[spark] object Utils extends Logging { * @param hook The code to run during shutdown. * @return A handle that can be used to unregister the shutdown hook. */ - def addShutdownHook(priority: Int, hook: () => Unit): AnyRef = { + def addShutdownHook(priority: Int)(hook: () => Unit): AnyRef = { shutdownHooks.add(priority, hook) } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 93ae45133ce24..70cb57ffd8c69 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -95,14 +95,8 @@ private[spark] class ApplicationMaster( val fs = FileSystem.get(yarnConf) - Utils.addShutdownHook { () => - // If the SparkContext is still registered, shut it down as a best case effort in case - // users do not call sc.stop or do System.exit(). - val sc = sparkContextRef.get() - if (sc != null) { - logInfo("Invoking sc stop from shutdown hook") - sc.stop() - } + // This shutdown hook should run *after* the SparkContext is shut down. 
+ Utils.addShutdownHook(Utils.SPARK_CONTEXT_SHUTDOWN_PRIORITY - 1) { () => val maxAppAttempts = client.getMaxRegAttempts(sparkConf, yarnConf) val isLastAttempt = client.getAttemptId().getAttemptId() >= maxAppAttempts From ab5adb7a973eec9d95c7575c864cba9f8d83a0fd Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 27 Apr 2015 19:50:55 -0400 Subject: [PATCH 144/191] [SPARK-7145] [CORE] commons-lang (2.x) classes used instead of commons-lang3 (3.x); commons-io used without dependency Remove use of commons-lang in favor of commons-lang3 classes; remove commons-io use in favor of Guava Author: Sean Owen Closes #5703 from srowen/SPARK-7145 and squashes the following commits: 21fbe03 [Sean Owen] Remove use of commons-lang in favor of commons-lang3 classes; remove commons-io use in favor of Guava --- .../test/scala/org/apache/spark/FileServerSuite.scala | 7 +++---- .../apache/spark/metrics/InputOutputMetricsSuite.scala | 4 ++-- .../netty/NettyBlockTransferSecuritySuite.scala | 10 +++++++--- external/flume-sink/pom.xml | 4 ++++ .../flume/sink/SparkAvroCallbackHandler.scala | 4 ++-- .../main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala | 6 +++++- .../sql/hive/thriftserver/AbstractSparkSQLDriver.scala | 4 ++-- .../sql/hive/thriftserver/SparkSQLCLIDriver.scala | 8 +++----- .../apache/spark/sql/hive/execution/UDFListString.java | 6 +++--- .../spark/sql/hive/MetastoreDataSourcesSuite.scala | 9 ++++----- 10 files changed, 35 insertions(+), 27 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala index a69e9b761f9a7..c0439f934813e 100644 --- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala @@ -22,8 +22,7 @@ import java.net.URI import java.util.jar.{JarEntry, JarOutputStream} import javax.net.ssl.SSLException -import com.google.common.io.ByteStreams -import org.apache.commons.io.{FileUtils, IOUtils} +import com.google.common.io.{ByteStreams, Files} import org.apache.commons.lang3.RandomUtils import org.scalatest.FunSuite @@ -239,7 +238,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { def fileTransferTest(server: HttpFileServer, sm: SecurityManager = null): Unit = { val randomContent = RandomUtils.nextBytes(100) val file = File.createTempFile("FileServerSuite", "sslTests", tmpDir) - FileUtils.writeByteArrayToFile(file, randomContent) + Files.write(randomContent, file) server.addFile(file) val uri = new URI(server.serverUri + "/files/" + file.getName) @@ -254,7 +253,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { Utils.setupSecureURLConnection(connection, sm) } - val buf = IOUtils.toByteArray(connection.getInputStream) + val buf = ByteStreams.toByteArray(connection.getInputStream) assert(buf === randomContent) } diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index 190b08d950a02..ef3e213f1fcce 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -21,7 +21,7 @@ import java.io.{File, FileWriter, PrintWriter} import scala.collection.mutable.ArrayBuffer -import org.apache.commons.lang.math.RandomUtils +import org.apache.commons.lang3.RandomUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import 
org.apache.hadoop.io.{LongWritable, Text} @@ -60,7 +60,7 @@ class InputOutputMetricsSuite extends FunSuite with SharedSparkContext tmpFile = new File(testTempDir, getClass.getSimpleName + ".txt") val pw = new PrintWriter(new FileWriter(tmpFile)) for (x <- 1 to numRecords) { - pw.println(RandomUtils.nextInt(numBuckets)) + pw.println(RandomUtils.nextInt(0, numBuckets)) } pw.close() diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index 94bfa67451892..46d2e5173acae 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -17,14 +17,16 @@ package org.apache.spark.network.netty +import java.io.InputStreamReader import java.nio._ +import java.nio.charset.Charset import java.util.concurrent.TimeUnit import scala.concurrent.duration._ import scala.concurrent.{Await, Promise} import scala.util.{Failure, Success, Try} -import org.apache.commons.io.IOUtils +import com.google.common.io.CharStreams import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.shuffle.BlockFetchingListener import org.apache.spark.network.{BlockDataManager, BlockTransferService} @@ -32,7 +34,7 @@ import org.apache.spark.storage.{BlockId, ShuffleBlockId} import org.apache.spark.{SecurityManager, SparkConf} import org.mockito.Mockito._ import org.scalatest.mock.MockitoSugar -import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite, ShouldMatchers} +import org.scalatest.{FunSuite, ShouldMatchers} class NettyBlockTransferSecuritySuite extends FunSuite with MockitoSugar with ShouldMatchers { test("security default off") { @@ -113,7 +115,9 @@ class NettyBlockTransferSecuritySuite extends FunSuite with MockitoSugar with Sh val result = fetchBlock(exec0, exec1, "1", blockId) match { case Success(buf) => - IOUtils.toString(buf.createInputStream()) should equal(blockString) + val actualString = CharStreams.toString( + new InputStreamReader(buf.createInputStream(), Charset.forName("UTF-8"))) + actualString should equal(blockString) buf.release() Success() case Failure(t) => diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 67907bbfb6d1b..1f3e619d97a24 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -35,6 +35,10 @@ http://spark.apache.org/ + + org.apache.commons + commons-lang3 + org.apache.flume flume-ng-sdk diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala index 4373be443e67d..fd01807fc3ac4 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala @@ -21,9 +21,9 @@ import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable -import org.apache.flume.Channel -import org.apache.commons.lang.RandomStringUtils import com.google.common.util.concurrent.ThreadFactoryBuilder +import org.apache.flume.Channel +import org.apache.commons.lang3.RandomStringUtils /** * Class that implements the SparkFlumeProtocol, that is used by the Avro Netty Server to process diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index f326510042122..f3b5455574d1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.jdbc import java.sql.{Connection, DriverManager, ResultSet, ResultSetMetaData, SQLException} import java.util.Properties -import org.apache.commons.lang.StringEscapeUtils.escapeSql +import org.apache.commons.lang3.StringUtils + import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{Row, SpecificMutableRow} @@ -239,6 +240,9 @@ private[sql] class JDBCRDD( case _ => value } + private def escapeSql(value: String): String = + if (value == null) null else StringUtils.replace(value, "'", "''") + /** * Turns a single Filter into a String representing a SQL expression. * Returns null for an unhandled filter. diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala index 59f3a75768082..48ac9062af96a 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.hive.thriftserver import scala.collection.JavaConversions._ -import org.apache.commons.lang.exception.ExceptionUtils +import org.apache.commons.lang3.exception.ExceptionUtils import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema} import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse @@ -61,7 +61,7 @@ private[hive] abstract class AbstractSparkSQLDriver( } catch { case cause: Throwable => logError(s"Failed in [$command]", cause) - new CommandProcessorResponse(1, ExceptionUtils.getFullStackTrace(cause), null) + new CommandProcessorResponse(1, ExceptionUtils.getStackTrace(cause), null) } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 7e307bb4ad1e8..b7b6925aa87f7 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -24,18 +24,16 @@ import java.util.{ArrayList => JArrayList} import jline.{ConsoleReader, History} -import org.apache.commons.lang.StringUtils +import org.apache.commons.lang3.StringUtils import org.apache.commons.logging.LogFactory import org.apache.hadoop.conf.Configuration import org.apache.hadoop.hive.cli.{CliDriver, CliSessionState, OptionsProcessor} -import org.apache.hadoop.hive.common.LogUtils.LogInitializationException -import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils, LogUtils} +import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.exec.Utilities -import org.apache.hadoop.hive.ql.processors.{AddResourceProcessor, SetProcessor, 
CommandProcessor, CommandProcessorFactory} +import org.apache.hadoop.hive.ql.processors.{AddResourceProcessor, SetProcessor, CommandProcessor} import org.apache.hadoop.hive.ql.session.SessionState -import org.apache.hadoop.hive.shims.ShimLoader import org.apache.thrift.transport.TSocket import org.apache.spark.Logging diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java index efd34df293c88..f33210ebdae1b 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java @@ -17,10 +17,10 @@ package org.apache.spark.sql.hive.execution; -import org.apache.hadoop.hive.ql.exec.UDF; - import java.util.List; -import org.apache.commons.lang.StringUtils; + +import org.apache.commons.lang3.StringUtils; +import org.apache.hadoop.hive.ql.exec.UDF; public class UDFListString extends UDF { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index e09c702c8969e..0538aa203c5a0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -23,7 +23,6 @@ import scala.collection.mutable.ArrayBuffer import org.scalatest.BeforeAndAfterEach -import org.apache.commons.io.FileUtils import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.metastore.TableType import org.apache.hadoop.hive.ql.metadata.Table @@ -174,7 +173,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { sql("SELECT * FROM jsonTable"), Row("a", "b")) - FileUtils.deleteDirectory(tempDir) + Utils.deleteRecursively(tempDir) sparkContext.parallelize(("a1", "b1", "c1") :: Nil).toDF() .toJSON.saveAsTextFile(tempDir.getCanonicalPath) @@ -190,7 +189,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { checkAnswer( sql("SELECT * FROM jsonTable"), Row("a1", "b1", "c1")) - FileUtils.deleteDirectory(tempDir) + Utils.deleteRecursively(tempDir) } test("drop, change, recreate") { @@ -212,7 +211,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { sql("SELECT * FROM jsonTable"), Row("a", "b")) - FileUtils.deleteDirectory(tempDir) + Utils.deleteRecursively(tempDir) sparkContext.parallelize(("a", "b", "c") :: Nil).toDF() .toJSON.saveAsTextFile(tempDir.getCanonicalPath) @@ -231,7 +230,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { checkAnswer( sql("SELECT * FROM jsonTable"), Row("a", "b", "c")) - FileUtils.deleteDirectory(tempDir) + Utils.deleteRecursively(tempDir) } test("invalidate cache and reload") { From 62888a4ded91b3c2cbb05936c374c7ebfc10799e Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Mon, 27 Apr 2015 19:52:41 -0400 Subject: [PATCH 145/191] [SPARK-7162] [YARN] Launcher error in yarn-client jira: https://issues.apache.org/jira/browse/SPARK-7162 Author: GuoQiang Li Closes #5716 from witgo/SPARK-7162 and squashes the following commits: b64564c [GuoQiang Li] Launcher error in yarn-client --- yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 
019afbd1a1743..741239c953794 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -354,7 +354,7 @@ private[spark] class Client( val dir = new File(path) if (dir.isDirectory()) { dir.listFiles().foreach { file => - if (!hadoopConfFiles.contains(file.getName())) { + if (file.isFile && !hadoopConfFiles.contains(file.getName())) { hadoopConfFiles(file.getName()) = file } } From 4d9e560b5470029143926827b1cb9d72a0bfbeff Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Mon, 27 Apr 2015 19:02:51 -0700 Subject: [PATCH 146/191] [SPARK-7090] [MLLIB] Introduce LDAOptimizer to LDA to further improve extensibility jira: https://issues.apache.org/jira/browse/SPARK-7090 LDA was implemented with extensibility in mind. And with the development of OnlineLDA and Gibbs Sampling, we are collecting more detailed requirements from different algorithms. As Joseph Bradley jkbradley proposed in https://github.com/apache/spark/pull/4807 and with some further discussion, we'd like to adjust the code structure a little to present the common interface and extension point clearly. Basically class LDA would be a common entrance for LDA computing. And each LDA object will refer to a LDAOptimizer for the concrete algorithm implementation. Users can customize LDAOptimizer with specific parameters and assign it to LDA. Concrete changes: 1. Add a trait `LDAOptimizer`, which defines the common iterface for concrete implementations. Each subClass is a wrapper for a specific LDA algorithm. 2. Move EMOptimizer to file LDAOptimizer and inherits from LDAOptimizer, rename to EMLDAOptimizer. (in case a more generic EMOptimizer comes in the future) -adjust the constructor of EMOptimizer, since all the parameters should be passed in through initialState method. This can avoid unwanted confusion or overwrite. -move the code from LDA.initalState to initalState of EMLDAOptimizer 3. Add property ldaOptimizer to LDA and its getter/setter, and EMLDAOptimizer is the default Optimizer. 4. Change the return type of LDA.run from DistributedLDAModel to LDAModel. Further work: add OnlineLDAOptimizer and other possible Optimizers once ready. 
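A hedged, user-facing sketch of what the API looks like after this change (the corpus and parameter values are invented for illustration; the calls themselves appear in the diff below): `run()` now returns the generic `LDAModel`, so callers needing the graph-backed model cast the result, and the optimizer can be chosen either by instance or by name.

```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.clustering.{DistributedLDAModel, EMLDAOptimizer, LDA}
import org.apache.spark.mllib.linalg.Vectors

object LdaOptimizerSketch {
  def run(sc: SparkContext): Unit = {
    // Tiny made-up corpus of (document id, term-count vector) pairs.
    val corpus = sc.parallelize(Seq(
      (0L, Vectors.dense(1.0, 2.0, 0.0)),
      (1L, Vectors.dense(0.0, 3.0, 1.0))))

    val lda = new LDA()
      .setK(2)
      .setMaxIterations(10)
      .setOptimizer(new EMLDAOptimizer) // equivalently: .setOptimizer("em")

    // run() now returns LDAModel; EM still produces a DistributedLDAModel underneath,
    // so the examples updated in this patch cast when they need the distributed-only methods.
    val model = lda.run(corpus).asInstanceOf[DistributedLDAModel]
    println(s"vocabSize = ${model.vocabSize}")
  }
}
```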
Author: Yuhao Yang Closes #5661 from hhbyyh/ldaRefactor and squashes the following commits: 0e2e006 [Yuhao Yang] respond to review comments 08a45da [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaRefactor e756ce4 [Yuhao Yang] solve mima exception d74fd8f [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaRefactor 0bb8400 [Yuhao Yang] refactor LDA with Optimizer ec2f857 [Yuhao Yang] protoptype for discussion --- .../spark/examples/mllib/JavaLDAExample.java | 2 +- .../spark/examples/mllib/LDAExample.scala | 4 +- .../apache/spark/mllib/clustering/LDA.scala | 181 +++------------ .../spark/mllib/clustering/LDAModel.scala | 2 +- .../spark/mllib/clustering/LDAOptimizer.scala | 210 ++++++++++++++++++ .../spark/mllib/clustering/JavaLDASuite.java | 2 +- .../spark/mllib/clustering/LDASuite.scala | 2 +- project/MimaExcludes.scala | 4 + 8 files changed, 256 insertions(+), 151 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java index 36207ae38d9a9..fd53c81cc4974 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java @@ -58,7 +58,7 @@ public Tuple2 call(Tuple2 doc_id) { corpus.cache(); // Cluster the documents into three topics using LDA - DistributedLDAModel ldaModel = new LDA().setK(3).run(corpus); + DistributedLDAModel ldaModel = (DistributedLDAModel)new LDA().setK(3).run(corpus); // Output topics. Each is a distribution over words (matching word count vectors) System.out.println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala index 08a93595a2e17..a1850390c0a86 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala @@ -26,7 +26,7 @@ import scopt.OptionParser import org.apache.log4j.{Level, Logger} import org.apache.spark.{SparkContext, SparkConf} -import org.apache.spark.mllib.clustering.LDA +import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.rdd.RDD @@ -137,7 +137,7 @@ object LDAExample { sc.setCheckpointDir(params.checkpointDir.get) } val startTime = System.nanoTime() - val ldaModel = lda.run(corpus) + val ldaModel = lda.run(corpus).asInstanceOf[DistributedLDAModel] val elapsed = (System.nanoTime() - startTime) / 1e9 println(s"Finished training LDA model. 
Summary:") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index d006b39acb213..37bf88b73b911 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -17,16 +17,11 @@ package org.apache.spark.mllib.clustering -import java.util.Random - -import breeze.linalg.{DenseVector => BDV, normalize} - +import breeze.linalg.{DenseVector => BDV} import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaPairRDD import org.apache.spark.graphx._ -import org.apache.spark.graphx.impl.GraphImpl -import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer import org.apache.spark.mllib.linalg.Vector import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils @@ -42,16 +37,9 @@ import org.apache.spark.util.Utils * - "token": instance of a term appearing in a document * - "topic": multinomial distribution over words representing some concept * - * Currently, the underlying implementation uses Expectation-Maximization (EM), implemented - * according to the Asuncion et al. (2009) paper referenced below. - * * References: * - Original LDA paper (journal version): * Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003. - * - This class implements their "smoothed" LDA model. - * - Paper which clearly explains several algorithms, including EM: - * Asuncion, Welling, Smyth, and Teh. - * "On Smoothing and Inference for Topic Models." UAI, 2009. * * @see [[http://en.wikipedia.org/wiki/Latent_Dirichlet_allocation Latent Dirichlet allocation * (Wikipedia)]] @@ -63,10 +51,11 @@ class LDA private ( private var docConcentration: Double, private var topicConcentration: Double, private var seed: Long, - private var checkpointInterval: Int) extends Logging { + private var checkpointInterval: Int, + private var ldaOptimizer: LDAOptimizer) extends Logging { def this() = this(k = 10, maxIterations = 20, docConcentration = -1, topicConcentration = -1, - seed = Utils.random.nextLong(), checkpointInterval = 10) + seed = Utils.random.nextLong(), checkpointInterval = 10, ldaOptimizer = new EMLDAOptimizer) /** * Number of topics to infer. I.e., the number of soft cluster centers. @@ -220,6 +209,32 @@ class LDA private ( this } + + /** LDAOptimizer used to perform the actual calculation */ + def getOptimizer: LDAOptimizer = ldaOptimizer + + /** + * LDAOptimizer used to perform the actual calculation (default = EMLDAOptimizer) + */ + def setOptimizer(optimizer: LDAOptimizer): this.type = { + this.ldaOptimizer = optimizer + this + } + + /** + * Set the LDAOptimizer used to perform the actual calculation by algorithm name. + * Currently "em" is supported. + */ + def setOptimizer(optimizerName: String): this.type = { + this.ldaOptimizer = + optimizerName.toLowerCase match { + case "em" => new EMLDAOptimizer + case other => + throw new IllegalArgumentException(s"Only em is supported but got $other.") + } + this + } + /** * Learn an LDA model using the given dataset. * @@ -229,9 +244,9 @@ class LDA private ( * Document IDs must be unique and >= 0. 
* @return Inferred LDA model */ - def run(documents: RDD[(Long, Vector)]): DistributedLDAModel = { - val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration, seed, - checkpointInterval) + def run(documents: RDD[(Long, Vector)]): LDAModel = { + val state = ldaOptimizer.initialState(documents, k, getDocConcentration, getTopicConcentration, + seed, checkpointInterval) var iter = 0 val iterationTimes = Array.fill[Double](maxIterations)(0) while (iter < maxIterations) { @@ -241,12 +256,11 @@ class LDA private ( iterationTimes(iter) = elapsedSeconds iter += 1 } - state.graphCheckpointer.deleteAllCheckpoints() - new DistributedLDAModel(state, iterationTimes) + state.getLDAModel(iterationTimes) } /** Java-friendly version of [[run()]] */ - def run(documents: JavaPairRDD[java.lang.Long, Vector]): DistributedLDAModel = { + def run(documents: JavaPairRDD[java.lang.Long, Vector]): LDAModel = { run(documents.rdd.asInstanceOf[RDD[(Long, Vector)]]) } } @@ -320,88 +334,10 @@ private[clustering] object LDA { private[clustering] def isTermVertex(v: (VertexId, _)): Boolean = v._1 < 0 - /** - * Optimizer for EM algorithm which stores data + parameter graph, plus algorithm parameters. - * - * @param graph EM graph, storing current parameter estimates in vertex descriptors and - * data (token counts) in edge descriptors. - * @param k Number of topics - * @param vocabSize Number of unique terms - * @param docConcentration "alpha" - * @param topicConcentration "beta" or "eta" - */ - private[clustering] class EMOptimizer( - var graph: Graph[TopicCounts, TokenCount], - val k: Int, - val vocabSize: Int, - val docConcentration: Double, - val topicConcentration: Double, - checkpointInterval: Int) { - - private[LDA] val graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount]( - graph, checkpointInterval) - - def next(): EMOptimizer = { - val eta = topicConcentration - val W = vocabSize - val alpha = docConcentration - - val N_k = globalTopicTotals - val sendMsg: EdgeContext[TopicCounts, TokenCount, (Boolean, TopicCounts)] => Unit = - (edgeContext) => { - // Compute N_{wj} gamma_{wjk} - val N_wj = edgeContext.attr - // E-STEP: Compute gamma_{wjk} (smoothed topic distributions), scaled by token count - // N_{wj}. - val scaledTopicDistribution: TopicCounts = - computePTopic(edgeContext.srcAttr, edgeContext.dstAttr, N_k, W, eta, alpha) *= N_wj - edgeContext.sendToDst((false, scaledTopicDistribution)) - edgeContext.sendToSrc((false, scaledTopicDistribution)) - } - // This is a hack to detect whether we could modify the values in-place. - // TODO: Add zero/seqOp/combOp option to aggregateMessages. (SPARK-5438) - val mergeMsg: ((Boolean, TopicCounts), (Boolean, TopicCounts)) => (Boolean, TopicCounts) = - (m0, m1) => { - val sum = - if (m0._1) { - m0._2 += m1._2 - } else if (m1._1) { - m1._2 += m0._2 - } else { - m0._2 + m1._2 - } - (true, sum) - } - // M-STEP: Aggregation computes new N_{kj}, N_{wk} counts. - val docTopicDistributions: VertexRDD[TopicCounts] = - graph.aggregateMessages[(Boolean, TopicCounts)](sendMsg, mergeMsg) - .mapValues(_._2) - // Update the vertex descriptors with the new counts. - val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges) - graph = newGraph - graphCheckpointer.updateGraph(newGraph) - globalTopicTotals = computeGlobalTopicTotals() - this - } - - /** - * Aggregate distributions over topics from all term vertices. - * - * Note: This executes an action on the graph RDDs. 
- */ - var globalTopicTotals: TopicCounts = computeGlobalTopicTotals() - - private def computeGlobalTopicTotals(): TopicCounts = { - val numTopics = k - graph.vertices.filter(isTermVertex).values.fold(BDV.zeros[Double](numTopics))(_ += _) - } - - } - /** * Compute gamma_{wjk}, a distribution over topics k. */ - private def computePTopic( + private[clustering] def computePTopic( docTopicCounts: TopicCounts, termTopicCounts: TopicCounts, totalTopicCounts: TopicCounts, @@ -427,49 +363,4 @@ private[clustering] object LDA { // normalize BDV(gamma_wj) /= sum } - - /** - * Compute bipartite term/doc graph. - */ - private def initialState( - docs: RDD[(Long, Vector)], - k: Int, - docConcentration: Double, - topicConcentration: Double, - randomSeed: Long, - checkpointInterval: Int): EMOptimizer = { - // For each document, create an edge (Document -> Term) for each unique term in the document. - val edges: RDD[Edge[TokenCount]] = docs.flatMap { case (docID: Long, termCounts: Vector) => - // Add edges for terms with non-zero counts. - termCounts.toBreeze.activeIterator.filter(_._2 != 0.0).map { case (term, cnt) => - Edge(docID, term2index(term), cnt) - } - } - - val vocabSize = docs.take(1).head._2.size - - // Create vertices. - // Initially, we use random soft assignments of tokens to topics (random gamma). - def createVertices(): RDD[(VertexId, TopicCounts)] = { - val verticesTMP: RDD[(VertexId, TopicCounts)] = - edges.mapPartitionsWithIndex { case (partIndex, partEdges) => - val random = new Random(partIndex + randomSeed) - partEdges.flatMap { edge => - val gamma = normalize(BDV.fill[Double](k)(random.nextDouble()), 1.0) - val sum = gamma * edge.attr - Seq((edge.srcId, sum), (edge.dstId, sum)) - } - } - verticesTMP.reduceByKey(_ + _) - } - - val docTermVertices = createVertices() - - // Partition such that edges are grouped by document - val graph = Graph(docTermVertices, edges) - .partitionBy(PartitionStrategy.EdgePartition1D) - - new EMOptimizer(graph, k, vocabSize, docConcentration, topicConcentration, checkpointInterval) - } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 0a3f21ecee0dc..6cf26445f20a0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -203,7 +203,7 @@ class DistributedLDAModel private ( import LDA._ - private[clustering] def this(state: LDA.EMOptimizer, iterationTimes: Array[Double]) = { + private[clustering] def this(state: EMLDAOptimizer, iterationTimes: Array[Double]) = { this(state.graph, state.globalTopicTotals, state.k, state.vocabSize, state.docConcentration, state.topicConcentration, iterationTimes) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala new file mode 100644 index 0000000000000..ffd72a294c6c6 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import java.util.Random + +import breeze.linalg.{DenseVector => BDV, normalize} + +import org.apache.spark.annotation.Experimental +import org.apache.spark.graphx._ +import org.apache.spark.graphx.impl.GraphImpl +import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.rdd.RDD + +/** + * :: Experimental :: + * + * An LDAOptimizer specifies which optimization/learning/inference algorithm to use, and it can + * hold optimizer-specific parameters for users to set. + */ +@Experimental +trait LDAOptimizer{ + + /* + DEVELOPERS NOTE: + + An LDAOptimizer contains an algorithm for LDA and performs the actual computation, which + stores internal data structure (Graph or Matrix) and other parameters for the algorithm. + The interface is isolated to improve the extensibility of LDA. + */ + + /** + * Initializer for the optimizer. LDA passes the common parameters to the optimizer and + * the internal structure can be initialized properly. + */ + private[clustering] def initialState( + docs: RDD[(Long, Vector)], + k: Int, + docConcentration: Double, + topicConcentration: Double, + randomSeed: Long, + checkpointInterval: Int): LDAOptimizer + + private[clustering] def next(): LDAOptimizer + + private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel +} + +/** + * :: Experimental :: + * + * Optimizer for EM algorithm which stores data + parameter graph, plus algorithm parameters. + * + * Currently, the underlying implementation uses Expectation-Maximization (EM), implemented + * according to the Asuncion et al. (2009) paper referenced below. + * + * References: + * - Original LDA paper (journal version): + * Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003. + * - This class implements their "smoothed" LDA model. + * - Paper which clearly explains several algorithms, including EM: + * Asuncion, Welling, Smyth, and Teh. + * "On Smoothing and Inference for Topic Models." UAI, 2009. + * + */ +@Experimental +class EMLDAOptimizer extends LDAOptimizer{ + + import LDA._ + + /** + * Following fields will only be initialized through initialState method + */ + private[clustering] var graph: Graph[TopicCounts, TokenCount] = null + private[clustering] var k: Int = 0 + private[clustering] var vocabSize: Int = 0 + private[clustering] var docConcentration: Double = 0 + private[clustering] var topicConcentration: Double = 0 + private[clustering] var checkpointInterval: Int = 10 + private var graphCheckpointer: PeriodicGraphCheckpointer[TopicCounts, TokenCount] = null + + /** + * Compute bipartite term/doc graph. + */ + private[clustering] override def initialState( + docs: RDD[(Long, Vector)], + k: Int, + docConcentration: Double, + topicConcentration: Double, + randomSeed: Long, + checkpointInterval: Int): LDAOptimizer = { + // For each document, create an edge (Document -> Term) for each unique term in the document. 
+ val edges: RDD[Edge[TokenCount]] = docs.flatMap { case (docID: Long, termCounts: Vector) => + // Add edges for terms with non-zero counts. + termCounts.toBreeze.activeIterator.filter(_._2 != 0.0).map { case (term, cnt) => + Edge(docID, term2index(term), cnt) + } + } + + val vocabSize = docs.take(1).head._2.size + + // Create vertices. + // Initially, we use random soft assignments of tokens to topics (random gamma). + def createVertices(): RDD[(VertexId, TopicCounts)] = { + val verticesTMP: RDD[(VertexId, TopicCounts)] = + edges.mapPartitionsWithIndex { case (partIndex, partEdges) => + val random = new Random(partIndex + randomSeed) + partEdges.flatMap { edge => + val gamma = normalize(BDV.fill[Double](k)(random.nextDouble()), 1.0) + val sum = gamma * edge.attr + Seq((edge.srcId, sum), (edge.dstId, sum)) + } + } + verticesTMP.reduceByKey(_ + _) + } + + val docTermVertices = createVertices() + + // Partition such that edges are grouped by document + this.graph = Graph(docTermVertices, edges).partitionBy(PartitionStrategy.EdgePartition1D) + this.k = k + this.vocabSize = vocabSize + this.docConcentration = docConcentration + this.topicConcentration = topicConcentration + this.checkpointInterval = checkpointInterval + this.graphCheckpointer = new + PeriodicGraphCheckpointer[TopicCounts, TokenCount](graph, checkpointInterval) + this.globalTopicTotals = computeGlobalTopicTotals() + this + } + + private[clustering] override def next(): EMLDAOptimizer = { + require(graph != null, "graph is null, EMLDAOptimizer not initialized.") + + val eta = topicConcentration + val W = vocabSize + val alpha = docConcentration + + val N_k = globalTopicTotals + val sendMsg: EdgeContext[TopicCounts, TokenCount, (Boolean, TopicCounts)] => Unit = + (edgeContext) => { + // Compute N_{wj} gamma_{wjk} + val N_wj = edgeContext.attr + // E-STEP: Compute gamma_{wjk} (smoothed topic distributions), scaled by token count + // N_{wj}. + val scaledTopicDistribution: TopicCounts = + computePTopic(edgeContext.srcAttr, edgeContext.dstAttr, N_k, W, eta, alpha) *= N_wj + edgeContext.sendToDst((false, scaledTopicDistribution)) + edgeContext.sendToSrc((false, scaledTopicDistribution)) + } + // This is a hack to detect whether we could modify the values in-place. + // TODO: Add zero/seqOp/combOp option to aggregateMessages. (SPARK-5438) + val mergeMsg: ((Boolean, TopicCounts), (Boolean, TopicCounts)) => (Boolean, TopicCounts) = + (m0, m1) => { + val sum = + if (m0._1) { + m0._2 += m1._2 + } else if (m1._1) { + m1._2 += m0._2 + } else { + m0._2 + m1._2 + } + (true, sum) + } + // M-STEP: Aggregation computes new N_{kj}, N_{wk} counts. + val docTopicDistributions: VertexRDD[TopicCounts] = + graph.aggregateMessages[(Boolean, TopicCounts)](sendMsg, mergeMsg) + .mapValues(_._2) + // Update the vertex descriptors with the new counts. + val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges) + graph = newGraph + graphCheckpointer.updateGraph(newGraph) + globalTopicTotals = computeGlobalTopicTotals() + this + } + + /** + * Aggregate distributions over topics from all term vertices. + * + * Note: This executes an action on the graph RDDs. 
+ */ + private[clustering] var globalTopicTotals: TopicCounts = null + + private def computeGlobalTopicTotals(): TopicCounts = { + val numTopics = k + graph.vertices.filter(isTermVertex).values.fold(BDV.zeros[Double](numTopics))(_ += _) + } + + private[clustering] override def getLDAModel(iterationTimes: Array[Double]): LDAModel = { + require(graph != null, "graph is null, EMLDAOptimizer not initialized.") + this.graphCheckpointer.deleteAllCheckpoints() + new DistributedLDAModel(this, iterationTimes) + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java index dc10aa67c7c1f..fbe171b4b1ab1 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java @@ -88,7 +88,7 @@ public void distributedLDAModel() { .setMaxIterations(5) .setSeed(12345); - DistributedLDAModel model = lda.run(corpus); + DistributedLDAModel model = (DistributedLDAModel)lda.run(corpus); // Check: basic parameters LocalLDAModel localModel = model.toLocal(); diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index cc747dabb9968..41ec794146c69 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -68,7 +68,7 @@ class LDASuite extends FunSuite with MLlibTestSparkContext { .setSeed(12345) val corpus = sc.parallelize(tinyCorpus, 2) - val model: DistributedLDAModel = lda.run(corpus) + val model: DistributedLDAModel = lda.run(corpus).asInstanceOf[DistributedLDAModel] // Check: basic parameters val localModel = model.toLocal diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 7ef363a2f07ad..967961c2bf5c3 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -72,6 +72,10 @@ object MimaExcludes { // SPARK-6703 Add getOrCreate method to SparkContext ProblemFilters.exclude[IncompatibleResultTypeProblem] ("org.apache.spark.SparkContext.org$apache$spark$SparkContext$$activeContext") + )++ Seq( + // SPARK-7090 Introduce LDAOptimizer to LDA to further improve extensibility + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.mllib.clustering.LDA$EMOptimizer") ) case v if v.startsWith("1.3") => From 874a2ca93d095a0dfa1acfdacf0e9d80388c4422 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 27 Apr 2015 21:45:40 -0700 Subject: [PATCH 147/191] [SPARK-7174][Core] Move calling `TaskScheduler.executorHeartbeatReceived` to another thread `HeartbeatReceiver` will call `TaskScheduler.executorHeartbeatReceived`, which is a blocking operation because `TaskScheduler.executorHeartbeatReceived` will call ```Scala blockManagerMaster.driverEndpoint.askWithReply[Boolean]( BlockManagerHeartbeat(blockManagerId), 600 seconds) ``` finally. Even if it asks from a local Actor, it may block the current Akka thread. E.g., the reply may be dispatched to the same thread of the ask operation. So the reply cannot be processed. An extreme case is setting the thread number of Akka dispatch thread pool to 1. 
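The fix adopted further down boils down to a simple handoff pattern: instead of doing the blocking work on the RPC dispatch thread and replying inline, the handler submits the work to a dedicated single-threaded executor and replies from there. The sketch below is a simplified stand-in, not Spark's internal code: `ReplyContext` and the other names are invented, while the real change to `HeartbeatReceiver` appears in the diff later in this patch.

```scala
import java.util.concurrent.Executors

// Simplified stand-in for Spark's RpcCallContext, invented for this sketch.
trait ReplyContext { def reply(response: Any): Unit }

object NonBlockingHeartbeatSketch {
  // One dedicated thread absorbs the potentially blocking call, mirroring the
  // patch's "heartbeat-receiver-event-loop-thread"; the RPC thread returns at once.
  private val eventLoop = Executors.newSingleThreadScheduledExecutor()

  def handleHeartbeat(context: ReplyContext)(blockingCheck: () => Boolean): Unit = {
    // Fast bookkeeping stays on the calling thread; the blocking ask is handed off.
    eventLoop.submit(new Runnable {
      override def run(): Unit = {
        val reregister = blockingCheck() // may block, e.g. on the BlockManager ask
        context.reply(reregister)
      }
    })
  }
}
```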
jstack log: ``` "sparkDriver-akka.actor.default-dispatcher-14" daemon prio=10 tid=0x00007f2a8c02d000 nid=0x725 waiting on condition [0x00007f2b1d6d0000] java.lang.Thread.State: TIMED_WAITING (parking) at sun.misc.Unsafe.park(Native Method) - parking to wait for <0x00000006197a0868> (a scala.concurrent.impl.Promise$CompletionLatch) at java.util.concurrent.locks.LockSupport.parkNanos(LockSupport.java:226) at java.util.concurrent.locks.AbstractQueuedSynchronizer.doAcquireSharedNanos(AbstractQueuedSynchronizer.java:1033) at java.util.concurrent.locks.AbstractQueuedSynchronizer.tryAcquireSharedNanos(AbstractQueuedSynchronizer.java:1326) at scala.concurrent.impl.Promise$DefaultPromise.tryAwait(Promise.scala:208) at scala.concurrent.impl.Promise$DefaultPromise.ready(Promise.scala:218) at scala.concurrent.impl.Promise$DefaultPromise.result(Promise.scala:223) at scala.concurrent.Await$$anonfun$result$1.apply(package.scala:107) at akka.dispatch.MonitorableThreadFactory$AkkaForkJoinWorkerThread$$anon$3.block(ThreadPoolBuilder.scala:169) at scala.concurrent.forkjoin.ForkJoinPool.managedBlock(ForkJoinPool.java:3640) at akka.dispatch.MonitorableThreadFactory$AkkaForkJoinWorkerThread.blockOn(ThreadPoolBuilder.scala:167) at scala.concurrent.Await$.result(package.scala:107) at org.apache.spark.rpc.RpcEndpointRef.askWithReply(RpcEnv.scala:355) at org.apache.spark.scheduler.DAGScheduler.executorHeartbeatReceived(DAGScheduler.scala:169) at org.apache.spark.scheduler.TaskSchedulerImpl.executorHeartbeatReceived(TaskSchedulerImpl.scala:367) at org.apache.spark.HeartbeatReceiver$$anonfun$receiveAndReply$1.applyOrElse(HeartbeatReceiver.scala:103) at org.apache.spark.rpc.akka.AkkaRpcEnv.org$apache$spark$rpc$akka$AkkaRpcEnv$$processMessage(AkkaRpcEnv.scala:182) at org.apache.spark.rpc.akka.AkkaRpcEnv$$anonfun$actorRef$lzycompute$1$1$$anon$1$$anonfun$receiveWithLogging$1$$anonfun$applyOrElse$4.apply$mcV$sp(AkkaRpcEnv.scala:128) at org.apache.spark.rpc.akka.AkkaRpcEnv.org$apache$spark$rpc$akka$AkkaRpcEnv$$safelyCall(AkkaRpcEnv.scala:203) at org.apache.spark.rpc.akka.AkkaRpcEnv$$anonfun$actorRef$lzycompute$1$1$$anon$1$$anonfun$receiveWithLogging$1.applyOrElse(AkkaRpcEnv.scala:127) at scala.runtime.AbstractPartialFunction$mcVL$sp.apply$mcVL$sp(AbstractPartialFunction.scala:33) at scala.runtime.AbstractPartialFunction$mcVL$sp.apply(AbstractPartialFunction.scala:33) at scala.runtime.AbstractPartialFunction$mcVL$sp.apply(AbstractPartialFunction.scala:25) at org.apache.spark.util.ActorLogReceive$$anon$1.apply(ActorLogReceive.scala:59) at org.apache.spark.util.ActorLogReceive$$anon$1.apply(ActorLogReceive.scala:42) at scala.PartialFunction$class.applyOrElse(PartialFunction.scala:118) at org.apache.spark.util.ActorLogReceive$$anon$1.applyOrElse(ActorLogReceive.scala:42) at akka.actor.Actor$class.aroundReceive(Actor.scala:465) at org.apache.spark.rpc.akka.AkkaRpcEnv$$anonfun$actorRef$lzycompute$1$1$$anon$1.aroundReceive(AkkaRpcEnv.scala:94) at akka.actor.ActorCell.receiveMessage(ActorCell.scala:516) at akka.actor.ActorCell.invoke(ActorCell.scala:487) at akka.dispatch.Mailbox.processMailbox(Mailbox.scala:238) at akka.dispatch.Mailbox.run(Mailbox.scala:220) at akka.dispatch.ForkJoinExecutorConfigurator$AkkaForkJoinTask.exec(AbstractDispatcher.scala:393) at scala.concurrent.forkjoin.ForkJoinTask.doExec(ForkJoinTask.java:260) at scala.concurrent.forkjoin.ForkJoinPool$WorkQueue.runTask(ForkJoinPool.java:1339) at scala.concurrent.forkjoin.ForkJoinPool.runWorker(ForkJoinPool.java:1979) at 
scala.concurrent.forkjoin.ForkJoinWorkerThread.run(ForkJoinWorkerThread.java:107) ``` This PR moved this blocking operation to a separated thread. Author: zsxwing Closes #5723 from zsxwing/SPARK-7174 and squashes the following commits: 98bfe48 [zsxwing] Use a single thread for checking timeout and reporting executorHeartbeatReceived 5b3b545 [zsxwing] Move calling `TaskScheduler.executorHeartbeatReceived` to another thread to avoid blocking the Akka thread pool --- .../org/apache/spark/HeartbeatReceiver.scala | 26 ++++++++++++------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 68d05d5b02537..f2b024ff6cb67 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -76,13 +76,15 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) private var timeoutCheckingTask: ScheduledFuture[_] = null - private val timeoutCheckingThread = - ThreadUtils.newDaemonSingleThreadScheduledExecutor("heartbeat-timeout-checking-thread") + // "eventLoopThread" is used to run some pretty fast actions. The actions running in it should not + // block the thread for a long time. + private val eventLoopThread = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("heartbeat-receiver-event-loop-thread") private val killExecutorThread = ThreadUtils.newDaemonSingleThreadExecutor("kill-executor-thread") override def onStart(): Unit = { - timeoutCheckingTask = timeoutCheckingThread.scheduleAtFixedRate(new Runnable { + timeoutCheckingTask = eventLoopThread.scheduleAtFixedRate(new Runnable { override def run(): Unit = Utils.tryLogNonFatalError { Option(self).foreach(_.send(ExpireDeadHosts)) } @@ -99,11 +101,15 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case heartbeat @ Heartbeat(executorId, taskMetrics, blockManagerId) => if (scheduler != null) { - val unknownExecutor = !scheduler.executorHeartbeatReceived( - executorId, taskMetrics, blockManagerId) - val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor) executorLastSeen(executorId) = System.currentTimeMillis() - context.reply(response) + eventLoopThread.submit(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + val unknownExecutor = !scheduler.executorHeartbeatReceived( + executorId, taskMetrics, blockManagerId) + val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor) + context.reply(response) + } + }) } else { // Because Executor will sleep several seconds before sending the first "Heartbeat", this // case rarely happens. 
However, if it really happens, log it and ask the executor to @@ -125,7 +131,9 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) if (sc.supportDynamicAllocation) { // Asynchronously kill the executor to avoid blocking the current thread killExecutorThread.submit(new Runnable { - override def run(): Unit = sc.killExecutor(executorId) + override def run(): Unit = Utils.tryLogNonFatalError { + sc.killExecutor(executorId) + } }) } executorLastSeen.remove(executorId) @@ -137,7 +145,7 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) if (timeoutCheckingTask != null) { timeoutCheckingTask.cancel(true) } - timeoutCheckingThread.shutdownNow() + eventLoopThread.shutdownNow() killExecutorThread.shutdownNow() } } From 29576e786072bd4218e10036ddfc8d367b1c1446 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 27 Apr 2015 23:10:14 -0700 Subject: [PATCH 148/191] [SPARK-6829] Added math functions for DataFrames Implemented almost all math functions found in scala.math (max, min and abs were already present). cc mengxr marmbrus Author: Burak Yavuz Closes #5616 from brkyvz/math-udfs and squashes the following commits: fb27153 [Burak Yavuz] reverted exception message 836a098 [Burak Yavuz] fixed test and addressed small comment e5f0d13 [Burak Yavuz] addressed code review v2.2 b26c5fb [Burak Yavuz] addressed review v2.1 2761f08 [Burak Yavuz] addressed review v2 6588a5b [Burak Yavuz] fixed merge conflicts b084e10 [Burak Yavuz] Addressed code review 029e739 [Burak Yavuz] fixed atan2 test 534cc11 [Burak Yavuz] added more tests, addressed comments fa68dbe [Burak Yavuz] added double specific test data 937d5a5 [Burak Yavuz] use doubles instead of ints 8e28fff [Burak Yavuz] Added apache header 7ec8f7f [Burak Yavuz] Added math functions for DataFrames --- .../catalyst/analysis/HiveTypeCoercion.scala | 19 + .../sql/catalyst/expressions/Expression.scala | 10 + .../expressions/mathfuncs/binary.scala | 93 +++ .../expressions/mathfuncs/unary.scala | 168 ++++++ .../ExpressionEvaluationSuite.scala | 165 +++++ .../org/apache/spark/sql/mathfunctions.scala | 562 ++++++++++++++++++ .../apache/spark/sql/JavaDataFrameSuite.java | 9 + .../spark/sql/ColumnExpressionSuite.scala | 1 - .../spark/sql/MathExpressionsSuite.scala | 233 ++++++++ 9 files changed, 1259 insertions(+), 1 deletion(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 35c7f00d4e42a..73c9a1c7afdad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -79,6 +79,7 @@ trait HiveTypeCoercion { CaseWhenCoercion :: Division :: PropagateTypes :: + ExpectedInputConversion :: Nil /** @@ -643,4 +644,22 @@ trait HiveTypeCoercion { } } + /** + * Casts types according to the expected input types for Expressions that have the trait + * `ExpectsInputTypes`. 
+ */ + object ExpectedInputConversion extends Rule[LogicalPlan] { + + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + // Skip nodes who's children have not been resolved yet. + case e if !e.childrenResolved => e + + case e: ExpectsInputTypes if e.children.map(_.dataType) != e.expectedChildTypes => + val newC = (e.children, e.children.map(_.dataType), e.expectedChildTypes).zipped.map { + case (child, actual, expected) => + if (actual == expected) child else Cast(child, expected) + } + e.withNewChildren(newC) + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 4e3bbc06a5b4c..1d71c1b4b0c7c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -109,3 +109,13 @@ case class GroupExpression(children: Seq[Expression]) extends Expression { override def foldable: Boolean = false override def dataType: DataType = throw new UnsupportedOperationException } + +/** + * Expressions that require a specific `DataType` as input should implement this trait + * so that the proper type conversions can be performed in the analyzer. + */ +trait ExpectsInputTypes { + + def expectedChildTypes: Seq[DataType] + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala new file mode 100644 index 0000000000000..5b4d912a64f71 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.mathfuncs + +import org.apache.spark.sql.catalyst.analysis.UnresolvedException +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, BinaryExpression, Expression, Row} +import org.apache.spark.sql.types._ + +/** + * A binary expression specifically for math functions that take two `Double`s as input and returns + * a `Double`. + * @param f The math function. 
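To make the new coercion rule concrete, here is a rough before/after sketch of the rewrite it performs, using the `Pow` expression introduced later in this patch. This is illustrative only: the analyzer applies the rewrite automatically, and the explicit `Cast` tree below just shows the shape of the result.

```scala
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
import org.apache.spark.sql.catalyst.expressions.mathfuncs.Pow
import org.apache.spark.sql.types.DoubleType

// Before analysis: Pow expects (DoubleType, DoubleType) but receives integer literals.
val raw = Pow(Literal(2), Literal(3))

// After ExpectedInputConversion: every child whose type differs from the expected
// input type is wrapped in a Cast to that type.
val coerced = Pow(Cast(Literal(2), DoubleType), Cast(Literal(3), DoubleType))
```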
+ * @param name The short name of the function + */ +abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) + extends BinaryExpression with Serializable with ExpectsInputTypes { self: Product => + type EvaluatedType = Any + override def symbol: String = null + override def expectedChildTypes: Seq[DataType] = Seq(DoubleType, DoubleType) + + override def nullable: Boolean = left.nullable || right.nullable + override def toString: String = s"$name($left, $right)" + + override lazy val resolved = + left.resolved && right.resolved && + left.dataType == right.dataType && + !DecimalType.isFixed(left.dataType) + + override def dataType: DataType = { + if (!resolved) { + throw new UnresolvedException(this, + s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") + } + left.dataType + } + + override def eval(input: Row): Any = { + val evalE1 = left.eval(input) + if (evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + val result = f(evalE1.asInstanceOf[Double], evalE2.asInstanceOf[Double]) + if (result.isNaN) null else result + } + } + } +} + +case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER") + +case class Hypot( + left: Expression, + right: Expression) extends BinaryMathExpression(math.hypot, "HYPOT") + +case class Atan2( + left: Expression, + right: Expression) extends BinaryMathExpression(math.atan2, "ATAN2") { + override def eval(input: Row): Any = { + val evalE1 = left.eval(input) + if (evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + // With codegen, the values returned by -0.0 and 0.0 are different. Handled with +0.0 + val result = math.atan2(evalE1.asInstanceOf[Double] + 0.0, + evalE2.asInstanceOf[Double] + 0.0) + if (result.isNaN) null else result + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala new file mode 100644 index 0000000000000..96cb77d487529 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.mathfuncs + +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, Row, UnaryExpression} +import org.apache.spark.sql.types._ + +/** + * A unary expression specifically for math functions. Math Functions expect a specific type of + * input format, therefore these functions extend `ExpectsInputTypes`. 
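A quick REPL-style aside on the `+ 0.0` normalization in `Atan2.eval` above: for `atan2`, negative zero and positive zero produce results of opposite sign, and adding positive zero collapses `-0.0` to `0.0`, which keeps interpreted and code-generated evaluation consistent.

```scala
math.atan2(0.0, -1.0)   // ==  3.141592653589793
math.atan2(-0.0, -1.0)  // == -3.141592653589793
-0.0 + 0.0              // ==  0.0, so adding +0.0 removes the sign of negative zero
```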
+ * @param name The short name of the function + */ +abstract class MathematicalExpression(name: String) + extends UnaryExpression with Serializable with ExpectsInputTypes { + self: Product => + type EvaluatedType = Any + + override def dataType: DataType = DoubleType + override def foldable: Boolean = child.foldable + override def nullable: Boolean = true + override def toString: String = s"$name($child)" +} + +/** + * A unary expression specifically for math functions that take a `Double` as input and return + * a `Double`. + * @param f The math function. + * @param name The short name of the function + */ +abstract class MathematicalExpressionForDouble(f: Double => Double, name: String) + extends MathematicalExpression(name) { self: Product => + + override def expectedChildTypes: Seq[DataType] = Seq(DoubleType) + + override def eval(input: Row): Any = { + val evalE = child.eval(input) + if (evalE == null) { + null + } else { + val result = f(evalE.asInstanceOf[Double]) + if (result.isNaN) null else result + } + } +} + +/** + * A unary expression specifically for math functions that take an `Int` as input and return + * an `Int`. + * @param f The math function. + * @param name The short name of the function + */ +abstract class MathematicalExpressionForInt(f: Int => Int, name: String) + extends MathematicalExpression(name) { self: Product => + + override def dataType: DataType = IntegerType + override def expectedChildTypes: Seq[DataType] = Seq(IntegerType) + + override def eval(input: Row): Any = { + val evalE = child.eval(input) + if (evalE == null) null else f(evalE.asInstanceOf[Int]) + } +} + +/** + * A unary expression specifically for math functions that take a `Float` as input and return + * a `Float`. + * @param f The math function. + * @param name The short name of the function + */ +abstract class MathematicalExpressionForFloat(f: Float => Float, name: String) + extends MathematicalExpression(name) { self: Product => + + override def dataType: DataType = FloatType + override def expectedChildTypes: Seq[DataType] = Seq(FloatType) + + override def eval(input: Row): Any = { + val evalE = child.eval(input) + if (evalE == null) { + null + } else { + val result = f(evalE.asInstanceOf[Float]) + if (result.isNaN) null else result + } + } +} + +/** + * A unary expression specifically for math functions that take a `Long` as input and return + * a `Long`. + * @param f The math function. 
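As an illustration of the null-on-NaN convention these `eval` implementations share, here is a small sketch; it assumes `EmptyRow` from the Catalyst expressions package is visible to the calling code, as it is in the test suite changes later in this patch.

```scala
import org.apache.spark.sql.catalyst.expressions.{EmptyRow, Literal}
import org.apache.spark.sql.catalyst.expressions.mathfuncs.{Asin, Sin}
import org.apache.spark.sql.types.DoubleType

Sin(Literal(0.0)).eval(EmptyRow)                      // 0.0
Asin(Literal(2.0)).eval(EmptyRow)                     // null, because math.asin(2.0) is NaN
Sin(Literal.create(null, DoubleType)).eval(EmptyRow)  // a null input stays null
```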
+ * @param name The short name of the function + */ +abstract class MathematicalExpressionForLong(f: Long => Long, name: String) + extends MathematicalExpression(name) { self: Product => + + override def dataType: DataType = LongType + override def expectedChildTypes: Seq[DataType] = Seq(LongType) + + override def eval(input: Row): Any = { + val evalE = child.eval(input) + if (evalE == null) null else f(evalE.asInstanceOf[Long]) + } +} + +case class Sin(child: Expression) extends MathematicalExpressionForDouble(math.sin, "SIN") + +case class Asin(child: Expression) extends MathematicalExpressionForDouble(math.asin, "ASIN") + +case class Sinh(child: Expression) extends MathematicalExpressionForDouble(math.sinh, "SINH") + +case class Cos(child: Expression) extends MathematicalExpressionForDouble(math.cos, "COS") + +case class Acos(child: Expression) extends MathematicalExpressionForDouble(math.acos, "ACOS") + +case class Cosh(child: Expression) extends MathematicalExpressionForDouble(math.cosh, "COSH") + +case class Tan(child: Expression) extends MathematicalExpressionForDouble(math.tan, "TAN") + +case class Atan(child: Expression) extends MathematicalExpressionForDouble(math.atan, "ATAN") + +case class Tanh(child: Expression) extends MathematicalExpressionForDouble(math.tanh, "TANH") + +case class Ceil(child: Expression) extends MathematicalExpressionForDouble(math.ceil, "CEIL") + +case class Floor(child: Expression) extends MathematicalExpressionForDouble(math.floor, "FLOOR") + +case class Rint(child: Expression) extends MathematicalExpressionForDouble(math.rint, "ROUND") + +case class Cbrt(child: Expression) extends MathematicalExpressionForDouble(math.cbrt, "CBRT") + +case class Signum(child: Expression) extends MathematicalExpressionForDouble(math.signum, "SIGNUM") + +case class ISignum(child: Expression) extends MathematicalExpressionForInt(math.signum, "ISIGNUM") + +case class FSignum(child: Expression) extends MathematicalExpressionForFloat(math.signum, "FSIGNUM") + +case class LSignum(child: Expression) extends MathematicalExpressionForLong(math.signum, "LSIGNUM") + +case class ToDegrees(child: Expression) + extends MathematicalExpressionForDouble(math.toDegrees, "DEGREES") + +case class ToRadians(child: Expression) + extends MathematicalExpressionForDouble(math.toRadians, "RADIANS") + +case class Log(child: Expression) extends MathematicalExpressionForDouble(math.log, "LOG") + +case class Log10(child: Expression) extends MathematicalExpressionForDouble(math.log10, "LOG10") + +case class Log1p(child: Expression) extends MathematicalExpressionForDouble(math.log1p, "LOG1P") + +case class Exp(child: Expression) extends MathematicalExpressionForDouble(math.exp, "EXP") + +case class Expm1(child: Expression) extends MathematicalExpressionForDouble(math.expm1, "EXPM1") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 76298f03c94ae..5390ce43c6639 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -28,6 +28,7 @@ import org.scalatest.Matchers._ import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField import org.apache.spark.sql.catalyst.dsl.expressions._ +import 
org.apache.spark.sql.catalyst.expressions.mathfuncs._ import org.apache.spark.sql.types._ @@ -1152,6 +1153,170 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(c1 ^ c2, 3, row) checkEvaluation(~c1, -2, row) } + + /** + * Used for testing math functions for DataFrames. + * @param c The DataFrame function + * @param f The functions in scala.math + * @param domain The set of values to run the function with + * @param expectNull Whether the given values should return null or not + * @tparam T Generic type for primitives + */ + def unaryMathFunctionEvaluation[@specialized(Int, Double, Float, Long) T]( + c: Expression => Expression, + f: T => T, + domain: Iterable[T] = (-20 to 20).map(_ * 0.1), + expectNull: Boolean = false): Unit = { + if (expectNull) { + domain.foreach { value => + checkEvaluation(c(Literal(value)), null, EmptyRow) + } + } else { + domain.foreach { value => + checkEvaluation(c(Literal(value)), f(value), EmptyRow) + } + } + checkEvaluation(c(Literal.create(null, DoubleType)), null, create_row(null)) + } + + test("sin") { + unaryMathFunctionEvaluation(Sin, math.sin) + } + + test("asin") { + unaryMathFunctionEvaluation(Asin, math.asin, (-10 to 10).map(_ * 0.1)) + unaryMathFunctionEvaluation(Asin, math.asin, (11 to 20).map(_ * 0.1), true) + } + + test("sinh") { + unaryMathFunctionEvaluation(Sinh, math.sinh) + } + + test("cos") { + unaryMathFunctionEvaluation(Cos, math.cos) + } + + test("acos") { + unaryMathFunctionEvaluation(Acos, math.acos, (-10 to 10).map(_ * 0.1)) + unaryMathFunctionEvaluation(Acos, math.acos, (11 to 20).map(_ * 0.1), true) + } + + test("cosh") { + unaryMathFunctionEvaluation(Cosh, math.cosh) + } + + test("tan") { + unaryMathFunctionEvaluation(Tan, math.tan) + } + + test("atan") { + unaryMathFunctionEvaluation(Atan, math.atan) + } + + test("tanh") { + unaryMathFunctionEvaluation(Tanh, math.tanh) + } + + test("toDeg") { + unaryMathFunctionEvaluation(ToDegrees, math.toDegrees) + } + + test("toRad") { + unaryMathFunctionEvaluation(ToRadians, math.toRadians) + } + + test("cbrt") { + unaryMathFunctionEvaluation(Cbrt, math.cbrt) + } + + test("ceil") { + unaryMathFunctionEvaluation(Ceil, math.ceil) + } + + test("floor") { + unaryMathFunctionEvaluation(Floor, math.floor) + } + + test("rint") { + unaryMathFunctionEvaluation(Rint, math.rint) + } + + test("exp") { + unaryMathFunctionEvaluation(Exp, math.exp) + } + + test("expm1") { + unaryMathFunctionEvaluation(Expm1, math.expm1) + } + + test("signum") { + unaryMathFunctionEvaluation[Double](Signum, math.signum) + } + + test("isignum") { + unaryMathFunctionEvaluation[Int](ISignum, math.signum, (-5 to 5)) + } + + test("fsignum") { + unaryMathFunctionEvaluation[Float](FSignum, math.signum, (-5 to 5).map(_.toFloat)) + } + + test("lsignum") { + unaryMathFunctionEvaluation[Long](LSignum, math.signum, (5 to 5).map(_.toLong)) + } + + test("log") { + unaryMathFunctionEvaluation(Log, math.log, (0 to 20).map(_ * 0.1)) + unaryMathFunctionEvaluation(Log, math.log, (-5 to -1).map(_ * 0.1), true) + } + + test("log10") { + unaryMathFunctionEvaluation(Log10, math.log10, (0 to 20).map(_ * 0.1)) + unaryMathFunctionEvaluation(Log10, math.log10, (-5 to -1).map(_ * 0.1), true) + } + + test("log1p") { + unaryMathFunctionEvaluation(Log1p, math.log1p, (-1 to 20).map(_ * 0.1)) + unaryMathFunctionEvaluation(Log1p, math.log1p, (-10 to -2).map(_ * 1.0), true) + } + + /** + * Used for testing math functions for DataFrames. 
+ * @param c The DataFrame function + * @param f The functions in scala.math + * @param domain The set of values to run the function with + */ + def binaryMathFunctionEvaluation( + c: (Expression, Expression) => Expression, + f: (Double, Double) => Double, + domain: Iterable[(Double, Double)] = (-20 to 20).map(v => (v * 0.1, v * -0.1)), + expectNull: Boolean = false): Unit = { + if (expectNull) { + domain.foreach { case (v1, v2) => + checkEvaluation(c(v1, v2), null, create_row(null)) + } + } else { + domain.foreach { case (v1, v2) => + checkEvaluation(c(v1, v2), f(v1 + 0.0, v2 + 0.0), EmptyRow) + checkEvaluation(c(v2, v1), f(v2 + 0.0, v1 + 0.0), EmptyRow) + } + } + checkEvaluation(c(Literal.create(null, DoubleType), 1.0), null, create_row(null)) + checkEvaluation(c(1.0, Literal.create(null, DoubleType)), null, create_row(null)) + } + + test("pow") { + binaryMathFunctionEvaluation(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0))) + binaryMathFunctionEvaluation(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), true) + } + + test("hypot") { + binaryMathFunctionEvaluation(Hypot, math.hypot) + } + + test("atan2") { + binaryMathFunctionEvaluation(Atan2, math.atan2) + } } // TODO: Make the tests work with codegen. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala new file mode 100644 index 0000000000000..84f62bf47f955 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala @@ -0,0 +1,562 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.language.implicitConversions + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.mathfuncs._ +import org.apache.spark.sql.functions.lit + +/** + * :: Experimental :: + * Mathematical Functions available for [[DataFrame]]. + * + * @groupname double_funcs Functions that require DoubleType as an input + * @groupname int_funcs Functions that require IntegerType as an input + * @groupname float_funcs Functions that require FloatType as an input + * @groupname long_funcs Functions that require LongType as an input + */ +@Experimental +// scalastyle:off +object mathfunctions { +// scalastyle:on + + private[this] implicit def toColumn(expr: Expression): Column = Column(expr) + + /** + * Computes the sine of the given value. + * + * @group double_funcs + */ + def sin(e: Column): Column = Sin(e.expr) + + /** + * Computes the sine of the given column. 
+ * + * @group double_funcs + */ + def sin(columnName: String): Column = sin(Column(columnName)) + + /** + * Computes the sine inverse of the given value; the returned angle is in the range + * -pi/2 through pi/2. + * + * @group double_funcs + */ + def asin(e: Column): Column = Asin(e.expr) + + /** + * Computes the sine inverse of the given column; the returned angle is in the range + * -pi/2 through pi/2. + * + * @group double_funcs + */ + def asin(columnName: String): Column = asin(Column(columnName)) + + /** + * Computes the hyperbolic sine of the given value. + * + * @group double_funcs + */ + def sinh(e: Column): Column = Sinh(e.expr) + + /** + * Computes the hyperbolic sine of the given column. + * + * @group double_funcs + */ + def sinh(columnName: String): Column = sinh(Column(columnName)) + + /** + * Computes the cosine of the given value. + * + * @group double_funcs + */ + def cos(e: Column): Column = Cos(e.expr) + + /** + * Computes the cosine of the given column. + * + * @group double_funcs + */ + def cos(columnName: String): Column = cos(Column(columnName)) + + /** + * Computes the cosine inverse of the given value; the returned angle is in the range + * 0.0 through pi. + * + * @group double_funcs + */ + def acos(e: Column): Column = Acos(e.expr) + + /** + * Computes the cosine inverse of the given column; the returned angle is in the range + * 0.0 through pi. + * + * @group double_funcs + */ + def acos(columnName: String): Column = acos(Column(columnName)) + + /** + * Computes the hyperbolic cosine of the given value. + * + * @group double_funcs + */ + def cosh(e: Column): Column = Cosh(e.expr) + + /** + * Computes the hyperbolic cosine of the given column. + * + * @group double_funcs + */ + def cosh(columnName: String): Column = cosh(Column(columnName)) + + /** + * Computes the tangent of the given value. + * + * @group double_funcs + */ + def tan(e: Column): Column = Tan(e.expr) + + /** + * Computes the tangent of the given column. + * + * @group double_funcs + */ + def tan(columnName: String): Column = tan(Column(columnName)) + + /** + * Computes the tangent inverse of the given value. + * + * @group double_funcs + */ + def atan(e: Column): Column = Atan(e.expr) + + /** + * Computes the tangent inverse of the given column. + * + * @group double_funcs + */ + def atan(columnName: String): Column = atan(Column(columnName)) + + /** + * Computes the hyperbolic tangent of the given value. + * + * @group double_funcs + */ + def tanh(e: Column): Column = Tanh(e.expr) + + /** + * Computes the hyperbolic tangent of the given column. + * + * @group double_funcs + */ + def tanh(columnName: String): Column = tanh(Column(columnName)) + + /** + * Converts an angle measured in radians to an approximately equivalent angle measured in degrees. + * + * @group double_funcs + */ + def toDeg(e: Column): Column = ToDegrees(e.expr) + + /** + * Converts an angle measured in radians to an approximately equivalent angle measured in degrees. + * + * @group double_funcs + */ + def toDeg(columnName: String): Column = toDeg(Column(columnName)) + + /** + * Converts an angle measured in degrees to an approximately equivalent angle measured in radians. + * + * @group double_funcs + */ + def toRad(e: Column): Column = ToRadians(e.expr) + + /** + * Converts an angle measured in degrees to an approximately equivalent angle measured in radians. + * + * @group double_funcs + */ + def toRad(columnName: String): Column = toRad(Column(columnName)) + + /** + * Computes the ceiling of the given value. 
+ * + * @group double_funcs + */ + def ceil(e: Column): Column = Ceil(e.expr) + + /** + * Computes the ceiling of the given column. + * + * @group double_funcs + */ + def ceil(columnName: String): Column = ceil(Column(columnName)) + + /** + * Computes the floor of the given value. + * + * @group double_funcs + */ + def floor(e: Column): Column = Floor(e.expr) + + /** + * Computes the floor of the given column. + * + * @group double_funcs + */ + def floor(columnName: String): Column = floor(Column(columnName)) + + /** + * Returns the double value that is closest in value to the argument and + * is equal to a mathematical integer. + * + * @group double_funcs + */ + def rint(e: Column): Column = Rint(e.expr) + + /** + * Returns the double value that is closest in value to the argument and + * is equal to a mathematical integer. + * + * @group double_funcs + */ + def rint(columnName: String): Column = rint(Column(columnName)) + + /** + * Computes the cube-root of the given value. + * + * @group double_funcs + */ + def cbrt(e: Column): Column = Cbrt(e.expr) + + /** + * Computes the cube-root of the given column. + * + * @group double_funcs + */ + def cbrt(columnName: String): Column = cbrt(Column(columnName)) + + /** + * Computes the signum of the given value. + * + * @group double_funcs + */ + def signum(e: Column): Column = Signum(e.expr) + + /** + * Computes the signum of the given column. + * + * @group double_funcs + */ + def signum(columnName: String): Column = signum(Column(columnName)) + + /** + * Computes the signum of the given value. For IntegerType. + * + * @group int_funcs + */ + def isignum(e: Column): Column = ISignum(e.expr) + + /** + * Computes the signum of the given column. For IntegerType. + * + * @group int_funcs + */ + def isignum(columnName: String): Column = isignum(Column(columnName)) + + /** + * Computes the signum of the given value. For FloatType. + * + * @group float_funcs + */ + def fsignum(e: Column): Column = FSignum(e.expr) + + /** + * Computes the signum of the given column. For FloatType. + * + * @group float_funcs + */ + def fsignum(columnName: String): Column = fsignum(Column(columnName)) + + /** + * Computes the signum of the given value. For LongType. + * + * @group long_funcs + */ + def lsignum(e: Column): Column = LSignum(e.expr) + + /** + * Computes the signum of the given column. For FloatType. + * + * @group long_funcs + */ + def lsignum(columnName: String): Column = lsignum(Column(columnName)) + + /** + * Computes the natural logarithm of the given value. + * + * @group double_funcs + */ + def log(e: Column): Column = Log(e.expr) + + /** + * Computes the natural logarithm of the given column. + * + * @group double_funcs + */ + def log(columnName: String): Column = log(Column(columnName)) + + /** + * Computes the logarithm of the given value in Base 10. + * + * @group double_funcs + */ + def log10(e: Column): Column = Log10(e.expr) + + /** + * Computes the logarithm of the given value in Base 10. + * + * @group double_funcs + */ + def log10(columnName: String): Column = log10(Column(columnName)) + + /** + * Computes the natural logarithm of the given value plus one. + * + * @group double_funcs + */ + def log1p(e: Column): Column = Log1p(e.expr) + + /** + * Computes the natural logarithm of the given column plus one. + * + * @group double_funcs + */ + def log1p(columnName: String): Column = log1p(Column(columnName)) + + /** + * Computes the exponential of the given value. 
+ * + * @group double_funcs + */ + def exp(e: Column): Column = Exp(e.expr) + + /** + * Computes the exponential of the given column. + * + * @group double_funcs + */ + def exp(columnName: String): Column = exp(Column(columnName)) + + /** + * Computes the exponential of the given value minus one. + * + * @group double_funcs + */ + def expm1(e: Column): Column = Expm1(e.expr) + + /** + * Computes the exponential of the given column. + * + * @group double_funcs + */ + def expm1(columnName: String): Column = expm1(Column(columnName)) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group double_funcs + */ + def pow(l: Column, r: Column): Column = Pow(l.expr, r.expr) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group double_funcs + */ + def pow(l: Column, rightName: String): Column = pow(l, Column(rightName)) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group double_funcs + */ + def pow(leftName: String, r: Column): Column = pow(Column(leftName), r) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group double_funcs + */ + def pow(leftName: String, rightName: String): Column = pow(Column(leftName), Column(rightName)) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group double_funcs + */ + def pow(l: Column, r: Double): Column = pow(l, lit(r).expr) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group double_funcs + */ + def pow(leftName: String, r: Double): Column = pow(Column(leftName), r) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group double_funcs + */ + def pow(l: Double, r: Column): Column = pow(lit(l).expr, r) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group double_funcs + */ + def pow(l: Double, rightName: String): Column = pow(l, Column(rightName)) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * @group double_funcs + */ + def hypot(l: Column, r: Column): Column = Hypot(l.expr, r.expr) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * @group double_funcs + */ + def hypot(l: Column, rightName: String): Column = hypot(l, Column(rightName)) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * @group double_funcs + */ + def hypot(leftName: String, r: Column): Column = hypot(Column(leftName), r) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * @group double_funcs + */ + def hypot(leftName: String, rightName: String): Column = + hypot(Column(leftName), Column(rightName)) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * @group double_funcs + */ + def hypot(l: Column, r: Double): Column = hypot(l, lit(r).expr) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * @group double_funcs + */ + def hypot(leftName: String, r: Double): Column = hypot(Column(leftName), r) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. 
+ * + * @group double_funcs + */ + def hypot(l: Double, r: Column): Column = hypot(lit(l).expr, r) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * @group double_funcs + */ + def hypot(l: Double, rightName: String): Column = hypot(l, Column(rightName)) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group double_funcs + */ + def atan2(l: Column, r: Column): Column = Atan2(l.expr, r.expr) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group double_funcs + */ + def atan2(l: Column, rightName: String): Column = atan2(l, Column(rightName)) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group double_funcs + */ + def atan2(leftName: String, r: Column): Column = atan2(Column(leftName), r) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group double_funcs + */ + def atan2(leftName: String, rightName: String): Column = + atan2(Column(leftName), Column(rightName)) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group double_funcs + */ + def atan2(l: Column, r: Double): Column = atan2(l, lit(r).expr) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group double_funcs + */ + def atan2(leftName: String, r: Double): Column = atan2(Column(leftName), r) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group double_funcs + */ + def atan2(l: Double, r: Column): Column = atan2(lit(l).expr, r) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). 
+ * + * @group double_funcs + */ + def atan2(l: Double, rightName: String): Column = atan2(l, Column(rightName)) +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index e02c84872c628..e5c9504d21042 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -41,6 +41,7 @@ import java.util.Map; import static org.apache.spark.sql.functions.*; +import static org.apache.spark.sql.mathfunctions.*; public class JavaDataFrameSuite { private transient JavaSparkContext jsc; @@ -98,6 +99,14 @@ public void testVarargMethods() { df.groupBy().agg(countDistinct("key", "value")); df.groupBy().agg(countDistinct(col("key"), col("value"))); df.select(coalesce(col("key"))); + + // Varargs with mathfunctions + DataFrame df2 = context.table("testData2"); + df2.select(exp("a"), exp("b")); + df2.select(exp(log("a"))); + df2.select(pow("a", "a"), pow("b", 2.0)); + df2.select(pow(col("a"), col("b")), exp("b")); + df2.select(sin("a"), acos("b")); } @Ignore diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 904073b8cb2aa..680b5c636960d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -22,7 +22,6 @@ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ - class ColumnExpressionSuite extends QueryTest { import org.apache.spark.sql.TestData._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala new file mode 100644 index 0000000000000..561553cc925cb --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
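The Java test above already drives the new functions through the varargs overloads; the equivalent Scala usage is just as direct. A small sketch, assuming an already-constructed DataFrame `df` with numeric columns "x" and "y" (the column names are illustrative):

```scala
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.mathfunctions._

// `df` is assumed to exist and to contain double columns "x" and "y".
df.select(sin("x"), toDeg("x"), pow("x", 2.0), hypot("x", "y"), atan2(col("y"), col("x")))
```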
+ */ + +package org.apache.spark.sql + +import java.lang.{Double => JavaDouble} + +import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.mathfunctions._ +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext.implicits._ + +private[this] object MathExpressionsTestData { + + case class DoubleData(a: JavaDouble, b: JavaDouble) + val doubleData = TestSQLContext.sparkContext.parallelize( + (1 to 10).map(i => DoubleData(i * 0.2 - 1, i * -0.2 + 1))).toDF() + + val nnDoubleData = TestSQLContext.sparkContext.parallelize( + (1 to 10).map(i => DoubleData(i * 0.1, i * -0.1))).toDF() + + case class NullDoubles(a: JavaDouble) + val nullDoubles = + TestSQLContext.sparkContext.parallelize( + NullDoubles(1.0) :: + NullDoubles(2.0) :: + NullDoubles(3.0) :: + NullDoubles(null) :: Nil + ).toDF() +} + +class MathExpressionsSuite extends QueryTest { + + import MathExpressionsTestData._ + + def testOneToOneMathFunction[@specialized(Int, Long, Float, Double) T]( + c: Column => Column, + f: T => T): Unit = { + checkAnswer( + doubleData.select(c('a)), + (1 to 10).map(n => Row(f((n * 0.2 - 1).asInstanceOf[T]))) + ) + + checkAnswer( + doubleData.select(c('b)), + (1 to 10).map(n => Row(f((-n * 0.2 + 1).asInstanceOf[T]))) + ) + + checkAnswer( + doubleData.select(c(lit(null))), + (1 to 10).map(_ => Row(null)) + ) + } + + def testOneToOneNonNegativeMathFunction(c: Column => Column, f: Double => Double): Unit = { + checkAnswer( + nnDoubleData.select(c('a)), + (1 to 10).map(n => Row(f(n * 0.1))) + ) + + if (f(-1) === math.log1p(-1)) { + checkAnswer( + nnDoubleData.select(c('b)), + (1 to 9).map(n => Row(f(n * -0.1))) :+ Row(Double.NegativeInfinity) + ) + } else { + checkAnswer( + nnDoubleData.select(c('b)), + (1 to 10).map(n => Row(null)) + ) + } + + checkAnswer( + nnDoubleData.select(c(lit(null))), + (1 to 10).map(_ => Row(null)) + ) + } + + def testTwoToOneMathFunction( + c: (Column, Column) => Column, + d: (Column, Double) => Column, + f: (Double, Double) => Double): Unit = { + checkAnswer( + nnDoubleData.select(c('a, 'a)), + nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), r.getDouble(0)))) + ) + + checkAnswer( + nnDoubleData.select(c('a, 'b)), + nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), r.getDouble(1)))) + ) + + checkAnswer( + nnDoubleData.select(d('a, 2.0)), + nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), 2.0))) + ) + + checkAnswer( + nnDoubleData.select(d('a, -0.5)), + nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), -0.5))) + ) + + val nonNull = nullDoubles.collect().toSeq.filter(r => r.get(0) != null) + + checkAnswer( + nullDoubles.select(c('a, 'a)).orderBy('a.asc), + Row(null) +: nonNull.map(r => Row(f(r.getDouble(0), r.getDouble(0)))) + ) + } + + test("sin") { + testOneToOneMathFunction(sin, math.sin) + } + + test("asin") { + testOneToOneMathFunction(asin, math.asin) + } + + test("sinh") { + testOneToOneMathFunction(sinh, math.sinh) + } + + test("cos") { + testOneToOneMathFunction(cos, math.cos) + } + + test("acos") { + testOneToOneMathFunction(acos, math.acos) + } + + test("cosh") { + testOneToOneMathFunction(cosh, math.cosh) + } + + test("tan") { + testOneToOneMathFunction(tan, math.tan) + } + + test("atan") { + testOneToOneMathFunction(atan, math.atan) + } + + test("tanh") { + testOneToOneMathFunction(tanh, math.tanh) + } + + test("toDeg") { + testOneToOneMathFunction(toDeg, math.toDegrees) + } + + test("toRad") { + testOneToOneMathFunction(toRad, math.toRadians) + } + + test("cbrt") { + 
testOneToOneMathFunction(cbrt, math.cbrt) + } + + test("ceil") { + testOneToOneMathFunction(ceil, math.ceil) + } + + test("floor") { + testOneToOneMathFunction(floor, math.floor) + } + + test("rint") { + testOneToOneMathFunction(rint, math.rint) + } + + test("exp") { + testOneToOneMathFunction(exp, math.exp) + } + + test("expm1") { + testOneToOneMathFunction(expm1, math.expm1) + } + + test("signum") { + testOneToOneMathFunction[Double](signum, math.signum) + } + + test("isignum") { + testOneToOneMathFunction[Int](isignum, math.signum) + } + + test("fsignum") { + testOneToOneMathFunction[Float](fsignum, math.signum) + } + + test("lsignum") { + testOneToOneMathFunction[Long](lsignum, math.signum) + } + + test("pow") { + testTwoToOneMathFunction(pow, pow, math.pow) + } + + test("hypot") { + testTwoToOneMathFunction(hypot, hypot, math.hypot) + } + + test("atan2") { + testTwoToOneMathFunction(atan2, atan2, math.atan2) + } + + test("log") { + testOneToOneNonNegativeMathFunction(log, math.log) + } + + test("log10") { + testOneToOneNonNegativeMathFunction(log10, math.log10) + } + + test("log1p") { + testOneToOneNonNegativeMathFunction(log1p, math.log1p) + } + +} From 9e4e82b7bca1129bcd5e0274b9ae1b1be3fb93da Mon Sep 17 00:00:00 2001 From: jerryshao Date: Mon, 27 Apr 2015 23:48:02 -0700 Subject: [PATCH 149/191] [SPARK-5946] [STREAMING] Add Python API for direct Kafka stream Currently only added `createDirectStream` API, I'm not sure if `createRDD` is also needed, since some Java object needs to be wrapped in Python. Please help to review, thanks a lot. Author: jerryshao Author: Saisai Shao Closes #4723 from jerryshao/direct-kafka-python-api and squashes the following commits: a1fe97c [jerryshao] Fix rebase issue eebf333 [jerryshao] Address the comments da40f4e [jerryshao] Fix Python 2.6 Syntax error issue 5c0ee85 [jerryshao] Style fix 4aeac18 [jerryshao] Fix bug in example code 7146d86 [jerryshao] Add unit test bf3bdd6 [jerryshao] Add more APIs and address the comments f5b3801 [jerryshao] Small style fix 8641835 [Saisai Shao] Rebase and update the code 589c05b [Saisai Shao] Fix the style d6fcb6a [Saisai Shao] Address the comments dfda902 [Saisai Shao] Style fix 0f7d168 [Saisai Shao] Add the doc and fix some style issues 67e6880 [Saisai Shao] Fix test bug 917b0db [Saisai Shao] Add Python createRDD API for Kakfa direct stream c3fc11d [jerryshao] Modify the docs 2c00936 [Saisai Shao] address the comments 3360f44 [jerryshao] Fix code style e0e0f0d [jerryshao] Code clean and bug fix 338c41f [Saisai Shao] Add python API and example for direct kafka stream --- .../streaming/direct_kafka_wordcount.py | 55 ++++++ .../spark/streaming/kafka/KafkaUtils.scala | 92 +++++++++- python/pyspark/streaming/kafka.py | 167 +++++++++++++++++- python/pyspark/streaming/tests.py | 84 ++++++++- 4 files changed, 383 insertions(+), 15 deletions(-) create mode 100644 examples/src/main/python/streaming/direct_kafka_wordcount.py diff --git a/examples/src/main/python/streaming/direct_kafka_wordcount.py b/examples/src/main/python/streaming/direct_kafka_wordcount.py new file mode 100644 index 0000000000000..6ef188a220c51 --- /dev/null +++ b/examples/src/main/python/streaming/direct_kafka_wordcount.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Counts words in UTF8 encoded, '\n' delimited text directly received from Kafka in every 2 seconds. + Usage: direct_kafka_wordcount.py + + To run this on your local machine, you need to setup Kafka and create a producer first, see + http://kafka.apache.org/documentation.html#quickstart + + and then run the example + `$ bin/spark-submit --jars external/kafka-assembly/target/scala-*/\ + spark-streaming-kafka-assembly-*.jar \ + examples/src/main/python/streaming/direct_kafka_wordcount.py \ + localhost:9092 test` +""" + +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext +from pyspark.streaming.kafka import KafkaUtils + +if __name__ == "__main__": + if len(sys.argv) != 3: + print >> sys.stderr, "Usage: direct_kafka_wordcount.py " + exit(-1) + + sc = SparkContext(appName="PythonStreamingDirectKafkaWordCount") + ssc = StreamingContext(sc, 2) + + brokers, topic = sys.argv[1:] + kvs = KafkaUtils.createDirectStream(ssc, [topic], {"metadata.broker.list": brokers}) + lines = kvs.map(lambda x: x[1]) + counts = lines.flatMap(lambda line: line.split(" ")) \ + .map(lambda word: (word, 1)) \ + .reduceByKey(lambda a, b: a+b) + counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index 5a9bd4214cf51..0721ddaf7055a 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -21,6 +21,7 @@ import java.lang.{Integer => JInt} import java.lang.{Long => JLong} import java.util.{Map => JMap} import java.util.{Set => JSet} +import java.util.{List => JList} import scala.reflect.ClassTag import scala.collection.JavaConversions._ @@ -234,7 +235,6 @@ object KafkaUtils { new KafkaRDD[K, V, KD, VD, R](sc, kafkaParams, offsetRanges, leaderMap, messageHandler) } - /** * Create a RDD from Kafka using offset ranges for each topic and partition. 
* @@ -558,4 +558,94 @@ private class KafkaUtilsPythonHelper { topics, storageLevel) } + + def createRDD( + jsc: JavaSparkContext, + kafkaParams: JMap[String, String], + offsetRanges: JList[OffsetRange], + leaders: JMap[TopicAndPartition, Broker]): JavaPairRDD[Array[Byte], Array[Byte]] = { + val messageHandler = new JFunction[MessageAndMetadata[Array[Byte], Array[Byte]], + (Array[Byte], Array[Byte])] { + def call(t1: MessageAndMetadata[Array[Byte], Array[Byte]]): (Array[Byte], Array[Byte]) = + (t1.key(), t1.message()) + } + + val jrdd = KafkaUtils.createRDD[ + Array[Byte], + Array[Byte], + DefaultDecoder, + DefaultDecoder, + (Array[Byte], Array[Byte])]( + jsc, + classOf[Array[Byte]], + classOf[Array[Byte]], + classOf[DefaultDecoder], + classOf[DefaultDecoder], + classOf[(Array[Byte], Array[Byte])], + kafkaParams, + offsetRanges.toArray(new Array[OffsetRange](offsetRanges.size())), + leaders, + messageHandler + ) + new JavaPairRDD(jrdd.rdd) + } + + def createDirectStream( + jssc: JavaStreamingContext, + kafkaParams: JMap[String, String], + topics: JSet[String], + fromOffsets: JMap[TopicAndPartition, JLong] + ): JavaPairInputDStream[Array[Byte], Array[Byte]] = { + + if (!fromOffsets.isEmpty) { + import scala.collection.JavaConversions._ + val topicsFromOffsets = fromOffsets.keySet().map(_.topic) + if (topicsFromOffsets != topics.toSet) { + throw new IllegalStateException(s"The specified topics: ${topics.toSet.mkString(" ")} " + + s"do not equal to the topic from offsets: ${topicsFromOffsets.mkString(" ")}") + } + } + + if (fromOffsets.isEmpty) { + KafkaUtils.createDirectStream[Array[Byte], Array[Byte], DefaultDecoder, DefaultDecoder]( + jssc, + classOf[Array[Byte]], + classOf[Array[Byte]], + classOf[DefaultDecoder], + classOf[DefaultDecoder], + kafkaParams, + topics) + } else { + val messageHandler = new JFunction[MessageAndMetadata[Array[Byte], Array[Byte]], + (Array[Byte], Array[Byte])] { + def call(t1: MessageAndMetadata[Array[Byte], Array[Byte]]): (Array[Byte], Array[Byte]) = + (t1.key(), t1.message()) + } + + val jstream = KafkaUtils.createDirectStream[ + Array[Byte], + Array[Byte], + DefaultDecoder, + DefaultDecoder, + (Array[Byte], Array[Byte])]( + jssc, + classOf[Array[Byte]], + classOf[Array[Byte]], + classOf[DefaultDecoder], + classOf[DefaultDecoder], + classOf[(Array[Byte], Array[Byte])], + kafkaParams, + fromOffsets, + messageHandler) + new JavaPairInputDStream(jstream.inputDStream) + } + } + + def createOffsetRange(topic: String, partition: JInt, fromOffset: JLong, untilOffset: JLong + ): OffsetRange = OffsetRange.create(topic, partition, fromOffset, untilOffset) + + def createTopicAndPartition(topic: String, partition: JInt): TopicAndPartition = + TopicAndPartition(topic, partition) + + def createBroker(host: String, port: JInt): Broker = Broker(host, port) } diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index 8d610d6569b4a..e278b29003f69 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -17,11 +17,12 @@ from py4j.java_gateway import Py4JJavaError +from pyspark.rdd import RDD from pyspark.storagelevel import StorageLevel from pyspark.serializers import PairDeserializer, NoOpSerializer from pyspark.streaming import DStream -__all__ = ['KafkaUtils', 'utf8_decoder'] +__all__ = ['Broker', 'KafkaUtils', 'OffsetRange', 'TopicAndPartition', 'utf8_decoder'] def utf8_decoder(s): @@ -67,7 +68,104 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={}, except Py4JJavaError as e: # TODO: use --jar once 
it also work on driver if 'ClassNotFoundException' in str(e.java_exception): - print(""" + KafkaUtils._printErrorMsg(ssc.sparkContext) + raise e + ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) + stream = DStream(jstream, ssc, ser) + return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) + + @staticmethod + def createDirectStream(ssc, topics, kafkaParams, fromOffsets={}, + keyDecoder=utf8_decoder, valueDecoder=utf8_decoder): + """ + .. note:: Experimental + + Create an input stream that directly pulls messages from a Kafka Broker and specific offset. + + This is not a receiver based Kafka input stream, it directly pulls the message from Kafka + in each batch duration and processed without storing. + + This does not use Zookeeper to store offsets. The consumed offsets are tracked + by the stream itself. For interoperability with Kafka monitoring tools that depend on + Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. + You can access the offsets used in each batch from the generated RDDs (see + + To recover from driver failures, you have to enable checkpointing in the StreamingContext. + The information on consumed offset can be recovered from the checkpoint. + See the programming guide for details (constraints, etc.). + + :param ssc: StreamingContext object. + :param topics: list of topic_name to consume. + :param kafkaParams: Additional params for Kafka. + :param fromOffsets: Per-topic/partition Kafka offsets defining the (inclusive) starting + point of the stream. + :param keyDecoder: A function used to decode key (default is utf8_decoder). + :param valueDecoder: A function used to decode value (default is utf8_decoder). + :return: A DStream object + """ + if not isinstance(topics, list): + raise TypeError("topics should be list") + if not isinstance(kafkaParams, dict): + raise TypeError("kafkaParams should be dict") + + try: + helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper") + helper = helperClass.newInstance() + + jfromOffsets = dict([(k._jTopicAndPartition(helper), + v) for (k, v) in fromOffsets.items()]) + jstream = helper.createDirectStream(ssc._jssc, kafkaParams, set(topics), jfromOffsets) + except Py4JJavaError as e: + if 'ClassNotFoundException' in str(e.java_exception): + KafkaUtils._printErrorMsg(ssc.sparkContext) + raise e + + ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) + stream = DStream(jstream, ssc, ser) + return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) + + @staticmethod + def createRDD(sc, kafkaParams, offsetRanges, leaders={}, + keyDecoder=utf8_decoder, valueDecoder=utf8_decoder): + """ + .. note:: Experimental + + Create a RDD from Kafka using offset ranges for each topic and partition. + :param sc: SparkContext object + :param kafkaParams: Additional params for Kafka + :param offsetRanges: list of offsetRange to specify topic:partition:[start, end) to consume + :param leaders: Kafka brokers for each TopicAndPartition in offsetRanges. May be an empty + map, in which case leaders will be looked up on the driver. 
+ :param keyDecoder: A function used to decode key (default is utf8_decoder) + :param valueDecoder: A function used to decode value (default is utf8_decoder) + :return: A RDD object + """ + if not isinstance(kafkaParams, dict): + raise TypeError("kafkaParams should be dict") + if not isinstance(offsetRanges, list): + raise TypeError("offsetRanges should be list") + + try: + helperClass = sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper") + helper = helperClass.newInstance() + joffsetRanges = [o._jOffsetRange(helper) for o in offsetRanges] + jleaders = dict([(k._jTopicAndPartition(helper), + v._jBroker(helper)) for (k, v) in leaders.items()]) + jrdd = helper.createRDD(sc._jsc, kafkaParams, joffsetRanges, jleaders) + except Py4JJavaError as e: + if 'ClassNotFoundException' in str(e.java_exception): + KafkaUtils._printErrorMsg(sc) + raise e + + ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) + rdd = RDD(jrdd, sc, ser) + return rdd.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) + + @staticmethod + def _printErrorMsg(sc): + print(""" ________________________________________________________________________________________________ Spark Streaming's Kafka libraries not found in class path. Try one of the following. @@ -85,8 +183,63 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={}, ________________________________________________________________________________________________ -""" % (ssc.sparkContext.version, ssc.sparkContext.version)) - raise e - ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) - stream = DStream(jstream, ssc, ser) - return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) +""" % (sc.version, sc.version)) + + +class OffsetRange(object): + """ + Represents a range of offsets from a single Kafka TopicAndPartition. + """ + + def __init__(self, topic, partition, fromOffset, untilOffset): + """ + Create a OffsetRange to represent range of offsets + :param topic: Kafka topic name. + :param partition: Kafka partition id. + :param fromOffset: Inclusive starting offset. + :param untilOffset: Exclusive ending offset. + """ + self._topic = topic + self._partition = partition + self._fromOffset = fromOffset + self._untilOffset = untilOffset + + def _jOffsetRange(self, helper): + return helper.createOffsetRange(self._topic, self._partition, self._fromOffset, + self._untilOffset) + + +class TopicAndPartition(object): + """ + Represents a specific top and partition for Kafka. + """ + + def __init__(self, topic, partition): + """ + Create a Python TopicAndPartition to map to the Java related object + :param topic: Kafka topic name. + :param partition: Kafka partition id. + """ + self._topic = topic + self._partition = partition + + def _jTopicAndPartition(self, helper): + return helper.createTopicAndPartition(self._topic, self._partition) + + +class Broker(object): + """ + Represent the host and port info for a Kafka broker. + """ + + def __init__(self, host, port): + """ + Create a Python Broker to map to the Java related object. + :param host: Broker's hostname. + :param port: Broker's port. 
+ """ + self._host = host + self._port = port + + def _jBroker(self, helper): + return helper.createBroker(self._host, self._port) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 5fa1e5ef081ab..7c06c203455d9 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -21,6 +21,7 @@ import time import operator import tempfile +import random import struct from functools import reduce @@ -35,7 +36,7 @@ from pyspark.context import SparkConf, SparkContext, RDD from pyspark.streaming.context import StreamingContext -from pyspark.streaming.kafka import KafkaUtils +from pyspark.streaming.kafka import Broker, KafkaUtils, OffsetRange, TopicAndPartition class PySparkStreamingTestCase(unittest.TestCase): @@ -590,9 +591,27 @@ def tearDown(self): super(KafkaStreamTests, self).tearDown() + def _randomTopic(self): + return "topic-%d" % random.randint(0, 10000) + + def _validateStreamResult(self, sendData, stream): + result = {} + for i in chain.from_iterable(self._collect(stream.map(lambda x: x[1]), + sum(sendData.values()))): + result[i] = result.get(i, 0) + 1 + + self.assertEqual(sendData, result) + + def _validateRddResult(self, sendData, rdd): + result = {} + for i in rdd.map(lambda x: x[1]).collect(): + result[i] = result.get(i, 0) + 1 + + self.assertEqual(sendData, result) + def test_kafka_stream(self): """Test the Python Kafka stream API.""" - topic = "topic1" + topic = self._randomTopic() sendData = {"a": 3, "b": 5, "c": 10} self._kafkaTestUtils.createTopic(topic) @@ -601,13 +620,64 @@ def test_kafka_stream(self): stream = KafkaUtils.createStream(self.ssc, self._kafkaTestUtils.zkAddress(), "test-streaming-consumer", {topic: 1}, {"auto.offset.reset": "smallest"}) + self._validateStreamResult(sendData, stream) - result = {} - for i in chain.from_iterable(self._collect(stream.map(lambda x: x[1]), - sum(sendData.values()))): - result[i] = result.get(i, 0) + 1 + def test_kafka_direct_stream(self): + """Test the Python direct Kafka stream API.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2, "c": 3} + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(), + "auto.offset.reset": "smallest"} - self.assertEqual(sendData, result) + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + + stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams) + self._validateStreamResult(sendData, stream) + + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_direct_stream_from_offset(self): + """Test the Python direct Kafka stream API with start offset specified.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2, "c": 3} + fromOffsets = {TopicAndPartition(topic, 0): long(0)} + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress()} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + + stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams, fromOffsets) + self._validateStreamResult(sendData, stream) + + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_rdd(self): + """Test the Python direct Kafka RDD API.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2} + offsetRanges = [OffsetRange(topic, 0, long(0), long(sum(sendData.values())))] + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress()} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, 
sendData) + + rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges) + self._validateRddResult(sendData, rdd) + + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_rdd_with_leaders(self): + """Test the Python direct Kafka RDD API with leaders.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2, "c": 3} + offsetRanges = [OffsetRange(topic, 0, long(0), long(sum(sendData.values())))] + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress()} + address = self._kafkaTestUtils.brokerAddress().split(":") + leaders = {TopicAndPartition(topic, 0): Broker(address[0], int(address[1]))} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + + rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges, leaders) + self._validateRddResult(sendData, rdd) if __name__ == "__main__": unittest.main() From bf35edd9d4b8b11df9f47b6ff43831bc95f06322 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 28 Apr 2015 00:38:14 -0700 Subject: [PATCH 150/191] [SPARK-7187] SerializationDebugger should not crash user code rxin Author: Andrew Or Closes #5734 from andrewor14/ser-deb and squashes the following commits: e8aad6c [Andrew Or] NonFatal 57d0ef4 [Andrew Or] try catch improveException --- .../spark/serializer/SerializationDebugger.scala | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala index cecb992579655..5abfa467c0ec8 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala @@ -23,6 +23,7 @@ import java.security.AccessController import scala.annotation.tailrec import scala.collection.mutable +import scala.util.control.NonFatal import org.apache.spark.Logging @@ -35,8 +36,15 @@ private[serializer] object SerializationDebugger extends Logging { */ def improveException(obj: Any, e: NotSerializableException): NotSerializableException = { if (enableDebugging && reflect != null) { - new NotSerializableException( - e.getMessage + "\nSerialization stack:\n" + find(obj).map("\t- " + _).mkString("\n")) + try { + new NotSerializableException( + e.getMessage + "\nSerialization stack:\n" + find(obj).map("\t- " + _).mkString("\n")) + } catch { + case NonFatal(t) => + // Fall back to old exception + logWarning("Exception in serialization debugger", t) + e + } } else { e } From d94cd1a733d5715792e6c4eac87f0d5c81aebbe2 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 28 Apr 2015 00:39:08 -0700 Subject: [PATCH 151/191] [SPARK-7135][SQL] DataFrame expression for monotonically increasing IDs. Author: Reynold Xin Closes #5709 from rxin/inc-id and squashes the following commits: 7853611 [Reynold Xin] private sql. a9fda0d [Reynold Xin] Missed a few numbers. 343d896 [Reynold Xin] Self review feedback. a7136cb [Reynold Xin] [SPARK-7135][SQL] DataFrame expression for monotonically increasing IDs. 
--- python/pyspark/sql/functions.py | 22 +++++++- .../MonotonicallyIncreasingID.scala | 53 +++++++++++++++++++ .../expressions/SparkPartitionID.scala | 6 +-- .../org/apache/spark/sql/functions.scala | 16 ++++++ .../spark/sql/ColumnExpressionSuite.scala | 11 ++++ 5 files changed, 103 insertions(+), 5 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index f48b7b5d10af7..7b86655d9c82f 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -103,8 +103,28 @@ def countDistinct(col, *cols): return Column(jc) +def monotonicallyIncreasingId(): + """A column that generates monotonically increasing 64-bit integers. + + The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. + The current implementation puts the partition ID in the upper 31 bits, and the record number + within each partition in the lower 33 bits. The assumption is that the data frame has + less than 1 billion partitions, and each partition has less than 8 billion records. + + As an example, consider a [[DataFrame]] with two partitions, each with 3 records. + This expression would return the following IDs: + 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. + + >>> df0 = sc.parallelize(range(2), 2).mapPartitions(lambda x: [(1,), (2,), (3,)]).toDF(['col1']) + >>> df0.select(monotonicallyIncreasingId().alias('id')).collect() + [Row(id=0), Row(id=1), Row(id=2), Row(id=8589934592), Row(id=8589934593), Row(id=8589934594)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.monotonicallyIncreasingId()) + + def sparkPartitionId(): - """Returns a column for partition ID of the Spark task. + """A column for partition ID of the Spark task. Note that this is indeterministic because it depends on data partitioning and task scheduling. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala new file mode 100644 index 0000000000000..9ac732b55b188 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.expressions + +import org.apache.spark.TaskContext +import org.apache.spark.sql.catalyst.expressions.{Row, LeafExpression} +import org.apache.spark.sql.types.{LongType, DataType} + +/** + * Returns monotonically increasing 64-bit integers. 
+ * + * The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. + * The current implementation puts the partition ID in the upper 31 bits, and the lower 33 bits + * represent the record number within each partition. The assumption is that the data frame has + * less than 1 billion partitions, and each partition has less than 8 billion records. + * + * Since this expression is stateful, it cannot be a case object. + */ +private[sql] case class MonotonicallyIncreasingID() extends LeafExpression { + + /** + * Record ID within each partition. By being transient, count's value is reset to 0 every time + * we serialize and deserialize it. + */ + @transient private[this] var count: Long = 0L + + override type EvaluatedType = Long + + override def nullable: Boolean = false + + override def dataType: DataType = LongType + + override def eval(input: Row): Long = { + val currentCount = count + count += 1 + (TaskContext.get().partitionId().toLong << 33) + currentCount + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala index fe7607c6ac340..c2c6cbd491598 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala @@ -18,16 +18,14 @@ package org.apache.spark.sql.execution.expressions import org.apache.spark.TaskContext -import org.apache.spark.sql.catalyst.expressions.{Row, Expression} -import org.apache.spark.sql.catalyst.trees +import org.apache.spark.sql.catalyst.expressions.{LeafExpression, Row} import org.apache.spark.sql.types.{IntegerType, DataType} /** * Expression that returns the current partition id of the Spark task. */ -case object SparkPartitionID extends Expression with trees.LeafNode[Expression] { - self: Product => +private[sql] case object SparkPartitionID extends LeafExpression { override type EvaluatedType = Int diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 9738fd4f93bad..aa31d04a0cbe4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -301,6 +301,22 @@ object functions { */ def lower(e: Column): Column = Lower(e.expr) + /** + * A column expression that generates monotonically increasing 64-bit integers. + * + * The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. + * The current implementation puts the partition ID in the upper 31 bits, and the record number + * within each partition in the lower 33 bits. The assumption is that the data frame has + * less than 1 billion partitions, and each partition has less than 8 billion records. + * + * As an example, consider a [[DataFrame]] with two partitions, each with 3 records. + * This expression would return the following IDs: + * 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. + * + * @group normal_funcs + */ + def monotonicallyIncreasingId(): Column = execution.expressions.MonotonicallyIncreasingID() + /** * Unary minus, i.e. negate the expression. 
* {{{ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 680b5c636960d..2ba5fc21ff57c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -309,6 +309,17 @@ class ColumnExpressionSuite extends QueryTest { ) } + test("monotonicallyIncreasingId") { + // Make sure we have 2 partitions, each with 2 records. + val df = TestSQLContext.sparkContext.parallelize(1 to 2, 2).mapPartitions { iter => + Iterator(Tuple1(1), Tuple1(2)) + }.toDF("a") + checkAnswer( + df.select(monotonicallyIncreasingId()), + Row(0L) :: Row(1L) :: Row((1L << 33) + 0L) :: Row((1L << 33) + 1L) :: Nil + ) + } + test("sparkPartitionId") { val df = TestSQLContext.sparkContext.parallelize(1 to 1, 1).map(i => (i, i)).toDF("a", "b") checkAnswer( From e13cd86567a43672297bb488088dd8f40ec799bf Mon Sep 17 00:00:00 2001 From: Pei-Lun Lee Date: Tue, 28 Apr 2015 16:50:18 +0800 Subject: [PATCH 152/191] [SPARK-6352] [SQL] Custom parquet output committer Add new config "spark.sql.parquet.output.committer.class" to allow custom parquet output committer and an output committer class specific to use on s3. Fix compilation error introduced by https://github.com/apache/spark/pull/5042. Respect ParquetOutputFormat.ENABLE_JOB_SUMMARY flag. Author: Pei-Lun Lee Closes #5525 from ypcat/spark-6352 and squashes the following commits: 54c6b15 [Pei-Lun Lee] error handling 472870e [Pei-Lun Lee] add back custom parquet output committer ddd0f69 [Pei-Lun Lee] Merge branch 'master' of https://github.com/apache/spark into spark-6352 9ece5c5 [Pei-Lun Lee] compatibility with hadoop 1.x 8413fcd [Pei-Lun Lee] Merge branch 'master' of https://github.com/apache/spark into spark-6352 fe65915 [Pei-Lun Lee] add support for parquet config parquet.enable.summary-metadata e17bf47 [Pei-Lun Lee] Merge branch 'master' of https://github.com/apache/spark into spark-6352 9ae7545 [Pei-Lun Lee] [SPARL-6352] [SQL] Change to allow custom parquet output committer. 0d540b9 [Pei-Lun Lee] [SPARK-6352] [SQL] add license c42468c [Pei-Lun Lee] [SPARK-6352] [SQL] add test case 0fc03ca [Pei-Lun Lee] [SPARK-6532] [SQL] hide class DirectParquetOutputCommitter 769bd67 [Pei-Lun Lee] DirectParquetOutputCommitter f75e261 [Pei-Lun Lee] DirectParquetOutputCommitter --- .../DirectParquetOutputCommitter.scala | 73 +++++++++++++++++++ .../sql/parquet/ParquetTableOperations.scala | 21 ++++++ .../spark/sql/parquet/ParquetIOSuite.scala | 22 ++++++ 3 files changed, 116 insertions(+) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala new file mode 100644 index 0000000000000..f5ce2718bec4a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parquet + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter + +import parquet.Log +import parquet.hadoop.util.ContextUtil +import parquet.hadoop.{ParquetFileReader, ParquetFileWriter, ParquetOutputCommitter, ParquetOutputFormat} + +private[parquet] class DirectParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) + extends ParquetOutputCommitter(outputPath, context) { + val LOG = Log.getLog(classOf[ParquetOutputCommitter]) + + override def getWorkPath(): Path = outputPath + override def abortTask(taskContext: TaskAttemptContext): Unit = {} + override def commitTask(taskContext: TaskAttemptContext): Unit = {} + override def needsTaskCommit(taskContext: TaskAttemptContext): Boolean = true + override def setupJob(jobContext: JobContext): Unit = {} + override def setupTask(taskContext: TaskAttemptContext): Unit = {} + + override def commitJob(jobContext: JobContext) { + val configuration = ContextUtil.getConfiguration(jobContext) + val fileSystem = outputPath.getFileSystem(configuration) + + if (configuration.getBoolean(ParquetOutputFormat.ENABLE_JOB_SUMMARY, true)) { + try { + val outputStatus = fileSystem.getFileStatus(outputPath) + val footers = ParquetFileReader.readAllFootersInParallel(configuration, outputStatus) + try { + ParquetFileWriter.writeMetadataFile(configuration, outputPath, footers) + } catch { + case e: Exception => { + LOG.warn("could not write summary file for " + outputPath, e) + val metadataPath = new Path(outputPath, ParquetFileWriter.PARQUET_METADATA_FILE) + if (fileSystem.exists(metadataPath)) { + fileSystem.delete(metadataPath, true) + } + } + } + } catch { + case e: Exception => LOG.warn("could not write summary file for " + outputPath, e) + } + } + + if (configuration.getBoolean("mapreduce.fileoutputcommitter.marksuccessfuljobs", true)) { + try { + val successPath = new Path(outputPath, FileOutputCommitter.SUCCEEDED_FILE_NAME) + fileSystem.create(successPath).close() + } catch { + case e: Exception => LOG.warn("could not write success file for " + outputPath, e) + } + } + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index a938b77578686..aded126ea0615 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -381,6 +381,7 @@ private[parquet] class AppendingParquetOutputFormat(offset: Int) extends parquet.hadoop.ParquetOutputFormat[Row] { // override to accept existing directories as valid output directory override def checkOutputSpecs(job: JobContext): Unit = {} + var committer: OutputCommitter = null // override to choose output filename so not overwrite existing ones override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { @@ -403,6 +404,26 @@ private[parquet] class 
AppendingParquetOutputFormat(offset: Int) private def getTaskAttemptID(context: TaskAttemptContext): TaskAttemptID = { context.getClass.getMethod("getTaskAttemptID").invoke(context).asInstanceOf[TaskAttemptID] } + + // override to create output committer from configuration + override def getOutputCommitter(context: TaskAttemptContext): OutputCommitter = { + if (committer == null) { + val output = getOutputPath(context) + val cls = context.getConfiguration.getClass("spark.sql.parquet.output.committer.class", + classOf[ParquetOutputCommitter], classOf[ParquetOutputCommitter]) + val ctor = cls.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext]) + committer = ctor.newInstance(output, context).asInstanceOf[ParquetOutputCommitter] + } + committer + } + + // FileOutputFormat.getOutputPath takes JobConf in hadoop-1 but JobContext in hadoop-2 + private def getOutputPath(context: TaskAttemptContext): Path = { + context.getConfiguration().get("mapred.output.dir") match { + case null => null + case name => new Path(name) + } + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index 97c0f439acf13..b504842053690 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -381,6 +381,28 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { } } } + + test("SPARK-6352 DirectParquetOutputCommitter") { + // Write to a parquet file and let it fail. + // _temporary should be missing if direct output committer works. + try { + configuration.set("spark.sql.parquet.output.committer.class", + "org.apache.spark.sql.parquet.DirectParquetOutputCommitter") + sqlContext.udf.register("div0", (x: Int) => x / 0) + withTempPath { dir => + intercept[org.apache.spark.SparkException] { + sqlContext.sql("select div0(1)").saveAsParquetFile(dir.getCanonicalPath) + } + val path = new Path(dir.getCanonicalPath, "_temporary") + val fs = path.getFileSystem(configuration) + assert(!fs.exists(path)) + } + } + finally { + configuration.set("spark.sql.parquet.output.committer.class", + "parquet.hadoop.ParquetOutputCommitter") + } + } } class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll { From 7f3b3b7eb7d14767124a28ec0062c4d60d6c16fc Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 28 Apr 2015 07:48:34 -0400 Subject: [PATCH 153/191] [SPARK-7168] [BUILD] Update plugin versions in Maven build and centralize versions Update Maven build plugin versions and centralize plugin version management Author: Sean Owen Closes #5720 from srowen/SPARK-7168 and squashes the following commits: 98a8947 [Sean Owen] Make install, deploy plugin versions explicit 4ecf3b2 [Sean Owen] Update Maven build plugin versions and centralize plugin version management --- assembly/pom.xml | 1 - core/pom.xml | 1 - network/common/pom.xml | 1 - pom.xml | 44 ++++++++++++++++++++++++++++++------------ sql/hive/pom.xml | 1 - 5 files changed, 32 insertions(+), 16 deletions(-) diff --git a/assembly/pom.xml b/assembly/pom.xml index 20593e710dedb..2b4d0a990bf22 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -194,7 +194,6 @@ org.apache.maven.plugins maven-assembly-plugin - 2.4 dist diff --git a/core/pom.xml b/core/pom.xml index 5e89d548cd47f..459ef66712c36 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -478,7 +478,6 @@ org.codehaus.mojo exec-maven-plugin - 1.3.2 sparkr-pkg diff --git 
a/network/common/pom.xml b/network/common/pom.xml index 22c738bde6d42..0c3147761cfc5 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -95,7 +95,6 @@ org.apache.maven.plugins maven-jar-plugin - 2.2 test-jar-on-test-compile diff --git a/pom.xml b/pom.xml index 9fbce1d639d8b..928f5d0f5efad 100644 --- a/pom.xml +++ b/pom.xml @@ -1082,7 +1082,7 @@ org.apache.maven.plugins maven-enforcer-plugin - 1.3.1 + 1.4 enforce-versions @@ -1105,7 +1105,7 @@ org.codehaus.mojo build-helper-maven-plugin - 1.8 + 1.9.1 net.alchim31.maven @@ -1176,7 +1176,7 @@ org.apache.maven.plugins maven-compiler-plugin - 3.1 + 3.3 ${java.version} ${java.version} @@ -1189,7 +1189,7 @@ org.apache.maven.plugins maven-surefire-plugin - 2.18 + 2.18.1 @@ -1260,17 +1260,17 @@ org.apache.maven.plugins maven-jar-plugin - 2.4 + 2.6 org.apache.maven.plugins maven-antrun-plugin - 1.7 + 1.8 org.apache.maven.plugins maven-source-plugin - 2.2.1 + 2.4 true @@ -1287,7 +1287,7 @@ org.apache.maven.plugins maven-clean-plugin - 2.5 + 2.6.1 @@ -1305,7 +1305,27 @@ org.apache.maven.plugins maven-javadoc-plugin - 2.10.1 + 2.10.3 + + + org.codehaus.mojo + exec-maven-plugin + 1.4.0 + + + org.apache.maven.plugins + maven-assembly-plugin + 2.5.3 + + + org.apache.maven.plugins + maven-install-plugin + 2.5.2 + + + org.apache.maven.plugins + maven-deploy-plugin + 2.8.2 @@ -1315,7 +1335,7 @@ org.apache.maven.plugins maven-dependency-plugin - 2.9 + 2.10 test-compile @@ -1334,7 +1354,7 @@ org.codehaus.gmavenplus gmavenplus-plugin - 1.2 + 1.5 process-test-classes @@ -1359,7 +1379,7 @@ org.apache.maven.plugins maven-shade-plugin - 2.2 + 2.3 false diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 21dce8d8a565a..e322340094e6f 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -183,7 +183,6 @@ org.apache.maven.plugins maven-dependency-plugin - 2.4 copy-dependencies From 75905c57cd57bc5b650ac5f486580ef8a229b260 Mon Sep 17 00:00:00 2001 From: Jim Carroll Date: Tue, 28 Apr 2015 07:51:02 -0400 Subject: [PATCH 154/191] [SPARK-7100] [MLLIB] Fix persisted RDD leak in GradientBoostTrees This fixes a leak of a persisted RDD where GradientBoostTrees can call persist but never unpersists. 
Jira: https://issues.apache.org/jira/browse/SPARK-7100 Discussion: http://apache-spark-developers-list.1001551.n3.nabble.com/GradientBoostTrees-leaks-a-persisted-RDD-td11750.html Author: Jim Carroll Closes #5669 from jimfcarroll/gb-unpersist-fix and squashes the following commits: 45f4b03 [Jim Carroll] [SPARK-7100][MLLib] Fix persisted RDD leak in GradientBoostTrees --- .../apache/spark/mllib/tree/GradientBoostedTrees.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index 0e31c7ed58df8..deac390130128 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -177,9 +177,10 @@ object GradientBoostedTrees extends Logging { treeStrategy.assertValid() // Cache input - if (input.getStorageLevel == StorageLevel.NONE) { + val persistedInput = if (input.getStorageLevel == StorageLevel.NONE) { input.persist(StorageLevel.MEMORY_AND_DISK) - } + true + } else false timer.stop("init") @@ -265,6 +266,9 @@ object GradientBoostedTrees extends Logging { logInfo("Internal timing for DecisionTree:") logInfo(s"$timer") + + if (persistedInput) input.unpersist() + if (validate) { new GradientBoostedTreesModel( boostingStrategy.treeStrategy.algo, From 268c419f1586110b90e68f98cd000a782d18828c Mon Sep 17 00:00:00 2001 From: Masayoshi TSUZUKI Date: Tue, 28 Apr 2015 07:55:21 -0400 Subject: [PATCH 155/191] [SPARK-6435] spark-shell --jars option does not add all jars to classpath Modified to accept double-quotated args properly in spark-shell.cmd. Author: Masayoshi TSUZUKI Closes #5227 from tsudukim/feature/SPARK-6435-2 and squashes the following commits: ac55787 [Masayoshi TSUZUKI] removed unnecessary argument. 60789a7 [Masayoshi TSUZUKI] Merge branch 'master' of https://github.com/apache/spark into feature/SPARK-6435-2 1fee420 [Masayoshi TSUZUKI] fixed test code for escaping '='. 0d4dc41 [Masayoshi TSUZUKI] - escaped comman and semicolon in CommandBuilderUtils.java - added random string to the temporary filename - double-quotation followed by `cmd /c` did not worked properly - no need to escape `=` by `^` - if double-quoted string ended with `\` like classpath, the last `\` is parsed as the escape charactor and the closing `"` didn't work properly 2a332e5 [Masayoshi TSUZUKI] Merge branch 'master' into feature/SPARK-6435-2 04f4291 [Masayoshi TSUZUKI] [SPARK-6435] spark-shell --jars option does not add all jars to classpath --- bin/spark-class2.cmd | 5 ++++- .../org/apache/spark/launcher/CommandBuilderUtils.java | 9 ++++----- .../src/main/java/org/apache/spark/launcher/Main.java | 6 +----- .../apache/spark/launcher/CommandBuilderUtilsSuite.java | 5 ++++- 4 files changed, 13 insertions(+), 12 deletions(-) diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd index 3d068dd3a2739..db09fa27e51a6 100644 --- a/bin/spark-class2.cmd +++ b/bin/spark-class2.cmd @@ -61,7 +61,10 @@ if not "x%JAVA_HOME%"=="x" set RUNNER=%JAVA_HOME%\bin\java rem The launcher library prints the command to be executed in a single line suitable for being rem executed by the batch interpreter. So read all the output of the launcher into a variable. 
-for /f "tokens=*" %%i in ('cmd /C ""%RUNNER%" -cp %LAUNCH_CLASSPATH% org.apache.spark.launcher.Main %*"') do ( +set LAUNCHER_OUTPUT=%temp%\spark-class-launcher-output-%RANDOM%.txt +"%RUNNER%" -cp %LAUNCH_CLASSPATH% org.apache.spark.launcher.Main %* > %LAUNCHER_OUTPUT% +for /f "tokens=*" %%i in (%LAUNCHER_OUTPUT%) do ( set SPARK_CMD=%%i ) +del %LAUNCHER_OUTPUT% %SPARK_CMD% diff --git a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java index 8028e42ffb483..261402856ac5e 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java +++ b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java @@ -244,7 +244,7 @@ static String quoteForBatchScript(String arg) { boolean needsQuotes = false; for (int i = 0; i < arg.length(); i++) { int c = arg.codePointAt(i); - if (Character.isWhitespace(c) || c == '"' || c == '=') { + if (Character.isWhitespace(c) || c == '"' || c == '=' || c == ',' || c == ';') { needsQuotes = true; break; } @@ -261,15 +261,14 @@ static String quoteForBatchScript(String arg) { quoted.append('"'); break; - case '=': - quoted.append('^'); - break; - default: break; } quoted.appendCodePoint(cp); } + if (arg.codePointAt(arg.length() - 1) == '\\') { + quoted.append("\\"); + } quoted.append("\""); return quoted.toString(); } diff --git a/launcher/src/main/java/org/apache/spark/launcher/Main.java b/launcher/src/main/java/org/apache/spark/launcher/Main.java index 206acfb514d86..929b29a49ed70 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/Main.java +++ b/launcher/src/main/java/org/apache/spark/launcher/Main.java @@ -101,12 +101,9 @@ public static void main(String[] argsArray) throws Exception { * The method quotes all arguments so that spaces are handled as expected. Quotes within arguments * are "double quoted" (which is batch for escaping a quote). This page has more details about * quoting and other batch script fun stuff: http://ss64.com/nt/syntax-esc.html - * - * The command is executed using "cmd /c" and formatted in single line, since that's the - * easiest way to consume this from a batch script (see spark-class2.cmd). */ private static String prepareWindowsCommand(List cmd, Map childEnv) { - StringBuilder cmdline = new StringBuilder("cmd /c \""); + StringBuilder cmdline = new StringBuilder(); for (Map.Entry e : childEnv.entrySet()) { cmdline.append(String.format("set %s=%s", e.getKey(), e.getValue())); cmdline.append(" && "); @@ -115,7 +112,6 @@ private static String prepareWindowsCommand(List cmd, Map Date: Tue, 28 Apr 2015 09:46:08 -0700 Subject: [PATCH 156/191] [SPARK-5253] [ML] LinearRegression with L1/L2 (ElasticNet) using OWLQN Author: DB Tsai Author: DB Tsai Closes #4259 from dbtsai/lir and squashes the following commits: a81c201 [DB Tsai] add import org.apache.spark.util.Utils back 9fc48ed [DB Tsai] rebase 2178b63 [DB Tsai] add comments 9988ca8 [DB Tsai] addressed feedback and fixed a bug. TODO: documentation and build another synthetic dataset which can catch the bug fixed in this commit. 
fcbaefe [DB Tsai] Refactoring 4eb078d [DB Tsai] first commit --- .../ml/param/shared/SharedParamsCodeGen.scala | 4 +- .../spark/ml/param/shared/sharedParams.scala | 34 ++ .../ml/regression/LinearRegression.scala | 304 ++++++++++++++++-- .../apache/spark/mllib/linalg/Vectors.scala | 8 +- .../spark/mllib/optimization/Gradient.scala | 6 +- .../spark/mllib/optimization/LBFGS.scala | 15 +- .../mllib/util/LinearDataGenerator.scala | 43 ++- .../ml/regression/LinearRegressionSuite.scala | 158 +++++++-- 8 files changed, 508 insertions(+), 64 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index e88c48741e99f..3f7e8f5a0b22c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -46,7 +46,9 @@ private[shared] object SharedParamsCodeGen { ParamDesc[String]("outputCol", "output column name"), ParamDesc[Int]("checkpointInterval", "checkpoint interval"), ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")), - ParamDesc[Long]("seed", "random seed", Some("Utils.random.nextLong()"))) + ParamDesc[Long]("seed", "random seed", Some("Utils.random.nextLong()")), + ParamDesc[Double]("elasticNetParam", "the ElasticNet mixing parameter"), + ParamDesc[Double]("tol", "the convergence tolerance for iterative algorithms")) val code = genSharedParams(params) val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala" diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index a860b8834cff9..7d2c76d6c62c8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -276,4 +276,38 @@ trait HasSeed extends Params { /** @group getParam */ final def getSeed: Long = getOrDefault(seed) } + +/** + * :: DeveloperApi :: + * Trait for shared param elasticNetParam. + */ +@DeveloperApi +trait HasElasticNetParam extends Params { + + /** + * Param for the ElasticNet mixing parameter. + * @group param + */ + final val elasticNetParam: DoubleParam = new DoubleParam(this, "elasticNetParam", "the ElasticNet mixing parameter") + + /** @group getParam */ + final def getElasticNetParam: Double = getOrDefault(elasticNetParam) +} + +/** + * :: DeveloperApi :: + * Trait for shared param tol. + */ +@DeveloperApi +trait HasTol extends Params { + + /** + * Param for the convergence tolerance for iterative algorithms. 
+ * @group param + */ + final val tol: DoubleParam = new DoubleParam(this, "tol", "the convergence tolerance for iterative algorithms") + + /** @group getParam */ + final def getTol: Double = getOrDefault(tol) +} // scalastyle:on diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 26ca7459c4fdf..f92c6816eb54c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -17,21 +17,29 @@ package org.apache.spark.ml.regression +import scala.collection.mutable + +import breeze.linalg.{norm => brzNorm, DenseVector => BDV} +import breeze.optimize.{LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} +import breeze.optimize.{CachedDiffFunction, DiffFunction} + import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param.{Params, ParamMap} -import org.apache.spark.ml.param.shared._ -import org.apache.spark.mllib.linalg.{BLAS, Vector} -import org.apache.spark.mllib.regression.LinearRegressionWithSGD +import org.apache.spark.ml.param.shared.{HasElasticNetParam, HasMaxIter, HasRegParam, HasTol} +import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.BLAS._ +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame import org.apache.spark.storage.StorageLevel - +import org.apache.spark.util.StatCounter /** * Params for linear regression. */ private[regression] trait LinearRegressionParams extends RegressorParams - with HasRegParam with HasMaxIter - + with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol /** * :: AlphaComponent :: @@ -42,34 +50,119 @@ private[regression] trait LinearRegressionParams extends RegressorParams class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel] with LinearRegressionParams { - setDefault(regParam -> 0.1, maxIter -> 100) - - /** @group setParam */ + /** + * Set the regularization parameter. + * Default is 0.0. + * @group setParam + */ def setRegParam(value: Double): this.type = set(regParam, value) + setDefault(regParam -> 0.0) + + /** + * Set the ElasticNet mixing parameter. + * For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty. + * For 0 < alpha < 1, the penalty is a combination of L1 and L2. + * Default is 0.0 which is an L2 penalty. + * @group setParam + */ + def setElasticNetParam(value: Double): this.type = set(elasticNetParam, value) + setDefault(elasticNetParam -> 0.0) - /** @group setParam */ + /** + * Set the maximal number of iterations. + * Default is 100. + * @group setParam + */ def setMaxIter(value: Int): this.type = set(maxIter, value) + setDefault(maxIter -> 100) + + /** + * Set the convergence tolerance of iterations. + * Smaller value will lead to higher accuracy with the cost of more iterations. + * Default is 1E-6. + * @group setParam + */ + def setTol(value: Double): this.type = set(tol, value) + setDefault(tol -> 1E-6) override protected def train(dataset: DataFrame, paramMap: ParamMap): LinearRegressionModel = { - // Extract columns from data. If dataset is persisted, do not persist oldDataset. - val oldDataset = extractLabeledPoints(dataset, paramMap) + // Extract columns from data. If dataset is persisted, do not persist instances. 
+ val instances = extractLabeledPoints(dataset, paramMap).map { + case LabeledPoint(label: Double, features: Vector) => (label, features) + } val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE if (handlePersistence) { - oldDataset.persist(StorageLevel.MEMORY_AND_DISK) + instances.persist(StorageLevel.MEMORY_AND_DISK) + } + + val (summarizer, statCounter) = instances.treeAggregate( + (new MultivariateOnlineSummarizer, new StatCounter))( { + case ((summarizer: MultivariateOnlineSummarizer, statCounter: StatCounter), + (label: Double, features: Vector)) => + (summarizer.add(features), statCounter.merge(label)) + }, { + case ((summarizer1: MultivariateOnlineSummarizer, statCounter1: StatCounter), + (summarizer2: MultivariateOnlineSummarizer, statCounter2: StatCounter)) => + (summarizer1.merge(summarizer2), statCounter1.merge(statCounter2)) + }) + + val numFeatures = summarizer.mean.size + val yMean = statCounter.mean + val yStd = math.sqrt(statCounter.variance) + + val featuresMean = summarizer.mean.toArray + val featuresStd = summarizer.variance.toArray.map(math.sqrt) + + // Since we implicitly do the feature scaling when we compute the cost function + // to improve the convergence, the effective regParam will be changed. + val effectiveRegParam = paramMap(regParam) / yStd + val effectiveL1RegParam = paramMap(elasticNetParam) * effectiveRegParam + val effectiveL2RegParam = (1.0 - paramMap(elasticNetParam)) * effectiveRegParam + + val costFun = new LeastSquaresCostFun(instances, yStd, yMean, + featuresStd, featuresMean, effectiveL2RegParam) + + val optimizer = if (paramMap(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) { + new BreezeLBFGS[BDV[Double]](paramMap(maxIter), 10, paramMap(tol)) + } else { + new BreezeOWLQN[Int, BDV[Double]](paramMap(maxIter), 10, effectiveL1RegParam, paramMap(tol)) + } + + val initialWeights = Vectors.zeros(numFeatures) + val states = + optimizer.iterations(new CachedDiffFunction(costFun), initialWeights.toBreeze.toDenseVector) + + var state = states.next() + val lossHistory = mutable.ArrayBuilder.make[Double] + + while (states.hasNext) { + lossHistory += state.value + state = states.next() + } + lossHistory += state.value + + // TODO: Based on the sparsity of weights, we may convert the weights to the sparse vector. + // The weights are trained in the scaled space; we're converting them back to + // the original space. + val weights = { + val rawWeights = state.x.toArray.clone() + var i = 0 + while (i < rawWeights.length) { + rawWeights(i) *= { if (featuresStd(i) != 0.0) yStd / featuresStd(i) else 0.0 } + i += 1 + } + Vectors.dense(rawWeights) } - // Train model - val lr = new LinearRegressionWithSGD() - lr.optimizer - .setRegParam(paramMap(regParam)) - .setNumIterations(paramMap(maxIter)) - val model = lr.run(oldDataset) - val lrm = new LinearRegressionModel(this, paramMap, model.weights, model.intercept) + // The intercept in R's GLMNET is computed using closed form after the coefficients are + // converged. See the following discussion for detail. 
+ // http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet + val intercept = yMean - dot(weights, Vectors.dense(featuresMean)) if (handlePersistence) { - oldDataset.unpersist() + instances.unpersist() } - lrm + new LinearRegressionModel(this, paramMap, weights, intercept) } } @@ -88,7 +181,7 @@ class LinearRegressionModel private[ml] ( with LinearRegressionParams { override protected def predict(features: Vector): Double = { - BLAS.dot(features, weights) + intercept + dot(features, weights) + intercept } override protected def copy(): LinearRegressionModel = { @@ -97,3 +190,168 @@ class LinearRegressionModel private[ml] ( m } } + +/** + * LeastSquaresAggregator computes the gradient and loss for a Least-squared loss function, + * as used in linear regression for samples in sparse or dense vector in a online fashion. + * + * Two LeastSquaresAggregator can be merged together to have a summary of loss and gradient of + * the corresponding joint dataset. + * + + * * Compute gradient and loss for a Least-squared loss function, as used in linear regression. + * This is correct for the averaged least squares loss function (mean squared error) + * L = 1/2n ||A weights-y||^2 + * See also the documentation for the precise formulation. + * + * @param weights weights/coefficients corresponding to features + * + * @param updater Updater to be used to update weights after every iteration. + */ +private class LeastSquaresAggregator( + weights: Vector, + labelStd: Double, + labelMean: Double, + featuresStd: Array[Double], + featuresMean: Array[Double]) extends Serializable { + + private var totalCnt: Long = 0 + private var lossSum = 0.0 + private var diffSum = 0.0 + + private val (effectiveWeightsArray: Array[Double], offset: Double, dim: Int) = { + val weightsArray = weights.toArray.clone() + var sum = 0.0 + var i = 0 + while (i < weightsArray.length) { + if (featuresStd(i) != 0.0) { + weightsArray(i) /= featuresStd(i) + sum += weightsArray(i) * featuresMean(i) + } else { + weightsArray(i) = 0.0 + } + i += 1 + } + (weightsArray, -sum + labelMean / labelStd, weightsArray.length) + } + private val effectiveWeightsVector = Vectors.dense(effectiveWeightsArray) + + private val gradientSumArray: Array[Double] = Array.ofDim[Double](dim) + + /** + * Add a new training data to this LeastSquaresAggregator, and update the loss and gradient + * of the objective function. + * + * @param label The label for this data point. + * @param data The features for one data point in dense/sparse vector format to be added + * into this aggregator. + * @return This LeastSquaresAggregator object. + */ + def add(label: Double, data: Vector): this.type = { + require(dim == data.size, s"Dimensions mismatch when adding new sample." + + s" Expecting $dim but got ${data.size}.") + + val diff = dot(data, effectiveWeightsVector) - label / labelStd + offset + + if (diff != 0) { + val localGradientSumArray = gradientSumArray + data.foreachActive { (index, value) => + if (featuresStd(index) != 0.0 && value != 0.0) { + localGradientSumArray(index) += diff * value / featuresStd(index) + } + } + lossSum += diff * diff / 2.0 + diffSum += diff + } + + totalCnt += 1 + this + } + + /** + * Merge another LeastSquaresAggregator, and update the loss and gradient + * of the objective function. + * (Note that it's in place merging; as a result, `this` object will be modified.) + * + * @param other The other LeastSquaresAggregator to be merged. + * @return This LeastSquaresAggregator object. 
+ */ + def merge(other: LeastSquaresAggregator): this.type = { + require(dim == other.dim, s"Dimensions mismatch when merging with another " + + s"LeastSquaresAggregator. Expecting $dim but got ${other.dim}.") + + if (other.totalCnt != 0) { + totalCnt += other.totalCnt + lossSum += other.lossSum + diffSum += other.diffSum + + var i = 0 + val localThisGradientSumArray = this.gradientSumArray + val localOtherGradientSumArray = other.gradientSumArray + while (i < dim) { + localThisGradientSumArray(i) += localOtherGradientSumArray(i) + i += 1 + } + } + this + } + + def count: Long = totalCnt + + def loss: Double = lossSum / totalCnt + + def gradient: Vector = { + val result = Vectors.dense(gradientSumArray.clone()) + + val correction = { + val temp = effectiveWeightsArray.clone() + var i = 0 + while (i < temp.length) { + temp(i) *= featuresMean(i) + i += 1 + } + Vectors.dense(temp) + } + + axpy(-diffSum, correction, result) + scal(1.0 / totalCnt, result) + result + } +} + +/** + * LeastSquaresCostFun implements Breeze's DiffFunction[T] for Least Squares cost. + * It returns the loss and gradient with L2 regularization at a particular point (weights). + * It's used in Breeze's convex optimization routines. + */ +private class LeastSquaresCostFun( + data: RDD[(Double, Vector)], + labelStd: Double, + labelMean: Double, + featuresStd: Array[Double], + featuresMean: Array[Double], + effectiveL2regParam: Double) extends DiffFunction[BDV[Double]] { + + override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = { + val w = Vectors.fromBreeze(weights) + + val leastSquaresAggregator = data.treeAggregate(new LeastSquaresAggregator(w, labelStd, + labelMean, featuresStd, featuresMean))( + seqOp = (c, v) => (c, v) match { + case (aggregator, (label, features)) => aggregator.add(label, features) + }, + combOp = (c1, c2) => (c1, c2) match { + case (aggregator1, aggregator2) => aggregator1.merge(aggregator2) + }) + + // regVal is the sum of weight squares for L2 regularization + val norm = brzNorm(weights, 2.0) + val regVal = 0.5 * effectiveL2regParam * norm * norm + + val loss = leastSquaresAggregator.loss + regVal + val gradient = leastSquaresAggregator.gradient + axpy(effectiveL2regParam, w, gradient) + + (loss, gradient.toBreeze.asInstanceOf[BDV[Double]]) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 166c00cff634d..af0cfe22ca10d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -85,7 +85,7 @@ sealed trait Vector extends Serializable { /** * Converts the instance to a breeze vector. */ - private[mllib] def toBreeze: BV[Double] + private[spark] def toBreeze: BV[Double] /** * Gets the value of the ith element. @@ -284,7 +284,7 @@ object Vectors { /** * Creates a vector instance from a breeze vector. 
*/ - private[mllib] def fromBreeze(breezeVector: BV[Double]): Vector = { + private[spark] def fromBreeze(breezeVector: BV[Double]): Vector = { breezeVector match { case v: BDV[Double] => if (v.offset == 0 && v.stride == 1 && v.length == v.data.length) { @@ -483,7 +483,7 @@ class DenseVector(val values: Array[Double]) extends Vector { override def toArray: Array[Double] = values - private[mllib] override def toBreeze: BV[Double] = new BDV[Double](values) + private[spark] override def toBreeze: BV[Double] = new BDV[Double](values) override def apply(i: Int): Double = values(i) @@ -543,7 +543,7 @@ class SparseVector( new SparseVector(size, indices.clone(), values.clone()) } - private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size) + private[spark] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size) private[spark] override def foreachActive(f: (Int, Double) => Unit) = { var i = 0 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala index 8bfa0d2b64995..240baeb5a158b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala @@ -37,7 +37,11 @@ abstract class Gradient extends Serializable { * * @return (gradient: Vector, loss: Double) */ - def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) + def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = { + val gradient = Vectors.zeros(weights.size) + val loss = compute(data, label, weights, gradient) + (gradient, loss) + } /** * Compute the gradient and loss given the features of a single data point, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index ef6eccd90711a..efedc112d380e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -17,6 +17,7 @@ package org.apache.spark.mllib.optimization +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import breeze.linalg.{DenseVector => BDV} @@ -164,7 +165,7 @@ object LBFGS extends Logging { regParam: Double, initialWeights: Vector): (Vector, Array[Double]) = { - val lossHistory = new ArrayBuffer[Double](maxNumIterations) + val lossHistory = mutable.ArrayBuilder.make[Double] val numExamples = data.count() @@ -181,17 +182,19 @@ object LBFGS extends Logging { * and regVal is the regularization value computed in the previous iteration as well. */ var state = states.next() - while(states.hasNext) { - lossHistory.append(state.value) + while (states.hasNext) { + lossHistory += state.value state = states.next() } - lossHistory.append(state.value) + lossHistory += state.value val weights = Vectors.fromBreeze(state.x) + val lossHistoryArray = lossHistory.result() + logInfo("LBFGS.runLBFGS finished. 
Last 10 losses %s".format( - lossHistory.takeRight(10).mkString(", "))) + lossHistoryArray.takeRight(10).mkString(", "))) - (weights, lossHistory.toArray) + (weights, lossHistoryArray) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala index c9d33787b0bb5..d7bb943e84f53 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala @@ -56,6 +56,10 @@ object LinearDataGenerator { } /** + * For compatibility, the generated data without specifying the mean and variance + * will have zero mean and variance of (1.0/3.0) since the original output range is + * [-1, 1] with uniform distribution, and the variance of uniform distribution + * is (b - a)^2^ / 12 which will be (1.0/3.0) * * @param intercept Data intercept * @param weights Weights to be applied. @@ -70,10 +74,47 @@ object LinearDataGenerator { nPoints: Int, seed: Int, eps: Double = 0.1): Seq[LabeledPoint] = { + generateLinearInput(intercept, weights, + Array.fill[Double](weights.size)(0.0), + Array.fill[Double](weights.size)(1.0 / 3.0), + nPoints, seed, eps)} + + /** + * + * @param intercept Data intercept + * @param weights Weights to be applied. + * @param xMean the mean of the generated features. Lots of time, if the features are not properly + * standardized, the algorithm with poor implementation will have difficulty + * to converge. + * @param xVariance the variance of the generated features. + * @param nPoints Number of points in sample. + * @param seed Random seed + * @param eps Epsilon scaling factor. + * @return Seq of input. + */ + def generateLinearInput( + intercept: Double, + weights: Array[Double], + xMean: Array[Double], + xVariance: Array[Double], + nPoints: Int, + seed: Int, + eps: Double): Seq[LabeledPoint] = { val rnd = new Random(seed) val x = Array.fill[Array[Double]](nPoints)( - Array.fill[Double](weights.length)(2 * rnd.nextDouble - 1.0)) + Array.fill[Double](weights.length)(rnd.nextDouble)) + + x.map(vector => { + // This doesn't work if `vector` is a sparse vector. 
+ val vectorArray = vector.toArray + var i = 0 + while (i < vectorArray.size) { + vectorArray(i) = (vectorArray(i) - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i) + i += 1 + } + }) + val y = x.map { xi => blas.ddot(weights.length, xi, 1, weights, 1) + intercept + eps * rnd.nextGaussian() } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index bbb44c3e2dfc2..80323ef5201a6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -19,47 +19,149 @@ package org.apache.spark.ml.regression import org.scalatest.FunSuite -import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.mllib.linalg.DenseVector +import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{Row, SQLContext, DataFrame} class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { @transient var sqlContext: SQLContext = _ @transient var dataset: DataFrame = _ + /** + * In `LinearRegressionSuite`, we will make sure that the model trained by SparkML + * is the same as the one trained by R's glmnet package. The following instruction + * describes how to reproduce the data in R. + * + * import org.apache.spark.mllib.util.LinearDataGenerator + * val data = + * sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2), 10000, 42), 2) + * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).saveAsTextFile("path") + */ override def beforeAll(): Unit = { super.beforeAll() sqlContext = new SQLContext(sc) dataset = sqlContext.createDataFrame( - sc.parallelize(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42), 2)) + sc.parallelize(LinearDataGenerator.generateLinearInput( + 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2)) } - test("linear regression: default params") { - val lr = new LinearRegression - assert(lr.getLabelCol == "label") - val model = lr.fit(dataset) - model.transform(dataset) - .select("label", "prediction") - .collect() - // Check defaults - assert(model.getFeaturesCol == "features") - assert(model.getPredictionCol == "prediction") + test("linear regression with intercept without regularization") { + val trainer = new LinearRegression + val model = trainer.fit(dataset) + + /** + * Using the following R code to load the data and train the model using glmnet package. + * + * library("glmnet") + * data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) + * features <- as.matrix(data.frame(as.numeric(data$V2), as.numeric(data$V3))) + * label <- as.numeric(data$V1) + * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0)) + * > weights + * 3 x 1 sparse Matrix of class "dgCMatrix" + * s0 + * (Intercept) 6.300528 + * as.numeric.data.V2. 4.701024 + * as.numeric.data.V3. 
7.198257 + */ + val interceptR = 6.298698 + val weightsR = Array(4.700706, 7.199082) + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.weights(0) ~== weightsR(0) relTol 1E-3) + assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } + + test("linear regression with intercept with L1 regularization") { + val trainer = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57) + val model = trainer.fit(dataset) + + /** + * weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57)) + * > weights + * 3 x 1 sparse Matrix of class "dgCMatrix" + * s0 + * (Intercept) 6.311546 + * as.numeric.data.V2. 2.123522 + * as.numeric.data.V3. 4.605651 + */ + val interceptR = 6.243000 + val weightsR = Array(4.024821, 6.679841) + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.weights(0) ~== weightsR(0) relTol 1E-3) + assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } } - test("linear regression with setters") { - // Set params, train, and check as many as we can. - val lr = new LinearRegression() - .setMaxIter(10) - .setRegParam(1.0) - val model = lr.fit(dataset) - assert(model.fittingParamMap.get(lr.maxIter).get === 10) - assert(model.fittingParamMap.get(lr.regParam).get === 1.0) - - // Call fit() with new params, and check as many as we can. - val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.predictionCol -> "thePred") - assert(model2.fittingParamMap.get(lr.maxIter).get === 5) - assert(model2.fittingParamMap.get(lr.regParam).get === 0.1) - assert(model2.getPredictionCol == "thePred") + test("linear regression with intercept with L2 regularization") { + val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3) + val model = trainer.fit(dataset) + + /** + * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3)) + * > weights + * 3 x 1 sparse Matrix of class "dgCMatrix" + * s0 + * (Intercept) 6.328062 + * as.numeric.data.V2. 3.222034 + * as.numeric.data.V3. 
4.926260 + */ + val interceptR = 5.269376 + val weightsR = Array(3.736216, 5.712356) + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.weights(0) ~== weightsR(0) relTol 1E-3) + assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } + + test("linear regression with intercept with ElasticNet regularization") { + val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6) + val model = trainer.fit(dataset) + + /** + * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6)) + * > weights + * 3 x 1 sparse Matrix of class "dgCMatrix" + * s0 + * (Intercept) 6.324108 + * as.numeric.data.V2. 3.168435 + * as.numeric.data.V3. 5.200403 + */ + val interceptR = 5.696056 + val weightsR = Array(3.670489, 6.001122) + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.weights(0) ~== weightsR(0) relTol 1E-3) + assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } } } From b14cd2364932e504695bcc49486ffb4518fdf33d Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 28 Apr 2015 09:59:36 -0700 Subject: [PATCH 157/191] [SPARK-7140] [MLLIB] only scan the first 16 entries in Vector.hashCode The Python SerDe calls `Object.hashCode`, which is very expensive for Vectors. It is not necessary to scan the whole vector, especially for large ones. In this PR, we only scan the first 16 nonzeros. srowen Author: Xiangrui Meng Closes #5697 from mengxr/SPARK-7140 and squashes the following commits: 2abc86d [Xiangrui Meng] typo 8fb7d74 [Xiangrui Meng] update impl 1ebad60 [Xiangrui Meng] only scan the first 16 nonzeros in Vector.hashCode --- .../apache/spark/mllib/linalg/Vectors.scala | 88 ++++++++++++++----- 1 file changed, 67 insertions(+), 21 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index af0cfe22ca10d..34833e90d4af0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -52,7 +52,7 @@ sealed trait Vector extends Serializable { override def equals(other: Any): Boolean = { other match { - case v2: Vector => { + case v2: Vector => if (this.size != v2.size) return false (this, v2) match { case (s1: SparseVector, s2: SparseVector) => @@ -63,20 +63,28 @@ sealed trait Vector extends Serializable { Vectors.equals(0 until d1.size, d1.values, s1.indices, s1.values) case (_, _) => util.Arrays.equals(this.toArray, v2.toArray) } - } case _ => false } } + /** + * Returns a hash code value for the vector. The hash code is based on its size and its nonzeros + * in the first 16 entries, using a hash algorithm similar to [[java.util.Arrays.hashCode]]. 
+ */ override def hashCode(): Int = { - var result: Int = size + 31 - this.foreachActive { case (index, value) => - // ignore explict 0 for comparison between sparse and dense - if (value != 0) { - result = 31 * result + index - // refer to {@link java.util.Arrays.equals} for hash algorithm - val bits = java.lang.Double.doubleToLongBits(value) - result = 31 * result + (bits ^ (bits >>> 32)).toInt + // This is a reference implementation. It calls return in foreachActive, which is slow. + // Subclasses should override it with optimized implementation. + var result: Int = 31 + size + this.foreachActive { (index, value) => + if (index < 16) { + // ignore explicit 0 for comparison between sparse and dense + if (value != 0) { + result = 31 * result + index + val bits = java.lang.Double.doubleToLongBits(value) + result = 31 * result + (bits ^ (bits >>> 32)).toInt + } + } else { + return result } } result @@ -317,7 +325,7 @@ object Vectors { case SparseVector(n, ids, vs) => vs case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) } - val size = values.size + val size = values.length if (p == 1) { var sum = 0.0 @@ -371,8 +379,8 @@ object Vectors { val v1Indices = v1.indices val v2Values = v2.values val v2Indices = v2.indices - val nnzv1 = v1Indices.size - val nnzv2 = v2Indices.size + val nnzv1 = v1Indices.length + val nnzv2 = v2Indices.length var kv1 = 0 var kv2 = 0 @@ -401,7 +409,7 @@ object Vectors { case (DenseVector(vv1), DenseVector(vv2)) => var kv = 0 - val sz = vv1.size + val sz = vv1.length while (kv < sz) { val score = vv1(kv) - vv2(kv) squaredDistance += score * score @@ -422,7 +430,7 @@ object Vectors { var kv2 = 0 val indices = v1.indices var squaredDistance = 0.0 - val nnzv1 = indices.size + val nnzv1 = indices.length val nnzv2 = v2.size var iv1 = if (nnzv1 > 0) indices(kv1) else -1 @@ -451,8 +459,8 @@ object Vectors { v1Values: Array[Double], v2Indices: IndexedSeq[Int], v2Values: Array[Double]): Boolean = { - val v1Size = v1Values.size - val v2Size = v2Values.size + val v1Size = v1Values.length + val v2Size = v2Values.length var k1 = 0 var k2 = 0 var allEqual = true @@ -493,7 +501,7 @@ class DenseVector(val values: Array[Double]) extends Vector { private[spark] override def foreachActive(f: (Int, Double) => Unit) = { var i = 0 - val localValuesSize = values.size + val localValuesSize = values.length val localValues = values while (i < localValuesSize) { @@ -501,6 +509,22 @@ class DenseVector(val values: Array[Double]) extends Vector { i += 1 } } + + override def hashCode(): Int = { + var result: Int = 31 + size + var i = 0 + val end = math.min(values.length, 16) + while (i < end) { + val v = values(i) + if (v != 0.0) { + result = 31 * result + i + val bits = java.lang.Double.doubleToLongBits(values(i)) + result = 31 * result + (bits ^ (bits >>> 32)).toInt + } + i += 1 + } + result + } } object DenseVector { @@ -522,8 +546,8 @@ class SparseVector( val values: Array[Double]) extends Vector { require(indices.length == values.length, "Sparse vectors require that the dimension of the" + - s" indices match the dimension of the values. You provided ${indices.size} indices and " + - s" ${values.size} values.") + s" indices match the dimension of the values. 
You provided ${indices.length} indices and " + + s" ${values.length} values.") override def toString: String = s"($size,${indices.mkString("[", ",", "]")},${values.mkString("[", ",", "]")})" @@ -547,7 +571,7 @@ class SparseVector( private[spark] override def foreachActive(f: (Int, Double) => Unit) = { var i = 0 - val localValuesSize = values.size + val localValuesSize = values.length val localIndices = indices val localValues = values @@ -556,6 +580,28 @@ class SparseVector( i += 1 } } + + override def hashCode(): Int = { + var result: Int = 31 + size + val end = values.length + var continue = true + var k = 0 + while ((k < end) & continue) { + val i = indices(k) + if (i < 16) { + val v = values(k) + if (v != 0.0) { + result = 31 * result + i + val bits = java.lang.Double.doubleToLongBits(v) + result = 31 * result + (bits ^ (bits >>> 32)).toInt + } + } else { + continue = false + } + k += 1 + } + result + } } object SparseVector { From 52ccf1d3739694826915cdf01642bab02958eb78 Mon Sep 17 00:00:00 2001 From: "Zhang, Liye" Date: Tue, 28 Apr 2015 10:24:00 -0700 Subject: [PATCH 158/191] [Core][test][minor] replace try finally block with tryWithSafeFinally Author: Zhang, Liye Closes #5739 from liyezhang556520/trySafeFinally and squashes the following commits: 55683e5 [Zhang, Liye] replace try finally block with tryWithSafeFinally --- .../apache/spark/deploy/history/FsHistoryProviderSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index fcae603c7d18e..9e367a0d9af0d 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -224,9 +224,9 @@ class FsHistoryProviderSuite extends FunSuite with BeforeAndAfter with Matchers EventLoggingListener.initEventLog(new FileOutputStream(file)) } val writer = new OutputStreamWriter(bstream, "UTF-8") - try { + Utils.tryWithSafeFinally { events.foreach(e => writer.write(compact(render(JsonProtocol.sparkEventToJson(e))) + "\n")) - } finally { + } { writer.close() } } From 8aab94d8984e9d12194dbda47b2e7d9dbc036889 Mon Sep 17 00:00:00 2001 From: Iulian Dragos Date: Tue, 28 Apr 2015 12:08:18 -0700 Subject: [PATCH 159/191] [SPARK-4286] Add an external shuffle service that can be run as a daemon. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This allows Mesos deployments to use the shuffle service (and implicitly dynamic allocation). It does so by adding a new "main" class and two corresponding scripts in `sbin`: - `sbin/start-shuffle-service.sh` - `sbin/stop-shuffle-service.sh` Specific options can be passed in `SPARK_SHUFFLE_OPTS`. This is picking up work from #3861 /cc tnachen Author: Iulian Dragos Closes #4990 from dragos/feature/external-shuffle-service and squashes the following commits: 6c2b148 [Iulian Dragos] Import order and wrong name fixup. 07804ad [Iulian Dragos] Moved ExternalShuffleService to the `deploy` package + other minor tweaks. 4dc1f91 [Iulian Dragos] Reviewer’s comments: 8145429 [Iulian Dragos] Add an external shuffle service that can be run as a daemon. 
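A minimal usage sketch (not part of this patch): service options go in SPARK_SHUFFLE_OPTS before
invoking the new scripts. The property name below is an assumption used only for illustration.

    # illustrative only -- assumed property name, not taken from this patch
    export SPARK_SHUFFLE_OPTS="-Dspark.shuffle.service.port=7337"
    sbin/start-shuffle-service.sh    # start the external shuffle daemon on this machine
    # ... run Mesos executors with dynamic allocation against the running service ...
    sbin/stop-shuffle-service.sh     # stop it again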
--- conf/spark-env.sh.template | 3 +- ...ice.scala => ExternalShuffleService.scala} | 59 ++++++++++++++++--- .../apache/spark/deploy/worker/Worker.scala | 13 ++-- docs/job-scheduling.md | 2 +- .../launcher/SparkClassCommandBuilder.java | 4 ++ sbin/start-shuffle-service.sh | 33 +++++++++++ sbin/stop-shuffle-service.sh | 25 ++++++++ 7 files changed, 124 insertions(+), 15 deletions(-) rename core/src/main/scala/org/apache/spark/deploy/{worker/StandaloneWorkerShuffleService.scala => ExternalShuffleService.scala} (59%) create mode 100755 sbin/start-shuffle-service.sh create mode 100755 sbin/stop-shuffle-service.sh diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index 67f81d33361e1..43c4288912b18 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -3,7 +3,7 @@ # This file is sourced when running various Spark programs. # Copy it as spark-env.sh and edit that to configure Spark for your site. -# Options read when launching programs locally with +# Options read when launching programs locally with # ./bin/run-example or ./bin/spark-submit # - HADOOP_CONF_DIR, to point Spark towards Hadoop configuration files # - SPARK_LOCAL_IP, to set the IP address Spark binds to on this node @@ -39,6 +39,7 @@ # - SPARK_WORKER_DIR, to set the working directory of worker processes # - SPARK_WORKER_OPTS, to set config properties only for the worker (e.g. "-Dx=y") # - SPARK_HISTORY_OPTS, to set config properties only for the history server (e.g. "-Dx=y") +# - SPARK_SHUFFLE_OPTS, to set config properties only for the external shuffle service (e.g. "-Dx=y") # - SPARK_DAEMON_JAVA_OPTS, to set config properties for all daemons (e.g. "-Dx=y") # - SPARK_PUBLIC_DNS, to set the public dns name of the master or workers diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala similarity index 59% rename from core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala rename to core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index b9798963bab0a..cd16f992a3c0a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -15,7 +15,9 @@ * limitations under the License. */ -package org.apache.spark.deploy.worker +package org.apache.spark.deploy + +import java.util.concurrent.CountDownLatch import org.apache.spark.{Logging, SparkConf, SecurityManager} import org.apache.spark.network.TransportContext @@ -23,6 +25,7 @@ import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.sasl.SaslRpcHandler import org.apache.spark.network.server.TransportServer import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler +import org.apache.spark.util.Utils /** * Provides a server from which Executors can read shuffle files (rather than reading directly from @@ -31,8 +34,8 @@ import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler * * Optionally requires SASL authentication in order to read. See [[SecurityManager]]. 
*/ -private[worker] -class StandaloneWorkerShuffleService(sparkConf: SparkConf, securityManager: SecurityManager) +private[deploy] +class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityManager) extends Logging { private val enabled = sparkConf.getBoolean("spark.shuffle.service.enabled", false) @@ -51,16 +54,58 @@ class StandaloneWorkerShuffleService(sparkConf: SparkConf, securityManager: Secu /** Starts the external shuffle service if the user has configured us to. */ def startIfEnabled() { if (enabled) { - require(server == null, "Shuffle server already started") - logInfo(s"Starting shuffle service on port $port with useSasl = $useSasl") - server = transportContext.createServer(port) + start() } } + /** Start the external shuffle service */ + def start() { + require(server == null, "Shuffle server already started") + logInfo(s"Starting shuffle service on port $port with useSasl = $useSasl") + server = transportContext.createServer(port) + } + def stop() { - if (enabled && server != null) { + if (server != null) { server.close() server = null } } } + +/** + * A main class for running the external shuffle service. + */ +object ExternalShuffleService extends Logging { + @volatile + private var server: ExternalShuffleService = _ + + private val barrier = new CountDownLatch(1) + + def main(args: Array[String]): Unit = { + val sparkConf = new SparkConf + Utils.loadDefaultSparkProperties(sparkConf) + val securityManager = new SecurityManager(sparkConf) + + // we override this value since this service is started from the command line + // and we assume the user really wants it to be running + sparkConf.set("spark.shuffle.service.enabled", "true") + server = new ExternalShuffleService(sparkConf, securityManager) + server.start() + + installShutdownHook() + + // keep running until the process is terminated + barrier.await() + } + + private def installShutdownHook(): Unit = { + Runtime.getRuntime.addShutdownHook(new Thread("External Shuffle Service shutdown thread") { + override def run() { + logInfo("Shutting down shuffle service.") + server.stop() + barrier.countDown() + } + }) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 3ee2eb69e8a4e..8f3cc54051048 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -34,6 +34,7 @@ import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.{Command, ExecutorDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ +import org.apache.spark.deploy.ExternalShuffleService import org.apache.spark.deploy.master.{DriverState, Master} import org.apache.spark.deploy.worker.ui.WorkerWebUI import org.apache.spark.metrics.MetricsSystem @@ -61,7 +62,7 @@ private[worker] class Worker( assert (port > 0) // For worker and executor IDs - private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") + private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // Send a heartbeat every (heartbeat timeout) / 4 milliseconds private val HEARTBEAT_MILLIS = conf.getLong("spark.worker.timeout", 60) * 1000 / 4 @@ -85,10 +86,10 @@ private[worker] class Worker( private val CLEANUP_ENABLED = conf.getBoolean("spark.worker.cleanup.enabled", false) // How often worker will clean up old app folders - private val 
CLEANUP_INTERVAL_MILLIS = + private val CLEANUP_INTERVAL_MILLIS = conf.getLong("spark.worker.cleanup.interval", 60 * 30) * 1000 // TTL for app folders/data; after TTL expires it will be cleaned up - private val APP_DATA_RETENTION_SECS = + private val APP_DATA_RETENTION_SECS = conf.getLong("spark.worker.cleanup.appDataTtl", 7 * 24 * 3600) private val testing: Boolean = sys.props.contains("spark.testing") @@ -112,7 +113,7 @@ private[worker] class Worker( } else { new File(sys.env.get("SPARK_HOME").getOrElse(".")) } - + var workDir: File = null val finishedExecutors = new HashMap[String, ExecutorRunner] val drivers = new HashMap[String, DriverRunner] @@ -122,7 +123,7 @@ private[worker] class Worker( val finishedApps = new HashSet[String] // The shuffle service is not actually started unless configured. - private val shuffleService = new StandaloneWorkerShuffleService(conf, securityMgr) + private val shuffleService = new ExternalShuffleService(conf, securityMgr) private val publicAddress = { val envVar = conf.getenv("SPARK_PUBLIC_DNS") @@ -134,7 +135,7 @@ private[worker] class Worker( private val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf, securityMgr) private val workerSource = new WorkerSource(this) - + private var registrationRetryTimer: Option[Cancellable] = None var coresUsed = 0 diff --git a/docs/job-scheduling.md b/docs/job-scheduling.md index 963e88a3e1d8f..8d9c2ba2041b2 100644 --- a/docs/job-scheduling.md +++ b/docs/job-scheduling.md @@ -32,7 +32,7 @@ Resource allocation can be configured as follows, based on the cluster type: * **Standalone mode:** By default, applications submitted to the standalone mode cluster will run in FIFO (first-in-first-out) order, and each application will try to use all available nodes. You can limit the number of nodes an application uses by setting the `spark.cores.max` configuration property in it, - or change the default for applications that don't set this setting through `spark.deploy.defaultCores`. + or change the default for applications that don't set this setting through `spark.deploy.defaultCores`. Finally, in addition to controlling cores, each application's `spark.executor.memory` setting controls its memory use. 
* **Mesos:** To use static partitioning on Mesos, set the `spark.mesos.coarse` configuration property to `true`, diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java index e601a0a19f368..d80abf2a8676e 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java @@ -69,6 +69,10 @@ public List buildCommand(Map env) throws IOException { } else if (className.equals("org.apache.spark.executor.MesosExecutorBackend")) { javaOptsKeys.add("SPARK_EXECUTOR_OPTS"); memKey = "SPARK_EXECUTOR_MEMORY"; + } else if (className.equals("org.apache.spark.deploy.ExternalShuffleService")) { + javaOptsKeys.add("SPARK_DAEMON_JAVA_OPTS"); + javaOptsKeys.add("SPARK_SHUFFLE_OPTS"); + memKey = "SPARK_DAEMON_MEMORY"; } else if (className.startsWith("org.apache.spark.tools.")) { String sparkHome = getSparkHome(); File toolsDir = new File(join(File.separator, sparkHome, "tools", "target", diff --git a/sbin/start-shuffle-service.sh b/sbin/start-shuffle-service.sh new file mode 100755 index 0000000000000..4fddcf7f95d40 --- /dev/null +++ b/sbin/start-shuffle-service.sh @@ -0,0 +1,33 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Starts the external shuffle server on the machine this script is executed on. +# +# Usage: start-shuffle-server.sh +# +# Use the SPARK_SHUFFLE_OPTS environment variable to set shuffle server configuration. +# + +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" + +. "$sbin/spark-config.sh" +. "$SPARK_PREFIX/bin/load-spark-env.sh" + +exec "$sbin"/spark-daemon.sh start org.apache.spark.deploy.ExternalShuffleService 1 diff --git a/sbin/stop-shuffle-service.sh b/sbin/stop-shuffle-service.sh new file mode 100755 index 0000000000000..4cb6891ae27fa --- /dev/null +++ b/sbin/stop-shuffle-service.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# Stops the external shuffle service on the machine this script is executed on. + +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" + +"$sbin"/spark-daemon.sh stop org.apache.spark.deploy.ExternalShuffleService 1 From 2d222fb39dd978e5a33cde6ceb59307cbdf7b171 Mon Sep 17 00:00:00 2001 From: Ilya Ganelin Date: Tue, 28 Apr 2015 12:18:55 -0700 Subject: [PATCH 160/191] [SPARK-5932] [CORE] Use consistent naming for size properties I've added an interface to JavaUtils to do byte conversion and added hooks within Utils.scala to handle conversion within Spark code (like for time strings). I've added matching tests for size conversion, and then updated all deprecated configs and documentation as per SPARK-5933. Author: Ilya Ganelin Closes #5574 from ilganeli/SPARK-5932 and squashes the following commits: 11f6999 [Ilya Ganelin] Nit fixes 49a8720 [Ilya Ganelin] Whitespace fix 2ab886b [Ilya Ganelin] Scala style fc85733 [Ilya Ganelin] Got rid of floating point math 852a407 [Ilya Ganelin] [SPARK-5932] Added much improved overflow handling. Can now handle sizes up to Long.MAX_VALUE Petabytes instead of being capped at Long.MAX_VALUE Bytes 9ee779c [Ilya Ganelin] Simplified fraction matches 22413b1 [Ilya Ganelin] Made MAX private 3dfae96 [Ilya Ganelin] Fixed some nits. Added automatic conversion of old paramter for kryoserializer.mb to new values. e428049 [Ilya Ganelin] resolving merge conflict 8b43748 [Ilya Ganelin] Fixed error in pattern matching for doubles 84a2581 [Ilya Ganelin] Added smoother handling of fractional values for size parameters. This now throws an exception and added a warning for old spark.kryoserializer.buffer d3d09b6 [Ilya Ganelin] [SPARK-5932] Fixing error in KryoSerializer fe286b4 [Ilya Ganelin] Resolved merge conflict c7803cd [Ilya Ganelin] Empty lines 54b78b4 [Ilya Ganelin] Simplified byteUnit class 69e2f20 [Ilya Ganelin] Updates to code f32bc01 [Ilya Ganelin] [SPARK-5932] Fixed error in API in SparkConf.scala where Kb conversion wasn't being done properly (was Mb). Added test cases for both timeUnit and ByteUnit conversion f15f209 [Ilya Ganelin] Fixed conversion of kryo buffer size 0f4443e [Ilya Ganelin] Merge remote-tracking branch 'upstream/master' into SPARK-5932 35a7fa7 [Ilya Ganelin] Minor formatting 928469e [Ilya Ganelin] [SPARK-5932] Converted some longs to ints 5d29f90 [Ilya Ganelin] [SPARK-5932] Finished documentation updates 7a6c847 [Ilya Ganelin] [SPARK-5932] Updated spark.shuffle.file.buffer afc9a38 [Ilya Ganelin] [SPARK-5932] Updated spark.broadcast.blockSize and spark.storage.memoryMapThreshold ae7e9f6 [Ilya Ganelin] [SPARK-5932] Updated spark.io.compression.snappy.block.size 2d15681 [Ilya Ganelin] [SPARK-5932] Updated spark.executor.logs.rolling.size.maxBytes 1fbd435 [Ilya Ganelin] [SPARK-5932] Updated spark.broadcast.blockSize eba4de6 [Ilya Ganelin] [SPARK-5932] Updated spark.shuffle.file.buffer.kb b809a78 [Ilya Ganelin] [SPARK-5932] Updated spark.kryoserializer.buffer.max 0cdff35 [Ilya Ganelin] [SPARK-5932] Updated to use bibibytes in method names. Updated spark.kryoserializer.buffer.mb and spark.reducer.maxMbInFlight 475370a [Ilya Ganelin] [SPARK-5932] Simplified ByteUnit code, switched to using longs. 
Updated docs to clarify that we use kibi, mebi etc instead of kilo, mega 851d691 [Ilya Ganelin] [SPARK-5932] Updated memoryStringToMb to use new interfaces a9f4fcf [Ilya Ganelin] [SPARK-5932] Added unit tests for unit conversion 747393a [Ilya Ganelin] [SPARK-5932] Added unit tests for ByteString conversion 09ea450 [Ilya Ganelin] [SPARK-5932] Added byte string conversion to Jav utils 5390fd9 [Ilya Ganelin] Merge remote-tracking branch 'upstream/master' into SPARK-5932 db9a963 [Ilya Ganelin] Closing second spark context 1dc0444 [Ilya Ganelin] Added ref equality check 8c884fa [Ilya Ganelin] Made getOrCreate synchronized cb0c6b7 [Ilya Ganelin] Doc updates and code cleanup 270cfe3 [Ilya Ganelin] [SPARK-6703] Documentation fixes 15e8dea [Ilya Ganelin] Updated comments and added MiMa Exclude 0e1567c [Ilya Ganelin] Got rid of unecessary option for AtomicReference dfec4da [Ilya Ganelin] Changed activeContext to AtomicReference 733ec9f [Ilya Ganelin] Fixed some bugs in test code 8be2f83 [Ilya Ganelin] Replaced match with if e92caf7 [Ilya Ganelin] [SPARK-6703] Added test to ensure that getOrCreate both allows creation, retrieval, and a second context if desired a99032f [Ilya Ganelin] Spacing fix d7a06b8 [Ilya Ganelin] Updated SparkConf class to add getOrCreate method. Started test suite implementation --- .../scala/org/apache/spark/SparkConf.scala | 90 ++++++++++++++- .../spark/broadcast/TorrentBroadcast.scala | 3 +- .../apache/spark/io/CompressionCodec.scala | 8 +- .../spark/serializer/KryoSerializer.scala | 17 +-- .../shuffle/FileShuffleBlockManager.scala | 3 +- .../hash/BlockStoreShuffleFetcher.scala | 3 +- .../org/apache/spark/storage/DiskStore.scala | 3 +- .../scala/org/apache/spark/util/Utils.scala | 53 ++++++--- .../collection/ExternalAppendOnlyMap.scala | 6 +- .../util/collection/ExternalSorter.scala | 4 +- .../util/logging/RollingFileAppender.scala | 2 +- .../org/apache/spark/DistributedSuite.scala | 2 +- .../org/apache/spark/SparkConfSuite.scala | 19 ++++ .../KryoSerializerResizableOutputSuite.scala | 8 +- .../serializer/KryoSerializerSuite.scala | 2 +- .../BlockManagerReplicationSuite.scala | 2 +- .../spark/storage/BlockManagerSuite.scala | 6 +- .../org/apache/spark/util/UtilsSuite.scala | 100 +++++++++++++++- docs/configuration.md | 60 ++++++---- docs/tuning.md | 2 +- .../spark/examples/mllib/MovieLensALS.scala | 2 +- .../apache/spark/network/util/ByteUnit.java | 67 +++++++++++ .../apache/spark/network/util/JavaUtils.java | 107 ++++++++++++++++-- 23 files changed, 488 insertions(+), 81 deletions(-) create mode 100644 network/common/src/main/java/org/apache/spark/network/util/ByteUnit.java diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index c1996e08756a6..a8fc90ad2050e 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -211,7 +211,74 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { Utils.timeStringAsMs(get(key, defaultValue)) } + /** + * Get a size parameter as bytes; throws a NoSuchElementException if it's not set. If no + * suffix is provided then bytes are assumed. + * @throws NoSuchElementException + */ + def getSizeAsBytes(key: String): Long = { + Utils.byteStringAsBytes(get(key)) + } + + /** + * Get a size parameter as bytes, falling back to a default if not set. If no + * suffix is provided then bytes are assumed. 
+ */ + def getSizeAsBytes(key: String, defaultValue: String): Long = { + Utils.byteStringAsBytes(get(key, defaultValue)) + } + + /** + * Get a size parameter as Kibibytes; throws a NoSuchElementException if it's not set. If no + * suffix is provided then Kibibytes are assumed. + * @throws NoSuchElementException + */ + def getSizeAsKb(key: String): Long = { + Utils.byteStringAsKb(get(key)) + } + + /** + * Get a size parameter as Kibibytes, falling back to a default if not set. If no + * suffix is provided then Kibibytes are assumed. + */ + def getSizeAsKb(key: String, defaultValue: String): Long = { + Utils.byteStringAsKb(get(key, defaultValue)) + } + + /** + * Get a size parameter as Mebibytes; throws a NoSuchElementException if it's not set. If no + * suffix is provided then Mebibytes are assumed. + * @throws NoSuchElementException + */ + def getSizeAsMb(key: String): Long = { + Utils.byteStringAsMb(get(key)) + } + + /** + * Get a size parameter as Mebibytes, falling back to a default if not set. If no + * suffix is provided then Mebibytes are assumed. + */ + def getSizeAsMb(key: String, defaultValue: String): Long = { + Utils.byteStringAsMb(get(key, defaultValue)) + } + + /** + * Get a size parameter as Gibibytes; throws a NoSuchElementException if it's not set. If no + * suffix is provided then Gibibytes are assumed. + * @throws NoSuchElementException + */ + def getSizeAsGb(key: String): Long = { + Utils.byteStringAsGb(get(key)) + } + /** + * Get a size parameter as Gibibytes, falling back to a default if not set. If no + * suffix is provided then Gibibytes are assumed. + */ + def getSizeAsGb(key: String, defaultValue: String): Long = { + Utils.byteStringAsGb(get(key, defaultValue)) + } + /** Get a parameter as an Option */ def getOption(key: String): Option[String] = { Option(settings.get(key)).orElse(getDeprecatedConfig(key, this)) @@ -407,7 +474,13 @@ private[spark] object SparkConf extends Logging { "The spark.cache.class property is no longer being used! Specify storage levels using " + "the RDD.persist() method instead."), DeprecatedConfig("spark.yarn.user.classpath.first", "1.3", - "Please use spark.{driver,executor}.userClassPathFirst instead.")) + "Please use spark.{driver,executor}.userClassPathFirst instead."), + DeprecatedConfig("spark.kryoserializer.buffer.mb", "1.4", + "Please use spark.kryoserializer.buffer instead. The default value for " + + "spark.kryoserializer.buffer.mb was previously specified as '0.064'. Fractional values " + + "are no longer accepted. To specify the equivalent now, one may use '64k'.") + ) + Map(configs.map { cfg => (cfg.key -> cfg) }:_*) } @@ -432,6 +505,21 @@ private[spark] object SparkConf extends Logging { AlternateConfig("spark.yarn.applicationMaster.waitTries", "1.3", // Translate old value to a duration, with 10s wait time per try. 
translation = s => s"${s.toLong * 10}s")), + "spark.reducer.maxSizeInFlight" -> Seq( + AlternateConfig("spark.reducer.maxMbInFlight", "1.4")), + "spark.kryoserializer.buffer" -> + Seq(AlternateConfig("spark.kryoserializer.buffer.mb", "1.4", + translation = s => s"${s.toDouble * 1000}k")), + "spark.kryoserializer.buffer.max" -> Seq( + AlternateConfig("spark.kryoserializer.buffer.max.mb", "1.4")), + "spark.shuffle.file.buffer" -> Seq( + AlternateConfig("spark.shuffle.file.buffer.kb", "1.4")), + "spark.executor.logs.rolling.maxSize" -> Seq( + AlternateConfig("spark.executor.logs.rolling.size.maxBytes", "1.4")), + "spark.io.compression.snappy.blockSize" -> Seq( + AlternateConfig("spark.io.compression.snappy.block.size", "1.4")), + "spark.io.compression.lz4.blockSize" -> Seq( + AlternateConfig("spark.io.compression.lz4.block.size", "1.4")), "spark.rpc.numRetries" -> Seq( AlternateConfig("spark.akka.num.retries", "1.4")), "spark.rpc.retry.wait" -> Seq( diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 23b02e60338fb..a0c9b5e63c744 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -74,7 +74,8 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) } else { None } - blockSize = conf.getInt("spark.broadcast.blockSize", 4096) * 1024 + // Note: use getSizeAsKb (not bytes) to maintain compatiblity if no units are provided + blockSize = conf.getSizeAsKb("spark.broadcast.blockSize", "4m").toInt * 1024 } setConf(SparkEnv.get.conf) diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 0709b6d689e86..0756cdb2ed8e6 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -97,7 +97,7 @@ private[spark] object CompressionCodec { /** * :: DeveloperApi :: * LZ4 implementation of [[org.apache.spark.io.CompressionCodec]]. - * Block size can be configured by `spark.io.compression.lz4.block.size`. + * Block size can be configured by `spark.io.compression.lz4.blockSize`. * * Note: The wire protocol for this codec is not guaranteed to be compatible across versions * of Spark. This is intended for use as an internal compression utility within a single Spark @@ -107,7 +107,7 @@ private[spark] object CompressionCodec { class LZ4CompressionCodec(conf: SparkConf) extends CompressionCodec { override def compressedOutputStream(s: OutputStream): OutputStream = { - val blockSize = conf.getInt("spark.io.compression.lz4.block.size", 32768) + val blockSize = conf.getSizeAsBytes("spark.io.compression.lz4.blockSize", "32k").toInt new LZ4BlockOutputStream(s, blockSize) } @@ -137,7 +137,7 @@ class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec { /** * :: DeveloperApi :: * Snappy implementation of [[org.apache.spark.io.CompressionCodec]]. - * Block size can be configured by `spark.io.compression.snappy.block.size`. + * Block size can be configured by `spark.io.compression.snappy.blockSize`. * * Note: The wire protocol for this codec is not guaranteed to be compatible across versions * of Spark. 
This is intended for use as an internal compression utility within a single Spark @@ -153,7 +153,7 @@ class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec { } override def compressedOutputStream(s: OutputStream): OutputStream = { - val blockSize = conf.getInt("spark.io.compression.snappy.block.size", 32768) + val blockSize = conf.getSizeAsBytes("spark.io.compression.snappy.blockSize", "32k").toInt new SnappyOutputStream(s, blockSize) } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 579fb6624e692..754832b8a4ca7 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -49,16 +49,17 @@ class KryoSerializer(conf: SparkConf) with Logging with Serializable { - private val bufferSizeMb = conf.getDouble("spark.kryoserializer.buffer.mb", 0.064) - if (bufferSizeMb >= 2048) { - throw new IllegalArgumentException("spark.kryoserializer.buffer.mb must be less than " + - s"2048 mb, got: + $bufferSizeMb mb.") + private val bufferSizeKb = conf.getSizeAsKb("spark.kryoserializer.buffer", "64k") + + if (bufferSizeKb >= 2048) { + throw new IllegalArgumentException("spark.kryoserializer.buffer must be less than " + + s"2048 mb, got: + $bufferSizeKb mb.") } - private val bufferSize = (bufferSizeMb * 1024 * 1024).toInt + private val bufferSize = (bufferSizeKb * 1024).toInt - val maxBufferSizeMb = conf.getInt("spark.kryoserializer.buffer.max.mb", 64) + val maxBufferSizeMb = conf.getSizeAsMb("spark.kryoserializer.buffer.max", "64m").toInt if (maxBufferSizeMb >= 2048) { - throw new IllegalArgumentException("spark.kryoserializer.buffer.max.mb must be less than " + + throw new IllegalArgumentException("spark.kryoserializer.buffer.max must be less than " + s"2048 mb, got: + $maxBufferSizeMb mb.") } private val maxBufferSize = maxBufferSizeMb * 1024 * 1024 @@ -173,7 +174,7 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ } catch { case e: KryoException if e.getMessage.startsWith("Buffer overflow") => throw new SparkException(s"Kryo serialization failed: ${e.getMessage}. To avoid this, " + - "increase spark.kryoserializer.buffer.max.mb value.") + "increase spark.kryoserializer.buffer.max value.") } ByteBuffer.wrap(output.toBytes) } diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala index 538e150ead05a..e9b4e2b955dc8 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala @@ -78,7 +78,8 @@ class FileShuffleBlockManager(conf: SparkConf) private val consolidateShuffleFiles = conf.getBoolean("spark.shuffle.consolidateFiles", false) - private val bufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 + // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided + private val bufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 /** * Contains all the state related to a particular shuffle. 
This includes a pool of unused diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index 7a2c5ae32d98b..80374adc44296 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -79,7 +79,8 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { blockManager, blocksByAddress, serializer, - SparkEnv.get.conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024) + // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility + SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) val itr = blockFetcherItr.flatMap(unpackBlock) val completionIter = CompletionIterator[T, Iterator[T]](itr, { diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index 4b232ae7d3180..1f45956282166 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -31,8 +31,7 @@ import org.apache.spark.util.Utils private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManager) extends BlockStore(blockManager) with Logging { - val minMemoryMapBytes = blockManager.conf.getLong( - "spark.storage.memoryMapThreshold", 2 * 1024L * 1024L) + val minMemoryMapBytes = blockManager.conf.getSizeAsBytes("spark.storage.memoryMapThreshold", "2m") override def getSize(blockId: BlockId): Long = { diskManager.getFile(blockId.name).length diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 342bc9a06db47..4c028c06a5138 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1020,21 +1020,48 @@ private[spark] object Utils extends Logging { } /** - * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of megabytes. + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to bytes for internal use. + * + * If no suffix is provided, the passed number is assumed to be in bytes. + */ + def byteStringAsBytes(str: String): Long = { + JavaUtils.byteStringAsBytes(str) + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to kibibytes for internal use. + * + * If no suffix is provided, the passed number is assumed to be in kibibytes. + */ + def byteStringAsKb(str: String): Long = { + JavaUtils.byteStringAsKb(str) + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to mebibytes for internal use. + * + * If no suffix is provided, the passed number is assumed to be in mebibytes. + */ + def byteStringAsMb(str: String): Long = { + JavaUtils.byteStringAsMb(str) + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m, 500g) to gibibytes for internal use. + * + * If no suffix is provided, the passed number is assumed to be in gibibytes. + */ + def byteStringAsGb(str: String): Long = { + JavaUtils.byteStringAsGb(str) + } + + /** + * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of mebibytes. 
*/ def memoryStringToMb(str: String): Int = { - val lower = str.toLowerCase - if (lower.endsWith("k")) { - (lower.substring(0, lower.length-1).toLong / 1024).toInt - } else if (lower.endsWith("m")) { - lower.substring(0, lower.length-1).toInt - } else if (lower.endsWith("g")) { - lower.substring(0, lower.length-1).toInt * 1024 - } else if (lower.endsWith("t")) { - lower.substring(0, lower.length-1).toInt * 1024 * 1024 - } else {// no suffix, so it's just a number in bytes - (lower.toLong / 1024 / 1024).toInt - } + // Convert to bytes, rather than directly to MB, because when no units are specified the unit + // is assumed to be bytes + (JavaUtils.byteStringAsBytes(str) / 1024 / 1024).toInt } /** diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 30dd7f22e494f..f912049563906 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -89,8 +89,10 @@ class ExternalAppendOnlyMap[K, V, C]( // Number of bytes spilled in total private var _diskBytesSpilled = 0L - - private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 + + // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided + private val fileBufferSize = + sparkConf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 // Write metrics for current spill private var curWriteMetrics: ShuffleWriteMetrics = _ diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 79a695fb62086..ef3cac622505e 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -108,7 +108,9 @@ private[spark] class ExternalSorter[K, V, C]( private val conf = SparkEnv.get.conf private val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true) - private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 + + // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided + private val fileBufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 private val transferToEnabled = conf.getBoolean("spark.file.transferTo", true) // Size of object batches when reading/writing from serializers. 
diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala index e579421676343..7138b4b8e4533 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala @@ -138,7 +138,7 @@ private[spark] object RollingFileAppender { val STRATEGY_DEFAULT = "" val INTERVAL_PROPERTY = "spark.executor.logs.rolling.time.interval" val INTERVAL_DEFAULT = "daily" - val SIZE_PROPERTY = "spark.executor.logs.rolling.size.maxBytes" + val SIZE_PROPERTY = "spark.executor.logs.rolling.maxSize" val SIZE_DEFAULT = (1024 * 1024).toString val RETAINED_FILES_PROPERTY = "spark.executor.logs.rolling.maxRetainedFiles" val DEFAULT_BUFFER_SIZE = 8192 diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 97ea3578aa8ba..96a9c207ad022 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -77,7 +77,7 @@ class DistributedSuite extends FunSuite with Matchers with LocalSparkContext { } test("groupByKey where map output sizes exceed maxMbInFlight") { - val conf = new SparkConf().set("spark.reducer.maxMbInFlight", "1") + val conf = new SparkConf().set("spark.reducer.maxSizeInFlight", "1m") sc = new SparkContext(clusterUrl, "test", conf) // This data should be around 20 MB, so even with 4 mappers and 2 reducers, each map output // file should be about 2.5 MB diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index 272e6af0514e4..68d08e32f9aa4 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -24,11 +24,30 @@ import scala.language.postfixOps import scala.util.{Try, Random} import org.scalatest.FunSuite +import org.apache.spark.network.util.ByteUnit import org.apache.spark.serializer.{KryoRegistrator, KryoSerializer} import org.apache.spark.util.{RpcUtils, ResetSystemProperties} import com.esotericsoftware.kryo.Kryo class SparkConfSuite extends FunSuite with LocalSparkContext with ResetSystemProperties { + test("Test byteString conversion") { + val conf = new SparkConf() + // Simply exercise the API, we don't need a complete conversion test since that's handled in + // UtilsSuite.scala + assert(conf.getSizeAsBytes("fake","1k") === ByteUnit.KiB.toBytes(1)) + assert(conf.getSizeAsKb("fake","1k") === ByteUnit.KiB.toKiB(1)) + assert(conf.getSizeAsMb("fake","1k") === ByteUnit.KiB.toMiB(1)) + assert(conf.getSizeAsGb("fake","1k") === ByteUnit.KiB.toGiB(1)) + } + + test("Test timeString conversion") { + val conf = new SparkConf() + // Simply exercise the API, we don't need a complete conversion test since that's handled in + // UtilsSuite.scala + assert(conf.getTimeAsMs("fake","1ms") === TimeUnit.MILLISECONDS.toMillis(1)) + assert(conf.getTimeAsSeconds("fake","1000ms") === TimeUnit.MILLISECONDS.toSeconds(1000)) + } + test("loading from system properties") { System.setProperty("spark.test.testProperty", "2") val conf = new SparkConf() diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala index 967c9e9899c9d..da98d09184735 100644 --- 
a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala @@ -33,8 +33,8 @@ class KryoSerializerResizableOutputSuite extends FunSuite { test("kryo without resizable output buffer should fail on large array") { val conf = new SparkConf(false) conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - conf.set("spark.kryoserializer.buffer.mb", "1") - conf.set("spark.kryoserializer.buffer.max.mb", "1") + conf.set("spark.kryoserializer.buffer", "1m") + conf.set("spark.kryoserializer.buffer.max", "1m") val sc = new SparkContext("local", "test", conf) intercept[SparkException](sc.parallelize(x).collect()) LocalSparkContext.stop(sc) @@ -43,8 +43,8 @@ class KryoSerializerResizableOutputSuite extends FunSuite { test("kryo with resizable output buffer should succeed on large array") { val conf = new SparkConf(false) conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - conf.set("spark.kryoserializer.buffer.mb", "1") - conf.set("spark.kryoserializer.buffer.max.mb", "2") + conf.set("spark.kryoserializer.buffer", "1m") + conf.set("spark.kryoserializer.buffer.max", "2m") val sc = new SparkContext("local", "test", conf) assert(sc.parallelize(x).collect() === x) LocalSparkContext.stop(sc) diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index b070a54aa989b..1b13559e77cb8 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -269,7 +269,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { test("serialization buffer overflow reporting") { import org.apache.spark.SparkException - val kryoBufferMaxProperty = "spark.kryoserializer.buffer.max.mb" + val kryoBufferMaxProperty = "spark.kryoserializer.buffer.max" val largeObject = (1 to 1000000).toArray diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index ffa5162a31841..f647200402ecb 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -50,7 +50,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd val allStores = new ArrayBuffer[BlockManager] // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test - conf.set("spark.kryoserializer.buffer.mb", "1") + conf.set("spark.kryoserializer.buffer", "1m") val serializer = new KryoSerializer(conf) // Implicitly convert strings to BlockIds for test clarity. 
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 7d82a7c66ad1a..6957bc72e9903 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -55,7 +55,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach val shuffleManager = new HashShuffleManager(conf) // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test - conf.set("spark.kryoserializer.buffer.mb", "1") + conf.set("spark.kryoserializer.buffer", "1m") val serializer = new KryoSerializer(conf) // Implicitly convert strings to BlockIds for test clarity. @@ -814,14 +814,14 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach // be nice to refactor classes involved in disk storage in a way that // allows for easier testing. val blockManager = mock(classOf[BlockManager]) - when(blockManager.conf).thenReturn(conf.clone.set(confKey, 0.toString)) + when(blockManager.conf).thenReturn(conf.clone.set(confKey, "0")) val diskBlockManager = new DiskBlockManager(blockManager, conf) val diskStoreMapped = new DiskStore(blockManager, diskBlockManager) diskStoreMapped.putBytes(blockId, byteBuffer, StorageLevel.DISK_ONLY) val mapped = diskStoreMapped.getBytes(blockId).get - when(blockManager.conf).thenReturn(conf.clone.set(confKey, (1000 * 1000).toString)) + when(blockManager.conf).thenReturn(conf.clone.set(confKey, "1m")) val diskStoreNotMapped = new DiskStore(blockManager, diskBlockManager) diskStoreNotMapped.putBytes(blockId, byteBuffer, StorageLevel.DISK_ONLY) val notMapped = diskStoreNotMapped.getBytes(blockId).get diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 1ba99803f5a0e..62a3cbcdf69ea 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -23,7 +23,6 @@ import java.nio.{ByteBuffer, ByteOrder} import java.text.DecimalFormatSymbols import java.util.concurrent.TimeUnit import java.util.Locale -import java.util.PriorityQueue import scala.collection.mutable.ListBuffer import scala.util.Random @@ -35,6 +34,7 @@ import org.scalatest.FunSuite import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.spark.network.util.ByteUnit import org.apache.spark.SparkConf class UtilsSuite extends FunSuite with ResetSystemProperties { @@ -65,6 +65,10 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { assert(Utils.timeStringAsMs("1d") === TimeUnit.DAYS.toMillis(1)) // Test invalid strings + intercept[NumberFormatException] { + Utils.timeStringAsMs("600l") + } + intercept[NumberFormatException] { Utils.timeStringAsMs("This breaks 600s") } @@ -82,6 +86,100 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { } } + test("Test byteString conversion") { + // Test zero + assert(Utils.byteStringAsBytes("0") === 0) + + assert(Utils.byteStringAsGb("1") === 1) + assert(Utils.byteStringAsGb("1g") === 1) + assert(Utils.byteStringAsGb("1023m") === 0) + assert(Utils.byteStringAsGb("1024m") === 1) + assert(Utils.byteStringAsGb("1048575k") === 0) + assert(Utils.byteStringAsGb("1048576k") === 1) + assert(Utils.byteStringAsGb("1k") === 0) + assert(Utils.byteStringAsGb("1t") === ByteUnit.TiB.toGiB(1)) + assert(Utils.byteStringAsGb("1p") === 
ByteUnit.PiB.toGiB(1)) + + assert(Utils.byteStringAsMb("1") === 1) + assert(Utils.byteStringAsMb("1m") === 1) + assert(Utils.byteStringAsMb("1048575b") === 0) + assert(Utils.byteStringAsMb("1048576b") === 1) + assert(Utils.byteStringAsMb("1023k") === 0) + assert(Utils.byteStringAsMb("1024k") === 1) + assert(Utils.byteStringAsMb("3645k") === 3) + assert(Utils.byteStringAsMb("1024gb") === 1048576) + assert(Utils.byteStringAsMb("1g") === ByteUnit.GiB.toMiB(1)) + assert(Utils.byteStringAsMb("1t") === ByteUnit.TiB.toMiB(1)) + assert(Utils.byteStringAsMb("1p") === ByteUnit.PiB.toMiB(1)) + + assert(Utils.byteStringAsKb("1") === 1) + assert(Utils.byteStringAsKb("1k") === 1) + assert(Utils.byteStringAsKb("1m") === ByteUnit.MiB.toKiB(1)) + assert(Utils.byteStringAsKb("1g") === ByteUnit.GiB.toKiB(1)) + assert(Utils.byteStringAsKb("1t") === ByteUnit.TiB.toKiB(1)) + assert(Utils.byteStringAsKb("1p") === ByteUnit.PiB.toKiB(1)) + + assert(Utils.byteStringAsBytes("1") === 1) + assert(Utils.byteStringAsBytes("1k") === ByteUnit.KiB.toBytes(1)) + assert(Utils.byteStringAsBytes("1m") === ByteUnit.MiB.toBytes(1)) + assert(Utils.byteStringAsBytes("1g") === ByteUnit.GiB.toBytes(1)) + assert(Utils.byteStringAsBytes("1t") === ByteUnit.TiB.toBytes(1)) + assert(Utils.byteStringAsBytes("1p") === ByteUnit.PiB.toBytes(1)) + + // Overflow handling, 1073741824p exceeds Long.MAX_VALUE if converted straight to Bytes + // This demonstrates that we can have e.g 1024^3 PB without overflowing. + assert(Utils.byteStringAsGb("1073741824p") === ByteUnit.PiB.toGiB(1073741824)) + assert(Utils.byteStringAsMb("1073741824p") === ByteUnit.PiB.toMiB(1073741824)) + + // Run this to confirm it doesn't throw an exception + assert(Utils.byteStringAsBytes("9223372036854775807") === 9223372036854775807L) + assert(ByteUnit.PiB.toPiB(9223372036854775807L) === 9223372036854775807L) + + // Test overflow exception + intercept[IllegalArgumentException] { + // This value exceeds Long.MAX when converted to bytes + Utils.byteStringAsBytes("9223372036854775808") + } + + // Test overflow exception + intercept[IllegalArgumentException] { + // This value exceeds Long.MAX when converted to TB + ByteUnit.PiB.toTiB(9223372036854775807L) + } + + // Test fractional string + intercept[NumberFormatException] { + Utils.byteStringAsMb("0.064") + } + + // Test fractional string + intercept[NumberFormatException] { + Utils.byteStringAsMb("0.064m") + } + + // Test invalid strings + intercept[NumberFormatException] { + Utils.byteStringAsBytes("500ub") + } + + // Test invalid strings + intercept[NumberFormatException] { + Utils.byteStringAsBytes("This breaks 600b") + } + + intercept[NumberFormatException] { + Utils.byteStringAsBytes("This breaks 600") + } + + intercept[NumberFormatException] { + Utils.byteStringAsBytes("600gb This breaks") + } + + intercept[NumberFormatException] { + Utils.byteStringAsBytes("This 123mb breaks") + } + } + test("bytesToString") { assert(Utils.bytesToString(10) === "10.0 B") assert(Utils.bytesToString(1500) === "1500.0 B") diff --git a/docs/configuration.md b/docs/configuration.md index d587b91124cb8..72105feba4919 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -48,6 +48,17 @@ The following format is accepted: 5d (days) 1y (years) + +Properties that specify a byte size should be configured with a unit of size. 
+The following format is accepted: + + 1b (bytes) + 1k or 1kb (kibibytes = 1024 bytes) + 1m or 1mb (mebibytes = 1024 kibibytes) + 1g or 1gb (gibibytes = 1024 mebibytes) + 1t or 1tb (tebibytes = 1024 gibibytes) + 1p or 1pb (pebibytes = 1024 tebibytes) + ## Dynamically Loading Spark Properties In some cases, you may want to avoid hard-coding certain configurations in a `SparkConf`. For instance, if you'd like to run the same application with different masters or different @@ -272,12 +283,11 @@ Apart from these, the following properties are also available, and may be useful - + @@ -366,10 +376,10 @@ Apart from these, the following properties are also available, and may be useful
-  <td><code>spark.executor.logs.rolling.size.maxBytes</code></td>
+  <td><code>spark.executor.logs.rolling.maxSize</code></td>
   <td>(none)</td>
   <td>
     Set the max size of the file by which the executor logs will be rolled over.
-    Rolling is disabled by default. Value is set in terms of bytes.
-    See <code>spark.executor.logs.rolling.maxRetainedFiles</code>
+    Rolling is disabled by default. See <code>spark.executor.logs.rolling.maxRetainedFiles</code>
     for automatic cleaning of old logs.
   </td>
- - + + @@ -403,10 +413,10 @@ Apart from these, the following properties are also available, and may be useful - - + + @@ -582,18 +592,18 @@ Apart from these, the following properties are also available, and may be useful - - + + - - + + @@ -641,19 +651,19 @@ Apart from these, the following properties are also available, and may be useful - - + + - - + + @@ -698,9 +708,9 @@ Apart from these, the following properties are also available, and may be useful - + @@ -816,9 +826,9 @@ Apart from these, the following properties are also available, and may be useful - + diff --git a/docs/tuning.md b/docs/tuning.md index cbd227868b248..1cb223e74f382 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -60,7 +60,7 @@ val sc = new SparkContext(conf) The [Kryo documentation](https://github.com/EsotericSoftware/kryo) describes more advanced registration options, such as adding custom serialization code. -If your objects are large, you may also need to increase the `spark.kryoserializer.buffer.mb` +If your objects are large, you may also need to increase the `spark.kryoserializer.buffer` config property. The default is 2, but this value needs to be large enough to hold the *largest* object you will serialize. diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala index 0bc36ea65e1ab..99588b0984ab2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala @@ -100,7 +100,7 @@ object MovieLensALS { val conf = new SparkConf().setAppName(s"MovieLensALS with $params") if (params.kryo) { conf.registerKryoClasses(Array(classOf[mutable.BitSet], classOf[Rating])) - .set("spark.kryoserializer.buffer.mb", "8") + .set("spark.kryoserializer.buffer", "8m") } val sc = new SparkContext(conf) diff --git a/network/common/src/main/java/org/apache/spark/network/util/ByteUnit.java b/network/common/src/main/java/org/apache/spark/network/util/ByteUnit.java new file mode 100644 index 0000000000000..36d655017fb0d --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/ByteUnit.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.network.util; + +public enum ByteUnit { + BYTE (1), + KiB (1024L), + MiB ((long) Math.pow(1024L, 2L)), + GiB ((long) Math.pow(1024L, 3L)), + TiB ((long) Math.pow(1024L, 4L)), + PiB ((long) Math.pow(1024L, 5L)); + + private ByteUnit(long multiplier) { + this.multiplier = multiplier; + } + + // Interpret the provided number (d) with suffix (u) as this unit type. + // E.g. 
KiB.interpret(1, MiB) interprets 1MiB as its KiB representation = 1024k + public long convertFrom(long d, ByteUnit u) { + return u.convertTo(d, this); + } + + // Convert the provided number (d) interpreted as this unit type to unit type (u). + public long convertTo(long d, ByteUnit u) { + if (multiplier > u.multiplier) { + long ratio = multiplier / u.multiplier; + if (Long.MAX_VALUE / ratio < d) { + throw new IllegalArgumentException("Conversion of " + d + " exceeds Long.MAX_VALUE in " + + name() + ". Try a larger unit (e.g. MiB instead of KiB)"); + } + return d * ratio; + } else { + // Perform operations in this order to avoid potential overflow + // when computing d * multiplier + return d / (u.multiplier / multiplier); + } + } + + public double toBytes(long d) { + if (d < 0) { + throw new IllegalArgumentException("Negative size value. Size must be positive: " + d); + } + return d * multiplier; + } + + public long toKiB(long d) { return convertTo(d, KiB); } + public long toMiB(long d) { return convertTo(d, MiB); } + public long toGiB(long d) { return convertTo(d, GiB); } + public long toTiB(long d) { return convertTo(d, TiB); } + public long toPiB(long d) { return convertTo(d, PiB); } + + private final long multiplier; +} diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java index b6fbace509a0e..6b514aaa1290d 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -126,7 +126,7 @@ private static boolean isSymlink(File file) throws IOException { return !fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile()); } - private static ImmutableMap timeSuffixes = + private static final ImmutableMap timeSuffixes = ImmutableMap.builder() .put("us", TimeUnit.MICROSECONDS) .put("ms", TimeUnit.MILLISECONDS) @@ -137,6 +137,21 @@ private static boolean isSymlink(File file) throws IOException { .put("d", TimeUnit.DAYS) .build(); + private static final ImmutableMap byteSuffixes = + ImmutableMap.builder() + .put("b", ByteUnit.BYTE) + .put("k", ByteUnit.KiB) + .put("kb", ByteUnit.KiB) + .put("m", ByteUnit.MiB) + .put("mb", ByteUnit.MiB) + .put("g", ByteUnit.GiB) + .put("gb", ByteUnit.GiB) + .put("t", ByteUnit.TiB) + .put("tb", ByteUnit.TiB) + .put("p", ByteUnit.PiB) + .put("pb", ByteUnit.PiB) + .build(); + /** * Convert a passed time string (e.g. 50s, 100ms, or 250us) to a time count for * internal use. If no suffix is provided a direct conversion is attempted. @@ -145,16 +160,14 @@ private static long parseTimeString(String str, TimeUnit unit) { String lower = str.toLowerCase().trim(); try { - String suffix; - long val; Matcher m = Pattern.compile("(-?[0-9]+)([a-z]+)?").matcher(lower); - if (m.matches()) { - val = Long.parseLong(m.group(1)); - suffix = m.group(2); - } else { + if (!m.matches()) { throw new NumberFormatException("Failed to parse time string: " + str); } + long val = Long.parseLong(m.group(1)); + String suffix = m.group(2); + // Check for invalid suffixes if (suffix != null && !timeSuffixes.containsKey(suffix)) { throw new NumberFormatException("Invalid suffix: \"" + suffix + "\""); @@ -164,7 +177,7 @@ private static long parseTimeString(String str, TimeUnit unit) { return unit.convert(val, suffix != null ? 
timeSuffixes.get(suffix) : unit); } catch (NumberFormatException e) { String timeError = "Time must be specified as seconds (s), " + - "milliseconds (ms), microseconds (us), minutes (m or min) hour (h), or day (d). " + + "milliseconds (ms), microseconds (us), minutes (m or min), hour (h), or day (d). " + "E.g. 50s, 100ms, or 250us."; throw new NumberFormatException(timeError + "\n" + e.getMessage()); @@ -186,5 +199,83 @@ public static long timeStringAsMs(String str) { public static long timeStringAsSec(String str) { return parseTimeString(str, TimeUnit.SECONDS); } + + /** + * Convert a passed byte string (e.g. 50b, 100kb, or 250mb) to a ByteUnit for + * internal use. If no suffix is provided a direct conversion of the provided default is + * attempted. + */ + private static long parseByteString(String str, ByteUnit unit) { + String lower = str.toLowerCase().trim(); + + try { + Matcher m = Pattern.compile("([0-9]+)([a-z]+)?").matcher(lower); + Matcher fractionMatcher = Pattern.compile("([0-9]+\\.[0-9]+)([a-z]+)?").matcher(lower); + + if (m.matches()) { + long val = Long.parseLong(m.group(1)); + String suffix = m.group(2); + + // Check for invalid suffixes + if (suffix != null && !byteSuffixes.containsKey(suffix)) { + throw new NumberFormatException("Invalid suffix: \"" + suffix + "\""); + } + + // If suffix is valid use that, otherwise none was provided and use the default passed + return unit.convertFrom(val, suffix != null ? byteSuffixes.get(suffix) : unit); + } else if (fractionMatcher.matches()) { + throw new NumberFormatException("Fractional values are not supported. Input was: " + + fractionMatcher.group(1)); + } else { + throw new NumberFormatException("Failed to parse byte string: " + str); + } + + } catch (NumberFormatException e) { + String timeError = "Size must be specified as bytes (b), " + + "kibibytes (k), mebibytes (m), gibibytes (g), tebibytes (t), or pebibytes(p). " + + "E.g. 50b, 100k, or 250m."; + throw new NumberFormatException(timeError + "\n" + e.getMessage()); + } + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to bytes for + * internal use. + * + * If no suffix is provided, the passed number is assumed to be in bytes. + */ + public static long byteStringAsBytes(String str) { + return parseByteString(str, ByteUnit.BYTE); + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to kibibytes for + * internal use. + * + * If no suffix is provided, the passed number is assumed to be in kibibytes. + */ + public static long byteStringAsKb(String str) { + return parseByteString(str, ByteUnit.KiB); + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to mebibytes for + * internal use. + * + * If no suffix is provided, the passed number is assumed to be in mebibytes. + */ + public static long byteStringAsMb(String str) { + return parseByteString(str, ByteUnit.MiB); + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to gibibytes for + * internal use. + * + * If no suffix is provided, the passed number is assumed to be in gibibytes. + */ + public static long byteStringAsGb(String str) { + return parseByteString(str, ByteUnit.GiB); + } } From 80098109d908b738b43d397e024756ff617d0af4 Mon Sep 17 00:00:00 2001 From: "Zhang, Liye" Date: Tue, 28 Apr 2015 12:33:48 -0700 Subject: [PATCH 161/191] [SPARK-6314] [CORE] handle JsonParseException for history server This is handled in the same way with [SPARK-6197](https://issues.apache.org/jira/browse/SPARK-6197). 
The result of this PR is that exception showed in history server log will be replaced by a warning, and the application that with un-complete history log file will be listed on history server webUI Author: Zhang, Liye Closes #5736 from liyezhang556520/SPARK-6314 and squashes the following commits: b8d2d88 [Zhang, Liye] handle JsonParseException for history server --- .../org/apache/spark/deploy/history/FsHistoryProvider.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index a94ebf6e53750..fb2cbbcccc54b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -333,8 +333,9 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis } try { val appListener = new ApplicationEventListener + val appCompleted = isApplicationCompleted(eventLog) bus.addListener(appListener) - bus.replay(logInput, logPath.toString) + bus.replay(logInput, logPath.toString, !appCompleted) new FsApplicationHistoryInfo( logPath.getName(), appListener.appId.getOrElse(logPath.getName()), @@ -343,7 +344,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis appListener.endTime.getOrElse(-1L), getModificationTime(eventLog).get, appListener.sparkUser.getOrElse(NOT_STARTED), - isApplicationCompleted(eventLog)) + appCompleted) } finally { logInput.close() } From 53befacced828bbac53c6e3a4976ec3f036bae9e Mon Sep 17 00:00:00 2001 From: Timothy Chen Date: Tue, 28 Apr 2015 13:31:08 -0700 Subject: [PATCH 162/191] [SPARK-5338] [MESOS] Add cluster mode support for Mesos This patch adds the support for cluster mode to run on Mesos. It introduces a new Mesos framework dedicated to launch new apps/drivers, and can be called with the spark-submit script and specifying --master flag to the cluster mode REST interface instead of Mesos master. Example: ./bin/spark-submit --deploy-mode cluster --class org.apache.spark.examples.SparkPi --master mesos://10.0.0.206:8077 --executor-memory 1G --total-executor-cores 100 examples/target/spark-examples_2.10-1.3.0-SNAPSHOT.jar 30 Part of this patch is also to abstract the StandaloneRestServer so it can have different implementations of the REST endpoints. Features of the cluster mode in this PR: - Supports supervise mode where scheduler will keep trying to reschedule exited job. - Adds a new UI for the cluster mode scheduler to see all the running jobs, finished jobs, and supervise jobs waiting to be retried - Supports state persistence to ZK, so when the cluster scheduler fails over it can pick up all the queued and running jobs Author: Timothy Chen Author: Luc Bourlier Closes #5144 from tnachen/mesos_cluster_mode and squashes the following commits: 069e946 [Timothy Chen] Fix rebase. e24b512 [Timothy Chen] Persist submitted driver. 390c491 [Timothy Chen] Fix zk conf key for mesos zk engine. e324ac1 [Timothy Chen] Fix merge. fd5259d [Timothy Chen] Address review comments. 1553230 [Timothy Chen] Address review comments. c6c6b73 [Timothy Chen] Pass spark properties to mesos cluster tasks. f7d8046 [Timothy Chen] Change app name to spark cluster. 17f93a2 [Timothy Chen] Fix head of line blocking in scheduling drivers. 6ff8e5c [Timothy Chen] Address comments and add logging. df355cd [Timothy Chen] Add metrics to mesos cluster scheduler. 
20f7284 [Timothy Chen] Address review comments 7252612 [Timothy Chen] Fix tests. a46ad66 [Timothy Chen] Allow zk cli param override. 920fc4b [Timothy Chen] Fix scala style issues. 862b5b5 [Timothy Chen] Support asking driver status when it's retrying. 7f214c2 [Timothy Chen] Fix RetryState visibility e0f33f7 [Timothy Chen] Add supervise support and persist retries. 371ce65 [Timothy Chen] Handle cluster mode recovery and state persistence. 3d4dfa1 [Luc Bourlier] Adds support to kill submissions febfaba [Timothy Chen] Bound the finished drivers in memory 543a98d [Timothy Chen] Schedule multiple jobs 6887e5e [Timothy Chen] Support looking at SPARK_EXECUTOR_URI env variable in schedulers 8ec76bc [Timothy Chen] Fix Mesos dispatcher UI. d57d77d [Timothy Chen] Add documentation 825afa0 [Luc Bourlier] Supports more spark-submit parameters b8e7181 [Luc Bourlier] Adds a shutdown latch to keep the deamon running 0fa7780 [Luc Bourlier] Launch task through the mesos scheduler 5b7a12b [Timothy Chen] WIP: Making a cluster mode a mesos framework. 4b2f5ef [Timothy Chen] Specify user jar in command to be replaced with local. e775001 [Timothy Chen] Support fetching remote uris in driver runner. 7179495 [Timothy Chen] Change Driver page output and add logging 880bc27 [Timothy Chen] Add Mesos Cluster UI to display driver results 9986731 [Timothy Chen] Kill drivers when shutdown 67cbc18 [Timothy Chen] Rename StandaloneRestClient to RestClient and add sbin scripts e3facdd [Timothy Chen] Add Mesos Cluster dispatcher --- .../spark/deploy/FaultToleranceTest.scala | 2 +- .../{master => }/SparkCuratorUtil.scala | 10 +- .../org/apache/spark/deploy/SparkSubmit.scala | 48 +- .../spark/deploy/SparkSubmitArguments.scala | 11 +- .../apache/spark/deploy/master/Master.scala | 2 +- .../master/ZooKeeperLeaderElectionAgent.scala | 1 + .../master/ZooKeeperPersistenceEngine.scala | 1 + .../deploy/mesos/MesosClusterDispatcher.scala | 116 ++++ .../MesosClusterDispatcherArguments.scala | 101 +++ .../deploy/mesos/MesosDriverDescription.scala | 65 ++ .../deploy/mesos/ui/MesosClusterPage.scala | 114 ++++ .../deploy/mesos/ui/MesosClusterUI.scala | 48 ++ ...lient.scala => RestSubmissionClient.scala} | 35 +- .../deploy/rest/RestSubmissionServer.scala | 318 +++++++++ .../deploy/rest/StandaloneRestServer.scala | 344 ++-------- .../rest/SubmitRestProtocolRequest.scala | 2 +- .../rest/SubmitRestProtocolResponse.scala | 6 +- .../deploy/rest/mesos/MesosRestServer.scala | 158 +++++ .../mesos/CoarseMesosSchedulerBackend.scala | 82 +-- .../mesos/MesosClusterPersistenceEngine.scala | 134 ++++ .../cluster/mesos/MesosClusterScheduler.scala | 608 ++++++++++++++++++ .../mesos/MesosClusterSchedulerSource.scala | 40 ++ .../cluster/mesos/MesosSchedulerBackend.scala | 85 +-- .../cluster/mesos/MesosSchedulerUtils.scala | 95 +++ .../spark/deploy/SparkSubmitSuite.scala | 2 +- .../rest/StandaloneRestSubmitSuite.scala | 46 +- .../mesos/MesosClusterSchedulerSuite.scala | 76 +++ docs/running-on-mesos.md | 23 +- sbin/start-mesos-dispatcher.sh | 40 ++ sbin/stop-mesos-dispatcher.sh | 27 + 30 files changed, 2147 insertions(+), 493 deletions(-) rename core/src/main/scala/org/apache/spark/deploy/{master => }/SparkCuratorUtil.scala (89%) create mode 100644 core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/mesos/MesosDriverDescription.scala create mode 100644 
core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala rename core/src/main/scala/org/apache/spark/deploy/rest/{StandaloneRestClient.scala => RestSubmissionClient.scala} (93%) create mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSource.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala create mode 100644 core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala create mode 100755 sbin/start-mesos-dispatcher.sh create mode 100755 sbin/stop-mesos-dispatcher.sh diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala index a7c89276a045e..c048b78910f38 100644 --- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala @@ -32,7 +32,7 @@ import org.json4s._ import org.json4s.jackson.JsonMethods import org.apache.spark.{Logging, SparkConf, SparkContext} -import org.apache.spark.deploy.master.{RecoveryState, SparkCuratorUtil} +import org.apache.spark.deploy.master.RecoveryState import org.apache.spark.util.Utils /** diff --git a/core/src/main/scala/org/apache/spark/deploy/master/SparkCuratorUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkCuratorUtil.scala similarity index 89% rename from core/src/main/scala/org/apache/spark/deploy/master/SparkCuratorUtil.scala rename to core/src/main/scala/org/apache/spark/deploy/SparkCuratorUtil.scala index 5b22481ea8c5f..b8d3993540220 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/SparkCuratorUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkCuratorUtil.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.deploy.master +package org.apache.spark.deploy import scala.collection.JavaConversions._ @@ -25,15 +25,17 @@ import org.apache.zookeeper.KeeperException import org.apache.spark.{Logging, SparkConf} -private[deploy] object SparkCuratorUtil extends Logging { +private[spark] object SparkCuratorUtil extends Logging { private val ZK_CONNECTION_TIMEOUT_MILLIS = 15000 private val ZK_SESSION_TIMEOUT_MILLIS = 60000 private val RETRY_WAIT_MILLIS = 5000 private val MAX_RECONNECT_ATTEMPTS = 3 - def newClient(conf: SparkConf): CuratorFramework = { - val ZK_URL = conf.get("spark.deploy.zookeeper.url") + def newClient( + conf: SparkConf, + zkUrlConf: String = "spark.deploy.zookeeper.url"): CuratorFramework = { + val ZK_URL = conf.get(zkUrlConf) val zk = CuratorFrameworkFactory.newClient(ZK_URL, ZK_SESSION_TIMEOUT_MILLIS, ZK_CONNECTION_TIMEOUT_MILLIS, new ExponentialBackoffRetry(RETRY_WAIT_MILLIS, MAX_RECONNECT_ATTEMPTS)) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 296a0764b8baf..f4f572e1e256e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -36,11 +36,11 @@ import org.apache.ivy.core.retrieve.RetrieveOptions import org.apache.ivy.core.settings.IvySettings import org.apache.ivy.plugins.matcher.GlobPatternMatcher import org.apache.ivy.plugins.resolver.{ChainResolver, IBiblioResolver} - import org.apache.spark.SPARK_VERSION import org.apache.spark.deploy.rest._ import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} + /** * Whether to submit, kill, or request the status of an application. * The latter two operations are currently supported only for standalone cluster mode. @@ -114,18 +114,20 @@ object SparkSubmit { } } - /** Kill an existing submission using the REST protocol. Standalone cluster mode only. */ + /** + * Kill an existing submission using the REST protocol. Standalone and Mesos cluster mode only. + */ private def kill(args: SparkSubmitArguments): Unit = { - new StandaloneRestClient() + new RestSubmissionClient() .killSubmission(args.master, args.submissionToKill) } /** * Request the status of an existing submission using the REST protocol. - * Standalone cluster mode only. + * Standalone and Mesos cluster mode only. */ private def requestStatus(args: SparkSubmitArguments): Unit = { - new StandaloneRestClient() + new RestSubmissionClient() .requestSubmissionStatus(args.master, args.submissionToRequestStatusFor) } @@ -252,6 +254,7 @@ object SparkSubmit { } val isYarnCluster = clusterManager == YARN && deployMode == CLUSTER + val isMesosCluster = clusterManager == MESOS && deployMode == CLUSTER // Resolve maven dependencies if there are any and add classpath to jars. 
Add them to py-files // too for packages that include Python code @@ -294,8 +297,9 @@ object SparkSubmit { // The following modes are not supported or applicable (clusterManager, deployMode) match { - case (MESOS, CLUSTER) => - printErrorAndExit("Cluster deploy mode is currently not supported for Mesos clusters.") + case (MESOS, CLUSTER) if args.isPython => + printErrorAndExit("Cluster deploy mode is currently not supported for python " + + "applications on Mesos clusters.") case (STANDALONE, CLUSTER) if args.isPython => printErrorAndExit("Cluster deploy mode is currently not supported for python " + "applications on standalone clusters.") @@ -377,15 +381,6 @@ object SparkSubmit { OptionAssigner(args.driverExtraLibraryPath, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.driver.extraLibraryPath"), - // Standalone cluster only - // Do not set CL arguments here because there are multiple possibilities for the main class - OptionAssigner(args.jars, STANDALONE, CLUSTER, sysProp = "spark.jars"), - OptionAssigner(args.ivyRepoPath, STANDALONE, CLUSTER, sysProp = "spark.jars.ivy"), - OptionAssigner(args.driverMemory, STANDALONE, CLUSTER, sysProp = "spark.driver.memory"), - OptionAssigner(args.driverCores, STANDALONE, CLUSTER, sysProp = "spark.driver.cores"), - OptionAssigner(args.supervise.toString, STANDALONE, CLUSTER, - sysProp = "spark.driver.supervise"), - // Yarn client only OptionAssigner(args.queue, YARN, CLIENT, sysProp = "spark.yarn.queue"), OptionAssigner(args.numExecutors, YARN, CLIENT, sysProp = "spark.executor.instances"), @@ -413,7 +408,15 @@ object SparkSubmit { OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS, ALL_DEPLOY_MODES, sysProp = "spark.cores.max"), OptionAssigner(args.files, LOCAL | STANDALONE | MESOS, ALL_DEPLOY_MODES, - sysProp = "spark.files") + sysProp = "spark.files"), + OptionAssigner(args.jars, STANDALONE | MESOS, CLUSTER, sysProp = "spark.jars"), + OptionAssigner(args.driverMemory, STANDALONE | MESOS, CLUSTER, + sysProp = "spark.driver.memory"), + OptionAssigner(args.driverCores, STANDALONE | MESOS, CLUSTER, + sysProp = "spark.driver.cores"), + OptionAssigner(args.supervise.toString, STANDALONE | MESOS, CLUSTER, + sysProp = "spark.driver.supervise"), + OptionAssigner(args.ivyRepoPath, STANDALONE, CLUSTER, sysProp = "spark.jars.ivy") ) // In client mode, launch the application main class directly @@ -452,7 +455,7 @@ object SparkSubmit { // All Spark parameters are expected to be passed to the client through system properties. 
if (args.isStandaloneCluster) { if (args.useRest) { - childMainClass = "org.apache.spark.deploy.rest.StandaloneRestClient" + childMainClass = "org.apache.spark.deploy.rest.RestSubmissionClient" childArgs += (args.primaryResource, args.mainClass) } else { // In legacy standalone cluster mode, use Client as a wrapper around the user class @@ -496,6 +499,15 @@ object SparkSubmit { } } + if (isMesosCluster) { + assert(args.useRest, "Mesos cluster mode is only supported through the REST submission API") + childMainClass = "org.apache.spark.deploy.rest.RestSubmissionClient" + childArgs += (args.primaryResource, args.mainClass) + if (args.childArgs != null) { + childArgs ++= args.childArgs + } + } + // Load any properties specified through --conf and the default properties file for ((k, v) <- args.sparkProperties) { sysProps.getOrElseUpdate(k, v) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index c896842943f2b..c621b8fc86f94 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -241,8 +241,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S } private def validateKillArguments(): Unit = { - if (!master.startsWith("spark://")) { - SparkSubmit.printErrorAndExit("Killing submissions is only supported in standalone mode!") + if (!master.startsWith("spark://") && !master.startsWith("mesos://")) { + SparkSubmit.printErrorAndExit( + "Killing submissions is only supported in standalone or Mesos mode!") } if (submissionToKill == null) { SparkSubmit.printErrorAndExit("Please specify a submission to kill.") @@ -250,9 +251,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S } private def validateStatusRequestArguments(): Unit = { - if (!master.startsWith("spark://")) { + if (!master.startsWith("spark://") && !master.startsWith("mesos://")) { SparkSubmit.printErrorAndExit( - "Requesting submission statuses is only supported in standalone mode!") + "Requesting submission statuses is only supported in standalone or Mesos mode!") } if (submissionToRequestStatusFor == null) { SparkSubmit.printErrorAndExit("Please specify a submission to request status for.") @@ -485,6 +486,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | | Spark standalone with cluster deploy mode only: | --driver-cores NUM Cores for driver (Default: 1). + | + | Spark standalone or Mesos with cluster deploy mode only: | --supervise If given, restarts the driver on failure. | --kill SUBMISSION_ID If given, kills the driver specified. | --status SUBMISSION_ID If given, requests the status of the driver specified. 
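A rough sketch of the kill/status path added above, using the shared REST client against a Mesos dispatcher; the dispatcher host and submission id are made-up placeholders, and since RestSubmissionClient is private[spark] this logic normally runs inside SparkSubmit via spark-submit --kill / --status rather than in user code:

    import org.apache.spark.deploy.rest.RestSubmissionClient

    // Placeholders: a dispatcher on its default port 7077 and a previously
    // returned submission id.
    val master = "mesos://dispatcher-host:7077"
    val submissionId = "driver-20150428000000-0000"

    val client = new RestSubmissionClient()

    // Ask the dispatcher for the driver's current state ...
    val status = client.requestSubmissionStatus(master, submissionId)
    // ... or request that the driver be killed.
    val killResponse = client.killSubmission(master, submissionId)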
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index ff2eed6dee70a..1c21c179562ac 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -130,7 +130,7 @@ private[master] class Master( private val restServer = if (restServerEnabled) { val port = conf.getInt("spark.master.rest.port", 6066) - Some(new StandaloneRestServer(host, port, self, masterUrl, conf)) + Some(new StandaloneRestServer(host, port, conf, self, masterUrl)) } else { None } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala index 4823fd7cac0cb..52758d6a7c4be 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala @@ -23,6 +23,7 @@ import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.master.MasterMessages._ import org.apache.curator.framework.CuratorFramework import org.apache.curator.framework.recipes.leader.{LeaderLatchListener, LeaderLatch} +import org.apache.spark.deploy.SparkCuratorUtil private[master] class ZooKeeperLeaderElectionAgent(val masterActor: LeaderElectable, conf: SparkConf) extends LeaderLatchListener with LeaderElectionAgent with Logging { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala index a285783f72000..80db6d474b5c1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala @@ -26,6 +26,7 @@ import org.apache.curator.framework.CuratorFramework import org.apache.zookeeper.CreateMode import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.deploy.SparkCuratorUtil private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serialization: Serialization) diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala new file mode 100644 index 0000000000000..5d4e5b899dfdc --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.mesos + +import java.util.concurrent.CountDownLatch + +import org.apache.spark.deploy.mesos.ui.MesosClusterUI +import org.apache.spark.deploy.rest.mesos.MesosRestServer +import org.apache.spark.scheduler.cluster.mesos._ +import org.apache.spark.util.SignalLogger +import org.apache.spark.{Logging, SecurityManager, SparkConf} + +/* + * A dispatcher that is responsible for managing and launching drivers, and is intended to be + * used for Mesos cluster mode. The dispatcher is a long-running process started by the user in + * the cluster independently of Spark applications. + * It contains a [[MesosRestServer]] that listens for requests to submit drivers and a + * [[MesosClusterScheduler]] that processes these requests by negotiating with the Mesos master + * for resources. + * + * A typical new driver lifecycle is the following: + * - Driver submitted via spark-submit talking to the [[MesosRestServer]] + * - [[MesosRestServer]] queues the driver request to [[MesosClusterScheduler]] + * - [[MesosClusterScheduler]] gets resource offers and launches the drivers that are in queue + * + * This dispatcher supports both Mesos fine-grain or coarse-grain mode as the mode is configurable + * per driver launched. + * This class is needed since Mesos doesn't manage frameworks, so the dispatcher acts as + * a daemon to launch drivers as Mesos frameworks upon request. The dispatcher is also started and + * stopped by sbin/start-mesos-dispatcher and sbin/stop-mesos-dispatcher respectively. + */ +private[mesos] class MesosClusterDispatcher( + args: MesosClusterDispatcherArguments, + conf: SparkConf) + extends Logging { + + private val publicAddress = Option(conf.getenv("SPARK_PUBLIC_DNS")).getOrElse(args.host) + private val recoveryMode = conf.get("spark.mesos.deploy.recoveryMode", "NONE").toUpperCase() + logInfo("Recovery mode in Mesos dispatcher set to: " + recoveryMode) + + private val engineFactory = recoveryMode match { + case "NONE" => new BlackHoleMesosClusterPersistenceEngineFactory + case "ZOOKEEPER" => new ZookeeperMesosClusterPersistenceEngineFactory(conf) + case _ => throw new IllegalArgumentException("Unsupported recovery mode: " + recoveryMode) + } + + private val scheduler = new MesosClusterScheduler(engineFactory, conf) + + private val server = new MesosRestServer(args.host, args.port, conf, scheduler) + private val webUi = new MesosClusterUI( + new SecurityManager(conf), + args.webUiPort, + conf, + publicAddress, + scheduler) + + private val shutdownLatch = new CountDownLatch(1) + + def start(): Unit = { + webUi.bind() + scheduler.frameworkUrl = webUi.activeWebUiUrl + scheduler.start() + server.start() + } + + def awaitShutdown(): Unit = { + shutdownLatch.await() + } + + def stop(): Unit = { + webUi.stop() + server.stop() + scheduler.stop() + shutdownLatch.countDown() + } +} + +private[mesos] object MesosClusterDispatcher extends Logging { + def main(args: Array[String]) { + SignalLogger.register(log) + val conf = new SparkConf + val dispatcherArgs = new MesosClusterDispatcherArguments(args, conf) + conf.setMaster(dispatcherArgs.masterUrl) + conf.setAppName(dispatcherArgs.name) + dispatcherArgs.zookeeperUrl.foreach { z => + conf.set("spark.mesos.deploy.recoveryMode", "ZOOKEEPER") + conf.set("spark.mesos.deploy.zookeeper.url", z) + } + val dispatcher = new MesosClusterDispatcher(dispatcherArgs, conf) + dispatcher.start() + val shutdownHook = new Thread() { + override def run() { + logInfo("Shutdown hook is shutting down dispatcher") + dispatcher.stop() 
+ dispatcher.awaitShutdown() + } + } + Runtime.getRuntime.addShutdownHook(shutdownHook) + dispatcher.awaitShutdown() + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala new file mode 100644 index 0000000000000..894cb78d8591a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.mesos + +import org.apache.spark.SparkConf +import org.apache.spark.util.{IntParam, Utils} + + +private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: SparkConf) { + var host = Utils.localHostName() + var port = 7077 + var name = "Spark Cluster" + var webUiPort = 8081 + var masterUrl: String = _ + var zookeeperUrl: Option[String] = None + var propertiesFile: String = _ + + parse(args.toList) + + propertiesFile = Utils.loadDefaultSparkProperties(conf, propertiesFile) + + private def parse(args: List[String]): Unit = args match { + case ("--host" | "-h") :: value :: tail => + Utils.checkHost(value, "Please use hostname " + value) + host = value + parse(tail) + + case ("--port" | "-p") :: IntParam(value) :: tail => + port = value + parse(tail) + + case ("--webui-port" | "-p") :: IntParam(value) :: tail => + webUiPort = value + parse(tail) + + case ("--zk" | "-z") :: value :: tail => + zookeeperUrl = Some(value) + parse(tail) + + case ("--master" | "-m") :: value :: tail => + if (!value.startsWith("mesos://")) { + System.err.println("Cluster dispatcher only supports mesos (uri begins with mesos://)") + System.exit(1) + } + masterUrl = value.stripPrefix("mesos://") + parse(tail) + + case ("--name") :: value :: tail => + name = value + parse(tail) + + case ("--properties-file") :: value :: tail => + propertiesFile = value + parse(tail) + + case ("--help") :: tail => + printUsageAndExit(0) + + case Nil => { + if (masterUrl == null) { + System.err.println("--master is required") + printUsageAndExit(1) + } + } + + case _ => + printUsageAndExit(1) + } + + private def printUsageAndExit(exitCode: Int): Unit = { + System.err.println( + "Usage: MesosClusterDispatcher [options]\n" + + "\n" + + "Options:\n" + + " -h HOST, --host HOST Hostname to listen on\n" + + " -p PORT, --port PORT Port to listen on (default: 7077)\n" + + " --webui-port WEBUI_PORT WebUI Port to listen on (default: 8081)\n" + + " --name NAME Framework name to show in Mesos UI\n" + + " -m --master MASTER URI for connecting to Mesos master\n" + + " -z --zk ZOOKEEPER Comma delimited URLs for connecting to \n" + + " Zookeeper for persistence\n" + + " --properties-file FILE Path to a custom Spark 
properties file.\n" + + " Default is conf/spark-defaults.conf.") + System.exit(exitCode) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosDriverDescription.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosDriverDescription.scala new file mode 100644 index 0000000000000..1948226800afe --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosDriverDescription.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.mesos + +import java.util.Date + +import org.apache.spark.deploy.Command +import org.apache.spark.scheduler.cluster.mesos.MesosClusterRetryState + +/** + * Describes a Spark driver that is submitted from the + * [[org.apache.spark.deploy.rest.mesos.MesosRestServer]], to be launched by + * [[org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler]]. + * @param jarUrl URL to the application jar + * @param mem Amount of memory for the driver + * @param cores Number of cores for the driver + * @param supervise Supervise the driver for long running app + * @param command The command to launch the driver. + * @param schedulerProperties Extra properties to pass the Mesos scheduler + */ +private[spark] class MesosDriverDescription( + val name: String, + val jarUrl: String, + val mem: Int, + val cores: Double, + val supervise: Boolean, + val command: Command, + val schedulerProperties: Map[String, String], + val submissionId: String, + val submissionDate: Date, + val retryState: Option[MesosClusterRetryState] = None) + extends Serializable { + + def copy( + name: String = name, + jarUrl: String = jarUrl, + mem: Int = mem, + cores: Double = cores, + supervise: Boolean = supervise, + command: Command = command, + schedulerProperties: Map[String, String] = schedulerProperties, + submissionId: String = submissionId, + submissionDate: Date = submissionDate, + retryState: Option[MesosClusterRetryState] = retryState): MesosDriverDescription = { + new MesosDriverDescription(name, jarUrl, mem, cores, supervise, command, schedulerProperties, + submissionId, submissionDate, retryState) + } + + override def toString: String = s"MesosDriverDescription (${command.mainClass})" +} diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala new file mode 100644 index 0000000000000..7b2005e0f1237 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.mesos.ui + +import javax.servlet.http.HttpServletRequest + +import scala.xml.Node + +import org.apache.mesos.Protos.TaskStatus +import org.apache.spark.deploy.mesos.MesosDriverDescription +import org.apache.spark.scheduler.cluster.mesos.MesosClusterSubmissionState +import org.apache.spark.ui.{UIUtils, WebUIPage} + +private[mesos] class MesosClusterPage(parent: MesosClusterUI) extends WebUIPage("") { + def render(request: HttpServletRequest): Seq[Node] = { + val state = parent.scheduler.getSchedulerState() + val queuedHeaders = Seq("Driver ID", "Submit Date", "Main Class", "Driver Resources") + val driverHeaders = queuedHeaders ++ + Seq("Start Date", "Mesos Slave ID", "State") + val retryHeaders = Seq("Driver ID", "Submit Date", "Description") ++ + Seq("Last Failed Status", "Next Retry Time", "Attempt Count") + val queuedTable = UIUtils.listingTable(queuedHeaders, queuedRow, state.queuedDrivers) + val launchedTable = UIUtils.listingTable(driverHeaders, driverRow, state.launchedDrivers) + val finishedTable = UIUtils.listingTable(driverHeaders, driverRow, state.finishedDrivers) + val retryTable = UIUtils.listingTable(retryHeaders, retryRow, state.pendingRetryDrivers) + val content = +

+      <p>Mesos Framework ID: {state.frameworkId}</p>
+      <div class="row-fluid">
+        <div class="span12">
+          <h4>Queued Drivers:</h4>
+          {queuedTable}
+          <h4>Launched Drivers:</h4>
+          {launchedTable}
+          <h4>Finished Drivers:</h4>
+          {finishedTable}
+          <h4>Supervise drivers waiting for retry:</h4>
+          {retryTable}
+        </div>
+      </div>
; + UIUtils.basicSparkPage(content, "Spark Drivers for Mesos cluster") + } + + private def queuedRow(submission: MesosDriverDescription): Seq[Node] = { +
+ + + + + + } + + private def driverRow(state: MesosClusterSubmissionState): Seq[Node] = { + + + + + + + + + + } + + private def retryRow(submission: MesosDriverDescription): Seq[Node] = { + + + + + + + + + } + + private def stateString(status: Option[TaskStatus]): String = { + if (status.isEmpty) { + return "" + } + val sb = new StringBuilder + val s = status.get + sb.append(s"State: ${s.getState}") + if (status.get.hasMessage) { + sb.append(s", Message: ${s.getMessage}") + } + if (status.get.hasHealthy) { + sb.append(s", Healthy: ${s.getHealthy}") + } + if (status.get.hasSource) { + sb.append(s", Source: ${s.getSource}") + } + if (status.get.hasReason) { + sb.append(s", Reason: ${s.getReason}") + } + if (status.get.hasTimestamp) { + sb.append(s", Time: ${s.getTimestamp}") + } + sb.toString() + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala new file mode 100644 index 0000000000000..4865d46dbc4ab --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.mesos.ui + +import org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler +import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.ui.JettyUtils._ +import org.apache.spark.ui.{SparkUI, WebUI} + +/** + * UI that displays driver results from the [[org.apache.spark.deploy.mesos.MesosClusterDispatcher]] + */ +private[spark] class MesosClusterUI( + securityManager: SecurityManager, + port: Int, + conf: SparkConf, + dispatcherPublicAddress: String, + val scheduler: MesosClusterScheduler) + extends WebUI(securityManager, port, conf) { + + initialize() + + def activeWebUiUrl: String = "http://" + dispatcherPublicAddress + ":" + boundPort + + override def initialize() { + attachPage(new MesosClusterPage(this)) + attachHandler(createStaticHandler(MesosClusterUI.STATIC_RESOURCE_DIR, "/static")) + } +} + +private object MesosClusterUI { + val STATIC_RESOURCE_DIR = SparkUI.STATIC_RESOURCE_DIR +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala similarity index 93% rename from core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala rename to core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala index b8fd406fb6f9a..307cebfb4bd09 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala @@ -30,9 +30,7 @@ import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion} import org.apache.spark.util.Utils /** - * A client that submits applications to the standalone Master using a REST protocol. - * This client is intended to communicate with the [[StandaloneRestServer]] and is - * currently used for cluster mode only. + * A client that submits applications to a [[RestSubmissionServer]]. * * In protocol version v1, the REST URL takes the form http://[host:port]/v1/submissions/[action], * where [action] can be one of create, kill, or status. Each type of request is represented in @@ -53,8 +51,10 @@ import org.apache.spark.util.Utils * implementation of this client can use that information to retry using the version specified * by the server. */ -private[deploy] class StandaloneRestClient extends Logging { - import StandaloneRestClient._ +private[spark] class RestSubmissionClient extends Logging { + import RestSubmissionClient._ + + private val supportedMasterPrefixes = Seq("spark://", "mesos://") /** * Submit an application specified by the parameters in the provided request. @@ -62,7 +62,7 @@ private[deploy] class StandaloneRestClient extends Logging { * If the submission was successful, poll the status of the submission and report * it to the user. Otherwise, report the error message provided by the server. */ - private[rest] def createSubmission( + def createSubmission( master: String, request: CreateSubmissionRequest): SubmitRestProtocolResponse = { logInfo(s"Submitting a request to launch an application in $master.") @@ -107,7 +107,7 @@ private[deploy] class StandaloneRestClient extends Logging { } /** Construct a message that captures the specified parameters for submitting an application. 
*/ - private[rest] def constructSubmitRequest( + def constructSubmitRequest( appResource: String, mainClass: String, appArgs: Array[String], @@ -219,14 +219,23 @@ private[deploy] class StandaloneRestClient extends Logging { /** Return the base URL for communicating with the server, including the protocol version. */ private def getBaseUrl(master: String): String = { - val masterUrl = master.stripPrefix("spark://").stripSuffix("/") + var masterUrl = master + supportedMasterPrefixes.foreach { prefix => + if (master.startsWith(prefix)) { + masterUrl = master.stripPrefix(prefix) + } + } + masterUrl = masterUrl.stripSuffix("/") s"http://$masterUrl/$PROTOCOL_VERSION/submissions" } /** Throw an exception if this is not standalone mode. */ private def validateMaster(master: String): Unit = { - if (!master.startsWith("spark://")) { - throw new IllegalArgumentException("This REST client is only supported in standalone mode.") + val valid = supportedMasterPrefixes.exists { prefix => master.startsWith(prefix) } + if (!valid) { + throw new IllegalArgumentException( + "This REST client only supports master URLs that start with " + + "one of the following: " + supportedMasterPrefixes.mkString(",")) } } @@ -295,7 +304,7 @@ private[deploy] class StandaloneRestClient extends Logging { } } -private[rest] object StandaloneRestClient { +private[spark] object RestSubmissionClient { private val REPORT_DRIVER_STATUS_INTERVAL = 1000 private val REPORT_DRIVER_STATUS_MAX_TRIES = 10 val PROTOCOL_VERSION = "v1" @@ -315,7 +324,7 @@ private[rest] object StandaloneRestClient { } val sparkProperties = conf.getAll.toMap val environmentVariables = env.filter { case (k, _) => k.startsWith("SPARK_") } - val client = new StandaloneRestClient + val client = new RestSubmissionClient val submitRequest = client.constructSubmitRequest( appResource, mainClass, appArgs, sparkProperties, environmentVariables) client.createSubmission(master, submitRequest) @@ -323,7 +332,7 @@ private[rest] object StandaloneRestClient { def main(args: Array[String]): Unit = { if (args.size < 2) { - sys.error("Usage: StandaloneRestClient [app resource] [main class] [app args*]") + sys.error("Usage: RestSubmissionClient [app resource] [main class] [app args*]") sys.exit(1) } val appResource = args(0) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala new file mode 100644 index 0000000000000..2e78d03e5c0cc --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala @@ -0,0 +1,318 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.rest + +import java.net.InetSocketAddress +import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} + +import scala.io.Source +import com.fasterxml.jackson.core.JsonProcessingException +import org.eclipse.jetty.server.Server +import org.eclipse.jetty.servlet.{ServletHolder, ServletContextHandler} +import org.eclipse.jetty.util.thread.QueuedThreadPool +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion} +import org.apache.spark.util.Utils + +/** + * A server that responds to requests submitted by the [[RestSubmissionClient]]. + * + * This server responds with different HTTP codes depending on the situation: + * 200 OK - Request was processed successfully + * 400 BAD REQUEST - Request was malformed, not successfully validated, or of unexpected type + * 468 UNKNOWN PROTOCOL VERSION - Request specified a protocol this server does not understand + * 500 INTERNAL SERVER ERROR - Server throws an exception internally while processing the request + * + * The server always includes a JSON representation of the relevant [[SubmitRestProtocolResponse]] + * in the HTTP body. If an error occurs, however, the server will include an [[ErrorResponse]] + * instead of the one expected by the client. If the construction of this error response itself + * fails, the response will consist of an empty body with a response code that indicates internal + * server error. + */ +private[spark] abstract class RestSubmissionServer( + val host: String, + val requestedPort: Int, + val masterConf: SparkConf) extends Logging { + protected val submitRequestServlet: SubmitRequestServlet + protected val killRequestServlet: KillRequestServlet + protected val statusRequestServlet: StatusRequestServlet + + private var _server: Option[Server] = None + + // A mapping from URL prefixes to servlets that serve them. Exposed for testing. + protected val baseContext = s"/${RestSubmissionServer.PROTOCOL_VERSION}/submissions" + protected lazy val contextToServlet = Map[String, RestServlet]( + s"$baseContext/create/*" -> submitRequestServlet, + s"$baseContext/kill/*" -> killRequestServlet, + s"$baseContext/status/*" -> statusRequestServlet, + "/*" -> new ErrorServlet // default handler + ) + + /** Start the server and return the bound port. */ + def start(): Int = { + val (server, boundPort) = Utils.startServiceOnPort[Server](requestedPort, doStart, masterConf) + _server = Some(server) + logInfo(s"Started REST server for submitting applications on port $boundPort") + boundPort + } + + /** + * Map the servlets to their corresponding contexts and attach them to a server. + * Return a 2-tuple of the started server and the bound port. 
+ */ + private def doStart(startPort: Int): (Server, Int) = { + val server = new Server(new InetSocketAddress(host, startPort)) + val threadPool = new QueuedThreadPool + threadPool.setDaemon(true) + server.setThreadPool(threadPool) + val mainHandler = new ServletContextHandler + mainHandler.setContextPath("/") + contextToServlet.foreach { case (prefix, servlet) => + mainHandler.addServlet(new ServletHolder(servlet), prefix) + } + server.setHandler(mainHandler) + server.start() + val boundPort = server.getConnectors()(0).getLocalPort + (server, boundPort) + } + + def stop(): Unit = { + _server.foreach(_.stop()) + } +} + +private[rest] object RestSubmissionServer { + val PROTOCOL_VERSION = RestSubmissionClient.PROTOCOL_VERSION + val SC_UNKNOWN_PROTOCOL_VERSION = 468 +} + +/** + * An abstract servlet for handling requests passed to the [[RestSubmissionServer]]. + */ +private[rest] abstract class RestServlet extends HttpServlet with Logging { + + /** + * Serialize the given response message to JSON and send it through the response servlet. + * This validates the response before sending it to ensure it is properly constructed. + */ + protected def sendResponse( + responseMessage: SubmitRestProtocolResponse, + responseServlet: HttpServletResponse): Unit = { + val message = validateResponse(responseMessage, responseServlet) + responseServlet.setContentType("application/json") + responseServlet.setCharacterEncoding("utf-8") + responseServlet.getWriter.write(message.toJson) + } + + /** + * Return any fields in the client request message that the server does not know about. + * + * The mechanism for this is to reconstruct the JSON on the server side and compare the + * diff between this JSON and the one generated on the client side. Any fields that are + * only in the client JSON are treated as unexpected. + */ + protected def findUnknownFields( + requestJson: String, + requestMessage: SubmitRestProtocolMessage): Array[String] = { + val clientSideJson = parse(requestJson) + val serverSideJson = parse(requestMessage.toJson) + val Diff(_, _, unknown) = clientSideJson.diff(serverSideJson) + unknown match { + case j: JObject => j.obj.map { case (k, _) => k }.toArray + case _ => Array.empty[String] // No difference + } + } + + /** Return a human readable String representation of the exception. */ + protected def formatException(e: Throwable): String = { + val stackTraceString = e.getStackTrace.map { "\t" + _ }.mkString("\n") + s"$e\n$stackTraceString" + } + + /** Construct an error message to signal the fact that an exception has been thrown. */ + protected def handleError(message: String): ErrorResponse = { + val e = new ErrorResponse + e.serverSparkVersion = sparkVersion + e.message = message + e + } + + /** + * Parse a submission ID from the relative path, assuming it is the first part of the path. + * For instance, we expect the path to take the form /[submission ID]/maybe/something/else. + * The returned submission ID cannot be empty. If the path is unexpected, return None. + */ + protected def parseSubmissionId(path: String): Option[String] = { + if (path == null || path.isEmpty) { + None + } else { + path.stripPrefix("/").split("/").headOption.filter(_.nonEmpty) + } + } + + /** + * Validate the response to ensure that it is correctly constructed. + * + * If it is, simply return the message as is. Otherwise, return an error response instead + * to propagate the exception back to the client and set the appropriate error code. 
+ */ + private def validateResponse( + responseMessage: SubmitRestProtocolResponse, + responseServlet: HttpServletResponse): SubmitRestProtocolResponse = { + try { + responseMessage.validate() + responseMessage + } catch { + case e: Exception => + responseServlet.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR) + handleError("Internal server error: " + formatException(e)) + } + } +} + +/** + * A servlet for handling kill requests passed to the [[RestSubmissionServer]]. + */ +private[rest] abstract class KillRequestServlet extends RestServlet { + + /** + * If a submission ID is specified in the URL, have the Master kill the corresponding + * driver and return an appropriate response to the client. Otherwise, return error. + */ + protected override def doPost( + request: HttpServletRequest, + response: HttpServletResponse): Unit = { + val submissionId = parseSubmissionId(request.getPathInfo) + val responseMessage = submissionId.map(handleKill).getOrElse { + response.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError("Submission ID is missing in kill request.") + } + sendResponse(responseMessage, response) + } + + protected def handleKill(submissionId: String): KillSubmissionResponse +} + +/** + * A servlet for handling status requests passed to the [[RestSubmissionServer]]. + */ +private[rest] abstract class StatusRequestServlet extends RestServlet { + + /** + * If a submission ID is specified in the URL, request the status of the corresponding + * driver from the Master and include it in the response. Otherwise, return error. + */ + protected override def doGet( + request: HttpServletRequest, + response: HttpServletResponse): Unit = { + val submissionId = parseSubmissionId(request.getPathInfo) + val responseMessage = submissionId.map(handleStatus).getOrElse { + response.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError("Submission ID is missing in status request.") + } + sendResponse(responseMessage, response) + } + + protected def handleStatus(submissionId: String): SubmissionStatusResponse +} + +/** + * A servlet for handling submit requests passed to the [[RestSubmissionServer]]. + */ +private[rest] abstract class SubmitRequestServlet extends RestServlet { + + /** + * Submit an application to the Master with parameters specified in the request. + * + * The request is assumed to be a [[SubmitRestProtocolRequest]] in the form of JSON. + * If the request is successfully processed, return an appropriate response to the + * client indicating so. Otherwise, return error instead. + */ + protected override def doPost( + requestServlet: HttpServletRequest, + responseServlet: HttpServletResponse): Unit = { + val responseMessage = + try { + val requestMessageJson = Source.fromInputStream(requestServlet.getInputStream).mkString + val requestMessage = SubmitRestProtocolMessage.fromJson(requestMessageJson) + // The response should have already been validated on the client. + // In case this is not true, validate it ourselves to avoid potential NPEs. 
+ requestMessage.validate() + handleSubmit(requestMessageJson, requestMessage, responseServlet) + } catch { + // The client failed to provide a valid JSON, so this is not our fault + case e @ (_: JsonProcessingException | _: SubmitRestProtocolException) => + responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError("Malformed request: " + formatException(e)) + } + sendResponse(responseMessage, responseServlet) + } + + protected def handleSubmit( + requestMessageJson: String, + requestMessage: SubmitRestProtocolMessage, + responseServlet: HttpServletResponse): SubmitRestProtocolResponse +} + +/** + * A default servlet that handles error cases that are not captured by other servlets. + */ +private class ErrorServlet extends RestServlet { + private val serverVersion = RestSubmissionServer.PROTOCOL_VERSION + + /** Service a faulty request by returning an appropriate error message to the client. */ + protected override def service( + request: HttpServletRequest, + response: HttpServletResponse): Unit = { + val path = request.getPathInfo + val parts = path.stripPrefix("/").split("/").filter(_.nonEmpty).toList + var versionMismatch = false + var msg = + parts match { + case Nil => + // http://host:port/ + "Missing protocol version." + case `serverVersion` :: Nil => + // http://host:port/correct-version + "Missing the /submissions prefix." + case `serverVersion` :: "submissions" :: tail => + // http://host:port/correct-version/submissions/* + "Missing an action: please specify one of /create, /kill, or /status." + case unknownVersion :: tail => + // http://host:port/unknown-version/* + versionMismatch = true + s"Unknown protocol version '$unknownVersion'." + case _ => + // never reached + s"Malformed path $path." + } + msg += s" Please submit requests through http://[host]:[port]/$serverVersion/submissions/..." 
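+    // Worked example (host and port illustrative): a client that still speaks an older
+    // protocol sends POST http://host:6066/v0/submissions/create. getPathInfo returns
+    // "/v0/submissions/create", so parts == List("v0", "submissions", "create") and the first
+    // matching case is `unknownVersion :: tail`. versionMismatch becomes true, the status is
+    // set to 468 (SC_UNKNOWN_PROTOCOL_VERSION), and highestProtocolVersion = "v1" is included
+    // so the client knows which version to retry with.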
+ val error = handleError(msg) + // If there is a version mismatch, include the highest protocol version that + // this server supports in case the client wants to retry with our version + if (versionMismatch) { + error.highestProtocolVersion = serverVersion + response.setStatus(RestSubmissionServer.SC_UNKNOWN_PROTOCOL_VERSION) + } else { + response.setStatus(HttpServletResponse.SC_BAD_REQUEST) + } + sendResponse(error, response) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index 2d6b8d4204795..502b9bb701ccf 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -18,26 +18,16 @@ package org.apache.spark.deploy.rest import java.io.File -import java.net.InetSocketAddress -import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} - -import scala.io.Source +import javax.servlet.http.HttpServletResponse import akka.actor.ActorRef -import com.fasterxml.jackson.core.JsonProcessingException -import org.eclipse.jetty.server.Server -import org.eclipse.jetty.servlet.{ServletHolder, ServletContextHandler} -import org.eclipse.jetty.util.thread.QueuedThreadPool -import org.json4s._ -import org.json4s.jackson.JsonMethods._ - -import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion} -import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils} -import org.apache.spark.deploy.{Command, DeployMessages, DriverDescription} import org.apache.spark.deploy.ClientArguments._ +import org.apache.spark.deploy.{Command, DeployMessages, DriverDescription} +import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils} +import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf} /** - * A server that responds to requests submitted by the [[StandaloneRestClient]]. + * A server that responds to requests submitted by the [[RestSubmissionClient]]. * This is intended to be embedded in the standalone Master and used in cluster mode only. * * This server responds with different HTTP codes depending on the situation: @@ -54,173 +44,31 @@ import org.apache.spark.deploy.ClientArguments._ * * @param host the address this server should bind to * @param requestedPort the port this server will attempt to bind to + * @param masterConf the conf used by the Master * @param masterActor reference to the Master actor to which requests can be sent * @param masterUrl the URL of the Master new drivers will attempt to connect to - * @param masterConf the conf used by the Master */ private[deploy] class StandaloneRestServer( host: String, requestedPort: Int, + masterConf: SparkConf, masterActor: ActorRef, - masterUrl: String, - masterConf: SparkConf) - extends Logging { - - import StandaloneRestServer._ - - private var _server: Option[Server] = None - - // A mapping from URL prefixes to servlets that serve them. Exposed for testing. - protected val baseContext = s"/$PROTOCOL_VERSION/submissions" - protected val contextToServlet = Map[String, StandaloneRestServlet]( - s"$baseContext/create/*" -> new SubmitRequestServlet(masterActor, masterUrl, masterConf), - s"$baseContext/kill/*" -> new KillRequestServlet(masterActor, masterConf), - s"$baseContext/status/*" -> new StatusRequestServlet(masterActor, masterConf), - "/*" -> new ErrorServlet // default handler - ) - - /** Start the server and return the bound port. 
*/ - def start(): Int = { - val (server, boundPort) = Utils.startServiceOnPort[Server](requestedPort, doStart, masterConf) - _server = Some(server) - logInfo(s"Started REST server for submitting applications on port $boundPort") - boundPort - } - - /** - * Map the servlets to their corresponding contexts and attach them to a server. - * Return a 2-tuple of the started server and the bound port. - */ - private def doStart(startPort: Int): (Server, Int) = { - val server = new Server(new InetSocketAddress(host, startPort)) - val threadPool = new QueuedThreadPool - threadPool.setDaemon(true) - server.setThreadPool(threadPool) - val mainHandler = new ServletContextHandler - mainHandler.setContextPath("/") - contextToServlet.foreach { case (prefix, servlet) => - mainHandler.addServlet(new ServletHolder(servlet), prefix) - } - server.setHandler(mainHandler) - server.start() - val boundPort = server.getConnectors()(0).getLocalPort - (server, boundPort) - } - - def stop(): Unit = { - _server.foreach(_.stop()) - } -} - -private[rest] object StandaloneRestServer { - val PROTOCOL_VERSION = StandaloneRestClient.PROTOCOL_VERSION - val SC_UNKNOWN_PROTOCOL_VERSION = 468 -} - -/** - * An abstract servlet for handling requests passed to the [[StandaloneRestServer]]. - */ -private[rest] abstract class StandaloneRestServlet extends HttpServlet with Logging { - - /** - * Serialize the given response message to JSON and send it through the response servlet. - * This validates the response before sending it to ensure it is properly constructed. - */ - protected def sendResponse( - responseMessage: SubmitRestProtocolResponse, - responseServlet: HttpServletResponse): Unit = { - val message = validateResponse(responseMessage, responseServlet) - responseServlet.setContentType("application/json") - responseServlet.setCharacterEncoding("utf-8") - responseServlet.getWriter.write(message.toJson) - } - - /** - * Return any fields in the client request message that the server does not know about. - * - * The mechanism for this is to reconstruct the JSON on the server side and compare the - * diff between this JSON and the one generated on the client side. Any fields that are - * only in the client JSON are treated as unexpected. - */ - protected def findUnknownFields( - requestJson: String, - requestMessage: SubmitRestProtocolMessage): Array[String] = { - val clientSideJson = parse(requestJson) - val serverSideJson = parse(requestMessage.toJson) - val Diff(_, _, unknown) = clientSideJson.diff(serverSideJson) - unknown match { - case j: JObject => j.obj.map { case (k, _) => k }.toArray - case _ => Array.empty[String] // No difference - } - } - - /** Return a human readable String representation of the exception. */ - protected def formatException(e: Throwable): String = { - val stackTraceString = e.getStackTrace.map { "\t" + _ }.mkString("\n") - s"$e\n$stackTraceString" - } - - /** Construct an error message to signal the fact that an exception has been thrown. */ - protected def handleError(message: String): ErrorResponse = { - val e = new ErrorResponse - e.serverSparkVersion = sparkVersion - e.message = message - e - } - - /** - * Parse a submission ID from the relative path, assuming it is the first part of the path. - * For instance, we expect the path to take the form /[submission ID]/maybe/something/else. - * The returned submission ID cannot be empty. If the path is unexpected, return None. 
- */ - protected def parseSubmissionId(path: String): Option[String] = { - if (path == null || path.isEmpty) { - None - } else { - path.stripPrefix("/").split("/").headOption.filter(_.nonEmpty) - } - } - - /** - * Validate the response to ensure that it is correctly constructed. - * - * If it is, simply return the message as is. Otherwise, return an error response instead - * to propagate the exception back to the client and set the appropriate error code. - */ - private def validateResponse( - responseMessage: SubmitRestProtocolResponse, - responseServlet: HttpServletResponse): SubmitRestProtocolResponse = { - try { - responseMessage.validate() - responseMessage - } catch { - case e: Exception => - responseServlet.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR) - handleError("Internal server error: " + formatException(e)) - } - } + masterUrl: String) + extends RestSubmissionServer(host, requestedPort, masterConf) { + + protected override val submitRequestServlet = + new StandaloneSubmitRequestServlet(masterActor, masterUrl, masterConf) + protected override val killRequestServlet = + new StandaloneKillRequestServlet(masterActor, masterConf) + protected override val statusRequestServlet = + new StandaloneStatusRequestServlet(masterActor, masterConf) } /** * A servlet for handling kill requests passed to the [[StandaloneRestServer]]. */ -private[rest] class KillRequestServlet(masterActor: ActorRef, conf: SparkConf) - extends StandaloneRestServlet { - - /** - * If a submission ID is specified in the URL, have the Master kill the corresponding - * driver and return an appropriate response to the client. Otherwise, return error. - */ - protected override def doPost( - request: HttpServletRequest, - response: HttpServletResponse): Unit = { - val submissionId = parseSubmissionId(request.getPathInfo) - val responseMessage = submissionId.map(handleKill).getOrElse { - response.setStatus(HttpServletResponse.SC_BAD_REQUEST) - handleError("Submission ID is missing in kill request.") - } - sendResponse(responseMessage, response) - } +private[rest] class StandaloneKillRequestServlet(masterActor: ActorRef, conf: SparkConf) + extends KillRequestServlet { protected def handleKill(submissionId: String): KillSubmissionResponse = { val askTimeout = RpcUtils.askTimeout(conf) @@ -238,23 +86,8 @@ private[rest] class KillRequestServlet(masterActor: ActorRef, conf: SparkConf) /** * A servlet for handling status requests passed to the [[StandaloneRestServer]]. */ -private[rest] class StatusRequestServlet(masterActor: ActorRef, conf: SparkConf) - extends StandaloneRestServlet { - - /** - * If a submission ID is specified in the URL, request the status of the corresponding - * driver from the Master and include it in the response. Otherwise, return error. 
- */ - protected override def doGet( - request: HttpServletRequest, - response: HttpServletResponse): Unit = { - val submissionId = parseSubmissionId(request.getPathInfo) - val responseMessage = submissionId.map(handleStatus).getOrElse { - response.setStatus(HttpServletResponse.SC_BAD_REQUEST) - handleError("Submission ID is missing in status request.") - } - sendResponse(responseMessage, response) - } +private[rest] class StandaloneStatusRequestServlet(masterActor: ActorRef, conf: SparkConf) + extends StatusRequestServlet { protected def handleStatus(submissionId: String): SubmissionStatusResponse = { val askTimeout = RpcUtils.askTimeout(conf) @@ -276,71 +109,11 @@ private[rest] class StatusRequestServlet(masterActor: ActorRef, conf: SparkConf) /** * A servlet for handling submit requests passed to the [[StandaloneRestServer]]. */ -private[rest] class SubmitRequestServlet( +private[rest] class StandaloneSubmitRequestServlet( masterActor: ActorRef, masterUrl: String, conf: SparkConf) - extends StandaloneRestServlet { - - /** - * Submit an application to the Master with parameters specified in the request. - * - * The request is assumed to be a [[SubmitRestProtocolRequest]] in the form of JSON. - * If the request is successfully processed, return an appropriate response to the - * client indicating so. Otherwise, return error instead. - */ - protected override def doPost( - requestServlet: HttpServletRequest, - responseServlet: HttpServletResponse): Unit = { - val responseMessage = - try { - val requestMessageJson = Source.fromInputStream(requestServlet.getInputStream).mkString - val requestMessage = SubmitRestProtocolMessage.fromJson(requestMessageJson) - // The response should have already been validated on the client. - // In case this is not true, validate it ourselves to avoid potential NPEs. - requestMessage.validate() - handleSubmit(requestMessageJson, requestMessage, responseServlet) - } catch { - // The client failed to provide a valid JSON, so this is not our fault - case e @ (_: JsonProcessingException | _: SubmitRestProtocolException) => - responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST) - handleError("Malformed request: " + formatException(e)) - } - sendResponse(responseMessage, responseServlet) - } - - /** - * Handle the submit request and construct an appropriate response to return to the client. - * - * This assumes that the request message is already successfully validated. - * If the request message is not of the expected type, return error to the client. 
- */ - private def handleSubmit( - requestMessageJson: String, - requestMessage: SubmitRestProtocolMessage, - responseServlet: HttpServletResponse): SubmitRestProtocolResponse = { - requestMessage match { - case submitRequest: CreateSubmissionRequest => - val askTimeout = RpcUtils.askTimeout(conf) - val driverDescription = buildDriverDescription(submitRequest) - val response = AkkaUtils.askWithReply[DeployMessages.SubmitDriverResponse]( - DeployMessages.RequestSubmitDriver(driverDescription), masterActor, askTimeout) - val submitResponse = new CreateSubmissionResponse - submitResponse.serverSparkVersion = sparkVersion - submitResponse.message = response.message - submitResponse.success = response.success - submitResponse.submissionId = response.driverId.orNull - val unknownFields = findUnknownFields(requestMessageJson, requestMessage) - if (unknownFields.nonEmpty) { - // If there are fields that the server does not know about, warn the client - submitResponse.unknownFields = unknownFields - } - submitResponse - case unexpected => - responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST) - handleError(s"Received message of unexpected type ${unexpected.messageType}.") - } - } + extends SubmitRequestServlet { /** * Build a driver description from the fields specified in the submit request. @@ -389,50 +162,37 @@ private[rest] class SubmitRequestServlet( new DriverDescription( appResource, actualDriverMemory, actualDriverCores, actualSuperviseDriver, command) } -} -/** - * A default servlet that handles error cases that are not captured by other servlets. - */ -private class ErrorServlet extends StandaloneRestServlet { - private val serverVersion = StandaloneRestServer.PROTOCOL_VERSION - - /** Service a faulty request by returning an appropriate error message to the client. */ - protected override def service( - request: HttpServletRequest, - response: HttpServletResponse): Unit = { - val path = request.getPathInfo - val parts = path.stripPrefix("/").split("/").filter(_.nonEmpty).toList - var versionMismatch = false - var msg = - parts match { - case Nil => - // http://host:port/ - "Missing protocol version." - case `serverVersion` :: Nil => - // http://host:port/correct-version - "Missing the /submissions prefix." - case `serverVersion` :: "submissions" :: tail => - // http://host:port/correct-version/submissions/* - "Missing an action: please specify one of /create, /kill, or /status." - case unknownVersion :: tail => - // http://host:port/unknown-version/* - versionMismatch = true - s"Unknown protocol version '$unknownVersion'." - case _ => - // never reached - s"Malformed path $path." - } - msg += s" Please submit requests through http://[host]:[port]/$serverVersion/submissions/..." - val error = handleError(msg) - // If there is a version mismatch, include the highest protocol version that - // this server supports in case the client wants to retry with our version - if (versionMismatch) { - error.highestProtocolVersion = serverVersion - response.setStatus(StandaloneRestServer.SC_UNKNOWN_PROTOCOL_VERSION) - } else { - response.setStatus(HttpServletResponse.SC_BAD_REQUEST) + /** + * Handle the submit request and construct an appropriate response to return to the client. + * + * This assumes that the request message is already successfully validated. + * If the request message is not of the expected type, return error to the client. 
+ */ + protected override def handleSubmit( + requestMessageJson: String, + requestMessage: SubmitRestProtocolMessage, + responseServlet: HttpServletResponse): SubmitRestProtocolResponse = { + requestMessage match { + case submitRequest: CreateSubmissionRequest => + val askTimeout = RpcUtils.askTimeout(conf) + val driverDescription = buildDriverDescription(submitRequest) + val response = AkkaUtils.askWithReply[DeployMessages.SubmitDriverResponse]( + DeployMessages.RequestSubmitDriver(driverDescription), masterActor, askTimeout) + val submitResponse = new CreateSubmissionResponse + submitResponse.serverSparkVersion = sparkVersion + submitResponse.message = response.message + submitResponse.success = response.success + submitResponse.submissionId = response.driverId.orNull + val unknownFields = findUnknownFields(requestMessageJson, requestMessage) + if (unknownFields.nonEmpty) { + // If there are fields that the server does not know about, warn the client + submitResponse.unknownFields = unknownFields + } + submitResponse + case unexpected => + responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError(s"Received message of unexpected type ${unexpected.messageType}.") } - sendResponse(error, response) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala index d80abdf15fb34..0d50a768942ed 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala @@ -61,7 +61,7 @@ private[rest] class CreateSubmissionRequest extends SubmitRestProtocolRequest { assertProperty[Boolean](key, "boolean", _.toBoolean) private def assertPropertyIsNumeric(key: String): Unit = - assertProperty[Int](key, "numeric", _.toInt) + assertProperty[Double](key, "numeric", _.toDouble) private def assertPropertyIsMemory(key: String): Unit = assertProperty[Int](key, "memory", Utils.memoryStringToMb) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala index 8fde8c142a4c1..0e226ee294cab 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala @@ -35,7 +35,7 @@ private[rest] abstract class SubmitRestProtocolResponse extends SubmitRestProtoc /** * A response to a [[CreateSubmissionRequest]] in the REST application submission protocol. */ -private[rest] class CreateSubmissionResponse extends SubmitRestProtocolResponse { +private[spark] class CreateSubmissionResponse extends SubmitRestProtocolResponse { var submissionId: String = null protected override def doValidate(): Unit = { super.doValidate() @@ -46,7 +46,7 @@ private[rest] class CreateSubmissionResponse extends SubmitRestProtocolResponse /** * A response to a kill request in the REST application submission protocol. */ -private[rest] class KillSubmissionResponse extends SubmitRestProtocolResponse { +private[spark] class KillSubmissionResponse extends SubmitRestProtocolResponse { var submissionId: String = null protected override def doValidate(): Unit = { super.doValidate() @@ -58,7 +58,7 @@ private[rest] class KillSubmissionResponse extends SubmitRestProtocolResponse { /** * A response to a status request in the REST application submission protocol. 
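(Aside on the assertPropertyIsNumeric change above, from assertProperty[Int] with _.toInt to assertProperty[Double] with _.toDouble: Mesos cluster mode accepts fractional numeric properties such as driver cores, which the old Int-based check could not parse. A minimal Scala repro, not project code:

  scala.util.Try("1.5".toInt)    // Failure(java.lang.NumberFormatException: For input string: "1.5")
  scala.util.Try("1.5".toDouble) // Success(1.5)

Consistent with this, MesosSubmitRequestServlet below parses driver cores with _.toDouble and defaults to DEFAULT_CORES = 1.0.)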
*/ -private[rest] class SubmissionStatusResponse extends SubmitRestProtocolResponse { +private[spark] class SubmissionStatusResponse extends SubmitRestProtocolResponse { var submissionId: String = null var driverState: String = null var workerId: String = null diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala new file mode 100644 index 0000000000000..fd17a980c9319 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest.mesos + +import java.io.File +import java.text.SimpleDateFormat +import java.util.Date +import java.util.concurrent.atomic.AtomicLong +import javax.servlet.http.HttpServletResponse + +import org.apache.spark.deploy.Command +import org.apache.spark.deploy.mesos.MesosDriverDescription +import org.apache.spark.deploy.rest._ +import org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler +import org.apache.spark.util.Utils +import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf} + + +/** + * A server that responds to requests submitted by the [[RestSubmissionClient]]. + * All requests are forwarded to + * [[org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler]]. + * This is intended to be used in Mesos cluster mode only. + * For more details about the REST submission please refer to [[RestSubmissionServer]] javadocs. + */ +private[spark] class MesosRestServer( + host: String, + requestedPort: Int, + masterConf: SparkConf, + scheduler: MesosClusterScheduler) + extends RestSubmissionServer(host, requestedPort, masterConf) { + + protected override val submitRequestServlet = + new MesosSubmitRequestServlet(scheduler, masterConf) + protected override val killRequestServlet = + new MesosKillRequestServlet(scheduler, masterConf) + protected override val statusRequestServlet = + new MesosStatusRequestServlet(scheduler, masterConf) +} + +private[deploy] class MesosSubmitRequestServlet( + scheduler: MesosClusterScheduler, + conf: SparkConf) + extends SubmitRequestServlet { + + private val DEFAULT_SUPERVISE = false + private val DEFAULT_MEMORY = 512 // mb + private val DEFAULT_CORES = 1.0 + + private val nextDriverNumber = new AtomicLong(0) + private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs + private def newDriverId(submitDate: Date): String = { + "driver-%s-%04d".format( + createDateFormat.format(submitDate), nextDriverNumber.incrementAndGet()) + } + + /** + * Build a driver description from the fields specified in the submit request. 
+ * + * This involves constructing a command that launches a mesos framework for the job. + * This does not currently consider fields used by python applications since python + * is not supported in mesos cluster mode yet. + */ + private def buildDriverDescription(request: CreateSubmissionRequest): MesosDriverDescription = { + // Required fields, including the main class because python is not yet supported + val appResource = Option(request.appResource).getOrElse { + throw new SubmitRestMissingFieldException("Application jar is missing.") + } + val mainClass = Option(request.mainClass).getOrElse { + throw new SubmitRestMissingFieldException("Main class is missing.") + } + + // Optional fields + val sparkProperties = request.sparkProperties + val driverExtraJavaOptions = sparkProperties.get("spark.driver.extraJavaOptions") + val driverExtraClassPath = sparkProperties.get("spark.driver.extraClassPath") + val driverExtraLibraryPath = sparkProperties.get("spark.driver.extraLibraryPath") + val superviseDriver = sparkProperties.get("spark.driver.supervise") + val driverMemory = sparkProperties.get("spark.driver.memory") + val driverCores = sparkProperties.get("spark.driver.cores") + val appArgs = request.appArgs + val environmentVariables = request.environmentVariables + val name = request.sparkProperties.get("spark.app.name").getOrElse(mainClass) + + // Construct driver description + val conf = new SparkConf(false).setAll(sparkProperties) + val extraClassPath = driverExtraClassPath.toSeq.flatMap(_.split(File.pathSeparator)) + val extraLibraryPath = driverExtraLibraryPath.toSeq.flatMap(_.split(File.pathSeparator)) + val extraJavaOpts = driverExtraJavaOptions.map(Utils.splitCommandString).getOrElse(Seq.empty) + val sparkJavaOpts = Utils.sparkJavaOpts(conf) + val javaOpts = sparkJavaOpts ++ extraJavaOpts + val command = new Command( + mainClass, appArgs, environmentVariables, extraClassPath, extraLibraryPath, javaOpts) + val actualSuperviseDriver = superviseDriver.map(_.toBoolean).getOrElse(DEFAULT_SUPERVISE) + val actualDriverMemory = driverMemory.map(Utils.memoryStringToMb).getOrElse(DEFAULT_MEMORY) + val actualDriverCores = driverCores.map(_.toDouble).getOrElse(DEFAULT_CORES) + val submitDate = new Date() + val submissionId = newDriverId(submitDate) + + new MesosDriverDescription( + name, appResource, actualDriverMemory, actualDriverCores, actualSuperviseDriver, + command, request.sparkProperties, submissionId, submitDate) + } + + protected override def handleSubmit( + requestMessageJson: String, + requestMessage: SubmitRestProtocolMessage, + responseServlet: HttpServletResponse): SubmitRestProtocolResponse = { + requestMessage match { + case submitRequest: CreateSubmissionRequest => + val driverDescription = buildDriverDescription(submitRequest) + val s = scheduler.submitDriver(driverDescription) + s.serverSparkVersion = sparkVersion + val unknownFields = findUnknownFields(requestMessageJson, requestMessage) + if (unknownFields.nonEmpty) { + // If there are fields that the server does not know about, warn the client + s.unknownFields = unknownFields + } + s + case unexpected => + responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError(s"Received message of unexpected type ${unexpected.messageType}.") + } + } +} + +private[deploy] class MesosKillRequestServlet(scheduler: MesosClusterScheduler, conf: SparkConf) + extends KillRequestServlet { + protected override def handleKill(submissionId: String): KillSubmissionResponse = { + val k = scheduler.killDriver(submissionId) + 
k.serverSparkVersion = sparkVersion + k + } +} + +private[deploy] class MesosStatusRequestServlet(scheduler: MesosClusterScheduler, conf: SparkConf) + extends StatusRequestServlet { + protected override def handleStatus(submissionId: String): SubmissionStatusResponse = { + val d = scheduler.getDriverStatus(submissionId) + d.serverSparkVersion = sparkVersion + d + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 82f652dae0378..3412301e64fd7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -18,20 +18,17 @@ package org.apache.spark.scheduler.cluster.mesos import java.io.File -import java.util.{List => JList} -import java.util.Collections +import java.util.{Collections, List => JList} import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, HashSet} -import org.apache.mesos.{Scheduler => MScheduler} -import org.apache.mesos._ -import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _} - -import org.apache.spark.{Logging, SparkContext, SparkEnv, SparkException, TaskState} +import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} +import org.apache.mesos.{Scheduler => MScheduler, _} import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend -import org.apache.spark.util.{Utils, AkkaUtils} +import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.{SparkContext, SparkEnv, SparkException, TaskState} /** * A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds @@ -49,17 +46,10 @@ private[spark] class CoarseMesosSchedulerBackend( master: String) extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) with MScheduler - with Logging { + with MesosSchedulerUtils { val MAX_SLAVE_FAILURES = 2 // Blacklist a slave after this many failures - // Lock used to wait for scheduler to be registered - var isRegistered = false - val registeredLock = new Object() - - // Driver for talking to Mesos - var driver: SchedulerDriver = null - // Maximum number of cores to acquire (TODO: we'll need more flexible controls here) val maxCores = conf.get("spark.cores.max", Int.MaxValue.toString).toInt @@ -87,26 +77,8 @@ private[spark] class CoarseMesosSchedulerBackend( override def start() { super.start() - - synchronized { - new Thread("CoarseMesosSchedulerBackend driver") { - setDaemon(true) - override def run() { - val scheduler = CoarseMesosSchedulerBackend.this - val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build() - driver = new MesosSchedulerDriver(scheduler, fwInfo, master) - try { { - val ret = driver.run() - logInfo("driver.run() returned with code " + ret) - } - } catch { - case e: Exception => logError("driver.run() failed", e) - } - } - }.start() - - waitForRegister() - } + val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build() + startScheduler(master, CoarseMesosSchedulerBackend.this, fwInfo) } def createCommand(offer: Offer, numCores: Int): CommandInfo = { @@ -150,8 +122,10 @@ private[spark] class CoarseMesosSchedulerBackend( conf.get("spark.driver.port"), CoarseGrainedSchedulerBackend.ENDPOINT_NAME) - val uri = 
conf.get("spark.executor.uri", null) - if (uri == null) { + val uri = conf.getOption("spark.executor.uri") + .orElse(Option(System.getenv("SPARK_EXECUTOR_URI"))) + + if (uri.isEmpty) { val runScript = new File(executorSparkHome, "./bin/spark-class").getCanonicalPath command.setValue( "%s \"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend" @@ -164,7 +138,7 @@ private[spark] class CoarseMesosSchedulerBackend( } else { // Grab everything to the first '.'. We'll use that and '*' to // glob the directory "correctly". - val basename = uri.split('/').last.split('.').head + val basename = uri.get.split('/').last.split('.').head command.setValue( s"cd $basename*; $prefixEnv " + "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend" + @@ -173,7 +147,7 @@ private[spark] class CoarseMesosSchedulerBackend( s" --hostname ${offer.getHostname}" + s" --cores $numCores" + s" --app-id $appId") - command.addUris(CommandInfo.URI.newBuilder().setValue(uri)) + command.addUris(CommandInfo.URI.newBuilder().setValue(uri.get)) } command.build() } @@ -183,18 +157,7 @@ private[spark] class CoarseMesosSchedulerBackend( override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) { appId = frameworkId.getValue logInfo("Registered as framework ID " + appId) - registeredLock.synchronized { - isRegistered = true - registeredLock.notifyAll() - } - } - - def waitForRegister() { - registeredLock.synchronized { - while (!isRegistered) { - registeredLock.wait() - } - } + markRegistered() } override def disconnected(d: SchedulerDriver) {} @@ -245,14 +208,6 @@ private[spark] class CoarseMesosSchedulerBackend( } } - /** Helper function to pull out a resource from a Mesos Resources protobuf */ - private def getResource(res: JList[Resource], name: String): Double = { - for (r <- res if r.getName == name) { - return r.getScalar.getValue - } - 0 - } - /** Build a Mesos resource protobuf object */ private def createResource(resourceName: String, quantity: Double): Protos.Resource = { Resource.newBuilder() @@ -284,7 +239,8 @@ private[spark] class CoarseMesosSchedulerBackend( "is Spark installed on it?") } } - driver.reviveOffers() // In case we'd rejected everything before but have now lost a node + // In case we'd rejected everything before but have now lost a node + mesosDriver.reviveOffers() } } } @@ -296,8 +252,8 @@ private[spark] class CoarseMesosSchedulerBackend( override def stop() { super.stop() - if (driver != null) { - driver.stop() + if (mesosDriver != null) { + mesosDriver.stop() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala new file mode 100644 index 0000000000000..3efc536f1456c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster.mesos
+
+import scala.collection.JavaConversions._
+
+import org.apache.curator.framework.CuratorFramework
+import org.apache.zookeeper.CreateMode
+import org.apache.zookeeper.KeeperException.NoNodeException
+
+import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.deploy.SparkCuratorUtil
+import org.apache.spark.util.Utils
+
+/**
+ * Persistence engine factory that is responsible for creating new persistence engines
+ * to store Mesos cluster mode state.
+ */
+private[spark] abstract class MesosClusterPersistenceEngineFactory(conf: SparkConf) {
+  def createEngine(path: String): MesosClusterPersistenceEngine
+}
+
+/**
+ * A Mesos cluster persistence engine is responsible for persisting Mesos cluster mode
+ * specific state, so that on failover all the state can be recovered and the scheduler
+ * can resume managing the drivers.
+ */
+private[spark] trait MesosClusterPersistenceEngine {
+  def persist(name: String, obj: Object): Unit
+  def expunge(name: String): Unit
+  def fetch[T](name: String): Option[T]
+  def fetchAll[T](): Iterable[T]
+}
+
+/**
+ * Zookeeper-backed persistence engine factory.
+ * All ZK engines created from this factory share the same Zookeeper client, so
+ * all of them reuse the same connection pool.
+ */
+private[spark] class ZookeeperMesosClusterPersistenceEngineFactory(conf: SparkConf)
+  extends MesosClusterPersistenceEngineFactory(conf) {
+
+  lazy val zk = SparkCuratorUtil.newClient(conf, "spark.mesos.deploy.zookeeper.url")
+
+  def createEngine(path: String): MesosClusterPersistenceEngine = {
+    new ZookeeperMesosClusterPersistenceEngine(path, zk, conf)
+  }
+}
+
+/**
+ * Black hole persistence engine factory that creates black hole
+ * persistence engines, which store nothing.
+ */
+private[spark] class BlackHoleMesosClusterPersistenceEngineFactory
+  extends MesosClusterPersistenceEngineFactory(null) {
+  def createEngine(path: String): MesosClusterPersistenceEngine = {
+    new BlackHoleMesosClusterPersistenceEngine
+  }
+}
+
+/**
+ * Black hole persistence engine that stores nothing.
+ */
+private[spark] class BlackHoleMesosClusterPersistenceEngine extends MesosClusterPersistenceEngine {
+  override def persist(name: String, obj: Object): Unit = {}
+  override def fetch[T](name: String): Option[T] = None
+  override def expunge(name: String): Unit = {}
+  override def fetchAll[T](): Iterable[T] = Iterable.empty[T]
+}
+
+/**
+ * Zookeeper-based Mesos cluster persistence engine that stores cluster mode state
+ * in Zookeeper. Each engine object operates under one folder in Zookeeper but
+ * reuses a shared Zookeeper client.
+ */
+private[spark] class ZookeeperMesosClusterPersistenceEngine(
+    baseDir: String,
+    zk: CuratorFramework,
+    conf: SparkConf)
+  extends MesosClusterPersistenceEngine with Logging {
+  private val WORKING_DIR =
+    conf.get("spark.deploy.zookeeper.dir", "/spark_mesos_dispatcher") + "/" + baseDir
+
+  SparkCuratorUtil.mkdir(zk, WORKING_DIR)
+
+  def path(name: String): String = {
+    WORKING_DIR + "/" + name
+  }
+
+  override def expunge(name: String): Unit = {
+    zk.delete().forPath(path(name))
+  }
+
+  override def persist(name: String, obj: Object): Unit = {
+    val serialized = Utils.serialize(obj)
+    val zkPath = path(name)
+    zk.create().withMode(CreateMode.PERSISTENT).forPath(zkPath, serialized)
+  }
+
+  override def fetch[T](name: String): Option[T] = {
+    val zkPath = path(name)
+
+    try {
+      val fileData = zk.getData().forPath(zkPath)
+      Some(Utils.deserialize[T](fileData))
+    } catch {
+      case e: NoNodeException => None
+      case e: Exception => {
+        logWarning("Exception while reading persisted file, deleting", e)
+        zk.delete().forPath(zkPath)
+        None
+      }
+    }
+  }
+
+  override def fetchAll[T](): Iterable[T] = {
+    zk.getChildren.forPath(WORKING_DIR).map(fetch[T]).flatten
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala
new file mode 100644
index 0000000000000..0396e62be5309
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala
@@ -0,0 +1,608 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster.mesos
+
+import java.io.File
+import java.util.concurrent.locks.ReentrantLock
+import java.util.{Collections, Date, List => JList}
+
+import scala.collection.JavaConversions._
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.mesos.Protos.Environment.Variable
+import org.apache.mesos.Protos.TaskStatus.Reason
+import org.apache.mesos.Protos.{TaskState => MesosTaskState, _}
+import org.apache.mesos.{Scheduler, SchedulerDriver}
+import org.apache.spark.deploy.mesos.MesosDriverDescription
+import org.apache.spark.deploy.rest.{CreateSubmissionResponse, KillSubmissionResponse, SubmissionStatusResponse}
+import org.apache.spark.metrics.MetricsSystem
+import org.apache.spark.util.Utils
+import org.apache.spark.{SecurityManager, SparkConf, SparkException, TaskState}
+
+
+/**
+ * Tracks the current state of a Mesos Task that runs a Spark driver.
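(Aside: the resulting ZooKeeper layout is easy to picture. With the default spark.deploy.zookeeper.dir, an engine created via createEngine("driverQueue") stores each entry under /spark_mesos_dispatcher/driverQueue/<name>, with Utils.serialize(obj) as the znode data. A sketch, where the submission id is illustrative:

  val engine = new ZookeeperMesosClusterPersistenceEngineFactory(conf).createEngine("driverQueue")
  engine.persist("driver-20150415103000-0001", desc)
  // creates znode /spark_mesos_dispatcher/driverQueue/driver-20150415103000-0001
  engine.fetchAll[MesosDriverDescription]()  // lists that folder's children and deserializes each
)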
+ * @param driverDescription Submitted driver description from
+ *        [[org.apache.spark.deploy.rest.mesos.MesosRestServer]]
+ * @param taskId Mesos TaskID generated for the task
+ * @param slaveId Slave ID that the task is assigned to
+ * @param mesosTaskStatus The last known task status update
+ * @param startDate The date the task was launched
+ */
+private[spark] class MesosClusterSubmissionState(
+    val driverDescription: MesosDriverDescription,
+    val taskId: TaskID,
+    val slaveId: SlaveID,
+    var mesosTaskStatus: Option[TaskStatus],
+    var startDate: Date)
+  extends Serializable {
+
+  def copy(): MesosClusterSubmissionState = {
+    new MesosClusterSubmissionState(
+      driverDescription, taskId, slaveId, mesosTaskStatus, startDate)
+  }
+}
+
+/**
+ * Tracks the retry state of a driver, which includes the next time it should be scheduled
+ * and the information necessary to do exponential backoff.
+ * This class is not thread-safe, and we expect the caller to handle synchronizing state.
+ * @param lastFailureStatus Last task status when it failed
+ * @param retries Number of times it has been retried
+ * @param nextRetry Time at which it should be retried next
+ * @param waitTime The amount of time the driver is scheduled to wait until the next retry
+ */
+private[spark] class MesosClusterRetryState(
+    val lastFailureStatus: TaskStatus,
+    val retries: Int,
+    val nextRetry: Date,
+    val waitTime: Int) extends Serializable {
+  def copy(): MesosClusterRetryState =
+    new MesosClusterRetryState(lastFailureStatus, retries, nextRetry, waitTime)
+}
+
+/**
+ * The full state of the cluster scheduler, currently used for displaying
+ * information on the UI.
+ * @param frameworkId Mesos framework ID for the cluster scheduler
+ * @param masterUrl The Mesos master URL
+ * @param queuedDrivers All drivers queued to be launched
+ * @param launchedDrivers All launched or running drivers
+ * @param finishedDrivers All terminated drivers
+ * @param pendingRetryDrivers All drivers pending to be retried
+ */
+private[spark] class MesosClusterSchedulerState(
+    val frameworkId: String,
+    val masterUrl: Option[String],
+    val queuedDrivers: Iterable[MesosDriverDescription],
+    val launchedDrivers: Iterable[MesosClusterSubmissionState],
+    val finishedDrivers: Iterable[MesosClusterSubmissionState],
+    val pendingRetryDrivers: Iterable[MesosDriverDescription])
+
+/**
+ * A Mesos scheduler that is responsible for launching submitted Spark drivers in cluster mode
+ * as Mesos tasks in a Mesos cluster.
+ * All drivers are launched asynchronously by the framework and will eventually be run
+ * by one of the slaves in the cluster. The results of the driver are stored in the slave's
+ * task sandbox, which is accessible by visiting the Mesos UI.
+ * This scheduler supports recovery by persisting all its state and performs task reconciliation
+ * on recovery, which fetches the latest state for all the drivers from the Mesos master.
+ */
+private[spark] class MesosClusterScheduler(
+    engineFactory: MesosClusterPersistenceEngineFactory,
+    conf: SparkConf)
+  extends Scheduler with MesosSchedulerUtils {
+  var frameworkUrl: String = _
+  private val metricsSystem =
+    MetricsSystem.createMetricsSystem("mesos_cluster", conf, new SecurityManager(conf))
+  private val master = conf.get("spark.master")
+  private val appName = conf.get("spark.app.name")
+  private val queuedCapacity = conf.getInt("spark.mesos.maxDrivers", 200)
+  private val retainedDrivers = conf.getInt("spark.mesos.retainedDrivers", 200)
+  private val maxRetryWaitTime = conf.getInt("spark.mesos.cluster.retry.wait.max", 60) // 1 minute
+  private val schedulerState = engineFactory.createEngine("scheduler")
+  private val stateLock = new ReentrantLock()
+  private val finishedDrivers =
+    new mutable.ArrayBuffer[MesosClusterSubmissionState](retainedDrivers)
+  private var frameworkId: String = null
+  // Holds all the launched drivers and their current launch state, keyed by driver id.
+  private val launchedDrivers = new mutable.HashMap[String, MesosClusterSubmissionState]()
+  // Holds a map of driver id to expected slave id that is passed to Mesos for reconciliation.
+  // All drivers that are loaded after failover are added here, as we need to get the latest
+  // state of the tasks from Mesos.
+  private val pendingRecover = new mutable.HashMap[String, SlaveID]()
+  // Stores all the submitted drivers that haven't been launched yet.
+  private val queuedDrivers = new ArrayBuffer[MesosDriverDescription]()
+  // All supervised drivers that are waiting to retry after termination.
+  private val pendingRetryDrivers = new ArrayBuffer[MesosDriverDescription]()
+  private val queuedDriversState = engineFactory.createEngine("driverQueue")
+  private val launchedDriversState = engineFactory.createEngine("launchedDrivers")
+  private val pendingRetryDriversState = engineFactory.createEngine("retryList")
+  // Flag marking whether the scheduler is ready to serve requests; this is false until the
+  // scheduler has registered with the Mesos master.
+  @volatile protected var ready = false
+  private var masterInfo: Option[MasterInfo] = None
+
+  def submitDriver(desc: MesosDriverDescription): CreateSubmissionResponse = {
+    val c = new CreateSubmissionResponse
+    if (!ready) {
+      c.success = false
+      c.message = "Scheduler is not ready to take requests"
+      return c
+    }
+
+    stateLock.synchronized {
+      if (isQueueFull()) {
+        c.success = false
+        c.message = "Already reached maximum submission size"
+        return c
+      }
+      c.submissionId = desc.submissionId
+      queuedDriversState.persist(desc.submissionId, desc)
+      queuedDrivers += desc
+      c.success = true
+    }
+    c
+  }
+
+  def killDriver(submissionId: String): KillSubmissionResponse = {
+    val k = new KillSubmissionResponse
+    if (!ready) {
+      k.success = false
+      k.message = "Scheduler is not ready to take requests"
+      return k
+    }
+    k.submissionId = submissionId
+    stateLock.synchronized {
+      // We look for the requested driver in the following places:
+      // 1. Check if the submission is running or launched.
+      // 2. Check if it's still queued.
+      // 3. Check if it's in the retry list.
+      // 4. Check if it has already completed.
+      if (launchedDrivers.contains(submissionId)) {
+        val task = launchedDrivers(submissionId)
+        mesosDriver.killTask(task.taskId)
+        k.success = true
+        k.message = "Killing running driver"
+      } else if (removeFromQueuedDrivers(submissionId)) {
+        k.success = true
+        k.message = "Removed driver while it's still pending"
+      } else if (removeFromPendingRetryDrivers(submissionId)) {
+        k.success = true
+        k.message = "Removed driver while it's being retried"
+      } else if (finishedDrivers.exists(_.driverDescription.submissionId.equals(submissionId))) {
+        k.success = false
+        k.message = "Driver already terminated"
+      } else {
+        k.success = false
+        k.message = "Cannot find driver"
+      }
+    }
+    k
+  }
+
+  def getDriverStatus(submissionId: String): SubmissionStatusResponse = {
+    val s = new SubmissionStatusResponse
+    if (!ready) {
+      s.success = false
+      s.message = "Scheduler is not ready to take requests"
+      return s
+    }
+    s.submissionId = submissionId
+    stateLock.synchronized {
+      if (queuedDrivers.exists(_.submissionId.equals(submissionId))) {
+        s.success = true
+        s.driverState = "QUEUED"
+      } else if (launchedDrivers.contains(submissionId)) {
+        s.success = true
+        s.driverState = "RUNNING"
+        launchedDrivers(submissionId).mesosTaskStatus.foreach(state => s.message = state.toString)
+      } else if (finishedDrivers.exists(_.driverDescription.submissionId.equals(submissionId))) {
+        s.success = true
+        s.driverState = "FINISHED"
+        finishedDrivers
+          .find(d => d.driverDescription.submissionId.equals(submissionId)).get.mesosTaskStatus
+          .foreach(state => s.message = state.toString)
+      } else if (pendingRetryDrivers.exists(_.submissionId.equals(submissionId))) {
+        val status = pendingRetryDrivers.find(_.submissionId.equals(submissionId))
+          .get.retryState.get.lastFailureStatus
+        s.success = true
+        s.driverState = "RETRYING"
+        s.message = status.toString
+      } else {
+        s.success = false
+        s.driverState = "NOT_FOUND"
+      }
+    }
+    s
+  }
+
+  private def isQueueFull(): Boolean = launchedDrivers.size >= queuedCapacity
+
+  /**
+   * Recovers the persisted scheduler state.
+   * We still need to do task reconciliation to stay up to date with the latest task states,
+   * as they might have changed while the scheduler was failing over.
+   */
+  private def recoverState(): Unit = {
+    stateLock.synchronized {
+      launchedDriversState.fetchAll[MesosClusterSubmissionState]().foreach { state =>
+        launchedDrivers(state.taskId.getValue) = state
+        pendingRecover(state.taskId.getValue) = state.slaveId
+      }
+      queuedDriversState.fetchAll[MesosDriverDescription]().foreach(d => queuedDrivers += d)
+      // There is a potential timing issue where a queued driver might have been launched
+      // but the scheduler shut down before that driver could be removed from the queue.
+      // We mitigate this issue by walking through all queued drivers and removing any
+      // that have already been launched.
+      queuedDrivers
+        .filter(d => launchedDrivers.contains(d.submissionId))
+        .foreach(d => removeFromQueuedDrivers(d.submissionId))
+      pendingRetryDriversState.fetchAll[MesosDriverDescription]()
+        .foreach(s => pendingRetryDrivers += s)
+      // TODO: Consider storing finished drivers so we can show them on the UI after
+      // failover. For now we clear the history on each recovery.
+      finishedDrivers.clear()
+    }
+  }
+
+  /**
+   * Starts the cluster scheduler and waits until the scheduler is registered.
+   * This also marks the scheduler as ready for requests.
+   */
+  def start(): Unit = {
+    // TODO: Implement leader election to make sure only one framework is running in the cluster.
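+    // Recovery sketch (behavioral summary; any persisted id and its format are illustrative):
+    //  - First boot: fetch("frameworkId") returns None, so we register without an id and
+    //    registered() persists whatever framework id the Mesos master assigns.
+    //  - Failover: fetch("frameworkId") returns Some(id), so we re-register under the same id.
+    //    Combined with the MAX_VALUE failover timeout set below, Mesos keeps the old driver
+    //    tasks running, and registered() can then reconcile their latest state.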
+ val fwId = schedulerState.fetch[String]("frameworkId") + val builder = FrameworkInfo.newBuilder() + .setUser(Utils.getCurrentUserName()) + .setName(appName) + .setWebuiUrl(frameworkUrl) + .setCheckpoint(true) + .setFailoverTimeout(Integer.MAX_VALUE) // Setting to max so tasks keep running on crash + fwId.foreach { id => + builder.setId(FrameworkID.newBuilder().setValue(id).build()) + frameworkId = id + } + recoverState() + metricsSystem.registerSource(new MesosClusterSchedulerSource(this)) + metricsSystem.start() + startScheduler(master, MesosClusterScheduler.this, builder.build()) + ready = true + } + + def stop(): Unit = { + ready = false + metricsSystem.report() + metricsSystem.stop() + mesosDriver.stop(true) + } + + override def registered( + driver: SchedulerDriver, + newFrameworkId: FrameworkID, + masterInfo: MasterInfo): Unit = { + logInfo("Registered as framework ID " + newFrameworkId.getValue) + if (newFrameworkId.getValue != frameworkId) { + frameworkId = newFrameworkId.getValue + schedulerState.persist("frameworkId", frameworkId) + } + markRegistered() + + stateLock.synchronized { + this.masterInfo = Some(masterInfo) + if (!pendingRecover.isEmpty) { + // Start task reconciliation if we need to recover. + val statuses = pendingRecover.collect { + case (taskId, slaveId) => + val newStatus = TaskStatus.newBuilder() + .setTaskId(TaskID.newBuilder().setValue(taskId).build()) + .setSlaveId(slaveId) + .setState(MesosTaskState.TASK_STAGING) + .build() + launchedDrivers.get(taskId).map(_.mesosTaskStatus.getOrElse(newStatus)) + .getOrElse(newStatus) + } + // TODO: Page the status updates to avoid trying to reconcile + // a large amount of tasks at once. + driver.reconcileTasks(statuses) + } + } + } + + private def buildDriverCommand(desc: MesosDriverDescription): CommandInfo = { + val appJar = CommandInfo.URI.newBuilder() + .setValue(desc.jarUrl.stripPrefix("file:").stripPrefix("local:")).build() + val builder = CommandInfo.newBuilder().addUris(appJar) + val entries = + (conf.getOption("spark.executor.extraLibraryPath").toList ++ + desc.command.libraryPathEntries) + val prefixEnv = if (!entries.isEmpty) { + Utils.libraryPathEnvPrefix(entries) + } else { + "" + } + val envBuilder = Environment.newBuilder() + desc.command.environment.foreach { case (k, v) => + envBuilder.addVariables(Variable.newBuilder().setName(k).setValue(v).build()) + } + // Pass all spark properties to executor. 
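As a quick aside on the mapping applied in the next line, here is a tiny self-contained illustration (the property names and values are made up): each scheduler property simply becomes a -D flag inside SPARK_EXECUTOR_OPTS.

    val sampleProps = Map("spark.executor.memory" -> "2g", "spark.cores.max" -> "4")
    val sampleOpts = sampleProps.map { case (k, v) => s"-D$k=$v" }.mkString(" ")
    // sampleOpts == "-Dspark.executor.memory=2g -Dspark.cores.max=4"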
+ val executorOpts = desc.schedulerProperties.map { case (k, v) => s"-D$k=$v" }.mkString(" ") + envBuilder.addVariables( + Variable.newBuilder().setName("SPARK_EXECUTOR_OPTS").setValue(executorOpts)) + val cmdOptions = generateCmdOption(desc) + val executorUri = desc.schedulerProperties.get("spark.executor.uri") + .orElse(desc.command.environment.get("SPARK_EXECUTOR_URI")) + val appArguments = desc.command.arguments.mkString(" ") + val cmd = if (executorUri.isDefined) { + builder.addUris(CommandInfo.URI.newBuilder().setValue(executorUri.get).build()) + val folderBasename = executorUri.get.split('/').last.split('.').head + val cmdExecutable = s"cd $folderBasename*; $prefixEnv bin/spark-submit" + val cmdJar = s"../${desc.jarUrl.split("/").last}" + s"$cmdExecutable ${cmdOptions.mkString(" ")} $cmdJar $appArguments" + } else { + val executorSparkHome = desc.schedulerProperties.get("spark.mesos.executor.home") + .orElse(conf.getOption("spark.home")) + .orElse(Option(System.getenv("SPARK_HOME"))) + .getOrElse { + throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!") + } + val cmdExecutable = new File(executorSparkHome, "./bin/spark-submit").getCanonicalPath + val cmdJar = desc.jarUrl.split("/").last + s"$cmdExecutable ${cmdOptions.mkString(" ")} $cmdJar $appArguments" + } + builder.setValue(cmd) + builder.setEnvironment(envBuilder.build()) + builder.build() + } + + private def generateCmdOption(desc: MesosDriverDescription): Seq[String] = { + var options = Seq( + "--name", desc.schedulerProperties("spark.app.name"), + "--class", desc.command.mainClass, + "--master", s"mesos://${conf.get("spark.master")}", + "--driver-cores", desc.cores.toString, + "--driver-memory", s"${desc.mem}M") + desc.schedulerProperties.get("spark.executor.memory").map { v => + options ++= Seq("--executor-memory", v) + } + desc.schedulerProperties.get("spark.cores.max").map { v => + options ++= Seq("--total-executor-cores", v) + } + options + } + + private class ResourceOffer(val offer: Offer, var cpu: Double, var mem: Double) { + override def toString(): String = { + s"Offer id: ${offer.getId.getValue}, cpu: $cpu, mem: $mem" + } + } + + /** + * This method takes all the possible candidates and attempt to schedule them with Mesos offers. + * Every time a new task is scheduled, the afterLaunchCallback is called to perform post scheduled + * logic on each task. 
+ */ + private def scheduleTasks( + candidates: Seq[MesosDriverDescription], + afterLaunchCallback: (String) => Boolean, + currentOffers: List[ResourceOffer], + tasks: mutable.HashMap[OfferID, ArrayBuffer[TaskInfo]]): Unit = { + for (submission <- candidates) { + val driverCpu = submission.cores + val driverMem = submission.mem + logTrace(s"Finding offer to launch driver with cpu: $driverCpu, mem: $driverMem") + val offerOption = currentOffers.find { o => + o.cpu >= driverCpu && o.mem >= driverMem + } + if (offerOption.isEmpty) { + logDebug(s"Unable to find offer to launch driver id: ${submission.submissionId}, " + + s"cpu: $driverCpu, mem: $driverMem") + } else { + val offer = offerOption.get + offer.cpu -= driverCpu + offer.mem -= driverMem + val taskId = TaskID.newBuilder().setValue(submission.submissionId).build() + val cpuResource = Resource.newBuilder() + .setName("cpus").setType(Value.Type.SCALAR) + .setScalar(Value.Scalar.newBuilder().setValue(driverCpu)).build() + val memResource = Resource.newBuilder() + .setName("mem").setType(Value.Type.SCALAR) + .setScalar(Value.Scalar.newBuilder().setValue(driverMem)).build() + val commandInfo = buildDriverCommand(submission) + val appName = submission.schedulerProperties("spark.app.name") + val taskInfo = TaskInfo.newBuilder() + .setTaskId(taskId) + .setName(s"Driver for $appName") + .setSlaveId(offer.offer.getSlaveId) + .setCommand(commandInfo) + .addResources(cpuResource) + .addResources(memResource) + .build() + val queuedTasks = tasks.getOrElseUpdate(offer.offer.getId, new ArrayBuffer[TaskInfo]) + queuedTasks += taskInfo + logTrace(s"Using offer ${offer.offer.getId.getValue} to launch driver " + + submission.submissionId) + val newState = new MesosClusterSubmissionState(submission, taskId, offer.offer.getSlaveId, + None, new Date()) + launchedDrivers(submission.submissionId) = newState + launchedDriversState.persist(submission.submissionId, newState) + afterLaunchCallback(submission.submissionId) + } + } + } + + override def resourceOffers(driver: SchedulerDriver, offers: JList[Offer]): Unit = { + val currentOffers = offers.map { o => + new ResourceOffer( + o, getResource(o.getResourcesList, "cpus"), getResource(o.getResourcesList, "mem")) + }.toList + logTrace(s"Received offers from Mesos: \n${currentOffers.mkString("\n")}") + val tasks = new mutable.HashMap[OfferID, ArrayBuffer[TaskInfo]]() + val currentTime = new Date() + + stateLock.synchronized { + // We first schedule all the supervised drivers that are ready to retry. + // This list will be empty if none of the drivers are marked as supervise. + val driversToRetry = pendingRetryDrivers.filter { d => + d.retryState.get.nextRetry.before(currentTime) + } + scheduleTasks( + driversToRetry, + removeFromPendingRetryDrivers, + currentOffers, + tasks) + // Then we walk through the queued drivers and try to schedule them. 
+ scheduleTasks( + queuedDrivers, + removeFromQueuedDrivers, + currentOffers, + tasks) + } + tasks.foreach { case (offerId, tasks) => + driver.launchTasks(Collections.singleton(offerId), tasks) + } + offers + .filter(o => !tasks.keySet.contains(o.getId)) + .foreach(o => driver.declineOffer(o.getId)) + } + + def getSchedulerState(): MesosClusterSchedulerState = { + def copyBuffer( + buffer: ArrayBuffer[MesosDriverDescription]): ArrayBuffer[MesosDriverDescription] = { + val newBuffer = new ArrayBuffer[MesosDriverDescription](buffer.size) + buffer.copyToBuffer(newBuffer) + newBuffer + } + stateLock.synchronized { + new MesosClusterSchedulerState( + frameworkId, + masterInfo.map(m => s"http://${m.getIp}:${m.getPort}"), + copyBuffer(queuedDrivers), + launchedDrivers.values.map(_.copy()).toList, + finishedDrivers.map(_.copy()).toList, + copyBuffer(pendingRetryDrivers)) + } + } + + override def offerRescinded(driver: SchedulerDriver, offerId: OfferID): Unit = {} + override def disconnected(driver: SchedulerDriver): Unit = {} + override def reregistered(driver: SchedulerDriver, masterInfo: MasterInfo): Unit = { + logInfo(s"Framework re-registered with master ${masterInfo.getId}") + } + override def slaveLost(driver: SchedulerDriver, slaveId: SlaveID): Unit = {} + override def error(driver: SchedulerDriver, error: String): Unit = { + logError("Error received: " + error) + } + + /** + * Check if the task state is a recoverable state that we can relaunch the task. + * Task state like TASK_ERROR are not relaunchable state since it wasn't able + * to be validated by Mesos. + */ + private def shouldRelaunch(state: MesosTaskState): Boolean = { + state == MesosTaskState.TASK_FAILED || + state == MesosTaskState.TASK_KILLED || + state == MesosTaskState.TASK_LOST + } + + override def statusUpdate(driver: SchedulerDriver, status: TaskStatus): Unit = { + val taskId = status.getTaskId.getValue + stateLock.synchronized { + if (launchedDrivers.contains(taskId)) { + if (status.getReason == Reason.REASON_RECONCILIATION && + !pendingRecover.contains(taskId)) { + // Task has already received update and no longer requires reconciliation. + return + } + val state = launchedDrivers(taskId) + // Check if the driver is supervise enabled and can be relaunched. 
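One detail worth seeing worked out before the relaunch branch below: the retry wait starts at one second, doubles after every failure, and is capped by spark.mesos.cluster.retry.wait.max (60 seconds by default). A small sketch of the resulting schedule:

    val maxWait = 60   // spark.mesos.cluster.retry.wait.max default, in seconds
    val waits = Iterator.iterate(1)(w => math.min(maxWait, w * 2)).take(8).toList
    // waits == List(1, 2, 4, 8, 16, 32, 60, 60)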
+ if (state.driverDescription.supervise && shouldRelaunch(status.getState)) { + removeFromLaunchedDrivers(taskId) + val retryState: Option[MesosClusterRetryState] = state.driverDescription.retryState + val (retries, waitTimeSec) = retryState + .map { rs => (rs.retries + 1, Math.min(maxRetryWaitTime, rs.waitTime * 2)) } + .getOrElse{ (1, 1) } + val nextRetry = new Date(new Date().getTime + waitTimeSec * 1000L) + + val newDriverDescription = state.driverDescription.copy( + retryState = Some(new MesosClusterRetryState(status, retries, nextRetry, waitTimeSec))) + pendingRetryDrivers += newDriverDescription + pendingRetryDriversState.persist(taskId, newDriverDescription) + } else if (TaskState.isFinished(TaskState.fromMesos(status.getState))) { + removeFromLaunchedDrivers(taskId) + if (finishedDrivers.size >= retainedDrivers) { + val toRemove = math.max(retainedDrivers / 10, 1) + finishedDrivers.trimStart(toRemove) + } + finishedDrivers += state + } + state.mesosTaskStatus = Option(status) + } else { + logError(s"Unable to find driver $taskId in status update") + } + } + } + + override def frameworkMessage( + driver: SchedulerDriver, + executorId: ExecutorID, + slaveId: SlaveID, + message: Array[Byte]): Unit = {} + + override def executorLost( + driver: SchedulerDriver, + executorId: ExecutorID, + slaveId: SlaveID, + status: Int): Unit = {} + + private def removeFromQueuedDrivers(id: String): Boolean = { + val index = queuedDrivers.indexWhere(_.submissionId.equals(id)) + if (index != -1) { + queuedDrivers.remove(index) + queuedDriversState.expunge(id) + true + } else { + false + } + } + + private def removeFromLaunchedDrivers(id: String): Boolean = { + if (launchedDrivers.remove(id).isDefined) { + launchedDriversState.expunge(id) + true + } else { + false + } + } + + private def removeFromPendingRetryDrivers(id: String): Boolean = { + val index = pendingRetryDrivers.indexWhere(_.submissionId.equals(id)) + if (index != -1) { + pendingRetryDrivers.remove(index) + pendingRetryDriversState.expunge(id) + true + } else { + false + } + } + + def getQueuedDriversSize: Int = queuedDrivers.size + def getLaunchedDriversSize: Int = launchedDrivers.size + def getPendingRetryDriversSize: Int = pendingRetryDrivers.size +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSource.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSource.scala new file mode 100644 index 0000000000000..1fe94974c8e36 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSource.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler.cluster.mesos + +import com.codahale.metrics.{Gauge, MetricRegistry} + +import org.apache.spark.metrics.source.Source + +private[mesos] class MesosClusterSchedulerSource(scheduler: MesosClusterScheduler) + extends Source { + override def sourceName: String = "mesos_cluster" + override def metricRegistry: MetricRegistry = new MetricRegistry() + + metricRegistry.register(MetricRegistry.name("waitingDrivers"), new Gauge[Int] { + override def getValue: Int = scheduler.getQueuedDriversSize + }) + + metricRegistry.register(MetricRegistry.name("launchedDrivers"), new Gauge[Int] { + override def getValue: Int = scheduler.getLaunchedDriversSize + }) + + metricRegistry.register(MetricRegistry.name("retryDrivers"), new Gauge[Int] { + override def getValue: Int = scheduler.getPendingRetryDriversSize + }) +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index d9d62b0e287ed..8346a2407489f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -18,23 +18,19 @@ package org.apache.spark.scheduler.cluster.mesos import java.io.File -import java.util.{ArrayList => JArrayList, List => JList} -import java.util.Collections +import java.util.{ArrayList => JArrayList, Collections, List => JList} import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, HashSet} +import org.apache.mesos.Protos.{ExecutorInfo => MesosExecutorInfo, TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _} import org.apache.mesos.protobuf.ByteString -import org.apache.mesos.{Scheduler => MScheduler} -import org.apache.mesos._ -import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, - ExecutorInfo => MesosExecutorInfo, _} - +import org.apache.mesos.{Scheduler => MScheduler, _} import org.apache.spark.executor.MesosExecutorBackend -import org.apache.spark.{Logging, SparkContext, SparkException, TaskState} -import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.util.Utils +import org.apache.spark.{SparkContext, SparkException, TaskState} /** * A SchedulerBackend for running fine-grained tasks on Mesos. 
Each Spark task is mapped to a @@ -47,14 +43,7 @@ private[spark] class MesosSchedulerBackend( master: String) extends SchedulerBackend with MScheduler - with Logging { - - // Lock used to wait for scheduler to be registered - var isRegistered = false - val registeredLock = new Object() - - // Driver for talking to Mesos - var driver: SchedulerDriver = null + with MesosSchedulerUtils { // Which slave IDs we have executors on val slaveIdsWithExecutors = new HashSet[String] @@ -73,26 +62,9 @@ private[spark] class MesosSchedulerBackend( @volatile var appId: String = _ override def start() { - synchronized { - classLoader = Thread.currentThread.getContextClassLoader - - new Thread("MesosSchedulerBackend driver") { - setDaemon(true) - override def run() { - val scheduler = MesosSchedulerBackend.this - val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build() - driver = new MesosSchedulerDriver(scheduler, fwInfo, master) - try { - val ret = driver.run() - logInfo("driver.run() returned with code " + ret) - } catch { - case e: Exception => logError("driver.run() failed", e) - } - } - }.start() - - waitForRegister() - } + val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build() + classLoader = Thread.currentThread.getContextClassLoader + startScheduler(master, MesosSchedulerBackend.this, fwInfo) } def createExecutorInfo(execId: String): MesosExecutorInfo = { @@ -125,17 +97,19 @@ private[spark] class MesosSchedulerBackend( } val command = CommandInfo.newBuilder() .setEnvironment(environment) - val uri = sc.conf.get("spark.executor.uri", null) + val uri = sc.conf.getOption("spark.executor.uri") + .orElse(Option(System.getenv("SPARK_EXECUTOR_URI"))) + val executorBackendName = classOf[MesosExecutorBackend].getName - if (uri == null) { + if (uri.isEmpty) { val executorPath = new File(executorSparkHome, "/bin/spark-class").getCanonicalPath command.setValue(s"$prefixEnv $executorPath $executorBackendName") } else { // Grab everything to the first '.'. We'll use that and '*' to // glob the directory "correctly". 
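A worked example of the globbing trick described above, using a hypothetical executor URI: only the text before the first '.' is kept, and the trailing '*' matches whatever directory Mesos extracts from the fetched archive.

    val sampleUri = "http://host/spark-1.3.1-bin-hadoop2.4.tgz"
    val sampleBasename = sampleUri.split('/').last.split('.').head
    // sampleBasename == "spark-1", so the executor command starts with "cd spark-1*; ..."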
- val basename = uri.split('/').last.split('.').head + val basename = uri.get.split('/').last.split('.').head command.setValue(s"cd ${basename}*; $prefixEnv ./bin/spark-class $executorBackendName") - command.addUris(CommandInfo.URI.newBuilder().setValue(uri)) + command.addUris(CommandInfo.URI.newBuilder().setValue(uri.get)) } val cpus = Resource.newBuilder() .setName("cpus") @@ -181,18 +155,7 @@ private[spark] class MesosSchedulerBackend( inClassLoader() { appId = frameworkId.getValue logInfo("Registered as framework ID " + appId) - registeredLock.synchronized { - isRegistered = true - registeredLock.notifyAll() - } - } - } - - def waitForRegister() { - registeredLock.synchronized { - while (!isRegistered) { - registeredLock.wait() - } + markRegistered() } } @@ -287,14 +250,6 @@ private[spark] class MesosSchedulerBackend( } } - /** Helper function to pull out a resource from a Mesos Resources protobuf */ - def getResource(res: JList[Resource], name: String): Double = { - for (r <- res if r.getName == name) { - return r.getScalar.getValue - } - 0 - } - /** Turn a Spark TaskDescription into a Mesos task */ def createMesosTask(task: TaskDescription, slaveId: String): MesosTaskInfo = { val taskId = TaskID.newBuilder().setValue(task.taskId.toString).build() @@ -339,13 +294,13 @@ private[spark] class MesosSchedulerBackend( } override def stop() { - if (driver != null) { - driver.stop() + if (mesosDriver != null) { + mesosDriver.stop() } } override def reviveOffers() { - driver.reviveOffers() + mesosDriver.reviveOffers() } override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {} @@ -380,7 +335,7 @@ private[spark] class MesosSchedulerBackend( } override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = { - driver.killTask( + mesosDriver.killTask( TaskID.newBuilder() .setValue(taskId.toString).build() ) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala new file mode 100644 index 0000000000000..d11228f3d016a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster.mesos + +import java.util.List +import java.util.concurrent.CountDownLatch + +import scala.collection.JavaConversions._ + +import org.apache.mesos.Protos.{FrameworkInfo, Resource, Status} +import org.apache.mesos.{MesosSchedulerDriver, Scheduler} +import org.apache.spark.Logging +import org.apache.spark.util.Utils + +/** + * Shared trait for implementing a Mesos Scheduler. 
This holds common state and helper + * methods and Mesos scheduler will use. + */ +private[mesos] trait MesosSchedulerUtils extends Logging { + // Lock used to wait for scheduler to be registered + private final val registerLatch = new CountDownLatch(1) + + // Driver for talking to Mesos + protected var mesosDriver: MesosSchedulerDriver = null + + /** + * Starts the MesosSchedulerDriver with the provided information. This method returns + * only after the scheduler has registered with Mesos. + * @param masterUrl Mesos master connection URL + * @param scheduler Scheduler object + * @param fwInfo FrameworkInfo to pass to the Mesos master + */ + def startScheduler(masterUrl: String, scheduler: Scheduler, fwInfo: FrameworkInfo): Unit = { + synchronized { + if (mesosDriver != null) { + registerLatch.await() + return + } + + new Thread(Utils.getFormattedClassName(this) + "-mesos-driver") { + setDaemon(true) + + override def run() { + mesosDriver = new MesosSchedulerDriver(scheduler, fwInfo, masterUrl) + try { + val ret = mesosDriver.run() + logInfo("driver.run() returned with code " + ret) + if (ret.equals(Status.DRIVER_ABORTED)) { + System.exit(1) + } + } catch { + case e: Exception => { + logError("driver.run() failed", e) + System.exit(1) + } + } + } + }.start() + + registerLatch.await() + } + } + + /** + * Signal that the scheduler has registered with Mesos. + */ + protected def markRegistered(): Unit = { + registerLatch.countDown() + } + + /** + * Get the amount of resources for the specified type from the resource list + */ + protected def getResource(res: List[Resource], name: String): Double = { + for (r <- res if r.getName == name) { + return r.getScalar.getValue + } + 0.0 + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 4561e5b8e9663..c4e6f06146b0a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -231,7 +231,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties val childArgsStr = childArgs.mkString(" ") if (useRest) { childArgsStr should endWith ("thejar.jar org.SomeClass arg1 arg2") - mainClass should be ("org.apache.spark.deploy.rest.StandaloneRestClient") + mainClass should be ("org.apache.spark.deploy.rest.RestSubmissionClient") } else { childArgsStr should startWith ("--supervise --memory 4g --cores 5") childArgsStr should include regex "launch spark://h:p .*thejar.jar org.SomeClass arg1 arg2" diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index 8e09976636386..0a318a27ac212 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -39,9 +39,9 @@ import org.apache.spark.deploy.master.DriverState._ * Tests for the REST application submission protocol used in standalone cluster mode. 
*/ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { - private val client = new StandaloneRestClient + private val client = new RestSubmissionClient private var actorSystem: Option[ActorSystem] = None - private var server: Option[StandaloneRestServer] = None + private var server: Option[RestSubmissionServer] = None override def afterEach() { actorSystem.foreach(_.shutdown()) @@ -89,7 +89,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { conf.set("spark.app.name", "dreamer") val appArgs = Array("one", "two", "six") // main method calls this - val response = StandaloneRestClient.run("app-resource", "main-class", appArgs, conf) + val response = RestSubmissionClient.run("app-resource", "main-class", appArgs, conf) val submitResponse = getSubmitResponse(response) assert(submitResponse.action === Utils.getFormattedClassName(submitResponse)) assert(submitResponse.serverSparkVersion === SPARK_VERSION) @@ -208,7 +208,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { test("good request paths") { val masterUrl = startSmartServer() val httpUrl = masterUrl.replace("spark://", "http://") - val v = StandaloneRestServer.PROTOCOL_VERSION + val v = RestSubmissionServer.PROTOCOL_VERSION val json = constructSubmitRequest(masterUrl).toJson val submitRequestPath = s"$httpUrl/$v/submissions/create" val killRequestPath = s"$httpUrl/$v/submissions/kill" @@ -238,7 +238,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { test("good request paths, bad requests") { val masterUrl = startSmartServer() val httpUrl = masterUrl.replace("spark://", "http://") - val v = StandaloneRestServer.PROTOCOL_VERSION + val v = RestSubmissionServer.PROTOCOL_VERSION val submitRequestPath = s"$httpUrl/$v/submissions/create" val killRequestPath = s"$httpUrl/$v/submissions/kill" val statusRequestPath = s"$httpUrl/$v/submissions/status" @@ -276,7 +276,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { test("bad request paths") { val masterUrl = startSmartServer() val httpUrl = masterUrl.replace("spark://", "http://") - val v = StandaloneRestServer.PROTOCOL_VERSION + val v = RestSubmissionServer.PROTOCOL_VERSION val (response1, code1) = sendHttpRequestWithResponse(httpUrl, "GET") val (response2, code2) = sendHttpRequestWithResponse(s"$httpUrl/", "GET") val (response3, code3) = sendHttpRequestWithResponse(s"$httpUrl/$v", "GET") @@ -292,7 +292,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { assert(code5 === HttpServletResponse.SC_BAD_REQUEST) assert(code6 === HttpServletResponse.SC_BAD_REQUEST) assert(code7 === HttpServletResponse.SC_BAD_REQUEST) - assert(code8 === StandaloneRestServer.SC_UNKNOWN_PROTOCOL_VERSION) + assert(code8 === RestSubmissionServer.SC_UNKNOWN_PROTOCOL_VERSION) // all responses should be error responses val errorResponse1 = getErrorResponse(response1) val errorResponse2 = getErrorResponse(response2) @@ -310,13 +310,13 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { assert(errorResponse5.highestProtocolVersion === null) assert(errorResponse6.highestProtocolVersion === null) assert(errorResponse7.highestProtocolVersion === null) - assert(errorResponse8.highestProtocolVersion === StandaloneRestServer.PROTOCOL_VERSION) + assert(errorResponse8.highestProtocolVersion === RestSubmissionServer.PROTOCOL_VERSION) } test("server returns unknown fields") { val masterUrl = startSmartServer() val httpUrl = masterUrl.replace("spark://", 
"http://") - val v = StandaloneRestServer.PROTOCOL_VERSION + val v = RestSubmissionServer.PROTOCOL_VERSION val submitRequestPath = s"$httpUrl/$v/submissions/create" val oldJson = constructSubmitRequest(masterUrl).toJson val oldFields = parse(oldJson).asInstanceOf[JObject].obj @@ -340,7 +340,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { test("client handles faulty server") { val masterUrl = startFaultyServer() val httpUrl = masterUrl.replace("spark://", "http://") - val v = StandaloneRestServer.PROTOCOL_VERSION + val v = RestSubmissionServer.PROTOCOL_VERSION val submitRequestPath = s"$httpUrl/$v/submissions/create" val killRequestPath = s"$httpUrl/$v/submissions/kill/anything" val statusRequestPath = s"$httpUrl/$v/submissions/status/anything" @@ -400,9 +400,9 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { val fakeMasterRef = _actorSystem.actorOf(Props(makeFakeMaster)) val _server = if (faulty) { - new FaultyStandaloneRestServer(localhost, 0, fakeMasterRef, "spark://fake:7077", conf) + new FaultyStandaloneRestServer(localhost, 0, conf, fakeMasterRef, "spark://fake:7077") } else { - new StandaloneRestServer(localhost, 0, fakeMasterRef, "spark://fake:7077", conf) + new StandaloneRestServer(localhost, 0, conf, fakeMasterRef, "spark://fake:7077") } val port = _server.start() // set these to clean them up after every test @@ -563,20 +563,18 @@ private class SmarterMaster extends Actor { private class FaultyStandaloneRestServer( host: String, requestedPort: Int, + masterConf: SparkConf, masterActor: ActorRef, - masterUrl: String, - masterConf: SparkConf) - extends StandaloneRestServer(host, requestedPort, masterActor, masterUrl, masterConf) { + masterUrl: String) + extends RestSubmissionServer(host, requestedPort, masterConf) { - protected override val contextToServlet = Map[String, StandaloneRestServlet]( - s"$baseContext/create/*" -> new MalformedSubmitServlet, - s"$baseContext/kill/*" -> new InvalidKillServlet, - s"$baseContext/status/*" -> new ExplodingStatusServlet, - "/*" -> new ErrorServlet - ) + protected override val submitRequestServlet = new MalformedSubmitServlet + protected override val killRequestServlet = new InvalidKillServlet + protected override val statusRequestServlet = new ExplodingStatusServlet /** A faulty servlet that produces malformed responses. */ - class MalformedSubmitServlet extends SubmitRequestServlet(masterActor, masterUrl, masterConf) { + class MalformedSubmitServlet + extends StandaloneSubmitRequestServlet(masterActor, masterUrl, masterConf) { protected override def sendResponse( responseMessage: SubmitRestProtocolResponse, responseServlet: HttpServletResponse): Unit = { @@ -586,7 +584,7 @@ private class FaultyStandaloneRestServer( } /** A faulty servlet that produces invalid responses. */ - class InvalidKillServlet extends KillRequestServlet(masterActor, masterConf) { + class InvalidKillServlet extends StandaloneKillRequestServlet(masterActor, masterConf) { protected override def handleKill(submissionId: String): KillSubmissionResponse = { val k = super.handleKill(submissionId) k.submissionId = null @@ -595,7 +593,7 @@ private class FaultyStandaloneRestServer( } /** A faulty status servlet that explodes. 
*/ - class ExplodingStatusServlet extends StatusRequestServlet(masterActor, masterConf) { + class ExplodingStatusServlet extends StandaloneStatusRequestServlet(masterActor, masterConf) { private def explode: Int = 1 / 0 protected override def handleStatus(submissionId: String): SubmissionStatusResponse = { val s = super.handleStatus(submissionId) diff --git a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala new file mode 100644 index 0000000000000..f28e29e9b8d8e --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.mesos + +import java.util.Date + +import org.scalatest.FunSuite +import org.scalatest.mock.MockitoSugar + +import org.apache.spark.deploy.Command +import org.apache.spark.deploy.mesos.MesosDriverDescription +import org.apache.spark.scheduler.cluster.mesos._ +import org.apache.spark.{LocalSparkContext, SparkConf} + + +class MesosClusterSchedulerSuite extends FunSuite with LocalSparkContext with MockitoSugar { + + private val command = new Command("mainClass", Seq("arg"), null, null, null, null) + + test("can queue drivers") { + val conf = new SparkConf() + conf.setMaster("mesos://localhost:5050") + conf.setAppName("spark mesos") + val scheduler = new MesosClusterScheduler( + new BlackHoleMesosClusterPersistenceEngineFactory, conf) { + override def start(): Unit = { ready = true } + } + scheduler.start() + val response = scheduler.submitDriver( + new MesosDriverDescription("d1", "jar", 1000, 1, true, + command, Map[String, String](), "s1", new Date())) + assert(response.success) + val response2 = + scheduler.submitDriver(new MesosDriverDescription( + "d1", "jar", 1000, 1, true, command, Map[String, String](), "s2", new Date())) + assert(response2.success) + val state = scheduler.getSchedulerState() + val queuedDrivers = state.queuedDrivers.toList + assert(queuedDrivers(0).submissionId == response.submissionId) + assert(queuedDrivers(1).submissionId == response2.submissionId) + } + + test("can kill queued drivers") { + val conf = new SparkConf() + conf.setMaster("mesos://localhost:5050") + conf.setAppName("spark mesos") + val scheduler = new MesosClusterScheduler( + new BlackHoleMesosClusterPersistenceEngineFactory, conf) { + override def start(): Unit = { ready = true } + } + scheduler.start() + val response = scheduler.submitDriver( + new MesosDriverDescription("d1", "jar", 1000, 1, true, + command, Map[String, String](), "s1", new Date())) + assert(response.success) + val killResponse = scheduler.killDriver(response.submissionId) + 
assert(killResponse.success) + val state = scheduler.getSchedulerState() + assert(state.queuedDrivers.isEmpty) + } +} diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 594bf78b67713..8f53d8201a089 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -78,6 +78,9 @@ To verify that the Mesos cluster is ready for Spark, navigate to the Mesos maste To use Mesos from Spark, you need a Spark binary package available in a place accessible by Mesos, and a Spark driver program configured to connect to Mesos. +Alternatively, you can also install Spark in the same location in all the Mesos slaves, and configure +`spark.mesos.executor.home` (defaults to SPARK_HOME) to point to that location. + ## Uploading Spark Package When Mesos runs a task on a Mesos slave for the first time, that slave must have a Spark binary @@ -107,7 +110,11 @@ the `make-distribution.sh` script included in a Spark source tarball/checkout. The Master URLs for Mesos are in the form `mesos://host:5050` for a single-master Mesos cluster, or `mesos://zk://host:2181` for a multi-master Mesos cluster using ZooKeeper. -The driver also needs some configuration in `spark-env.sh` to interact properly with Mesos: +## Client Mode + +In client mode, a Spark Mesos framework is launched directly on the client machine and waits for the driver output. + +The driver needs some configuration in `spark-env.sh` to interact properly with Mesos: 1. In `spark-env.sh` set some environment variables: * `export MESOS_NATIVE_JAVA_LIBRARY=`. This path is typically @@ -129,8 +136,7 @@ val sc = new SparkContext(conf) {% endhighlight %} (You can also use [`spark-submit`](submitting-applications.html) and configure `spark.executor.uri` -in the [conf/spark-defaults.conf](configuration.html#loading-default-configurations) file. Note -that `spark-submit` currently only supports deploying the Spark driver in `client` mode for Mesos.) +in the [conf/spark-defaults.conf](configuration.html#loading-default-configurations) file.) When running a shell, the `spark.executor.uri` parameter is inherited from `SPARK_EXECUTOR_URI`, so it does not need to be redundantly passed in as a system property. @@ -139,6 +145,17 @@ it does not need to be redundantly passed in as a system property. ./bin/spark-shell --master mesos://host:5050 {% endhighlight %} +## Cluster mode + +Spark on Mesos also supports cluster mode, where the driver is launched in the cluster and the client +can find the results of the driver from the Mesos Web UI. + +To use cluster mode, you must start the MesosClusterDispatcher in your cluster via the `sbin/start-mesos-dispatcher.sh` script, +passing in the Mesos master url (e.g: mesos://host:5050). + +From the client, you can submit a job to Mesos cluster by running `spark-submit` and specifying the master url +to the url of the MesosClusterDispatcher (e.g: mesos://dispatcher:7077). You can view driver statuses on the +Spark cluster Web UI. # Mesos Run Modes diff --git a/sbin/start-mesos-dispatcher.sh b/sbin/start-mesos-dispatcher.sh new file mode 100755 index 0000000000000..ef1fc573d5c65 --- /dev/null +++ b/sbin/start-mesos-dispatcher.sh @@ -0,0 +1,40 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Starts the Mesos Cluster Dispatcher on the machine this script is executed on. +# The Mesos Cluster Dispatcher is responsible for launching the Mesos framework and +# Rest server to handle driver requests for Mesos cluster mode. +# Only one cluster dispatcher is needed per Mesos cluster. + +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" + +. "$sbin/spark-config.sh" + +. "$SPARK_PREFIX/bin/load-spark-env.sh" + +if [ "$SPARK_MESOS_DISPATCHER_PORT" = "" ]; then + SPARK_MESOS_DISPATCHER_PORT=7077 +fi + +if [ "$SPARK_MESOS_DISPATCHER_HOST" = "" ]; then + SPARK_MESOS_DISPATCHER_HOST=`hostname` +fi + + +"$sbin"/spark-daemon.sh start org.apache.spark.deploy.mesos.MesosClusterDispatcher 1 --host $SPARK_MESOS_DISPATCHER_HOST --port $SPARK_MESOS_DISPATCHER_PORT "$@" diff --git a/sbin/stop-mesos-dispatcher.sh b/sbin/stop-mesos-dispatcher.sh new file mode 100755 index 0000000000000..cb65d95b5e524 --- /dev/null +++ b/sbin/stop-mesos-dispatcher.sh @@ -0,0 +1,27 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Stop the Mesos Cluster dispatcher on the machine this script is executed on. + +sbin=`dirname "$0"` +sbin=`cd "$sbin"; pwd` + +. "$sbin/spark-config.sh" + +"$sbin"/spark-daemon.sh stop org.apache.spark.deploy.mesos.MesosClusterDispatcher 1 + From 28b1af7420e0b5e7e2dfc09eafc45fe2ffcde5ec Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 28 Apr 2015 13:49:29 -0700 Subject: [PATCH 163/191] [MINOR] [CORE] Warn users who try to cache RDDs with dynamic allocation on. Author: Marcelo Vanzin Closes #5751 from vanzin/cached-rdd-warning and squashes the following commits: 554cc07 [Marcelo Vanzin] Change message. 9efb9da [Marcelo Vanzin] [minor] [core] Warn users who try to cache RDDs with dynamic allocation on. 
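A hedged example of the situation this warning targets (configuration values are illustrative, and a cluster manager that supports dynamic allocation plus the external shuffle service is assumed): caching an RDD routes through SparkContext.persistRDD, so the snippet below would log the message added in this patch.

    import org.apache.spark.{SparkConf, SparkContext}

    val conf = new SparkConf()
      .setAppName("cache-with-dynamic-allocation")
      .set("spark.dynamicAllocation.enabled", "true")
      .set("spark.shuffle.service.enabled", "true")   // needed for dynamic allocation
    val sc = new SparkContext(conf)
    // cache() registers the RDD through persistRDD(), which now logs the warning.
    val data = sc.parallelize(1 to 100).cache()
    data.count()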
--- core/src/main/scala/org/apache/spark/SparkContext.scala | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 65b903a55d5bd..d0cf2a8dd01cd 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1396,6 +1396,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Register an RDD to be persisted in memory and/or disk storage */ private[spark] def persistRDD(rdd: RDD[_]) { + _executorAllocationManager.foreach { _ => + logWarning( + s"Dynamic allocation currently does not support cached RDDs. Cached data for RDD " + + s"${rdd.id} will be lost when executors are removed.") + } persistentRdds(rdd.id) = rdd } From f0a1f90f53b447c61f405cdb2c553f3dc067bf8d Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 28 Apr 2015 14:07:26 -0700 Subject: [PATCH 164/191] [SPARK-7201] [MLLIB] move Identifiable to ml.util It shouldn't live directly under `spark.ml`. Author: Xiangrui Meng Closes #5749 from mengxr/SPARK-7201 and squashes the following commits: 53847f9 [Xiangrui Meng] move Identifiable to ml.util --- mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala | 1 + mllib/src/main/scala/org/apache/spark/ml/param/params.scala | 2 +- .../scala/org/apache/spark/ml/{ => util}/Identifiable.scala | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) rename mllib/src/main/scala/org/apache/spark/ml/{ => util}/Identifiable.scala (97%) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala index d2ca2e6871e6b..8b4b5fd8af986 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.DataFrame /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index ddc5907e7facd..014e124e440a3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -24,7 +24,7 @@ import scala.annotation.varargs import scala.collection.mutable import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} -import org.apache.spark.ml.Identifiable +import org.apache.spark.ml.util.Identifiable /** * :: AlphaComponent :: diff --git a/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala similarity index 97% rename from mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala rename to mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala index a1d49095c24ac..8a56748ab0a02 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.ml +package org.apache.spark.ml.util import java.util.UUID From 555213ebbf2be2ee523be8665bd5b9a47ae4bec8 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 28 Apr 2015 14:21:25 -0700 Subject: [PATCH 165/191] Closes #4807 Closes #5055 Closes #3583 From d36e67350c516a96d58abd50a0d5d22b3b22f291 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 28 Apr 2015 17:41:09 -0700 Subject: [PATCH 166/191] [SPARK-6965] [MLLIB] StringIndexer handles numeric input. Cast numeric types to String for indexing. Boolean type is not handled in this PR. jkbradley Author: Xiangrui Meng Closes #5753 from mengxr/SPARK-6965 and squashes the following commits: 2e34f3c [Xiangrui Meng] add actual type in the error message ad938bf [Xiangrui Meng] StringIndexer handles numeric input. --- .../spark/ml/feature/StringIndexer.scala | 17 ++++++++++++----- .../spark/ml/feature/StringIndexerSuite.scala | 19 +++++++++++++++++++ 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 23956c512c8a6..9db3b29e10d69 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -23,10 +23,9 @@ import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.sql.types.{NumericType, StringType, StructType} import org.apache.spark.util.collection.OpenHashMap /** @@ -37,7 +36,11 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = { val map = extractParamMap(paramMap) - SchemaUtils.checkColumnType(schema, map(inputCol), StringType) + val inputColName = map(inputCol) + val inputDataType = schema(inputColName).dataType + require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType], + s"The input column $inputColName must be either string type or numeric type, " + + s"but got $inputDataType.") val inputFields = schema.fields val outputColName = map(outputCol) require(inputFields.forall(_.name != outputColName), @@ -51,6 +54,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha /** * :: AlphaComponent :: * A label indexer that maps a string column of labels to an ML column of label indices. + * If the input column is numeric, we cast it to string and index the string values. * The indices are in [0, numLabels), ordered by label frequencies. * So the most frequent label gets index 0. 
*/ @@ -67,7 +71,9 @@ class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase override def fit(dataset: DataFrame, paramMap: ParamMap): StringIndexerModel = { val map = extractParamMap(paramMap) - val counts = dataset.select(map(inputCol)).map(_.getString(0)).countByValue() + val counts = dataset.select(col(map(inputCol)).cast(StringType)) + .map(_.getString(0)) + .countByValue() val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray val model = new StringIndexerModel(this, map, labels) Params.inheritValues(map, this, model) @@ -119,7 +125,8 @@ class StringIndexerModel private[ml] ( val outputColName = map(outputCol) val metadata = NominalAttribute.defaultAttr .withName(outputColName).withValues(labels).toMetadata() - dataset.select(col("*"), indexer(dataset(map(inputCol))).as(outputColName, metadata)) + dataset.select(col("*"), + indexer(dataset(map(inputCol)).cast(StringType)).as(outputColName, metadata)) } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 00b5d094d82f1..b6939e5870410 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -49,4 +49,23 @@ class StringIndexerSuite extends FunSuite with MLlibTestSparkContext { val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) assert(output === expected) } + + test("StringIndexer with a numeric input column") { + val data = sc.parallelize(Seq((0, 100), (1, 200), (2, 300), (3, 100), (4, 100), (5, 300)), 2) + val df = sqlContext.createDataFrame(data).toDF("id", "label") + val indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + .fit(df) + val transformed = indexer.transform(df) + val attr = Attribute.fromStructField(transformed.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attr.values.get === Array("100", "300", "200")) + val output = transformed.select("id", "labelIndex").map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + // 100 -> 0, 200 -> 2, 300 -> 1 + val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) + assert(output === expected) + } } From 5c8f4bd5fae539ab5fb992573d5357ed34e2f4d0 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 28 Apr 2015 19:31:57 -0700 Subject: [PATCH 167/191] [SPARK-7138] [STREAMING] Add method to BlockGenerator to add multiple records to BlockGenerator with single callback This is to ensure that receivers that receive data in small batches (like Kinesis) and want to add them but want the callback function to be called only once. This is for internal use only for improvement to Kinesis Receiver that we are planning to do. Author: Tathagata Das Closes #5695 from tdas/SPARK-7138 and squashes the following commits: a35cf7d [Tathagata Das] Fixed style. a7a4cb9 [Tathagata Das] Added extra method to BlockGenerator. 
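A hedged sketch of how a receiver might use the new method. BlockGenerator is private[streaming], so this only applies inside Spark's own receivers (the planned Kinesis improvement mentioned above); the helper, record, and metadata names here are placeholders, not part of this patch.

    // Somewhere inside a receiver that owns a BlockGenerator instance:
    def storeBatch(
        blockGenerator: BlockGenerator,
        records: Seq[Array[Byte]],
        batchMetadata: AnyRef): Unit = {
      // One onAddData() callback for the whole batch; every record lands in the same block.
      blockGenerator.addMultipleDataWithCallback(records.iterator, batchMetadata)
    }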
--- .../spark/streaming/receiver/BlockGenerator.scala | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala index f4963a78e1d18..4bebcc5aa7ca0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala @@ -126,6 +126,20 @@ private[streaming] class BlockGenerator( listener.onAddData(data, metadata) } + /** + * Push multiple data items into the buffer. After buffering the data, the + * `BlockGeneratorListener.onAddData` callback will be called. All received data items + * will be periodically pushed into BlockManager. Note that all the data items is guaranteed + * to be present in a single block. + */ + def addMultipleDataWithCallback(dataIterator: Iterator[Any], metadata: Any): Unit = synchronized { + dataIterator.foreach { data => + waitToPush() + currentBuffer += data + } + listener.onAddData(dataIterator, metadata) + } + /** Change the buffer to which single records are added to. */ private def updateCurrentBuffer(time: Long): Unit = synchronized { try { From a8aeadb7d4a2dc308a75a50fdd8065f9a32ef336 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 28 Apr 2015 21:15:47 -0700 Subject: [PATCH 168/191] [SPARK-7208] [ML] [PYTHON] Added Matrix, SparseMatrix to __all__ list in linalg.py Added Matrix, SparseMatrix to __all__ list in linalg.py CC: mengxr Author: Joseph K. Bradley Closes #5759 from jkbradley/SPARK-7208 and squashes the following commits: deb51a2 [Joseph K. Bradley] Added Matrix, SparseMatrix to __all__ list in linalg.py --- python/pyspark/mllib/linalg.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index cc9a4cf8ba170..a57c0b3ae0d00 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -39,7 +39,8 @@ IntegerType, ByteType -__all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors', 'DenseMatrix', 'Matrices'] +__all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors', + 'Matrix', 'DenseMatrix', 'SparseMatrix', 'Matrices'] if sys.version_info[:2] == (2, 7): From 5ef006fc4d010905e02cb905c9115b95ba55282b Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 28 Apr 2015 21:49:53 -0700 Subject: [PATCH 169/191] [SPARK-6756] [MLLIB] add toSparse, toDense, numActives, numNonzeros, and compressed to Vector Add `compressed` to `Vector` with some other methods: `numActives`, `numNonzeros`, `toSparse`, and `toDense`. 
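Since the commit message only lists the new methods, a short hedged usage example (values are illustrative) of the API added by this patch:

    import org.apache.spark.mllib.linalg.Vectors

    val dv = Vectors.dense(0.0, 2.0, 3.0, 0.0)
    dv.numActives      // 4: every stored entry counts, including explicit zeros
    dv.numNonzeros     // 2
    val sv = dv.toSparse        // sparse copy keeping only the two nonzero entries
    val best = dv.compressed    // dense or sparse, whichever needs less storage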
jkbradley Author: Xiangrui Meng Closes #5756 from mengxr/SPARK-6756 and squashes the following commits: 8d4ecbd [Xiangrui Meng] address comment and add mima excludes da54179 [Xiangrui Meng] add toSparse, toDense, numActives, numNonzeros, and compressed to Vector --- .../apache/spark/mllib/linalg/Vectors.scala | 93 +++++++++++++++++++ .../spark/mllib/linalg/VectorsSuite.scala | 44 +++++++++ project/MimaExcludes.scala | 12 +++ 3 files changed, 149 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 34833e90d4af0..188d1e542b5b5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -116,6 +116,40 @@ sealed trait Vector extends Serializable { * with type `Double`. */ private[spark] def foreachActive(f: (Int, Double) => Unit) + + /** + * Number of active entries. An "active entry" is an element which is explicitly stored, + * regardless of its value. Note that inactive entries have value 0. + */ + def numActives: Int + + /** + * Number of nonzero elements. This scans all active values and count nonzeros. + */ + def numNonzeros: Int + + /** + * Converts this vector to a sparse vector with all explicit zeros removed. + */ + def toSparse: SparseVector + + /** + * Converts this vector to a dense vector. + */ + def toDense: DenseVector = new DenseVector(this.toArray) + + /** + * Returns a vector in either dense or sparse format, whichever uses less storage. + */ + def compressed: Vector = { + val nnz = numNonzeros + // A dense vector needs 8 * size + 8 bytes, while a sparse vector needs 12 * nnz + 20 bytes. + if (1.5 * (nnz + 1.0) < size) { + toSparse + } else { + toDense + } + } } /** @@ -525,6 +559,34 @@ class DenseVector(val values: Array[Double]) extends Vector { } result } + + override def numActives: Int = size + + override def numNonzeros: Int = { + // same as values.count(_ != 0.0) but faster + var nnz = 0 + values.foreach { v => + if (v != 0.0) { + nnz += 1 + } + } + nnz + } + + override def toSparse: SparseVector = { + val nnz = numNonzeros + val ii = new Array[Int](nnz) + val vv = new Array[Double](nnz) + var k = 0 + foreachActive { (i, v) => + if (v != 0) { + ii(k) = i + vv(k) = v + k += 1 + } + } + new SparseVector(size, ii, vv) + } } object DenseVector { @@ -602,6 +664,37 @@ class SparseVector( } result } + + override def numActives: Int = values.length + + override def numNonzeros: Int = { + var nnz = 0 + values.foreach { v => + if (v != 0.0) { + nnz += 1 + } + } + nnz + } + + override def toSparse: SparseVector = { + val nnz = numNonzeros + if (nnz == numActives) { + this + } else { + val ii = new Array[Int](nnz) + val vv = new Array[Double](nnz) + var k = 0 + foreachActive { (i, v) => + if (v != 0.0) { + ii(k) = i + vv(k) = v + k += 1 + } + } + new SparseVector(size, ii, vv) + } + } } object SparseVector { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 2839c4c289b2d..24755e9ff46fc 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -270,4 +270,48 @@ class VectorsSuite extends FunSuite { assert(Vectors.norm(sv, 3.7) ~== math.pow(sv.toArray.foldLeft(0.0)((a, v) => a + math.pow(math.abs(v), 3.7)), 1.0 / 3.7) relTol 1E-8) } + + test("Vector numActive and 
numNonzeros") { + val dv = Vectors.dense(0.0, 2.0, 3.0, 0.0) + assert(dv.numActives === 4) + assert(dv.numNonzeros === 2) + + val sv = Vectors.sparse(4, Array(0, 1, 2), Array(0.0, 2.0, 3.0)) + assert(sv.numActives === 3) + assert(sv.numNonzeros === 2) + } + + test("Vector toSparse and toDense") { + val dv0 = Vectors.dense(0.0, 2.0, 3.0, 0.0) + assert(dv0.toDense === dv0) + val dv0s = dv0.toSparse + assert(dv0s.numActives === 2) + assert(dv0s === dv0) + + val sv0 = Vectors.sparse(4, Array(0, 1, 2), Array(0.0, 2.0, 3.0)) + assert(sv0.toDense === sv0) + val sv0s = sv0.toSparse + assert(sv0s.numActives === 2) + assert(sv0s === sv0) + } + + test("Vector.compressed") { + val dv0 = Vectors.dense(1.0, 2.0, 3.0, 0.0) + val dv0c = dv0.compressed.asInstanceOf[DenseVector] + assert(dv0c === dv0) + + val dv1 = Vectors.dense(0.0, 2.0, 0.0, 0.0) + val dv1c = dv1.compressed.asInstanceOf[SparseVector] + assert(dv1 === dv1c) + assert(dv1c.numActives === 1) + + val sv0 = Vectors.sparse(4, Array(1, 2), Array(2.0, 0.0)) + val sv0c = sv0.compressed.asInstanceOf[SparseVector] + assert(sv0 === sv0c) + assert(sv0c.numActives === 1) + + val sv1 = Vectors.sparse(4, Array(0, 1, 2), Array(1.0, 2.0, 3.0)) + val sv1c = sv1.compressed.asInstanceOf[DenseVector] + assert(sv1 === sv1c) + } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 967961c2bf5c3..3beafa158eb97 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -76,6 +76,18 @@ object MimaExcludes { // SPARK-7090 Introduce LDAOptimizer to LDA to further improve extensibility ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.mllib.clustering.LDA$EMOptimizer") + ) ++ Seq( + // SPARK-6756 add toSparse, toDense, numActives, numNonzeros, and compressed to Vector + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.compressed"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.toDense"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.numNonzeros"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.toSparse"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.numActives") ) case v if v.startsWith("1.3") => From 271c4c621d91d3f610ae89e5d2e5dab1a2009ca6 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Tue, 28 Apr 2015 22:48:04 -0700 Subject: [PATCH 170/191] [SPARK-7215] made coalesce and repartition a part of the query plan Coalesce and repartition now show up as part of the query plan, rather than resulting in a new `DataFrame`. 
cc rxin Author: Burak Yavuz Closes #5762 from brkyvz/df-repartition and squashes the following commits: b1e76dd [Burak Yavuz] added documentation on repartitions 5807e35 [Burak Yavuz] renamed coalescepartitions fa4509f [Burak Yavuz] rename coalesce 2c349b5 [Burak Yavuz] address comments f2e6af1 [Burak Yavuz] add ticks 686c90b [Burak Yavuz] made coalesce and repartition a part of the query plan --- .../catalyst/plans/logical/basicOperators.scala | 11 +++++++++++ .../sql/catalyst/plans/logical/partitioning.scala | 8 +++++++- .../scala/org/apache/spark/sql/DataFrame.scala | 9 ++------- .../spark/sql/execution/SparkStrategies.scala | 5 +++-- .../spark/sql/execution/basicOperators.scala | 14 ++++++++++++++ .../scala/org/apache/spark/sql/hive/HiveQl.scala | 6 +++--- 6 files changed, 40 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index bbc94a7ab3398..608e272da7784 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -310,6 +310,17 @@ case class Distinct(child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output } +/** + * Return a new RDD that has exactly `numPartitions` partitions. Differs from + * [[RepartitionByExpression]] as this method is called directly by DataFrames, because the user + * asked for `coalesce` or `repartition`. [[RepartitionByExpression]] is used when the consumer + * of the output requires some specific ordering or distribution of the data. + */ +case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan) + extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + /** * A relation with one row. This is used in "SELECT ..." without a from clause. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala index e737418d9c3bc..63df2c1ee72ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala @@ -32,5 +32,11 @@ abstract class RedistributeData extends UnaryNode { case class SortPartitions(sortExpressions: Seq[SortOrder], child: LogicalPlan) extends RedistributeData -case class Repartition(partitionExpressions: Seq[Expression], child: LogicalPlan) +/** + * This method repartitions data using [[Expression]]s, and receives information about the + * number of partitions during execution. Used when a specific ordering or distribution is + * expected by the consumer of the query result. Use [[Repartition]] for RDD-like + * `coalesce` and `repartition`.
+ */ +case class RepartitionByExpression(partitionExpressions: Seq[Expression], child: LogicalPlan) extends RedistributeData diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index ca6ae482eb2ab..2affba7d42cc8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -961,9 +961,7 @@ class DataFrame private[sql]( * @group rdd */ override def repartition(numPartitions: Int): DataFrame = { - sqlContext.createDataFrame( - queryExecution.toRdd.map(_.copy()).repartition(numPartitions), - schema, needsConversion = false) + Repartition(numPartitions, shuffle = true, logicalPlan) } /** @@ -974,10 +972,7 @@ class DataFrame private[sql]( * @group rdd */ override def coalesce(numPartitions: Int): DataFrame = { - sqlContext.createDataFrame( - queryExecution.toRdd.coalesce(numPartitions), - schema, - needsConversion = false) + Repartition(numPartitions, shuffle = false, logicalPlan) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 030ef118f75d4..3a0a6c86700a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -283,7 +283,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Distinct(child) => execution.Distinct(partial = false, execution.Distinct(partial = true, planLater(child))) :: Nil - + case logical.Repartition(numPartitions, shuffle, child) => + execution.Repartition(numPartitions, shuffle, planLater(child)) :: Nil case logical.SortPartitions(sortExprs, child) => // This sort only sorts tuples within a partition. Its requiredDistribution will be // an UnspecifiedDistribution. @@ -317,7 +318,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { generator, join = join, outer = outer, g.output, planLater(child)) :: Nil case logical.OneRowRelation => execution.PhysicalRDD(Nil, singleRowRdd) :: Nil - case logical.Repartition(expressions, child) => + case logical.RepartitionByExpression(expressions, child) => execution.Exchange( HashPartitioning(expressions, numPartitions), Nil, planLater(child)) :: Nil case e @ EvaluatePython(udf, child, _) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index d286fe81bee5f..1afdb409417ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -245,6 +245,20 @@ case class Distinct(partial: Boolean, child: SparkPlan) extends UnaryNode { } } +/** + * :: DeveloperApi :: + * Return a new RDD that has exactly `numPartitions` partitions. 
+ */ +@DeveloperApi +case class Repartition(numPartitions: Int, shuffle: Boolean, child: SparkPlan) + extends UnaryNode { + override def output: Seq[Attribute] = child.output + + override def execute(): RDD[Row] = { + child.execute().map(_.copy()).coalesce(numPartitions, shuffle) + } +} + /** * :: DeveloperApi :: diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 0ea6d57b816c6..2dc6463abafa7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -783,13 +783,13 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case (None, Some(perPartitionOrdering), None, None) => Sort(perPartitionOrdering.getChildren.map(nodeToSortOrder), false, withHaving) case (None, None, Some(partitionExprs), None) => - Repartition(partitionExprs.getChildren.map(nodeToExpr), withHaving) + RepartitionByExpression(partitionExprs.getChildren.map(nodeToExpr), withHaving) case (None, Some(perPartitionOrdering), Some(partitionExprs), None) => Sort(perPartitionOrdering.getChildren.map(nodeToSortOrder), false, - Repartition(partitionExprs.getChildren.map(nodeToExpr), withHaving)) + RepartitionByExpression(partitionExprs.getChildren.map(nodeToExpr), withHaving)) case (None, None, None, Some(clusterExprs)) => Sort(clusterExprs.getChildren.map(nodeToExpr).map(SortOrder(_, Ascending)), false, - Repartition(clusterExprs.getChildren.map(nodeToExpr), withHaving)) + RepartitionByExpression(clusterExprs.getChildren.map(nodeToExpr), withHaving)) case (None, None, None, None) => withHaving case _ => sys.error("Unsupported set of ordering / distribution clauses.") } From f98773a90ded0e408af6bbd85fafbaffbc5b825f Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Tue, 28 Apr 2015 23:05:02 -0700 Subject: [PATCH 171/191] [SPARK-7205] Support `.ivy2/local` and `.m2/repositories/` in --packages In addition, I made a small change that will allow users to import 2 different artifacts with the same name. That change is made in `[organization]_[artifact]-[revision].[ext]`. This used to be only `[artifact].[ext]` which might have caused collisions between artifacts with the same artifactId, but different groupIds. cc pwendell Author: Burak Yavuz Closes #5755 from brkyvz/local-caches and squashes the following commits: c47c9c5 [Burak Yavuz] Small fixes to --packages --- .../org/apache/spark/deploy/SparkSubmit.scala | 34 ++++++++++++++----- .../spark/deploy/SparkSubmitUtilsSuite.scala | 27 ++++++++------- 2 files changed, 40 insertions(+), 21 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index f4f572e1e256e..b8ae4af18d1d1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -734,13 +734,31 @@ private[deploy] object SparkSubmitUtils { /** * Extracts maven coordinates from a comma-delimited string * @param remoteRepos Comma-delimited string of remote repositories + * @param ivySettings The Ivy settings for this session * @return A ChainResolver used by Ivy to search for and resolve dependencies.
*/ - def createRepoResolvers(remoteRepos: Option[String]): ChainResolver = { + def createRepoResolvers(remoteRepos: Option[String], ivySettings: IvySettings): ChainResolver = { // We need a chain resolver if we want to check multiple repositories val cr = new ChainResolver cr.setName("list") + val localM2 = new IBiblioResolver + localM2.setM2compatible(true) + val m2Path = ".m2" + File.separator + "repository" + File.separator + localM2.setRoot(new File(System.getProperty("user.home"), m2Path).toURI.toString) + localM2.setUsepoms(true) + localM2.setName("local-m2-cache") + cr.add(localM2) + + val localIvy = new IBiblioResolver + localIvy.setRoot(new File(ivySettings.getDefaultIvyUserDir, + "local" + File.separator).toURI.toString) + val ivyPattern = Seq("[organisation]", "[module]", "[revision]", "[type]s", + "[artifact](-[classifier]).[ext]").mkString(File.separator) + localIvy.setPattern(ivyPattern) + localIvy.setName("local-ivy-cache") + cr.add(localIvy) + // the biblio resolver resolves POM declared dependencies val br: IBiblioResolver = new IBiblioResolver br.setM2compatible(true) @@ -773,8 +791,7 @@ private[deploy] object SparkSubmitUtils { /** * Output a comma-delimited list of paths for the downloaded jars to be added to the classpath - * (will append to jars in SparkSubmit). The name of the jar is given - * after a '!' by Ivy. It also sometimes contains '(bundle)' after '.jar'. Remove that as well. + * (will append to jars in SparkSubmit). * @param artifacts Sequence of dependencies that were resolved and retrieved * @param cacheDirectory directory where jars are cached * @return a comma-delimited list of paths for the dependencies @@ -783,10 +800,9 @@ private[deploy] object SparkSubmitUtils { artifacts: Array[AnyRef], cacheDirectory: File): String = { artifacts.map { artifactInfo => - val artifactString = artifactInfo.toString - val jarName = artifactString.drop(artifactString.lastIndexOf("!") + 1) + val artifact = artifactInfo.asInstanceOf[Artifact].getModuleRevisionId cacheDirectory.getAbsolutePath + File.separator + - jarName.substring(0, jarName.lastIndexOf(".jar") + 4) + s"${artifact.getOrganisation}_${artifact.getName}-${artifact.getRevision}.jar" }.mkString(",") } @@ -868,6 +884,7 @@ private[deploy] object SparkSubmitUtils { if (alternateIvyCache.trim.isEmpty) { new File(ivySettings.getDefaultIvyUserDir, "jars") } else { + ivySettings.setDefaultIvyUserDir(new File(alternateIvyCache)) ivySettings.setDefaultCache(new File(alternateIvyCache, "cache")) new File(alternateIvyCache, "jars") } @@ -877,7 +894,7 @@ private[deploy] object SparkSubmitUtils { // create a pattern matcher ivySettings.addMatcher(new GlobPatternMatcher) // create the dependency resolvers - val repoResolver = createRepoResolvers(remoteRepos) + val repoResolver = createRepoResolvers(remoteRepos, ivySettings) ivySettings.addResolver(repoResolver) ivySettings.setDefaultResolver(repoResolver.getName) @@ -911,7 +928,8 @@ private[deploy] object SparkSubmitUtils { } // retrieve all resolved dependencies ivy.retrieve(rr.getModuleDescriptor.getModuleRevisionId, - packagesDirectory.getAbsolutePath + File.separator + "[artifact](-[classifier]).[ext]", + packagesDirectory.getAbsolutePath + File.separator + + "[organization]_[artifact]-[revision].[ext]", retrieveOptions.setConfs(Array(ivyConfName))) System.setOut(sysOut) resolveDependencyPaths(rr.getArtifacts.toArray, packagesDirectory) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala 
b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala index 8bcca926097a1..1b2b699cb11e6 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.deploy import java.io.{PrintStream, OutputStream, File} +import org.apache.ivy.core.settings.IvySettings + import scala.collection.mutable.ArrayBuffer import org.scalatest.{BeforeAndAfterAll, FunSuite} @@ -56,24 +58,23 @@ class SparkSubmitUtilsSuite extends FunSuite with BeforeAndAfterAll { } test("create repo resolvers") { - val resolver1 = SparkSubmitUtils.createRepoResolvers(None) + val settings = new IvySettings + val res1 = SparkSubmitUtils.createRepoResolvers(None, settings) // should have central and spark-packages by default - assert(resolver1.getResolvers.size() === 2) - assert(resolver1.getResolvers.get(0).asInstanceOf[IBiblioResolver].getName === "central") - assert(resolver1.getResolvers.get(1).asInstanceOf[IBiblioResolver].getName === "spark-packages") + assert(res1.getResolvers.size() === 4) + assert(res1.getResolvers.get(0).asInstanceOf[IBiblioResolver].getName === "local-m2-cache") + assert(res1.getResolvers.get(1).asInstanceOf[IBiblioResolver].getName === "local-ivy-cache") + assert(res1.getResolvers.get(2).asInstanceOf[IBiblioResolver].getName === "central") + assert(res1.getResolvers.get(3).asInstanceOf[IBiblioResolver].getName === "spark-packages") val repos = "a/1,b/2,c/3" - val resolver2 = SparkSubmitUtils.createRepoResolvers(Option(repos)) - assert(resolver2.getResolvers.size() === 5) + val resolver2 = SparkSubmitUtils.createRepoResolvers(Option(repos), settings) + assert(resolver2.getResolvers.size() === 7) val expected = repos.split(",").map(r => s"$r/") resolver2.getResolvers.toArray.zipWithIndex.foreach { case (resolver: IBiblioResolver, i) => - if (i == 0) { - assert(resolver.getName === "central") - } else if (i == 1) { - assert(resolver.getName === "spark-packages") - } else { - assert(resolver.getName === s"repo-${i - 1}") - assert(resolver.getRoot === expected(i - 2)) + if (i > 3) { + assert(resolver.getName === s"repo-${i - 3}") + assert(resolver.getRoot === expected(i - 4)) } } } From 8dee2746b5857bb4c32b96f38840dd1b63574ab2 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Tue, 28 Apr 2015 23:38:59 -0700 Subject: [PATCH 172/191] MAINTENANCE: Automated closing of pull requests. This commit exists to close the following pull requests on Github: Closes #3205 (close requested by 'srowen') Closes #5478 (close requested by 'andrewor14') Closes #4910 (close requested by 'srowen') Closes #5080 (close requested by 'marmbrus') Closes #537 (close requested by 'srowen') Closes #5691 (close requested by 'srowen') Closes #5469 (close requested by 'marmbrus') From fe917f5ec9be8c8424416f7b5423ddb4318e03a0 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Wed, 29 Apr 2015 00:09:24 -0700 Subject: [PATCH 173/191] [SPARK-7188] added python support for math DataFrame functions Adds support for the math functions for DataFrames in PySpark. rxin I love Davies. 
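The patch also reorders the Scala-side `org.apache.spark.sql.mathfunctions` object; a sketch of the call patterns its overloads allow (assuming a DataFrame `df` with numeric columns `a` and `b`):

    import org.apache.spark.sql.mathfunctions._

    // Unary functions take a Column or a column name; binary functions also
    // accept a Double on either side, mirroring the int-to-float coercion in
    // the Python wrappers below.
    df.select(sin("a"), pow(df("a"), 2.0), atan2(1.0, "b"), hypot(df("a"), df("b")))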
Author: Burak Yavuz Closes #5750 from brkyvz/python-math-udfs and squashes the following commits: 7c4f563 [Burak Yavuz] removed is_math 3c4adde [Burak Yavuz] cleanup imports d5dca3f [Burak Yavuz] moved math functions to mathfunctions 25e6534 [Burak Yavuz] addressed comments v2.0 d3f7e0f [Burak Yavuz] addressed comments and added tests 7b7d7c4 [Burak Yavuz] remove tests for removed methods 33c2c15 [Burak Yavuz] fixed python style 3ee0c05 [Burak Yavuz] added python functions --- python/pyspark/sql/functions.py | 2 +- python/pyspark/sql/mathfunctions.py | 101 +++++ python/pyspark/sql/tests.py | 29 ++ .../expressions/mathfuncs/binary.scala | 12 +- .../expressions/mathfuncs/unary.scala | 122 +----- .../ExpressionEvaluationSuite.scala | 12 - .../org/apache/spark/sql/mathfunctions.scala | 409 +++++------------- .../spark/sql/MathExpressionsSuite.scala | 12 - 8 files changed, 276 insertions(+), 423 deletions(-) create mode 100644 python/pyspark/sql/mathfunctions.py diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 7b86655d9c82f..555c2fa5e7071 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -54,7 +54,7 @@ def _(col): 'upper': 'Converts a string expression to upper case.', 'lower': 'Converts a string expression to upper case.', 'sqrt': 'Computes the square root of the specified float value.', - 'abs': 'Computes the absolutle value.', + 'abs': 'Computes the absolute value.', 'max': 'Aggregate function: returns the maximum value of the expression in a group.', 'min': 'Aggregate function: returns the minimum value of the expression in a group.', diff --git a/python/pyspark/sql/mathfunctions.py b/python/pyspark/sql/mathfunctions.py new file mode 100644 index 0000000000000..7dbcab8694293 --- /dev/null +++ b/python/pyspark/sql/mathfunctions.py @@ -0,0 +1,101 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +A collection of builtin math functions +""" + +from pyspark import SparkContext +from pyspark.sql.dataframe import Column + +__all__ = [] + + +def _create_unary_mathfunction(name, doc=""): + """ Create a unary mathfunction by name""" + def _(col): + sc = SparkContext._active_spark_context + jc = getattr(sc._jvm.mathfunctions, name)(col._jc if isinstance(col, Column) else col) + return Column(jc) + _.__name__ = name + _.__doc__ = doc + return _ + + +def _create_binary_mathfunction(name, doc=""): + """ Create a binary mathfunction by name""" + def _(col1, col2): + sc = SparkContext._active_spark_context + # users might write ints for simplicity. This would throw an error on the JVM side. 
+ if type(col1) is int: + col1 = col1 * 1.0 + if type(col2) is int: + col2 = col2 * 1.0 + jc = getattr(sc._jvm.mathfunctions, name)(col1._jc if isinstance(col1, Column) else col1, + col2._jc if isinstance(col2, Column) else col2) + return Column(jc) + _.__name__ = name + _.__doc__ = doc + return _ + + +# math functions are found under another object therefore, they need to be handled separately +_mathfunctions = { + 'acos': 'Computes the cosine inverse of the given value; the returned angle is in the range' + + '0.0 through pi.', + 'asin': 'Computes the sine inverse of the given value; the returned angle is in the range' + + '-pi/2 through pi/2.', + 'atan': 'Computes the tangent inverse of the given value.', + 'cbrt': 'Computes the cube-root of the given value.', + 'ceil': 'Computes the ceiling of the given value.', + 'cos': 'Computes the cosine of the given value.', + 'cosh': 'Computes the hyperbolic cosine of the given value.', + 'exp': 'Computes the exponential of the given value.', + 'expm1': 'Computes the exponential of the given value minus one.', + 'floor': 'Computes the floor of the given value.', + 'log': 'Computes the natural logarithm of the given value.', + 'log10': 'Computes the logarithm of the given value in Base 10.', + 'log1p': 'Computes the natural logarithm of the given value plus one.', + 'rint': 'Returns the double value that is closest in value to the argument and' + + ' is equal to a mathematical integer.', + 'signum': 'Computes the signum of the given value.', + 'sin': 'Computes the sine of the given value.', + 'sinh': 'Computes the hyperbolic sine of the given value.', + 'tan': 'Computes the tangent of the given value.', + 'tanh': 'Computes the hyperbolic tangent of the given value.', + 'toDeg': 'Converts an angle measured in radians to an approximately equivalent angle ' + + 'measured in degrees.', + 'toRad': 'Converts an angle measured in degrees to an approximately equivalent angle ' + + 'measured in radians.' +} + +# math functions that take two arguments as input +_binary_mathfunctions = { + 'atan2': 'Returns the angle theta from the conversion of rectangular coordinates (x, y) to' + + 'polar coordinates (r, theta).', + 'hypot': 'Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow.', + 'pow': 'Returns the value of the first argument raised to the power of the second argument.' 
+} + +for _name, _doc in _mathfunctions.items(): + globals()[_name] = _create_unary_mathfunction(_name, _doc) +for _name, _doc in _binary_mathfunctions.items(): + globals()[_name] = _create_binary_mathfunction(_name, _doc) +del _name, _doc +__all__ += _mathfunctions.keys() +__all__ += _binary_mathfunctions.keys() +__all__.sort() diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index fe43c374f1cb1..2ffd18ebd7c89 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -387,6 +387,35 @@ def test_aggregator(self): self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0]) self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0]) + def test_math_functions(self): + df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF() + from pyspark.sql import mathfunctions as functions + import math + + def get_values(l): + return [j[0] for j in l] + + def assert_close(a, b): + c = get_values(b) + diff = [abs(v - c[k]) < 1e-6 for k, v in enumerate(a)] + return sum(diff) == len(a) + assert_close([math.cos(i) for i in range(10)], + df.select(functions.cos(df.a)).collect()) + assert_close([math.cos(i) for i in range(10)], + df.select(functions.cos("a")).collect()) + assert_close([math.sin(i) for i in range(10)], + df.select(functions.sin(df.a)).collect()) + assert_close([math.sin(i) for i in range(10)], + df.select(functions.sin(df['a'])).collect()) + assert_close([math.pow(i, 2 * i) for i in range(10)], + df.select(functions.pow(df.a, df.b)).collect()) + assert_close([math.pow(i, 2) for i in range(10)], + df.select(functions.pow(df.a, 2)).collect()) + assert_close([math.pow(i, 2) for i in range(10)], + df.select(functions.pow(df.a, 2.0)).collect()) + assert_close([math.hypot(i, 2 * i) for i in range(10)], + df.select(functions.hypot(df.a, df.b)).collect()) + def test_save_and_load(self): df = self.df tmpPath = tempfile.mkdtemp() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala index 5b4d912a64f71..fcc06d3aa1036 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala @@ -65,12 +65,6 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) } } -case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER") - -case class Hypot( - left: Expression, - right: Expression) extends BinaryMathExpression(math.hypot, "HYPOT") - case class Atan2( left: Expression, right: Expression) extends BinaryMathExpression(math.atan2, "ATAN2") { @@ -91,3 +85,9 @@ case class Atan2( } } } + +case class Hypot( + left: Expression, + right: Expression) extends BinaryMathExpression(math.hypot, "HYPOT") + +case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala index 96cb77d487529..dc68469e060cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala @@ -25,27 +25,16 @@ import org.apache.spark.sql.types._ * 
input format, therefore these functions extend `ExpectsInputTypes`. * @param name The short name of the function */ -abstract class MathematicalExpression(name: String) +abstract class MathematicalExpression(f: Double => Double, name: String) extends UnaryExpression with Serializable with ExpectsInputTypes { self: Product => type EvaluatedType = Any + override def expectedChildTypes: Seq[DataType] = Seq(DoubleType) override def dataType: DataType = DoubleType override def foldable: Boolean = child.foldable override def nullable: Boolean = true override def toString: String = s"$name($child)" -} - -/** - * A unary expression specifically for math functions that take a `Double` as input and return - * a `Double`. - * @param f The math function. - * @param name The short name of the function - */ -abstract class MathematicalExpressionForDouble(f: Double => Double, name: String) - extends MathematicalExpression(name) { self: Product => - - override def expectedChildTypes: Seq[DataType] = Seq(DoubleType) override def eval(input: Row): Any = { val evalE = child.eval(input) @@ -58,111 +47,46 @@ abstract class MathematicalExpressionForDouble(f: Double => Double, name: String } } -/** - * A unary expression specifically for math functions that take an `Int` as input and return - * an `Int`. - * @param f The math function. - * @param name The short name of the function - */ -abstract class MathematicalExpressionForInt(f: Int => Int, name: String) - extends MathematicalExpression(name) { self: Product => +case class Acos(child: Expression) extends MathematicalExpression(math.acos, "ACOS") - override def dataType: DataType = IntegerType - override def expectedChildTypes: Seq[DataType] = Seq(IntegerType) +case class Asin(child: Expression) extends MathematicalExpression(math.asin, "ASIN") - override def eval(input: Row): Any = { - val evalE = child.eval(input) - if (evalE == null) null else f(evalE.asInstanceOf[Int]) - } -} +case class Atan(child: Expression) extends MathematicalExpression(math.atan, "ATAN") -/** - * A unary expression specifically for math functions that take a `Float` as input and return - * a `Float`. - * @param f The math function. - * @param name The short name of the function - */ -abstract class MathematicalExpressionForFloat(f: Float => Float, name: String) - extends MathematicalExpression(name) { self: Product => +case class Cbrt(child: Expression) extends MathematicalExpression(math.cbrt, "CBRT") - override def dataType: DataType = FloatType - override def expectedChildTypes: Seq[DataType] = Seq(FloatType) +case class Ceil(child: Expression) extends MathematicalExpression(math.ceil, "CEIL") - override def eval(input: Row): Any = { - val evalE = child.eval(input) - if (evalE == null) { - null - } else { - val result = f(evalE.asInstanceOf[Float]) - if (result.isNaN) null else result - } - } -} - -/** - * A unary expression specifically for math functions that take a `Long` as input and return - * a `Long`. - * @param f The math function. 
- * @param name The short name of the function - */ -abstract class MathematicalExpressionForLong(f: Long => Long, name: String) - extends MathematicalExpression(name) { self: Product => - - override def dataType: DataType = LongType - override def expectedChildTypes: Seq[DataType] = Seq(LongType) - - override def eval(input: Row): Any = { - val evalE = child.eval(input) - if (evalE == null) null else f(evalE.asInstanceOf[Long]) - } -} - -case class Sin(child: Expression) extends MathematicalExpressionForDouble(math.sin, "SIN") - -case class Asin(child: Expression) extends MathematicalExpressionForDouble(math.asin, "ASIN") - -case class Sinh(child: Expression) extends MathematicalExpressionForDouble(math.sinh, "SINH") - -case class Cos(child: Expression) extends MathematicalExpressionForDouble(math.cos, "COS") +case class Cos(child: Expression) extends MathematicalExpression(math.cos, "COS") -case class Acos(child: Expression) extends MathematicalExpressionForDouble(math.acos, "ACOS") +case class Cosh(child: Expression) extends MathematicalExpression(math.cosh, "COSH") -case class Cosh(child: Expression) extends MathematicalExpressionForDouble(math.cosh, "COSH") +case class Exp(child: Expression) extends MathematicalExpression(math.exp, "EXP") -case class Tan(child: Expression) extends MathematicalExpressionForDouble(math.tan, "TAN") +case class Expm1(child: Expression) extends MathematicalExpression(math.expm1, "EXPM1") -case class Atan(child: Expression) extends MathematicalExpressionForDouble(math.atan, "ATAN") +case class Floor(child: Expression) extends MathematicalExpression(math.floor, "FLOOR") -case class Tanh(child: Expression) extends MathematicalExpressionForDouble(math.tanh, "TANH") +case class Log(child: Expression) extends MathematicalExpression(math.log, "LOG") -case class Ceil(child: Expression) extends MathematicalExpressionForDouble(math.ceil, "CEIL") +case class Log10(child: Expression) extends MathematicalExpression(math.log10, "LOG10") -case class Floor(child: Expression) extends MathematicalExpressionForDouble(math.floor, "FLOOR") +case class Log1p(child: Expression) extends MathematicalExpression(math.log1p, "LOG1P") -case class Rint(child: Expression) extends MathematicalExpressionForDouble(math.rint, "ROUND") +case class Rint(child: Expression) extends MathematicalExpression(math.rint, "ROUND") -case class Cbrt(child: Expression) extends MathematicalExpressionForDouble(math.cbrt, "CBRT") +case class Signum(child: Expression) extends MathematicalExpression(math.signum, "SIGNUM") -case class Signum(child: Expression) extends MathematicalExpressionForDouble(math.signum, "SIGNUM") +case class Sin(child: Expression) extends MathematicalExpression(math.sin, "SIN") -case class ISignum(child: Expression) extends MathematicalExpressionForInt(math.signum, "ISIGNUM") +case class Sinh(child: Expression) extends MathematicalExpression(math.sinh, "SINH") -case class FSignum(child: Expression) extends MathematicalExpressionForFloat(math.signum, "FSIGNUM") +case class Tan(child: Expression) extends MathematicalExpression(math.tan, "TAN") -case class LSignum(child: Expression) extends MathematicalExpressionForLong(math.signum, "LSIGNUM") +case class Tanh(child: Expression) extends MathematicalExpression(math.tanh, "TANH") case class ToDegrees(child: Expression) - extends MathematicalExpressionForDouble(math.toDegrees, "DEGREES") + extends MathematicalExpression(math.toDegrees, "DEGREES") case class ToRadians(child: Expression) - extends 
MathematicalExpressionForDouble(math.toRadians, "RADIANS") - -case class Log(child: Expression) extends MathematicalExpressionForDouble(math.log, "LOG") - -case class Log10(child: Expression) extends MathematicalExpressionForDouble(math.log10, "LOG10") - -case class Log1p(child: Expression) extends MathematicalExpressionForDouble(math.log1p, "LOG1P") - -case class Exp(child: Expression) extends MathematicalExpressionForDouble(math.exp, "EXP") - -case class Expm1(child: Expression) extends MathematicalExpressionForDouble(math.expm1, "EXPM1") + extends MathematicalExpression(math.toRadians, "RADIANS") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 5390ce43c6639..fa71001c9336e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -1253,18 +1253,6 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { unaryMathFunctionEvaluation[Double](Signum, math.signum) } - test("isignum") { - unaryMathFunctionEvaluation[Int](ISignum, math.signum, (-5 to 5)) - } - - test("fsignum") { - unaryMathFunctionEvaluation[Float](FSignum, math.signum, (-5 to 5).map(_.toFloat)) - } - - test("lsignum") { - unaryMathFunctionEvaluation[Long](LSignum, math.signum, (5 to 5).map(_.toLong)) - } - test("log") { unaryMathFunctionEvaluation(Log, math.log, (0 to 20).map(_ * 0.1)) unaryMathFunctionEvaluation(Log, math.log, (-5 to -1).map(_ * 0.1), true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala index 84f62bf47f955..d901542b7eaf3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala @@ -29,9 +29,6 @@ import org.apache.spark.sql.functions.lit * Mathematical Functions available for [[DataFrame]]. * * @groupname double_funcs Functions that require DoubleType as an input - * @groupname int_funcs Functions that require IntegerType as an input - * @groupname float_funcs Functions that require FloatType as an input - * @groupname long_funcs Functions that require LongType as an input */ @Experimental // scalastyle:off @@ -41,522 +38,348 @@ object mathfunctions { private[this] implicit def toColumn(expr: Expression): Column = Column(expr) /** - * Computes the sine of the given value. - * - * @group double_funcs + * Computes the cosine inverse of the given value; the returned angle is in the range + * 0.0 through pi. */ - def sin(e: Column): Column = Sin(e.expr) + def acos(e: Column): Column = Acos(e.expr) /** - * Computes the sine of the given column. - * - * @group double_funcs + * Computes the cosine inverse of the given column; the returned angle is in the range + * 0.0 through pi. */ - def sin(columnName: String): Column = sin(Column(columnName)) + def acos(columnName: String): Column = acos(Column(columnName)) /** * Computes the sine inverse of the given value; the returned angle is in the range * -pi/2 through pi/2. - * - * @group double_funcs */ def asin(e: Column): Column = Asin(e.expr) /** * Computes the sine inverse of the given column; the returned angle is in the range * -pi/2 through pi/2. 
- * - * @group double_funcs */ def asin(columnName: String): Column = asin(Column(columnName)) /** - * Computes the hyperbolic sine of the given value. - * - * @group double_funcs - */ - def sinh(e: Column): Column = Sinh(e.expr) - - /** - * Computes the hyperbolic sine of the given column. - * - * @group double_funcs + * Computes the tangent inverse of the given value. */ - def sinh(columnName: String): Column = sinh(Column(columnName)) + def atan(e: Column): Column = Atan(e.expr) /** - * Computes the cosine of the given value. - * - * @group double_funcs + * Computes the tangent inverse of the given column. */ - def cos(e: Column): Column = Cos(e.expr) + def atan(columnName: String): Column = atan(Column(columnName)) /** - * Computes the cosine of the given column. - * - * @group double_funcs + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). */ - def cos(columnName: String): Column = cos(Column(columnName)) + def atan2(l: Column, r: Column): Column = Atan2(l.expr, r.expr) /** - * Computes the cosine inverse of the given value; the returned angle is in the range - * 0.0 through pi. - * - * @group double_funcs + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). */ - def acos(e: Column): Column = Acos(e.expr) + def atan2(l: Column, rightName: String): Column = atan2(l, Column(rightName)) /** - * Computes the cosine inverse of the given column; the returned angle is in the range - * 0.0 through pi. - * - * @group double_funcs + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). */ - def acos(columnName: String): Column = acos(Column(columnName)) + def atan2(leftName: String, r: Column): Column = atan2(Column(leftName), r) /** - * Computes the hyperbolic cosine of the given value. - * - * @group double_funcs + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). */ - def cosh(e: Column): Column = Cosh(e.expr) + def atan2(leftName: String, rightName: String): Column = + atan2(Column(leftName), Column(rightName)) /** - * Computes the hyperbolic cosine of the given column. - * - * @group double_funcs - */ - def cosh(columnName: String): Column = cosh(Column(columnName)) - - /** - * Computes the tangent of the given value. - * - * @group double_funcs + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). */ - def tan(e: Column): Column = Tan(e.expr) + def atan2(l: Column, r: Double): Column = atan2(l, lit(r).expr) /** - * Computes the tangent of the given column. - * - * @group double_funcs + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). */ - def tan(columnName: String): Column = tan(Column(columnName)) + def atan2(leftName: String, r: Double): Column = atan2(Column(leftName), r) /** - * Computes the tangent inverse of the given value. - * - * @group double_funcs + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). */ - def atan(e: Column): Column = Atan(e.expr) + def atan2(l: Double, r: Column): Column = atan2(lit(l).expr, r) /** - * Computes the tangent inverse of the given column. - * - * @group double_funcs + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta).
*/ - def atan(columnName: String): Column = atan(Column(columnName)) + def atan2(l: Double, rightName: String): Column = atan2(l, Column(rightName)) /** - * Computes the hyperbolic tangent of the given value. - * - * @group double_funcs + * Computes the cube-root of the given value. */ - def tanh(e: Column): Column = Tanh(e.expr) + def cbrt(e: Column): Column = Cbrt(e.expr) /** - * Computes the hyperbolic tangent of the given column. - * - * @group double_funcs + * Computes the cube-root of the given column. */ - def tanh(columnName: String): Column = tanh(Column(columnName)) + def cbrt(columnName: String): Column = cbrt(Column(columnName)) /** - * Converts an angle measured in radians to an approximately equivalent angle measured in degrees. - * - * @group double_funcs + * Computes the ceiling of the given value. */ - def toDeg(e: Column): Column = ToDegrees(e.expr) + def ceil(e: Column): Column = Ceil(e.expr) /** - * Converts an angle measured in radians to an approximately equivalent angle measured in degrees. - * - * @group double_funcs + * Computes the ceiling of the given column. */ - def toDeg(columnName: String): Column = toDeg(Column(columnName)) + def ceil(columnName: String): Column = ceil(Column(columnName)) /** - * Converts an angle measured in degrees to an approximately equivalent angle measured in radians. - * - * @group double_funcs + * Computes the cosine of the given value. */ - def toRad(e: Column): Column = ToRadians(e.expr) + def cos(e: Column): Column = Cos(e.expr) /** - * Converts an angle measured in degrees to an approximately equivalent angle measured in radians. - * - * @group double_funcs + * Computes the cosine of the given column. */ - def toRad(columnName: String): Column = toRad(Column(columnName)) + def cos(columnName: String): Column = cos(Column(columnName)) /** - * Computes the ceiling of the given value. - * - * @group double_funcs + * Computes the hyperbolic cosine of the given value. */ - def ceil(e: Column): Column = Ceil(e.expr) + def cosh(e: Column): Column = Cosh(e.expr) /** - * Computes the ceiling of the given column. - * - * @group double_funcs + * Computes the hyperbolic cosine of the given column. */ - def ceil(columnName: String): Column = ceil(Column(columnName)) + def cosh(columnName: String): Column = cosh(Column(columnName)) /** - * Computes the floor of the given value. - * - * @group double_funcs + * Computes the exponential of the given value. */ - def floor(e: Column): Column = Floor(e.expr) + def exp(e: Column): Column = Exp(e.expr) /** - * Computes the floor of the given column. - * - * @group double_funcs + * Computes the exponential of the given column. */ - def floor(columnName: String): Column = floor(Column(columnName)) + def exp(columnName: String): Column = exp(Column(columnName)) /** - * Returns the double value that is closest in value to the argument and - * is equal to a mathematical integer. - * - * @group double_funcs + * Computes the exponential of the given value minus one. */ - def rint(e: Column): Column = Rint(e.expr) + def expm1(e: Column): Column = Expm1(e.expr) /** - * Returns the double value that is closest in value to the argument and - * is equal to a mathematical integer. - * - * @group double_funcs + * Computes the exponential of the given column. */ - def rint(columnName: String): Column = rint(Column(columnName)) + def expm1(columnName: String): Column = expm1(Column(columnName)) /** - * Computes the cube-root of the given value. - * - * @group double_funcs + * Computes the floor of the given value. 
*/ - def cbrt(e: Column): Column = Cbrt(e.expr) + def floor(e: Column): Column = Floor(e.expr) /** - * Computes the cube-root of the given column. - * - * @group double_funcs + * Computes the floor of the given column. */ - def cbrt(columnName: String): Column = cbrt(Column(columnName)) + def floor(columnName: String): Column = floor(Column(columnName)) /** - * Computes the signum of the given value. - * - * @group double_funcs + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. */ - def signum(e: Column): Column = Signum(e.expr) + def hypot(l: Column, r: Column): Column = Hypot(l.expr, r.expr) /** - * Computes the signum of the given column. - * - * @group double_funcs + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. */ - def signum(columnName: String): Column = signum(Column(columnName)) + def hypot(l: Column, rightName: String): Column = hypot(l, Column(rightName)) /** - * Computes the signum of the given value. For IntegerType. - * - * @group int_funcs + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. */ - def isignum(e: Column): Column = ISignum(e.expr) + def hypot(leftName: String, r: Column): Column = hypot(Column(leftName), r) /** - * Computes the signum of the given column. For IntegerType. - * - * @group int_funcs + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. */ - def isignum(columnName: String): Column = isignum(Column(columnName)) + def hypot(leftName: String, rightName: String): Column = + hypot(Column(leftName), Column(rightName)) /** - * Computes the signum of the given value. For FloatType. - * - * @group float_funcs + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. */ - def fsignum(e: Column): Column = FSignum(e.expr) + def hypot(l: Column, r: Double): Column = hypot(l, lit(r).expr) /** - * Computes the signum of the given column. For FloatType. - * - * @group float_funcs + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. */ - def fsignum(columnName: String): Column = fsignum(Column(columnName)) + def hypot(leftName: String, r: Double): Column = hypot(Column(leftName), r) /** - * Computes the signum of the given value. For LongType. - * - * @group long_funcs + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. */ - def lsignum(e: Column): Column = LSignum(e.expr) + def hypot(l: Double, r: Column): Column = hypot(lit(l).expr, r) /** - * Computes the signum of the given column. For FloatType. - * - * @group long_funcs + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. */ - def lsignum(columnName: String): Column = lsignum(Column(columnName)) + def hypot(l: Double, rightName: String): Column = hypot(l, Column(rightName)) /** * Computes the natural logarithm of the given value. - * - * @group double_funcs */ def log(e: Column): Column = Log(e.expr) /** * Computes the natural logarithm of the given column. - * - * @group double_funcs */ def log(columnName: String): Column = log(Column(columnName)) /** * Computes the logarithm of the given value in Base 10. - * - * @group double_funcs */ def log10(e: Column): Column = Log10(e.expr) /** * Computes the logarithm of the given value in Base 10. - * - * @group double_funcs */ def log10(columnName: String): Column = log10(Column(columnName)) /** * Computes the natural logarithm of the given value plus one. 
- * - * @group double_funcs */ def log1p(e: Column): Column = Log1p(e.expr) /** * Computes the natural logarithm of the given column plus one. - * - * @group double_funcs */ def log1p(columnName: String): Column = log1p(Column(columnName)) - /** - * Computes the exponential of the given value. - * - * @group double_funcs - */ - def exp(e: Column): Column = Exp(e.expr) - - /** - * Computes the exponential of the given column. - * - * @group double_funcs - */ - def exp(columnName: String): Column = exp(Column(columnName)) - - /** - * Computes the exponential of the given value minus one. - * - * @group double_funcs - */ - def expm1(e: Column): Column = Expm1(e.expr) - - /** - * Computes the exponential of the given column. - * - * @group double_funcs - */ - def expm1(columnName: String): Column = expm1(Column(columnName)) - /** * Returns the value of the first argument raised to the power of the second argument. - * - * @group double_funcs */ def pow(l: Column, r: Column): Column = Pow(l.expr, r.expr) /** * Returns the value of the first argument raised to the power of the second argument. - * - * @group double_funcs */ def pow(l: Column, rightName: String): Column = pow(l, Column(rightName)) /** * Returns the value of the first argument raised to the power of the second argument. - * - * @group double_funcs */ def pow(leftName: String, r: Column): Column = pow(Column(leftName), r) /** * Returns the value of the first argument raised to the power of the second argument. - * - * @group double_funcs */ def pow(leftName: String, rightName: String): Column = pow(Column(leftName), Column(rightName)) /** * Returns the value of the first argument raised to the power of the second argument. - * - * @group double_funcs */ def pow(l: Column, r: Double): Column = pow(l, lit(r).expr) /** * Returns the value of the first argument raised to the power of the second argument. - * - * @group double_funcs */ def pow(leftName: String, r: Double): Column = pow(Column(leftName), r) /** * Returns the value of the first argument raised to the power of the second argument. - * - * @group double_funcs */ def pow(l: Double, r: Column): Column = pow(lit(l).expr, r) /** * Returns the value of the first argument raised to the power of the second argument. - * - * @group double_funcs */ def pow(l: Double, rightName: String): Column = pow(l, Column(rightName)) /** - * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. - * - * @group double_funcs + * Returns the double value that is closest in value to the argument and + * is equal to a mathematical integer. */ - def hypot(l: Column, r: Column): Column = Hypot(l.expr, r.expr) + def rint(e: Column): Column = Rint(e.expr) /** - * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. - * - * @group double_funcs + * Returns the double value that is closest in value to the argument and + * is equal to a mathematical integer. */ - def hypot(l: Column, rightName: String): Column = hypot(l, Column(rightName)) + def rint(columnName: String): Column = rint(Column(columnName)) /** - * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. - * - * @group double_funcs + * Computes the signum of the given value. */ - def hypot(leftName: String, r: Column): Column = hypot(Column(leftName), r) + def signum(e: Column): Column = Signum(e.expr) /** - * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. - * - * @group double_funcs + * Computes the signum of the given column. 
*/ - def hypot(leftName: String, rightName: String): Column = - hypot(Column(leftName), Column(rightName)) + def signum(columnName: String): Column = signum(Column(columnName)) /** - * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. - * - * @group double_funcs + * Computes the sine of the given value. */ - def hypot(l: Column, r: Double): Column = hypot(l, lit(r).expr) + def sin(e: Column): Column = Sin(e.expr) /** - * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. - * - * @group double_funcs + * Computes the sine of the given column. */ - def hypot(leftName: String, r: Double): Column = hypot(Column(leftName), r) + def sin(columnName: String): Column = sin(Column(columnName)) /** - * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. - * - * @group double_funcs + * Computes the hyperbolic sine of the given value. */ - def hypot(l: Double, r: Column): Column = hypot(lit(l).expr, r) + def sinh(e: Column): Column = Sinh(e.expr) /** - * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. - * - * @group double_funcs + * Computes the hyperbolic sine of the given column. */ - def hypot(l: Double, rightName: String): Column = hypot(l, Column(rightName)) - + def sinh(columnName: String): Column = sinh(Column(columnName)) + /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). - * - * @group double_funcs + * Computes the tangent of the given value. */ - def atan2(l: Column, r: Column): Column = Atan2(l.expr, r.expr) + def tan(e: Column): Column = Tan(e.expr) /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). - * - * @group double_funcs + * Computes the tangent of the given column. */ - def atan2(l: Column, rightName: String): Column = atan2(l, Column(rightName)) - + def tan(columnName: String): Column = tan(Column(columnName)) + /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). - * - * @group double_funcs + * Computes the hyperbolic tangent of the given value. */ - def atan2(leftName: String, r: Column): Column = atan2(Column(leftName), r) + def tanh(e: Column): Column = Tanh(e.expr) /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). - * - * @group double_funcs + * Computes the hyperbolic tangent of the given column. */ - def atan2(leftName: String, rightName: String): Column = - atan2(Column(leftName), Column(rightName)) + def tanh(columnName: String): Column = tanh(Column(columnName)) /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). - * - * @group double_funcs + * Converts an angle measured in radians to an approximately equivalent angle measured in degrees. */ - def atan2(l: Column, r: Double): Column = atan2(l, lit(r).expr) + def toDeg(e: Column): Column = ToDegrees(e.expr) /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). - * - * @group double_funcs + * Converts an angle measured in radians to an approximately equivalent angle measured in degrees. 
*/ - def atan2(leftName: String, r: Double): Column = atan2(Column(leftName), r) + def toDeg(columnName: String): Column = toDeg(Column(columnName)) /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). - * - * @group double_funcs + * Converts an angle measured in degrees to an approximately equivalent angle measured in radians. */ - def atan2(l: Double, r: Column): Column = atan2(lit(l).expr, r) + def toRad(e: Column): Column = ToRadians(e.expr) /** - * Returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). - * - * @group double_funcs + * Converts an angle measured in degrees to an approximately equivalent angle measured in radians. */ - def atan2(l: Double, rightName: String): Column = atan2(l, Column(rightName)) + def toRad(columnName: String): Column = toRad(Column(columnName)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 561553cc925cb..9e19bb7482e9b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -194,18 +194,6 @@ class MathExpressionsSuite extends QueryTest { testOneToOneMathFunction[Double](signum, math.signum) } - test("isignum") { - testOneToOneMathFunction[Int](isignum, math.signum) - } - - test("fsignum") { - testOneToOneMathFunction[Float](fsignum, math.signum) - } - - test("lsignum") { - testOneToOneMathFunction[Long](lsignum, math.signum) - } - test("pow") { testTwoToOneMathFunction(pow, pow, math.pow) } From 1fd6ed9a56ac4671f4a3d25a42823ba3bf01f60f Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 29 Apr 2015 00:35:08 -0700 Subject: [PATCH 174/191] [SPARK-7204] [SQL] Fix callSite for Dataframe and SQL operations This patch adds SQL to the set of excluded libraries when generating a callSite. This makes the callSite mechanism work properly for the data frame API. I also added a small improvement for JDBC queries where we just use the string "Spark JDBC Server Query" instead of trying to give a callsite that doesn't make any sense to the user. Before (DF): ![screen shot 2015-04-28 at 1 29 26 pm](https://cloud.githubusercontent.com/assets/320616/7380170/ef63bfb0-edae-11e4-989c-f88a5ba6bbee.png) After (DF): ![screen shot 2015-04-28 at 1 34 58 pm](https://cloud.githubusercontent.com/assets/320616/7380181/fa7f6d90-edae-11e4-9559-26f163ed63b8.png) After (JDBC): ![screen shot 2015-04-28 at 2 00 10 pm](https://cloud.githubusercontent.com/assets/320616/7380185/02f5b2a4-edaf-11e4-8e5b-99bdc3df66dd.png) Author: Patrick Wendell Closes #5757 from pwendell/dataframes and squashes the following commits: 0d931a4 [Patrick Wendell] Attempting to fix PySpark tests 85bf740 [Patrick Wendell] [SPARK-7204] Fix callsite for dataframe operations. --- .../scala/org/apache/spark/util/Utils.scala | 28 +++++++++++++------ python/pyspark/sql/dataframe.py | 3 +- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 4c028c06a5138..4b5a5df5ef7b7 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1299,16 +1299,18 @@ private[spark] object Utils extends Logging { } /** Default filtering function for finding call sites using `getCallSite`. 
*/ - private def coreExclusionFunction(className: String): Boolean = { - // A regular expression to match classes of the "core" Spark API that we want to skip when - // finding the call site of a method. + private def sparkInternalExclusionFunction(className: String): Boolean = { + // A regular expression to match classes of the internal Spark APIs + // that we want to skip when finding the call site of a method. val SPARK_CORE_CLASS_REGEX = """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?(\.broadcast)?\.[A-Z]""".r + val SPARK_SQL_CLASS_REGEX = """^org\.apache\.spark\.sql.*""".r val SCALA_CORE_CLASS_PREFIX = "scala" - val isSparkCoreClass = SPARK_CORE_CLASS_REGEX.findFirstIn(className).isDefined + val isSparkClass = SPARK_CORE_CLASS_REGEX.findFirstIn(className).isDefined || + SPARK_SQL_CLASS_REGEX.findFirstIn(className).isDefined val isScalaClass = className.startsWith(SCALA_CORE_CLASS_PREFIX) // If the class is a Spark internal class or a Scala class, then exclude. - isSparkCoreClass || isScalaClass + isSparkClass || isScalaClass } /** @@ -1318,7 +1320,7 @@ private[spark] object Utils extends Logging { * * @param skipClass Function that is used to exclude non-user-code classes. */ - def getCallSite(skipClass: String => Boolean = coreExclusionFunction): CallSite = { + def getCallSite(skipClass: String => Boolean = sparkInternalExclusionFunction): CallSite = { // Keep crawling up the stack trace until we find the first function not inside of the spark // package. We track the last (shallowest) contiguous Spark method. This might be an RDD // transformation, a SparkContext function (such as parallelize), or anything else that leads @@ -1357,9 +1359,17 @@ private[spark] object Utils extends Logging { } val callStackDepth = System.getProperty("spark.callstack.depth", "20").toInt - CallSite( - shortForm = s"$lastSparkMethod at $firstUserFile:$firstUserLine", - longForm = callStack.take(callStackDepth).mkString("\n")) + val shortForm = + if (firstUserFile == "HiveSessionImpl.java") { + // To be more user friendly, show a nicer string for queries submitted from the JDBC + // server. + "Spark JDBC Server Query" + } else { + s"$lastSparkMethod at $firstUserFile:$firstUserLine" + } + val longForm = callStack.take(callStackDepth).mkString("\n") + + CallSite(shortForm, longForm) } /** Return a string containing part of a file from byte 'start' to 'end'. */ diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 4759f5fe783ad..6879fe0805c58 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -237,7 +237,8 @@ def explain(self, extended=False): :param extended: boolean, default ``False``. If ``False``, prints only the physical plan. >>> df.explain() - PhysicalRDD [age#0,name#1], MapPartitionsRDD[...] at mapPartitions at SQLContext.scala:... + PhysicalRDD [age#0,name#1], MapPartitionsRDD[...] at applySchemaToPythonRDD at\ + NativeMethodAccessorImpl.java:... >>> df.explain(True) == Parsed Logical Plan == From f49284b5bf3a69ed91a5e3e6e0ed3be93a6ab9e4 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 29 Apr 2015 01:07:26 -0700 Subject: [PATCH 175/191] [SPARK-7076][SPARK-7077][SPARK-7080][SQL] Use managed memory for aggregations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This patch adds managed-memory-based aggregation to Spark SQL / DataFrames. Instead of working with Java objects, this new aggregation path uses `sun.misc.Unsafe` to manipulate raw memory.
This reduces the memory footprint for aggregations, resulting in fewer spills, OutOfMemoryErrors, and garbage collection pauses. As a result, this allows for higher memory utilization. It can also result in better cache locality since objects will be stored closer together in memory. This feature can be enabled by setting `spark.sql.unsafe.enabled=true`. For now, this feature is only supported when codegen is enabled and only supports aggregations for which the grouping columns are primitive numeric types or strings and aggregated values are numeric. ### Managing memory with sun.misc.Unsafe This patch supports both on- and off-heap managed memory. - In on-heap mode, memory addresses are identified by the combination of a base Object and an offset within that object. - In off-heap mode, memory is addressed directly with 64-bit long addresses. To support both modes, functions that manipulate memory accept both `baseObject` and `baseOffset` fields. In off-heap mode, we simply pass `null` as `baseObject`. We allocate memory in large chunks, so memory fragmentation and allocation speed are not significant bottlenecks. By default, we use on-heap mode. To enable off-heap mode, set `spark.unsafe.offHeap=true`. To track allocated memory, this patch extends `SparkEnv` with an `ExecutorMemoryManager` and supplies each `TaskContext` with a `TaskMemoryManager`. These classes work together to track allocations and detect memory leaks. ### Compact tuple format This patch introduces `UnsafeRow`, a compact row layout. In this format, each tuple has three parts: a null bit set, fixed-length values, and variable-length values: ![image](https://cloud.githubusercontent.com/assets/50748/7328538/2fdb65ce-ea8b-11e4-9743-6c0f02bb7d1f.png) - Rows are always 8-byte word aligned (so their sizes will always be a multiple of 8 bytes) - The bit set is used for null tracking: - Position _i_ is set if and only if field _i_ is null - The bit set is aligned to an 8-byte word boundary. - Every field appears as an 8-byte word in the fixed-length values part: - If a field is null, we zero out the value. - If a field is variable-length, the word stores a relative offset (w.r.t. the base of the tuple) that points to the beginning of the field's data in the variable-length part. - Each variable-length data type can have its own encoding: - For strings, the first word stores the length of the string and is followed by UTF-8 encoded bytes. If necessary, the end of the string is padded with empty bytes in order to ensure word-alignment. For example, a tuple that consists of 3 fields of type (int, string, string), with value (null, “data”, “bricks”) would look like this (see the sizing sketch below): ![image](https://cloud.githubusercontent.com/assets/50748/7328526/1e21959c-ea8b-11e4-9a28-a4350fe4a7b5.png) This format allows us to compare tuples for equality by directly comparing their raw bytes. This also enables fast hashing of tuples. ### Hash map for performing aggregations This patch introduces `UnsafeFixedWidthAggregationMap`, a hash map for performing aggregations where the aggregation result columns are fixed-width. This map's keys and values are `Row` objects. `UnsafeFixedWidthAggregationMap` is implemented on top of `BytesToBytesMap`, an append-only map which supports byte-array keys and values. `BytesToBytesMap` stores pointers to key and value tuples. For each record with a new key, we copy the key, create the aggregation value buffer for that key, and put them in a buffer. The hash table then simply stores pointers to the key and value.
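Before describing updates, here is a quick back-of-the-envelope check of the tuple sizing described above. This sketch is illustrative only and is not part of the patch; the object and helper names are invented, and it assumes nothing beyond the layout rules just listed:

```scala
// Sizing the (int, string, string) example with value (null, "data", "bricks"),
// following the layout rules in the commit message (invented helper names).
object UnsafeRowSizeSketch {
  // One null-tracking bit per field, rounded up to a whole 8-byte word.
  def bitSetWidthInBytes(numFields: Int): Int = ((numFields + 63) / 64) * 8

  // Variable-length strings: an 8-byte length word, then UTF-8 bytes padded
  // out to the next 8-byte boundary.
  def stringFieldBytes(s: String): Int = {
    val utf8Len = s.getBytes("UTF-8").length
    8 + ((utf8Len + 7) / 8) * 8
  }

  def main(args: Array[String]): Unit = {
    val numFields = 3
    val fixed = bitSetWidthInBytes(numFields) + 8 * numFields         // 8 + 24 = 32
    val variable = stringFieldBytes("data") + stringFieldBytes("bricks") // 16 + 16 = 32
    println(s"total row size = ${fixed + variable} bytes")            // 64, a multiple of 8
  }
}
```

The null int consumes no variable-length space: its fixed-length word is zeroed and its bit is set in the null bitset. Returning to the map: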
For each record with an existing key, we simply run the aggregation function to update the values in place. This map is implemented using open hashing with triangular sequence probing. Each entry stores two words in a long array: the first word stores the address of the key and the second word stores the relative offset from the key tuple to the value tuple, as well as the key's 32-bit hashcode. By storing the full hashcode, we reduce the number of equality checks that need to be performed to handle position collisions (since the chance of hashcode collision is much lower than position collision). `UnsafeFixedWidthAggregationMap` allows regular Spark SQL `Row` objects to be used when probing the map. Internally, it encodes these rows into `UnsafeRow` format using `UnsafeRowConverter`. This conversion has a small overhead that can be eliminated in the future once we use UnsafeRows in other operators. [Review on Reviewable](https://reviewable.io/reviews/apache/spark/5725) Author: Josh Rosen Closes #5725 from JoshRosen/unsafe and squashes the following commits: eeee512 [Josh Rosen] Add converters for Null, Boolean, Byte, and Short columns. 81f34f8 [Josh Rosen] Follow 'place children last' convention for GeneratedAggregate 1bc36cc [Josh Rosen] Refactor UnsafeRowConverter to avoid unnecessary boxing. 017b2dc [Josh Rosen] Remove BytesToBytesMap.finalize() 50e9671 [Josh Rosen] Throw memory leak warning even in case of error; add warning about code duplication 70a39e4 [Josh Rosen] Split MemoryManager into ExecutorMemoryManager and TaskMemoryManager: 6e4b192 [Josh Rosen] Remove an unused method from ByteArrayMethods. de5e001 [Josh Rosen] Fix debug vs. trace in logging message. a19e066 [Josh Rosen] Rename unsafe Java test suites to match Scala test naming convention. 78a5b84 [Josh Rosen] Add logging to MemoryManager ce3c565 [Josh Rosen] More comments, formatting, and code cleanup. 529e571 [Josh Rosen] Measure timeSpentResizing in nanoseconds instead of milliseconds. 3ca84b2 [Josh Rosen] Only zero the used portion of groupingKeyConversionScratchSpace 162caf7 [Josh Rosen] Fix test compilation b45f070 [Josh Rosen] Don't redundantly store the offset from key to value, since we can compute this from the key size. a8e4a3f [Josh Rosen] Introduce MemoryManager interface; add to SparkEnv. 0925847 [Josh Rosen] Disable MiMa checks for new unsafe module cde4132 [Josh Rosen] Add missing pom.xml 9c19fc0 [Josh Rosen] Add configuration options for heap vs. offheap 6ffdaa1 [Josh Rosen] Null handling improvements in UnsafeRow. 31eaabc [Josh Rosen] Lots of TODO and doc cleanup. a95291e [Josh Rosen] Cleanups to string handling code afe8dca [Josh Rosen] Some Javadoc cleanup f3dcbfe [Josh Rosen] More mod replacement 854201a [Josh Rosen] Import and comment cleanup 06e929d [Josh Rosen] More warning cleanup ef6b3d3 [Josh Rosen] Fix a bunch of FindBugs and IntelliJ inspections 29a7575 [Josh Rosen] Remove debug logging 49aed30 [Josh Rosen] More long -> int conversion. b26f1d3 [Josh Rosen] Fix bug in murmur hash implementation. 765243d [Josh Rosen] Enable optional performance metrics for hash map. 23a440a [Josh Rosen] Bump up default hash map size 628f936 [Josh Rosen] Use ints instead of longs for indexing. 92d5a06 [Josh Rosen] Address a number of minor code review comments. 1f4b716 [Josh Rosen] Merge Unsafe code into the regular GeneratedAggregate, guarded by a configuration flag; integrate planner support and re-enable all tests.
d85eeff [Josh Rosen] Add basic sanity test for UnsafeFixedWidthAggregationMap bade966 [Josh Rosen] Comment update (bumping to refresh GitHub cache...) b3eaccd [Josh Rosen] Extract aggregation map into its own class. d2bb986 [Josh Rosen] Update to implement new Row methods added upstream 58ac393 [Josh Rosen] Use UNSAFE allocator in GeneratedAggregate (TODO: make this configurable) 7df6008 [Josh Rosen] Optimizations related to zeroing out memory: c1b3813 [Josh Rosen] Fix bug in UnsafeMemoryAllocator.free(): 738fa33 [Josh Rosen] Add feature flag to guard UnsafeGeneratedAggregate c55bf66 [Josh Rosen] Free buffer once iterator has been fully consumed. 62ab054 [Josh Rosen] Optimize for fact that get() is only called on String columns. c7f0b56 [Josh Rosen] Reuse UnsafeRow pointer in UnsafeRowConverter ae39694 [Josh Rosen] Add finalizer as "cleanup method of last resort" c754ae1 [Josh Rosen] Now that the store*() contract has been strengthened, we can remove an extra lookup f764d13 [Josh Rosen] Simplify address + length calculation in Location. 079f1bf [Josh Rosen] Some clarification of the BytesToBytesMap.lookup() / set() contract. 1a483c5 [Josh Rosen] First version that passes some aggregation tests: fc4c3a8 [Josh Rosen] Sketch how the converters will be used in UnsafeGeneratedAggregate 53ba9b7 [Josh Rosen] Start prototyping Java Row -> UnsafeRow converters 1ff814d [Josh Rosen] Add reminder to free memory on iterator completion 8a8f9df [Josh Rosen] Add skeleton for GeneratedAggregate integration. 5d55cef [Josh Rosen] Add skeleton for Row implementation. f03e9c1 [Josh Rosen] Play around with Unsafe implementations of more string methods. ab68e08 [Josh Rosen] Begin merging the UTF8String implementations. 480a74a [Josh Rosen] Initial import of code from Databricks unsafe utils repo.
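As a final illustrative aside (not part of the patch): the open hashing with triangular sequence probing described above can be sketched in a few lines. The class below is invented for illustration and assumes the table never fills up; with a power-of-two capacity, the probe offsets from the initial slot are the triangular numbers (0, 1, 3, 6, 10, ...), which eventually cover every slot.

```scala
// Toy sketch of open hashing with triangular sequence probing.
// Invented names; not Spark's BytesToBytesMap implementation.
class TriangularProbingTable(capacityPowerOfTwo: Int) {
  private val slots = new Array[String](capacityPowerOfTwo)
  private val mask = capacityPowerOfTwo - 1

  // Returns the slot index where `key` lives, inserting it if absent.
  def lookupOrInsert(key: String): Int = {
    var pos = key.hashCode & mask // mask is non-negative, so pos is too
    var step = 1
    // Probe i advances by i; for power-of-two capacities this sequence
    // visits all slots, so the loop terminates while the table has room.
    while (slots(pos) != null && slots(pos) != key) {
      pos = (pos + step) & mask
      step += 1
    }
    if (slots(pos) == null) slots(pos) = key
    pos
  }
}
```

Storing each key's full 32-bit hashcode alongside its entry, as the real map does, lets most probe collisions be rejected with a single integer comparison instead of a byte-by-byte key comparison.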
--- core/pom.xml | 5 + .../scala/org/apache/spark/SparkEnv.scala | 12 + .../scala/org/apache/spark/TaskContext.scala | 6 + .../org/apache/spark/TaskContextImpl.scala | 2 + .../org/apache/spark/executor/Executor.scala | 19 +- .../apache/spark/scheduler/DAGScheduler.scala | 22 +- .../org/apache/spark/scheduler/Task.scala | 16 +- .../java/org/apache/spark/JavaAPISuite.java | 2 +- .../org/apache/spark/CacheManagerSuite.scala | 8 +- .../org/apache/spark/rdd/PipedRDDSuite.scala | 2 +- .../spark/scheduler/TaskContextSuite.scala | 2 +- .../ShuffleBlockFetcherIteratorSuite.scala | 6 +- pom.xml | 2 + project/SparkBuild.scala | 7 +- sql/catalyst/pom.xml | 5 + .../UnsafeFixedWidthAggregationMap.java | 259 +++++++++ .../sql/catalyst/expressions/UnsafeRow.java | 435 ++++++++++++++ .../expressions/UnsafeRowConverter.scala | 223 +++++++ .../UnsafeFixedWidthAggregationMapSuite.scala | 119 ++++ .../expressions/UnsafeRowConverterSuite.scala | 153 +++++ .../scala/org/apache/spark/sql/SQLConf.scala | 9 + .../org/apache/spark/sql/SQLContext.scala | 2 + .../sql/execution/GeneratedAggregate.scala | 60 ++ .../spark/sql/execution/SparkStrategies.scala | 2 + unsafe/pom.xml | 69 +++ .../spark/unsafe/PlatformDependent.java | 87 +++ .../spark/unsafe/array/ByteArrayMethods.java | 56 ++ .../apache/spark/unsafe/array/LongArray.java | 78 +++ .../apache/spark/unsafe/bitset/BitSet.java | 105 ++++ .../spark/unsafe/bitset/BitSetMethods.java | 129 ++++ .../spark/unsafe/hash/Murmur3_x86_32.java | 96 +++ .../spark/unsafe/map/BytesToBytesMap.java | 549 ++++++++++++++++++ .../unsafe/map/HashMapGrowthStrategy.java | 39 ++ .../unsafe/memory/ExecutorMemoryManager.java | 58 ++ .../unsafe/memory/HeapMemoryAllocator.java | 35 ++ .../spark/unsafe/memory/MemoryAllocator.java | 33 ++ .../spark/unsafe/memory/MemoryBlock.java | 63 ++ .../spark/unsafe/memory/MemoryLocation.java | 54 ++ .../unsafe/memory/TaskMemoryManager.java | 237 ++++++++ .../unsafe/memory/UnsafeMemoryAllocator.java | 39 ++ .../spark/unsafe/array/LongArraySuite.java | 38 ++ .../spark/unsafe/bitset/BitSetSuite.java | 82 +++ .../unsafe/hash/Murmur3_x86_32Suite.java | 119 ++++ .../map/AbstractBytesToBytesMapSuite.java | 250 ++++++++ .../map/BytesToBytesMapOffHeapSuite.java | 29 + .../map/BytesToBytesMapOnHeapSuite.java | 29 + .../unsafe/memory/TaskMemoryManagerSuite.java | 41 ++ 47 files changed, 3675 insertions(+), 18 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala create mode 100644 unsafe/pom.xml create mode 100644 unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java create mode 100644 unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java create mode 100644 unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java create mode 100644 unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java create mode 100644 unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java create mode 100644 unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java create mode 
100644 unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java create mode 100644 unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java create mode 100644 unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java create mode 100644 unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java create mode 100644 unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java create mode 100644 unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java create mode 100644 unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java create mode 100644 unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java create mode 100644 unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java create mode 100644 unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java create mode 100644 unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java create mode 100644 unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java create mode 100644 unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java create mode 100644 unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java create mode 100644 unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java create mode 100644 unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java diff --git a/core/pom.xml b/core/pom.xml index 459ef66712c36..2dfb00d7ecf26 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -95,6 +95,11 @@ spark-network-shuffle_${scala.binary.version} ${project.version} + + org.apache.spark + spark-unsafe_${scala.binary.version} + ${project.version} + net.java.dev.jets3t jets3t diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 959aefabd8de4..0c4d28f786edd 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -40,6 +40,7 @@ import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinato import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} import org.apache.spark.storage._ +import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator} import org.apache.spark.util.{RpcUtils, Utils} /** @@ -69,6 +70,7 @@ class SparkEnv ( val sparkFilesDir: String, val metricsSystem: MetricsSystem, val shuffleMemoryManager: ShuffleMemoryManager, + val executorMemoryManager: ExecutorMemoryManager, val outputCommitCoordinator: OutputCommitCoordinator, val conf: SparkConf) extends Logging { @@ -382,6 +384,15 @@ object SparkEnv extends Logging { new OutputCommitCoordinatorEndpoint(rpcEnv, outputCommitCoordinator)) outputCommitCoordinator.coordinatorRef = Some(outputCommitCoordinatorRef) + val executorMemoryManager: ExecutorMemoryManager = { + val allocator = if (conf.getBoolean("spark.unsafe.offHeap", false)) { + MemoryAllocator.UNSAFE + } else { + MemoryAllocator.HEAP + } + new ExecutorMemoryManager(allocator) + } + val envInstance = new SparkEnv( executorId, rpcEnv, @@ -398,6 +409,7 @@ object SparkEnv extends Logging { sparkFilesDir, metricsSystem, shuffleMemoryManager, + executorMemoryManager, outputCommitCoordinator, conf) diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 
7d7fe1a446313..d09e17dea0911 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -21,6 +21,7 @@ import java.io.Serializable import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics +import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util.TaskCompletionListener @@ -133,4 +134,9 @@ abstract class TaskContext extends Serializable { /** ::DeveloperApi:: */ @DeveloperApi def taskMetrics(): TaskMetrics + + /** + * Returns the manager for this task's managed memory. + */ + private[spark] def taskMemoryManager(): TaskMemoryManager } diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 337c8e4ebebcd..b4d572cb52313 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -18,6 +18,7 @@ package org.apache.spark import org.apache.spark.executor.TaskMetrics +import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException} import scala.collection.mutable.ArrayBuffer @@ -27,6 +28,7 @@ private[spark] class TaskContextImpl( val partitionId: Int, override val taskAttemptId: Long, override val attemptNumber: Int, + override val taskMemoryManager: TaskMemoryManager, val runningLocally: Boolean = false, val taskMetrics: TaskMetrics = TaskMetrics.empty) extends TaskContext diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index f57e215c3f2ed..dd1c48e6cb953 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -32,6 +32,7 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{StorageLevel, TaskResultBlockId} +import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util._ /** @@ -178,6 +179,7 @@ private[spark] class Executor( } override def run(): Unit = { + val taskMemoryManager = new TaskMemoryManager(env.executorMemoryManager) val deserializeStartTime = System.currentTimeMillis() Thread.currentThread.setContextClassLoader(replClassLoader) val ser = env.closureSerializer.newInstance() @@ -190,6 +192,7 @@ private[spark] class Executor( val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask) updateDependencies(taskFiles, taskJars) task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader) + task.setTaskMemoryManager(taskMemoryManager) // If this task has been killed before we deserialized it, let's quit now. Otherwise, // continue executing the task. @@ -206,7 +209,21 @@ private[spark] class Executor( // Run the actual task and measure its runtime. taskStart = System.currentTimeMillis() - val value = task.run(taskAttemptId = taskId, attemptNumber = attemptNumber) + val value = try { + task.run(taskAttemptId = taskId, attemptNumber = attemptNumber) + } finally { + // Note: this memory freeing logic is duplicated in DAGScheduler.runLocallyWithinThread; + // when changing this, make sure to update both copies. 
+ val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory() + if (freedMemory > 0) { + val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId" + if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) { + throw new SparkException(errMsg) + } else { + logError(errMsg) + } + } + } val taskFinish = System.currentTimeMillis() // If the task has been killed, let's fail it. diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 8c4bff4e83afc..b7901c06a16e0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -34,6 +34,7 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.RDD import org.apache.spark.storage._ +import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util._ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat @@ -643,8 +644,15 @@ class DAGScheduler( try { val rdd = job.finalStage.rdd val split = rdd.partitions(job.partitions(0)) - val taskContext = new TaskContextImpl(job.finalStage.id, job.partitions(0), taskAttemptId = 0, - attemptNumber = 0, runningLocally = true) + val taskMemoryManager = new TaskMemoryManager(env.executorMemoryManager) + val taskContext = + new TaskContextImpl( + job.finalStage.id, + job.partitions(0), + taskAttemptId = 0, + attemptNumber = 0, + taskMemoryManager = taskMemoryManager, + runningLocally = true) TaskContext.setTaskContext(taskContext) try { val result = job.func(taskContext, rdd.iterator(split, taskContext)) @@ -652,6 +660,16 @@ class DAGScheduler( } finally { taskContext.markTaskCompleted() TaskContext.unset() + // Note: this memory freeing logic is duplicated in Executor.run(); when changing this, + // make sure to update both copies. 
+ val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory() + if (freedMemory > 0) { + if (sc.getConf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) { + throw new SparkException(s"Managed memory leak detected; size = $freedMemory bytes") + } else { + logError(s"Managed memory leak detected; size = $freedMemory bytes") + } + } } } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index b09b19e2ac9e7..586d1e06204c1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -25,6 +25,7 @@ import scala.collection.mutable.HashMap import org.apache.spark.{TaskContextImpl, TaskContext} import org.apache.spark.executor.TaskMetrics import org.apache.spark.serializer.SerializerInstance +import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util.ByteBufferInputStream import org.apache.spark.util.Utils @@ -52,8 +53,13 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex * @return the result of the task */ final def run(taskAttemptId: Long, attemptNumber: Int): T = { - context = new TaskContextImpl(stageId = stageId, partitionId = partitionId, - taskAttemptId = taskAttemptId, attemptNumber = attemptNumber, runningLocally = false) + context = new TaskContextImpl( + stageId = stageId, + partitionId = partitionId, + taskAttemptId = taskAttemptId, + attemptNumber = attemptNumber, + taskMemoryManager = taskMemoryManager, + runningLocally = false) TaskContext.setTaskContext(context) context.taskMetrics.setHostname(Utils.localHostName()) taskThread = Thread.currentThread() @@ -68,6 +74,12 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex } } + private var taskMemoryManager: TaskMemoryManager = _ + + def setTaskMemoryManager(taskMemoryManager: TaskMemoryManager): Unit = { + this.taskMemoryManager = taskMemoryManager + } + def runTask(context: TaskContext): T def preferredLocations: Seq[TaskLocation] = Nil diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 8a4f2a08fe701..34ac9361d46c6 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -1009,7 +1009,7 @@ public void persist() { @Test public void iterator() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2); - TaskContext context = new TaskContextImpl(0, 0, 0L, 0, false, new TaskMetrics()); + TaskContext context = new TaskContextImpl(0, 0, 0L, 0, null, false, new TaskMetrics()); Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue()); } diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index 70529d9216591..668ddf9f5f0a9 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -65,7 +65,7 @@ class CacheManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAf // in blockManager.put is a losing battle. You have been warned. 
blockManager = sc.env.blockManager cacheManager = sc.env.cacheManager - val context = new TaskContextImpl(0, 0, 0, 0) + val context = new TaskContextImpl(0, 0, 0, 0, null) val computeValue = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) val getValue = blockManager.get(RDDBlockId(rdd.id, split.index)) assert(computeValue.toList === List(1, 2, 3, 4)) @@ -77,7 +77,7 @@ class CacheManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAf val result = new BlockResult(Array(5, 6, 7).iterator, DataReadMethod.Memory, 12) when(blockManager.get(RDDBlockId(0, 0))).thenReturn(Some(result)) - val context = new TaskContextImpl(0, 0, 0, 0) + val context = new TaskContextImpl(0, 0, 0, 0, null) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(5, 6, 7)) } @@ -86,14 +86,14 @@ class CacheManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAf // Local computation should not persist the resulting value, so don't expect a put(). when(blockManager.get(RDDBlockId(0, 0))).thenReturn(None) - val context = new TaskContextImpl(0, 0, 0, 0, true) + val context = new TaskContextImpl(0, 0, 0, 0, null, true) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(1, 2, 3, 4)) } test("verify task metrics updated correctly") { cacheManager = sc.env.cacheManager - val context = new TaskContextImpl(0, 0, 0, 0) + val context = new TaskContextImpl(0, 0, 0, 0, null) cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY) assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2) } diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala index aea76c1adcc09..85eb2a1d07ba4 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala @@ -176,7 +176,7 @@ class PipedRDDSuite extends FunSuite with SharedSparkContext { } val hadoopPart1 = generateFakeHadoopPartition() val pipedRdd = new PipedRDD(nums, "printenv " + varName) - val tContext = new TaskContextImpl(0, 0, 0, 0) + val tContext = new TaskContextImpl(0, 0, 0, 0, null) val rddIter = pipedRdd.compute(hadoopPart1, tContext) val arr = rddIter.toArray assert(arr(0) == "/some/path") diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 057e226916027..83ae8701243e5 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -51,7 +51,7 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte } test("all TaskCompletionListeners should be called even if some fail") { - val context = new TaskContextImpl(0, 0, 0, 0) + val context = new TaskContextImpl(0, 0, 0, 0, null) val listener = mock(classOf[TaskCompletionListener]) context.addTaskCompletionListener(_ => throw new Exception("blah")) context.addTaskCompletionListener(listener) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 37b593b2c5f79..2080c432d77db 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -89,7 +89,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { ) val iterator = new ShuffleBlockFetcherIterator( - new TaskContextImpl(0, 0, 0, 0), + new TaskContextImpl(0, 0, 0, 0, null), transfer, blockManager, blocksByAddress, @@ -154,7 +154,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) - val taskContext = new TaskContextImpl(0, 0, 0, 0) + val taskContext = new TaskContextImpl(0, 0, 0, 0, null) val iterator = new ShuffleBlockFetcherIterator( taskContext, transfer, @@ -217,7 +217,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) - val taskContext = new TaskContextImpl(0, 0, 0, 0) + val taskContext = new TaskContextImpl(0, 0, 0, 0, null) val iterator = new ShuffleBlockFetcherIterator( taskContext, transfer, diff --git a/pom.xml b/pom.xml index 928f5d0f5efad..c85c5feeaf383 100644 --- a/pom.xml +++ b/pom.xml @@ -97,6 +97,7 @@ sql/catalyst sql/core sql/hive + unsafe assembly external/twitter external/flume @@ -1215,6 +1216,7 @@ false false true + true false diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 09b4976d10c26..b7dbcd9bc562a 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -34,11 +34,11 @@ object BuildCommons { val allProjects@Seq(bagel, catalyst, core, graphx, hive, hiveThriftServer, mllib, repl, sql, networkCommon, networkShuffle, streaming, streamingFlumeSink, streamingFlume, streamingKafka, - streamingMqtt, streamingTwitter, streamingZeromq, launcher) = + streamingMqtt, streamingTwitter, streamingZeromq, launcher, unsafe) = Seq("bagel", "catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl", "sql", "network-common", "network-shuffle", "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter", - "streaming-zeromq", "launcher").map(ProjectRef(buildLocation, _)) + "streaming-zeromq", "launcher", "unsafe").map(ProjectRef(buildLocation, _)) val optionallyEnabledProjects@Seq(yarn, yarnStable, java8Tests, sparkGangliaLgpl, sparkKinesisAsl) = Seq("yarn", "yarn-stable", "java8-tests", "ganglia-lgpl", @@ -159,7 +159,7 @@ object SparkBuild extends PomBuild { // TODO: Add Sql to mima checks // TODO: remove launcher from this list after 1.3. 
allProjects.filterNot(x => Seq(spark, sql, hive, hiveThriftServer, catalyst, repl, - networkCommon, networkShuffle, networkYarn, launcher).contains(x)).foreach { + networkCommon, networkShuffle, networkYarn, launcher, unsafe).contains(x)).foreach { x => enable(MimaBuild.mimaSettings(sparkHome, x))(x) } @@ -496,6 +496,7 @@ object TestSettings { javaOptions in Test += "-Dspark.ui.enabled=false", javaOptions in Test += "-Dspark.ui.showConsoleProgress=false", javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true", + javaOptions in Test += "-Dspark.unsafe.exceptionOnMemoryLeak=true", javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true", javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark") .map { case (k,v) => s"-D$k=$v" }.toSeq, diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 3dea2ee76542f..5c322d032d474 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -50,6 +50,11 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-unsafe_${scala.binary.version} + ${project.version} + org.scalacheck scalacheck_${scala.binary.version} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java new file mode 100644 index 0000000000000..299ff3728a6d9 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java @@ -0,0 +1,259 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions; + +import java.util.Arrays; +import java.util.Iterator; + +import org.apache.spark.sql.Row; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.map.BytesToBytesMap; +import org.apache.spark.unsafe.memory.MemoryLocation; +import org.apache.spark.unsafe.memory.TaskMemoryManager; + +/** + * Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width. + * + * This map supports a maximum of 2 billion keys. + */ +public final class UnsafeFixedWidthAggregationMap { + + /** + * An empty aggregation buffer, encoded in UnsafeRow format. When inserting a new key into the + * map, we copy this buffer and use it as the value. + */ + private final long[] emptyAggregationBuffer; + + private final StructType aggregationBufferSchema; + + private final StructType groupingKeySchema; + + /** + * Encodes grouping keys as UnsafeRows. 
+ */ + private final UnsafeRowConverter groupingKeyToUnsafeRowConverter; + + /** + * A hash map which maps from opaque byte-array keys to byte-array values. + */ + private final BytesToBytesMap map; + + /** + * Re-used pointer to the current aggregation buffer. + */ + private final UnsafeRow currentAggregationBuffer = new UnsafeRow(); + + /** + * Scratch space that is used when encoding grouping keys into UnsafeRow format. + * + * By default, this is a 1KB (128-word) array, but it will grow as necessary in case larger + * keys are encountered. + */ + private long[] groupingKeyConversionScratchSpace = new long[1024 / 8]; + + private final boolean enablePerfMetrics; + + /** + * @return true if UnsafeFixedWidthAggregationMap supports grouping keys with the given schema, + * false otherwise. + */ + public static boolean supportsGroupKeySchema(StructType schema) { + for (StructField field: schema.fields()) { + if (!UnsafeRow.readableFieldTypes.contains(field.dataType())) { + return false; + } + } + return true; + } + + /** + * @return true if UnsafeFixedWidthAggregationMap supports aggregation buffers with the given + * schema, false otherwise. + */ + public static boolean supportsAggregationBufferSchema(StructType schema) { + for (StructField field: schema.fields()) { + if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) { + return false; + } + } + return true; + } + + /** + * Create a new UnsafeFixedWidthAggregationMap. + * + * @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function) + * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion. + * @param groupingKeySchema the schema of the grouping key, used for row conversion. + * @param memoryManager the memory manager used to allocate our Unsafe memory structures. + * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing). + * @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact) + */ + public UnsafeFixedWidthAggregationMap( + Row emptyAggregationBuffer, + StructType aggregationBufferSchema, + StructType groupingKeySchema, + TaskMemoryManager memoryManager, + int initialCapacity, + boolean enablePerfMetrics) { + this.emptyAggregationBuffer = + convertToUnsafeRow(emptyAggregationBuffer, aggregationBufferSchema); + this.aggregationBufferSchema = aggregationBufferSchema; + this.groupingKeyToUnsafeRowConverter = new UnsafeRowConverter(groupingKeySchema); + this.groupingKeySchema = groupingKeySchema; + this.map = new BytesToBytesMap(memoryManager, initialCapacity, enablePerfMetrics); + this.enablePerfMetrics = enablePerfMetrics; + } + + /** + * Convert a Java object row into an UnsafeRow, allocating it into a new long array. + */ + private static long[] convertToUnsafeRow(Row javaRow, StructType schema) { + final UnsafeRowConverter converter = new UnsafeRowConverter(schema); + // Note: getSizeRequirement() is measured in bytes, so this array is larger than the row + // itself; its length in elements is reused below (and by callers) as the row's size in bytes. + final long[] unsafeRow = new long[converter.getSizeRequirement(javaRow)]; + final long writtenLength = + converter.writeRow(javaRow, unsafeRow, PlatformDependent.LONG_ARRAY_OFFSET); + assert (writtenLength == unsafeRow.length): "Size requirement calculation was wrong!"; + return unsafeRow; + } + + /** + * Return the aggregation buffer for the current group. For efficiency, all calls to this method + * return the same object.
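+ * + * Because the returned row object is reused across calls, callers must copy out any values + * they need to retain before calling this method again.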
+ */ + public UnsafeRow getAggregationBuffer(Row groupingKey) { + final int groupingKeySize = groupingKeyToUnsafeRowConverter.getSizeRequirement(groupingKey); + // Make sure that the buffer is large enough to hold the key. If it's not, grow it: + if (groupingKeySize > groupingKeyConversionScratchSpace.length) { + // This new array will be initially zero, so there's no need to zero it out here + groupingKeyConversionScratchSpace = new long[groupingKeySize]; + } else { + // Zero out the buffer that's used to hold the current row. This is necessary in order + // to ensure that rows hash properly, since garbage data from the previous row could + // otherwise end up as padding in this row. As a performance optimization, we only zero out + // the portion of the buffer that we'll actually write to. + Arrays.fill(groupingKeyConversionScratchSpace, 0, groupingKeySize, 0); + } + final long actualGroupingKeySize = groupingKeyToUnsafeRowConverter.writeRow( + groupingKey, + groupingKeyConversionScratchSpace, + PlatformDependent.LONG_ARRAY_OFFSET); + assert (groupingKeySize == actualGroupingKeySize) : "Size requirement calculation was wrong!"; + + // Probe our map using the serialized key + final BytesToBytesMap.Location loc = map.lookup( + groupingKeyConversionScratchSpace, + PlatformDependent.LONG_ARRAY_OFFSET, + groupingKeySize); + if (!loc.isDefined()) { + // This is the first time that we've seen this grouping key, so we'll insert a copy of the + // empty aggregation buffer into the map: + loc.putNewKey( + groupingKeyConversionScratchSpace, + PlatformDependent.LONG_ARRAY_OFFSET, + groupingKeySize, + emptyAggregationBuffer, + PlatformDependent.LONG_ARRAY_OFFSET, + emptyAggregationBuffer.length + ); + } + + // Reset the pointer to point to the value that we just stored or looked up: + final MemoryLocation address = loc.getValueAddress(); + currentAggregationBuffer.pointTo( + address.getBaseObject(), + address.getBaseOffset(), + aggregationBufferSchema.length(), + aggregationBufferSchema + ); + return currentAggregationBuffer; + } + + /** + * Mutable pair object returned by {@link UnsafeFixedWidthAggregationMap#iterator()}. + */ + public static class MapEntry { + private MapEntry() { }; + public final UnsafeRow key = new UnsafeRow(); + public final UnsafeRow value = new UnsafeRow(); + } + + /** + * Returns an iterator over the keys and values in this map. + * + * For efficiency, each call returns the same object. + */ + public Iterator iterator() { + return new Iterator() { + + private final MapEntry entry = new MapEntry(); + private final Iterator mapLocationIterator = map.iterator(); + + @Override + public boolean hasNext() { + return mapLocationIterator.hasNext(); + } + + @Override + public MapEntry next() { + final BytesToBytesMap.Location loc = mapLocationIterator.next(); + final MemoryLocation keyAddress = loc.getKeyAddress(); + final MemoryLocation valueAddress = loc.getValueAddress(); + entry.key.pointTo( + keyAddress.getBaseObject(), + keyAddress.getBaseOffset(), + groupingKeySchema.length(), + groupingKeySchema + ); + entry.value.pointTo( + valueAddress.getBaseObject(), + valueAddress.getBaseOffset(), + aggregationBufferSchema.length(), + aggregationBufferSchema + ); + return entry; + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + }; + } + + /** + * Free the unsafe memory associated with this map. 
+ */ + public void free() { + map.free(); + } + + @SuppressWarnings("UseOfSystemOutOrSystemErr") + public void printPerfMetrics() { + if (!enablePerfMetrics) { + throw new IllegalStateException("Perf metrics not enabled"); + } + System.out.println("Average probes per lookup: " + map.getAverageProbesPerLookup()); + System.out.println("Number of hash collisions: " + map.getNumHashCollisions()); + System.out.println("Time spent resizing (ns): " + map.getTimeSpentResizingNs()); + System.out.println("Total memory consumption (bytes): " + map.getTotalMemoryConsumption()); + } + +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java new file mode 100644 index 0000000000000..0a358ed408aa1 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -0,0 +1,435 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions; + +import scala.collection.Map; +import scala.collection.Seq; +import scala.collection.mutable.ArraySeq; + +import javax.annotation.Nullable; +import java.math.BigDecimal; +import java.sql.Date; +import java.util.*; + +import org.apache.spark.sql.Row; +import org.apache.spark.sql.types.DataType; +import static org.apache.spark.sql.types.DataTypes.*; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.UTF8String; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.bitset.BitSetMethods; + +/** + * An Unsafe implementation of Row which is backed by raw memory instead of Java objects. + * + * Each tuple has three parts: [null bit set] [values] [variable length portion] + * + * The bit set is used for null tracking and is aligned to 8-byte word boundaries. It stores + * one bit per field. + * + * In the `values` region, we store one 8-byte word per field. For fields that hold fixed-length + * primitive types, such as long, double, or int, we store the value directly in the word. For + * fields with non-primitive or variable-length values, we store a relative offset (w.r.t. the + * base address of the row) that points to the beginning of the variable-length field. + * + * Instances of `UnsafeRow` act as pointers to row data stored in this format. 
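+ * + * A row is not usable until `pointTo()` has been called to aim it at a region of backing + * memory; after that, fields are read and written through the type-specific getters and + * setters below.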
+ */ +public final class UnsafeRow implements MutableRow { + + private Object baseObject; + private long baseOffset; + + Object getBaseObject() { return baseObject; } + long getBaseOffset() { return baseOffset; } + + /** The number of fields in this row, used for calculating the bitset width (and in assertions) */ + private int numFields; + + /** The width of the null tracking bit set, in bytes */ + private int bitSetWidthInBytes; + /** + * This schema is required if you want to call generic get() and set() methods on + * this UnsafeRow, but is optional if callers will only use type-specific getTYPE() and setTYPE() + * methods. This should be removed after the planned InternalRow / Row split; right now, it's only + * needed by the generic get() method, which is only called internally by code that accesses + * UTF8String-typed columns. + */ + @Nullable + private StructType schema; + + private long getFieldOffset(int ordinal) { + return baseOffset + bitSetWidthInBytes + ordinal * 8L; + } + + public static int calculateBitSetWidthInBytes(int numFields) { + return ((numFields / 64) + (numFields % 64 == 0 ? 0 : 1)) * 8; + } + + /** + * Field types that can be updated in place in UnsafeRows (e.g. we support set() for these types) + */ + public static final Set<DataType> settableFieldTypes; + + /** + * Field types that can be read (but not set; e.g. set() will throw UnsupportedOperationException). + */ + public static final Set<DataType> readableFieldTypes; + + static { + settableFieldTypes = Collections.unmodifiableSet( + new HashSet<DataType>( + Arrays.asList(new DataType[] { + NullType, + BooleanType, + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType + }))); + + // We support get() on a superset of the types for which we support set(): + final Set<DataType> _readableFieldTypes = new HashSet<DataType>( + Arrays.asList(new DataType[]{ + StringType + })); + _readableFieldTypes.addAll(settableFieldTypes); + readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes); + } + + /** + * Construct a new UnsafeRow. The resulting row won't be usable until `pointTo()` has been called, + * since the value returned by this constructor is equivalent to a null pointer. + */ + public UnsafeRow() { } + + /** + * Update this UnsafeRow to point to different backing data. + * + * @param baseObject the base object + * @param baseOffset the offset within the base object + * @param numFields the number of fields in this row + * @param schema an optional schema; this is necessary if you want to call generic get() or set() + * methods on this row, but is optional if the caller will only use type-specific + * getTYPE() and setTYPE() methods. + */ + public void pointTo( + Object baseObject, + long baseOffset, + int numFields, + @Nullable StructType schema) { + assert numFields >= 0 : "numFields should >= 0"; + assert schema == null || schema.fields().length == numFields; + this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields); + this.baseObject = baseObject; + this.baseOffset = baseOffset; + this.numFields = numFields; + this.schema = schema; + } + + private void assertIndexIsValid(int index) { + assert index >= 0 : "index (" + index + ") should >= 0"; + assert index < numFields : "index (" + index + ") should < " + numFields; + } + + @Override + public void setNullAt(int i) { + assertIndexIsValid(i); + BitSetMethods.set(baseObject, baseOffset, i); + // To preserve row equality, zero out the value when setting the column to null.
+ // Since this row does not currently support updates to variable-length values, we don't + // have to worry about zeroing out that data. + PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(i), 0); + } + + private void setNotNullAt(int i) { + assertIndexIsValid(i); + BitSetMethods.unset(baseObject, baseOffset, i); + } + + @Override + public void update(int ordinal, Object value) { + throw new UnsupportedOperationException(); + } + + @Override + public void setInt(int ordinal, int value) { + assertIndexIsValid(ordinal); + setNotNullAt(ordinal); + PlatformDependent.UNSAFE.putInt(baseObject, getFieldOffset(ordinal), value); + } + + @Override + public void setLong(int ordinal, long value) { + assertIndexIsValid(ordinal); + setNotNullAt(ordinal); + PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(ordinal), value); + } + + @Override + public void setDouble(int ordinal, double value) { + assertIndexIsValid(ordinal); + setNotNullAt(ordinal); + PlatformDependent.UNSAFE.putDouble(baseObject, getFieldOffset(ordinal), value); + } + + @Override + public void setBoolean(int ordinal, boolean value) { + assertIndexIsValid(ordinal); + setNotNullAt(ordinal); + PlatformDependent.UNSAFE.putBoolean(baseObject, getFieldOffset(ordinal), value); + } + + @Override + public void setShort(int ordinal, short value) { + assertIndexIsValid(ordinal); + setNotNullAt(ordinal); + PlatformDependent.UNSAFE.putShort(baseObject, getFieldOffset(ordinal), value); + } + + @Override + public void setByte(int ordinal, byte value) { + assertIndexIsValid(ordinal); + setNotNullAt(ordinal); + PlatformDependent.UNSAFE.putByte(baseObject, getFieldOffset(ordinal), value); + } + + @Override + public void setFloat(int ordinal, float value) { + assertIndexIsValid(ordinal); + setNotNullAt(ordinal); + PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value); + } + + @Override + public void setString(int ordinal, String value) { + throw new UnsupportedOperationException(); + } + + @Override + public int size() { + return numFields; + } + + @Override + public int length() { + return size(); + } + + @Override + public StructType schema() { + return schema; + } + + @Override + public Object apply(int i) { + return get(i); + } + + @Override + public Object get(int i) { + assertIndexIsValid(i); + assert (schema != null) : "Schema must be defined when calling generic get() method"; + final DataType dataType = schema.fields()[i].dataType(); + // UnsafeRow is only designed to be invoked by internal code, which only invokes this generic + // get() method when trying to access UTF8String-typed columns. If we refactor the codebase to + // separate the internal and external row interfaces, then internal code can fetch strings via + // a new getUTF8String() method and we'll be able to remove this method.
+ if (isNullAt(i)) { + return null; + } else if (dataType == StringType) { + return getUTF8String(i); + } else { + throw new UnsupportedOperationException(); + } + } + + @Override + public boolean isNullAt(int i) { + assertIndexIsValid(i); + return BitSetMethods.isSet(baseObject, baseOffset, i); + } + + @Override + public boolean getBoolean(int i) { + assertIndexIsValid(i); + return PlatformDependent.UNSAFE.getBoolean(baseObject, getFieldOffset(i)); + } + + @Override + public byte getByte(int i) { + assertIndexIsValid(i); + return PlatformDependent.UNSAFE.getByte(baseObject, getFieldOffset(i)); + } + + @Override + public short getShort(int i) { + assertIndexIsValid(i); + return PlatformDependent.UNSAFE.getShort(baseObject, getFieldOffset(i)); + } + + @Override + public int getInt(int i) { + assertIndexIsValid(i); + return PlatformDependent.UNSAFE.getInt(baseObject, getFieldOffset(i)); + } + + @Override + public long getLong(int i) { + assertIndexIsValid(i); + return PlatformDependent.UNSAFE.getLong(baseObject, getFieldOffset(i)); + } + + @Override + public float getFloat(int i) { + assertIndexIsValid(i); + if (isNullAt(i)) { + return Float.NaN; + } else { + return PlatformDependent.UNSAFE.getFloat(baseObject, getFieldOffset(i)); + } + } + + @Override + public double getDouble(int i) { + assertIndexIsValid(i); + if (isNullAt(i)) { + return Double.NaN; + } else { + return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(i)); + } + } + + public UTF8String getUTF8String(int i) { + assertIndexIsValid(i); + final UTF8String str = new UTF8String(); + final long offsetToStringSize = getLong(i); + final int stringSizeInBytes = + (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offsetToStringSize); + final byte[] strBytes = new byte[stringSizeInBytes]; + PlatformDependent.copyMemory( + baseObject, + baseOffset + offsetToStringSize + 8, // The `+ 8` is to skip past the size to get the data + strBytes, + PlatformDependent.BYTE_ARRAY_OFFSET, + stringSizeInBytes + ); + str.set(strBytes); + return str; + } + + @Override + public String getString(int i) { + return getUTF8String(i).toString(); + } + + @Override + public BigDecimal getDecimal(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public Date getDate(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public <T> Seq<T> getSeq(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public <T> List<T> getList(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public <K, V> Map<K, V> getMap(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public <T> scala.collection.immutable.Map<String, T> getValuesMap(Seq<String> fieldNames) { + throw new UnsupportedOperationException(); + } + + @Override + public <K, V> java.util.Map<K, V> getJavaMap(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public Row getStruct(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public <T> T getAs(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public <T> T getAs(String fieldName) { + throw new UnsupportedOperationException(); + } + + @Override + public int fieldIndex(String name) { + throw new UnsupportedOperationException(); + } + + @Override + public Row copy() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean anyNull() { + return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes); + } + + @Override + public Seq<Object> toSeq() { + final ArraySeq<Object> values = new ArraySeq<Object>(numFields); + for (int
fieldNumber = 0; fieldNumber < numFields; fieldNumber++) { + values.update(fieldNumber, get(fieldNumber)); + } + return values; + } + + @Override + public String toString() { + return mkString("[", ",", "]"); + } + + @Override + public String mkString() { + return toSeq().mkString(); + } + + @Override + public String mkString(String sep) { + return toSeq().mkString(sep); + } + + @Override + public String mkString(String start, String sep, String end) { + return toSeq().mkString(start, sep, end); + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala new file mode 100644 index 0000000000000..5b2c8572784bd --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.array.ByteArrayMethods + +/** + * Converts Rows into UnsafeRow format. This class is NOT thread-safe. + * + * @param fieldTypes the data types of the row's columns. + */ +class UnsafeRowConverter(fieldTypes: Array[DataType]) { + + def this(schema: StructType) { + this(schema.fields.map(_.dataType)) + } + + /** Re-used pointer to the unsafe row being written */ + private[this] val unsafeRow = new UnsafeRow() + + /** Functions for encoding each column */ + private[this] val writers: Array[UnsafeColumnWriter] = { + fieldTypes.map(t => UnsafeColumnWriter.forType(t)) + } + + /** The size, in bytes, of the fixed-length portion of the row, including the null bitmap */ + private[this] val fixedLengthSize: Int = + (8 * fieldTypes.length) + UnsafeRow.calculateBitSetWidthInBytes(fieldTypes.length) + + /** + * Compute the amount of space, in bytes, required to encode the given row. + */ + def getSizeRequirement(row: Row): Int = { + var fieldNumber = 0 + var variableLengthFieldSize: Int = 0 + while (fieldNumber < writers.length) { + if (!row.isNullAt(fieldNumber)) { + variableLengthFieldSize += writers(fieldNumber).getSize(row, fieldNumber) + } + fieldNumber += 1 + } + fixedLengthSize + variableLengthFieldSize + } + + /** + * Convert the given row into UnsafeRow format. + * + * @param row the row to convert + * @param baseObject the base object of the destination address + * @param baseOffset the base offset of the destination address + * @return the number of bytes written. This should be equal to `getSizeRequirement(row)`. 
+ */ + def writeRow(row: Row, baseObject: Object, baseOffset: Long): Long = { + unsafeRow.pointTo(baseObject, baseOffset, writers.length, null) + var fieldNumber = 0 + var appendCursor: Int = fixedLengthSize + while (fieldNumber < writers.length) { + if (row.isNullAt(fieldNumber)) { + unsafeRow.setNullAt(fieldNumber) + } else { + appendCursor += writers(fieldNumber).write(row, unsafeRow, fieldNumber, appendCursor) + } + fieldNumber += 1 + } + appendCursor + } + +} + +/** + * Function for writing a column into an UnsafeRow. + */ +private abstract class UnsafeColumnWriter { + /** + * Write a value into an UnsafeRow. + * + * @param source the row being converted + * @param target a pointer to the converted unsafe row + * @param column the column to write + * @param appendCursor the offset from the start of the unsafe row to the end of the row; + * used for calculating where variable-length data should be written + * @return the number of variable-length bytes written + */ + def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int + + /** + * Return the number of bytes that are needed to write this variable-length value. + */ + def getSize(source: Row, column: Int): Int +} + +private object UnsafeColumnWriter { + + def forType(dataType: DataType): UnsafeColumnWriter = { + dataType match { + case NullType => NullUnsafeColumnWriter + case BooleanType => BooleanUnsafeColumnWriter + case ByteType => ByteUnsafeColumnWriter + case ShortType => ShortUnsafeColumnWriter + case IntegerType => IntUnsafeColumnWriter + case LongType => LongUnsafeColumnWriter + case FloatType => FloatUnsafeColumnWriter + case DoubleType => DoubleUnsafeColumnWriter + case StringType => StringUnsafeColumnWriter + case t => + throw new UnsupportedOperationException(s"Do not know how to write columns of type $t") + } + } +} + +// ------------------------------------------------------------------------------------------------ + +private object NullUnsafeColumnWriter extends NullUnsafeColumnWriter +private object BooleanUnsafeColumnWriter extends BooleanUnsafeColumnWriter +private object ByteUnsafeColumnWriter extends ByteUnsafeColumnWriter +private object ShortUnsafeColumnWriter extends ShortUnsafeColumnWriter +private object IntUnsafeColumnWriter extends IntUnsafeColumnWriter +private object LongUnsafeColumnWriter extends LongUnsafeColumnWriter +private object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter +private object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter +private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter + +private abstract class PrimitiveUnsafeColumnWriter extends UnsafeColumnWriter { + // Primitives don't write to the variable-length region: + def getSize(sourceRow: Row, column: Int): Int = 0 +} + +private class NullUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { + override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + target.setNullAt(column) + 0 + } +} + +private class BooleanUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { + override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + target.setBoolean(column, source.getBoolean(column)) + 0 + } +} + +private class ByteUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { + override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + target.setByte(column, source.getByte(column)) + 0 + } +} + +private class ShortUnsafeColumnWriter private() 
extends PrimitiveUnsafeColumnWriter { + override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + target.setShort(column, source.getShort(column)) + 0 + } +} + +private class IntUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { + override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + target.setInt(column, source.getInt(column)) + 0 + } +} + +private class LongUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { + override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + target.setLong(column, source.getLong(column)) + 0 + } +} + +private class FloatUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { + override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + target.setFloat(column, source.getFloat(column)) + 0 + } +} + +private class DoubleUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { + override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + target.setDouble(column, source.getDouble(column)) + 0 + } +} + +private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter { + def getSize(source: Row, column: Int): Int = { + val numBytes = source.get(column).asInstanceOf[UTF8String].getBytes.length + 8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) + } + + override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + val value = source.get(column).asInstanceOf[UTF8String] + val baseObject = target.getBaseObject + val baseOffset = target.getBaseOffset + val numBytes = value.getBytes.length + PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + appendCursor, numBytes) + PlatformDependent.copyMemory( + value.getBytes, + PlatformDependent.BYTE_ARRAY_OFFSET, + baseObject, + baseOffset + appendCursor + 8, + numBytes + ) + target.setLong(column, appendCursor) + 8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala new file mode 100644 index 0000000000000..7a19e511eb8b5 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import scala.collection.JavaConverters._ +import scala.util.Random + +import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, TaskMemoryManager, MemoryAllocator} +import org.scalatest.{BeforeAndAfterEach, FunSuite, Matchers} + +import org.apache.spark.sql.types._ + +class UnsafeFixedWidthAggregationMapSuite extends FunSuite with Matchers with BeforeAndAfterEach { + + import UnsafeFixedWidthAggregationMap._ + + private val groupKeySchema = StructType(StructField("product", StringType) :: Nil) + private val aggBufferSchema = StructType(StructField("salePrice", IntegerType) :: Nil) + private def emptyAggregationBuffer: Row = new GenericRow(Array[Any](0)) + + private var memoryManager: TaskMemoryManager = null + + override def beforeEach(): Unit = { + memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) + } + + override def afterEach(): Unit = { + if (memoryManager != null) { + memoryManager.cleanUpAllAllocatedMemory() + memoryManager = null + } + } + + test("supported schemas") { + assert(!supportsAggregationBufferSchema(StructType(StructField("x", StringType) :: Nil))) + assert(supportsGroupKeySchema(StructType(StructField("x", StringType) :: Nil))) + + assert( + !supportsAggregationBufferSchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil))) + assert( + !supportsGroupKeySchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil))) + } + + test("empty map") { + val map = new UnsafeFixedWidthAggregationMap( + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, + memoryManager, + 1024, // initial capacity + false // disable perf metrics + ) + assert(!map.iterator().hasNext) + map.free() + } + + test("updating values for a single key") { + val map = new UnsafeFixedWidthAggregationMap( + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, + memoryManager, + 1024, // initial capacity + false // disable perf metrics + ) + val groupKey = new GenericRow(Array[Any](UTF8String("cats"))) + + // Looking up a key stores a zero-entry in the map (like Python Counters or DefaultDicts) + map.getAggregationBuffer(groupKey) + val iter = map.iterator() + val entry = iter.next() + assert(!iter.hasNext) + entry.key.getString(0) should be ("cats") + entry.value.getInt(0) should be (0) + + // Modifications to rows retrieved from the map should update the values in the map + entry.value.setInt(0, 42) + map.getAggregationBuffer(groupKey).getInt(0) should be (42) + + map.free() + } + + test("inserting large random keys") { + val map = new UnsafeFixedWidthAggregationMap( + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, + memoryManager, + 128, // initial capacity + false // disable perf metrics + ) + val rand = new Random(42) + val groupKeys: Set[String] = Seq.fill(512)(rand.nextString(1024)).toSet + groupKeys.foreach { keyString => + map.getAggregationBuffer(new GenericRow(Array[Any](UTF8String(keyString)))) + } + val seenKeys: Set[String] = map.iterator().asScala.map { entry => + entry.key.getString(0) + }.toSet + seenKeys.size should be (groupKeys.size) + seenKeys should be (groupKeys) + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala new file mode 100644 index 0000000000000..3a60c7fd32675 --- /dev/null +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.util.Arrays + +import org.scalatest.{FunSuite, Matchers} + +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.array.ByteArrayMethods + +class UnsafeRowConverterSuite extends FunSuite with Matchers { + + test("basic conversion with only primitive types") { + val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType) + val converter = new UnsafeRowConverter(fieldTypes) + + val row = new SpecificMutableRow(fieldTypes) + row.setLong(0, 0) + row.setLong(1, 1) + row.setInt(2, 2) + + val sizeRequired: Int = converter.getSizeRequirement(row) + sizeRequired should be (8 + (3 * 8)) + val buffer: Array[Long] = new Array[Long](sizeRequired / 8) + val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET) + numBytesWritten should be (sizeRequired) + + val unsafeRow = new UnsafeRow() + unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) + unsafeRow.getLong(0) should be (0) + unsafeRow.getLong(1) should be (1) + unsafeRow.getInt(2) should be (2) + } + + test("basic conversion with primitive and string types") { + val fieldTypes: Array[DataType] = Array(LongType, StringType, StringType) + val converter = new UnsafeRowConverter(fieldTypes) + + val row = new SpecificMutableRow(fieldTypes) + row.setLong(0, 0) + row.setString(1, "Hello") + row.setString(2, "World") + + val sizeRequired: Int = converter.getSizeRequirement(row) + sizeRequired should be (8 + (8 * 3) + + ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length + 8) + + ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length + 8)) + val buffer: Array[Long] = new Array[Long](sizeRequired / 8) + val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET) + numBytesWritten should be (sizeRequired) + + val unsafeRow = new UnsafeRow() + unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) + unsafeRow.getLong(0) should be (0) + unsafeRow.getString(1) should be ("Hello") + unsafeRow.getString(2) should be ("World") + } + + test("null handling") { + val fieldTypes: Array[DataType] = Array( + NullType, + BooleanType, + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType) + val converter = new UnsafeRowConverter(fieldTypes) + + val rowWithAllNullColumns: Row = { + val r = new SpecificMutableRow(fieldTypes) + for (i <- 0 to fieldTypes.length - 1) { + r.setNullAt(i) + } + r + } + + val sizeRequired: Int = 
converter.getSizeRequirement(rowWithAllNullColumns) + val createdFromNullBuffer: Array[Long] = new Array[Long](sizeRequired / 8) + val numBytesWritten = converter.writeRow( + rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET) + numBytesWritten should be (sizeRequired) + + val createdFromNull = new UnsafeRow() + createdFromNull.pointTo( + createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) + for (i <- 0 to fieldTypes.length - 1) { + assert(createdFromNull.isNullAt(i)) + } + createdFromNull.getBoolean(1) should be (false) + createdFromNull.getByte(2) should be (0) + createdFromNull.getShort(3) should be (0) + createdFromNull.getInt(4) should be (0) + createdFromNull.getLong(5) should be (0) + assert(java.lang.Float.isNaN(createdFromNull.getFloat(6))) + assert(java.lang.Double.isNaN(createdFromNull.getFloat(7))) + + // If we have an UnsafeRow with columns that are initially non-null and we null out those + // columns, then the serialized row representation should be identical to what we would get by + // creating an entirely null row via the converter + val rowWithNoNullColumns: Row = { + val r = new SpecificMutableRow(fieldTypes) + r.setNullAt(0) + r.setBoolean(1, false) + r.setByte(2, 20) + r.setShort(3, 30) + r.setInt(4, 400) + r.setLong(5, 500) + r.setFloat(6, 600) + r.setDouble(7, 700) + r + } + val setToNullAfterCreationBuffer: Array[Long] = new Array[Long](sizeRequired / 8) + converter.writeRow( + rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET) + val setToNullAfterCreation = new UnsafeRow() + setToNullAfterCreation.pointTo( + setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) + + setToNullAfterCreation.isNullAt(0) should be (rowWithNoNullColumns.isNullAt(0)) + setToNullAfterCreation.getBoolean(1) should be (rowWithNoNullColumns.getBoolean(1)) + setToNullAfterCreation.getByte(2) should be (rowWithNoNullColumns.getByte(2)) + setToNullAfterCreation.getShort(3) should be (rowWithNoNullColumns.getShort(3)) + setToNullAfterCreation.getInt(4) should be (rowWithNoNullColumns.getInt(4)) + setToNullAfterCreation.getLong(5) should be (rowWithNoNullColumns.getLong(5)) + setToNullAfterCreation.getFloat(6) should be (rowWithNoNullColumns.getFloat(6)) + setToNullAfterCreation.getDouble(7) should be (rowWithNoNullColumns.getDouble(7)) + + for (i <- 0 to fieldTypes.length - 1) { + setToNullAfterCreation.setNullAt(i) + } + assert(Arrays.equals(createdFromNullBuffer, setToNullAfterCreationBuffer)) + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 4fc5de7e824fe..2fa602a6082dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -30,6 +30,7 @@ private[spark] object SQLConf { val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes" val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions" val CODEGEN_ENABLED = "spark.sql.codegen" + val UNSAFE_ENABLED = "spark.sql.unsafe.enabled" val DIALECT = "spark.sql.dialect" val PARQUET_BINARY_AS_STRING = "spark.sql.parquet.binaryAsString" @@ -149,6 +150,14 @@ private[sql] class SQLConf extends Serializable { */ private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, "false").toBoolean + /** + * When set to true, Spark SQL will use managed memory for certain operations. This option only + * takes effect if codegen is enabled. 
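For anyone trying the feature out, the two settings interact: the unsafe aggregation path is only considered when codegen is also enabled. A minimal driver sketch, assuming a Spark 1.x application with this patch applied (app name and master are placeholders):

    import org.apache.spark.SparkConf;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.sql.SQLContext;

    public class EnableUnsafeExample {
      public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("unsafe-demo").setMaster("local");
        JavaSparkContext sc = new JavaSparkContext(conf);
        SQLContext sqlContext = new SQLContext(sc);
        // Both keys default to "false" in the SQLConf change above; the unsafe
        // aggregation map is only used when codegen is on as well.
        sqlContext.setConf("spark.sql.codegen", "true");
        sqlContext.setConf("spark.sql.unsafe.enabled", "true");
        sc.stop();
      }
    }
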
+ * + * Defaults to false as this feature is currently experimental. + */ + private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED, "false").toBoolean + private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2, "true").toBoolean /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index a279b0f07c38a..bd4a55fa132fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -1011,6 +1011,8 @@ class SQLContext(@transient val sparkContext: SparkContext) def codegenEnabled: Boolean = self.conf.codegenEnabled + def unsafeEnabled: Boolean = self.conf.unsafeEnabled + def numPartitions: Int = self.conf.numShufflePartitions def strategies: Seq[Strategy] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index b1ef6556de1e9..5d9f202681045 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution +import org.apache.spark.TaskContext import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.trees._ @@ -40,6 +41,7 @@ case class AggregateEvaluation( * ensure all values where `groupingExpressions` are equal are present. * @param groupingExpressions expressions that are evaluated to determine grouping. * @param aggregateExpressions expressions that are computed for each group. + * @param unsafeEnabled whether to allow Unsafe-based aggregation buffers to be used. * @param child the input data source. */ @DeveloperApi @@ -47,6 +49,7 @@ case class GeneratedAggregate( partial: Boolean, groupingExpressions: Seq[Expression], aggregateExpressions: Seq[NamedExpression], + unsafeEnabled: Boolean, child: SparkPlan) extends UnaryNode { @@ -225,6 +228,21 @@ case class GeneratedAggregate( case e: Expression if groupMap.contains(e) => groupMap(e) }) + val aggregationBufferSchema: StructType = StructType.fromAttributes(computationSchema) + + val groupKeySchema: StructType = { + val fields = groupingExpressions.zipWithIndex.map { case (expr, idx) => + // This is a dummy field name + StructField(idx.toString, expr.dataType, expr.nullable) + } + StructType(fields) + } + + val schemaSupportsUnsafe: Boolean = { + UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) && + UnsafeFixedWidthAggregationMap.supportsGroupKeySchema(groupKeySchema) + } + child.execute().mapPartitions { iter => // Builds a new custom class for holding the results of aggregation for a group. 
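The `schemaSupportsUnsafe` check above boils down to requiring fixed-width columns, so that in-place updates to an aggregation buffer never change a row's size. A standalone sketch of that style of check (the class name and type strings here are illustrative, not the patch's API):

    import java.util.Arrays;
    import java.util.HashSet;
    import java.util.List;
    import java.util.Set;

    public class SchemaCheckSketch {
      // Only fixed-width primitives can live in an 8-byte-slot aggregation buffer.
      private static final Set<String> FIXED_WIDTH = new HashSet<String>(
          Arrays.asList("boolean", "byte", "short", "int", "long", "float", "double"));

      static boolean supportsAggregationBuffer(List<String> columnTypes) {
        for (String t : columnTypes) {
          if (!FIXED_WIDTH.contains(t)) {
            return false;
          }
        }
        return true;
      }

      public static void main(String[] args) {
        System.out.println(supportsAggregationBuffer(Arrays.asList("long", "double"))); // true
        System.out.println(supportsAggregationBuffer(Arrays.asList("string")));         // false
      }
    }
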
val initialValues = computeFunctions.flatMap(_.initialValues) @@ -265,7 +283,49 @@ case class GeneratedAggregate( val resultProjection = resultProjectionBuilder() Iterator(resultProjection(buffer)) + } else if (unsafeEnabled && schemaSupportsUnsafe) { + log.info("Using Unsafe-based aggregator") + val aggregationMap = new UnsafeFixedWidthAggregationMap( + newAggregationBuffer(EmptyRow), + aggregationBufferSchema, + groupKeySchema, + TaskContext.get.taskMemoryManager(), + 1024 * 16, // initial capacity + false // disable tracking of performance metrics + ) + + while (iter.hasNext) { + val currentRow: Row = iter.next() + val groupKey: Row = groupProjection(currentRow) + val aggregationBuffer = aggregationMap.getAggregationBuffer(groupKey) + updateProjection.target(aggregationBuffer)(joinedRow(aggregationBuffer, currentRow)) + } + + new Iterator[Row] { + private[this] val mapIterator = aggregationMap.iterator() + private[this] val resultProjection = resultProjectionBuilder() + + def hasNext: Boolean = mapIterator.hasNext + + def next(): Row = { + val entry = mapIterator.next() + val result = resultProjection(joinedRow(entry.key, entry.value)) + if (hasNext) { + result + } else { + // This is the last element in the iterator, so let's free the buffer. Before we do, + // though, we need to make a defensive copy of the result so that we don't return an + // object that might contain dangling pointers to the freed memory + val resultCopy = result.copy() + aggregationMap.free() + resultCopy + } + } + } } else { + if (unsafeEnabled) { + log.info("Not using Unsafe-based aggregator because it is not supported for this schema") + } val buffers = new java.util.HashMap[Row, MutableRow]() var currentRow: Row = null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 3a0a6c86700a8..af58911cc0e6a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -136,10 +136,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { partial = false, namedGroupingAttributes, rewrittenAggregateExpressions, + unsafeEnabled, execution.GeneratedAggregate( partial = true, groupingExpressions, partialComputation, + unsafeEnabled, planLater(child))) :: Nil // Cases where some aggregate can not be codegened diff --git a/unsafe/pom.xml b/unsafe/pom.xml new file mode 100644 index 0000000000000..8901d77591932 --- /dev/null +++ b/unsafe/pom.xml @@ -0,0 +1,69 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.10 + 1.4.0-SNAPSHOT + ../pom.xml + + + org.apache.spark + spark-unsafe_2.10 + jar + Spark Project Unsafe + http://spark.apache.org/ + + unsafe + + + + + + + com.google.code.findbugs + jsr305 + + + + + org.slf4j + slf4j-api + provided + + + + + junit + junit + test + + + com.novocode + junit-interface + test + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java b/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java new file mode 100644 index 0000000000000..91b2f9aa43921 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe; + +import java.lang.reflect.Field; + +import sun.misc.Unsafe; + +public final class PlatformDependent { + + public static final Unsafe UNSAFE; + + public static final int BYTE_ARRAY_OFFSET; + + public static final int INT_ARRAY_OFFSET; + + public static final int LONG_ARRAY_OFFSET; + + public static final int DOUBLE_ARRAY_OFFSET; + + /** + * Limits the number of bytes to copy per {@link Unsafe#copyMemory(long, long, long)} to + * allow safepoint polling during a large copy. + */ + private static final long UNSAFE_COPY_THRESHOLD = 1024L * 1024L; + + static { + sun.misc.Unsafe unsafe; + try { + Field unsafeField = Unsafe.class.getDeclaredField("theUnsafe"); + unsafeField.setAccessible(true); + unsafe = (sun.misc.Unsafe) unsafeField.get(null); + } catch (Throwable cause) { + unsafe = null; + } + UNSAFE = unsafe; + + if (UNSAFE != null) { + BYTE_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(byte[].class); + INT_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(int[].class); + LONG_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(long[].class); + DOUBLE_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(double[].class); + } else { + BYTE_ARRAY_OFFSET = 0; + INT_ARRAY_OFFSET = 0; + LONG_ARRAY_OFFSET = 0; + DOUBLE_ARRAY_OFFSET = 0; + } + } + + static public void copyMemory( + Object src, + long srcOffset, + Object dst, + long dstOffset, + long length) { + while (length > 0) { + long size = Math.min(length, UNSAFE_COPY_THRESHOLD); + UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size); + length -= size; + srcOffset += size; + dstOffset += size; + } + } + + /** + * Raises an exception bypassing compiler checks for checked exceptions. + */ + public static void throwException(Throwable t) { + UNSAFE.throwException(t); + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java new file mode 100644 index 0000000000000..53eadf96a6b52 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.array; + +import org.apache.spark.unsafe.PlatformDependent; + +public class ByteArrayMethods { + + private ByteArrayMethods() { + // Private constructor, since this class only contains static methods. + } + + public static int roundNumberOfBytesToNearestWord(int numBytes) { + int remainder = numBytes & 0x07; // This is equivalent to `numBytes % 8` + if (remainder == 0) { + return numBytes; + } else { + return numBytes + (8 - remainder); + } + } + + /** + * Optimized byte array equality check for 8-byte-word-aligned byte arrays. + * @return true if the arrays are equal, false otherwise + */ + public static boolean wordAlignedArrayEquals( + Object leftBaseObject, + long leftBaseOffset, + Object rightBaseObject, + long rightBaseOffset, + long arrayLengthInBytes) { + for (int i = 0; i < arrayLengthInBytes; i += 8) { + final long left = + PlatformDependent.UNSAFE.getLong(leftBaseObject, leftBaseOffset + i); + final long right = + PlatformDependent.UNSAFE.getLong(rightBaseObject, rightBaseOffset + i); + if (left != right) return false; + } + return true; + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java b/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java new file mode 100644 index 0000000000000..18d1f0d2d7eb2 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.array; + +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.memory.MemoryBlock; + +/** + * An array of long values. Compared with native JVM arrays, this: + *
+ * <ul>
+ *   <li>supports using both in-heap and off-heap memory</li>
+ *   <li>has no bound checking, and thus can crash the JVM process when assert is turned off</li>
+ * </ul>
+ */ +public final class LongArray { + + // This is a long so that we perform long multiplications when computing offsets. + private static final long WIDTH = 8; + + private final MemoryBlock memory; + private final Object baseObj; + private final long baseOffset; + + private final long length; + + public LongArray(MemoryBlock memory) { + assert memory.size() % WIDTH == 0 : "Memory not aligned (" + memory.size() + ")"; + assert memory.size() < (long) Integer.MAX_VALUE * 8: "Array size > 4 billion elements"; + this.memory = memory; + this.baseObj = memory.getBaseObject(); + this.baseOffset = memory.getBaseOffset(); + this.length = memory.size() / WIDTH; + } + + public MemoryBlock memoryBlock() { + return memory; + } + + /** + * Returns the number of elements this array can hold. + */ + public long size() { + return length; + } + + /** + * Sets the value at position {@code index}. + */ + public void set(int index, long value) { + assert index >= 0 : "index (" + index + ") should >= 0"; + assert index < length : "index (" + index + ") should < length (" + length + ")"; + PlatformDependent.UNSAFE.putLong(baseObj, baseOffset + index * WIDTH, value); + } + + /** + * Returns the value at position {@code index}. + */ + public long get(int index) { + assert index >= 0 : "index (" + index + ") should >= 0"; + assert index < length : "index (" + index + ") should < length (" + length + ")"; + return PlatformDependent.UNSAFE.getLong(baseObj, baseOffset + index * WIDTH); + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java new file mode 100644 index 0000000000000..f72e07fce92fd --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.bitset; + +import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.unsafe.memory.MemoryBlock; + +/** + * A fixed size uncompressed bit set backed by a {@link LongArray}. + * + * Each bit occupies exactly one bit of storage. + */ +public final class BitSet { + + /** A long array for the bits. */ + private final LongArray words; + + /** Length of the long array. */ + private final int numWords; + + private final Object baseObject; + private final long baseOffset; + + /** + * Creates a new {@link BitSet} using the specified memory block. Size of the memory block must be + * multiple of 8 bytes (i.e. 64 bits). 
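Several of these classes rely on sizes being padded up to 8-byte words, as done by `ByteArrayMethods.roundNumberOfBytesToNearestWord` above. A quick standalone check of that bit trick (`numBytes & 0x07` is equivalent to `numBytes % 8` for non-negative sizes):

    public class WordRoundingSketch {
      static int roundToWord(int numBytes) {
        int remainder = numBytes & 0x07; // numBytes % 8
        return remainder == 0 ? numBytes : numBytes + (8 - remainder);
      }

      public static void main(String[] args) {
        // Prints: 0->0, 1->8, 7->8, 8->8, 9->16, 13->16, 16->16
        for (int n : new int[] {0, 1, 7, 8, 9, 13, 16}) {
          System.out.println(n + " -> " + roundToWord(n));
        }
      }
    }
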
+ */ + public BitSet(MemoryBlock memory) { + words = new LongArray(memory); + assert (words.size() <= Integer.MAX_VALUE); + numWords = (int) words.size(); + baseObject = words.memoryBlock().getBaseObject(); + baseOffset = words.memoryBlock().getBaseOffset(); + } + + public MemoryBlock memoryBlock() { + return words.memoryBlock(); + } + + /** + * Returns the number of bits in this {@code BitSet}. + */ + public long capacity() { + return numWords * 64; + } + + /** + * Sets the bit at the specified index to {@code true}. + */ + public void set(int index) { + assert index < numWords * 64 : "index (" + index + ") should < length (" + numWords * 64 + ")"; + BitSetMethods.set(baseObject, baseOffset, index); + } + + /** + * Sets the bit at the specified index to {@code false}. + */ + public void unset(int index) { + assert index < numWords * 64 : "index (" + index + ") should < length (" + numWords * 64 + ")"; + BitSetMethods.unset(baseObject, baseOffset, index); + } + + /** + * Returns {@code true} if the bit is set at the specified index. + */ + public boolean isSet(int index) { + assert index < numWords * 64 : "index (" + index + ") should < length (" + numWords * 64 + ")"; + return BitSetMethods.isSet(baseObject, baseOffset, index); + } + + /** + * Returns the index of the first bit that is set to true that occurs on or after the + * specified starting index. If no such bit exists then {@code -1} is returned. + *

+ * <p>
+ * To iterate over the true bits in a BitSet, use the following loop:
+ * <pre>
+ * <code>
+ *  for (long i = bs.nextSetBit(0); i >= 0; i = bs.nextSetBit(i + 1)) {
+ *    // operate on index i here
+ *  }
+ * </code>
+ * </pre>
+ * + * @param fromIndex the index to start checking from (inclusive) + * @return the index of the next set bit, or -1 if there is no such bit + */ + public int nextSetBit(int fromIndex) { + return BitSetMethods.nextSetBit(baseObject, baseOffset, fromIndex, numWords); + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java new file mode 100644 index 0000000000000..f30626d8f4317 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.bitset; + +import org.apache.spark.unsafe.PlatformDependent; + +/** + * Methods for working with fixed-size uncompressed bitsets. + * + * We assume that the bitset data is word-aligned (that is, a multiple of 8 bytes in length). + * + * Each bit occupies exactly one bit of storage. + */ +public final class BitSetMethods { + + private static final long WORD_SIZE = 8; + + private BitSetMethods() { + // Make the default constructor private, since this only holds static methods. + } + + /** + * Sets the bit at the specified index to {@code true}. + */ + public static void set(Object baseObject, long baseOffset, int index) { + assert index >= 0 : "index (" + index + ") should >= 0"; + final long mask = 1L << (index & 0x3f); // mod 64 and shift + final long wordOffset = baseOffset + (index >> 6) * WORD_SIZE; + final long word = PlatformDependent.UNSAFE.getLong(baseObject, wordOffset); + PlatformDependent.UNSAFE.putLong(baseObject, wordOffset, word | mask); + } + + /** + * Sets the bit at the specified index to {@code false}. + */ + public static void unset(Object baseObject, long baseOffset, int index) { + assert index >= 0 : "index (" + index + ") should >= 0"; + final long mask = 1L << (index & 0x3f); // mod 64 and shift + final long wordOffset = baseOffset + (index >> 6) * WORD_SIZE; + final long word = PlatformDependent.UNSAFE.getLong(baseObject, wordOffset); + PlatformDependent.UNSAFE.putLong(baseObject, wordOffset, word & ~mask); + } + + /** + * Returns {@code true} if the bit is set at the specified index. + */ + public static boolean isSet(Object baseObject, long baseOffset, int index) { + assert index >= 0 : "index (" + index + ") should >= 0"; + final long mask = 1L << (index & 0x3f); // mod 64 and shift + final long wordOffset = baseOffset + (index >> 6) * WORD_SIZE; + final long word = PlatformDependent.UNSAFE.getLong(baseObject, wordOffset); + return (word & mask) != 0; + } + + /** + * Returns {@code true} if any bit is set. 
+   */
+  public static boolean anySet(Object baseObject, long baseOffset, long bitSetWidthInBytes) {
+    // Note: the loop bound must be exclusive; iterating i <= bitSetWidthInBytes would read
+    // one byte past the end of the bitset.
+    for (int i = 0; i < bitSetWidthInBytes; i++) {
+      if (PlatformDependent.UNSAFE.getByte(baseObject, baseOffset + i) != 0) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  /**
+   * Returns the index of the first bit that is set to true that occurs on or after the
+   * specified starting index. If no such bit exists then {@code -1} is returned.
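The accessors above all share the same index arithmetic: `index >> 6` selects the 64-bit word and `1L << (index & 0x3f)` selects the bit within it. A standalone illustration (class name is mine, for demonstration only):

    public class BitIndexSketch {
      public static void main(String[] args) {
        int index = 200;
        int wordIndex = index >> 6;       // 200 / 64 = 3: the bit lives in word 3
        long mask = 1L << (index & 0x3f); // 200 % 64 = 8: bit 8 within that word
        System.out.println(wordIndex + " " + Long.numberOfTrailingZeros(mask)); // prints: 3 8
      }
    }
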

+ * <p>
+ * To iterate over the true bits in a BitSet, use the following loop:
+ * <pre>
+ * <code>
+ *  for (long i = bs.nextSetBit(0, sizeInWords); i >= 0; i = bs.nextSetBit(i + 1, sizeInWords)) {
+ *    // operate on index i here
+ *  }
+ * </code>
+ * </pre>
+ * + * @param fromIndex the index to start checking from (inclusive) + * @param bitsetSizeInWords the size of the bitset, measured in 8-byte words + * @return the index of the next set bit, or -1 if there is no such bit + */ + public static int nextSetBit( + Object baseObject, + long baseOffset, + int fromIndex, + int bitsetSizeInWords) { + int wi = fromIndex >> 6; + if (wi >= bitsetSizeInWords) { + return -1; + } + + // Try to find the next set bit in the current word + final int subIndex = fromIndex & 0x3f; + long word = + PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + wi * WORD_SIZE) >> subIndex; + if (word != 0) { + return (wi << 6) + subIndex + java.lang.Long.numberOfTrailingZeros(word); + } + + // Find the next set bit in the rest of the words + wi += 1; + while (wi < bitsetSizeInWords) { + word = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + wi * WORD_SIZE); + if (word != 0) { + return (wi << 6) + java.lang.Long.numberOfTrailingZeros(word); + } + wi += 1; + } + + return -1; + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java new file mode 100644 index 0000000000000..85cd02469adb7 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.hash; + +import org.apache.spark.unsafe.PlatformDependent; + +/** + * 32-bit Murmur3 hasher. This is based on Guava's Murmur3_32HashFunction. + */ +public final class Murmur3_x86_32 { + private static final int C1 = 0xcc9e2d51; + private static final int C2 = 0x1b873593; + + private final int seed; + + public Murmur3_x86_32(int seed) { + this.seed = seed; + } + + @Override + public String toString() { + return "Murmur3_32(seed=" + seed + ")"; + } + + public int hashInt(int input) { + int k1 = mixK1(input); + int h1 = mixH1(seed, k1); + + return fmix(h1, 4); + } + + public int hashUnsafeWords(Object baseObject, long baseOffset, int lengthInBytes) { + // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method. 
+ assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)"; + int h1 = seed; + for (int offset = 0; offset < lengthInBytes; offset += 4) { + int halfWord = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + offset); + int k1 = mixK1(halfWord); + h1 = mixH1(h1, k1); + } + return fmix(h1, lengthInBytes); + } + + public int hashLong(long input) { + int low = (int) input; + int high = (int) (input >>> 32); + + int k1 = mixK1(low); + int h1 = mixH1(seed, k1); + + k1 = mixK1(high); + h1 = mixH1(h1, k1); + + return fmix(h1, 8); + } + + private static int mixK1(int k1) { + k1 *= C1; + k1 = Integer.rotateLeft(k1, 15); + k1 *= C2; + return k1; + } + + private static int mixH1(int h1, int k1) { + h1 ^= k1; + h1 = Integer.rotateLeft(h1, 13); + h1 = h1 * 5 + 0xe6546b64; + return h1; + } + + // Finalization mix - force all bits of a hash block to avalanche + private static int fmix(int h1, int length) { + h1 ^= length; + h1 ^= h1 >>> 16; + h1 *= 0x85ebca6b; + h1 ^= h1 >>> 13; + h1 *= 0xc2b2ae35; + h1 ^= h1 >>> 16; + return h1; + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java new file mode 100644 index 0000000000000..85b64c0833803 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -0,0 +1,549 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.map; + +import java.lang.Override; +import java.lang.UnsupportedOperationException; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; + +import org.apache.spark.unsafe.*; +import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.unsafe.bitset.BitSet; +import org.apache.spark.unsafe.hash.Murmur3_x86_32; +import org.apache.spark.unsafe.memory.*; + +/** + * An append-only hash map where keys and values are contiguous regions of bytes. + *
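A small usage sketch of the hasher defined above, using only the constructor and methods shown in this hunk: the hash is deterministic for a fixed seed, and different seeds decorrelate the outputs.

    import org.apache.spark.unsafe.hash.Murmur3_x86_32;

    public class MurmurDemo {
      public static void main(String[] args) {
        Murmur3_x86_32 a = new Murmur3_x86_32(0);
        Murmur3_x86_32 b = new Murmur3_x86_32(42);
        System.out.println(a.hashInt(12345) == a.hashInt(12345)); // true: deterministic
        System.out.println(a.hashInt(12345) == b.hashInt(12345)); // almost certainly false
        System.out.println(a.hashLong(12345L));                   // a 32-bit hash of a long
      }
    }
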

+ * <p>
+ * This is backed by a power-of-2-sized hash table, using quadratic probing with triangular numbers,
+ * which is guaranteed to exhaust the space.
+ * <p>
+ * The map can support up to 2^31 keys because we use 32 bit MurmurHash. If the key cardinality is
+ * higher than this, you should probably be using sorting instead of hashing for better cache
+ * locality.
+ * <p>
+ * This class is not thread safe. + */ +public final class BytesToBytesMap { + + private static final Murmur3_x86_32 HASHER = new Murmur3_x86_32(0); + + private static final HashMapGrowthStrategy growthStrategy = HashMapGrowthStrategy.DOUBLING; + + private final TaskMemoryManager memoryManager; + + /** + * A linked list for tracking all allocated data pages so that we can free all of our memory. + */ + private final List dataPages = new LinkedList(); + + /** + * The data page that will be used to store keys and values for new hashtable entries. When this + * page becomes full, a new page will be allocated and this pointer will change to point to that + * new page. + */ + private MemoryBlock currentDataPage = null; + + /** + * Offset into `currentDataPage` that points to the location where new data can be inserted into + * the page. + */ + private long pageCursor = 0; + + /** + * The size of the data pages that hold key and value data. Map entries cannot span multiple + * pages, so this limits the maximum entry size. + */ + private static final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes + + // This choice of page table size and page size means that we can address up to 500 gigabytes + // of memory. + + /** + * A single array to store the key and value. + * + * Position {@code 2 * i} in the array is used to track a pointer to the key at index {@code i}, + * while position {@code 2 * i + 1} in the array holds key's full 32-bit hashcode. + */ + private LongArray longArray; + // TODO: we're wasting 32 bits of space here; we can probably store fewer bits of the hashcode + // and exploit word-alignment to use fewer bits to hold the address. This might let us store + // only one long per map entry, increasing the chance that this array will fit in cache at the + // expense of maybe performing more lookups if we have hash collisions. Say that we stored only + // 27 bits of the hashcode and 37 bits of the address. 37 bits is enough to address 1 terabyte + // of RAM given word-alignment. If we use 13 bits of this for our page table, that gives us a + // maximum page size of 2^24 * 8 = ~134 megabytes per page. This change will require us to store + // full base addresses in the page table for off-heap mode so that we can reconstruct the full + // absolute memory addresses. + + /** + * A {@link BitSet} used to track location of the map where the key is set. + * Size of the bitset should be half of the size of the long array. + */ + private BitSet bitset; + + private final double loadFactor; + + /** + * Number of keys defined in the map. + */ + private int size; + + /** + * The map will be expanded once the number of keys exceeds this threshold. + */ + private int growthThreshold; + + /** + * Mask for truncating hashcodes so that they do not exceed the long array's size. + * This is a strength reduction optimization; we're essentially performing a modulus operation, + * but doing so with a bitmask because this is a power-of-2-sized hash map. + */ + private int mask; + + /** + * Return value of {@link BytesToBytesMap#lookup(Object, long, int)}. 
+ */ + private final Location loc; + + private final boolean enablePerfMetrics; + + private long timeSpentResizingNs = 0; + + private long numProbes = 0; + + private long numKeyLookups = 0; + + private long numHashCollisions = 0; + + public BytesToBytesMap( + TaskMemoryManager memoryManager, + int initialCapacity, + double loadFactor, + boolean enablePerfMetrics) { + this.memoryManager = memoryManager; + this.loadFactor = loadFactor; + this.loc = new Location(); + this.enablePerfMetrics = enablePerfMetrics; + allocate(initialCapacity); + } + + public BytesToBytesMap(TaskMemoryManager memoryManager, int initialCapacity) { + this(memoryManager, initialCapacity, 0.70, false); + } + + public BytesToBytesMap( + TaskMemoryManager memoryManager, + int initialCapacity, + boolean enablePerfMetrics) { + this(memoryManager, initialCapacity, 0.70, enablePerfMetrics); + } + + /** + * Returns the number of keys defined in the map. + */ + public int size() { return size; } + + /** + * Returns an iterator for iterating over the entries of this map. + * + * For efficiency, all calls to `next()` will return the same {@link Location} object. + * + * If any other lookups or operations are performed on this map while iterating over it, including + * `lookup()`, the behavior of the returned iterator is undefined. + */ + public Iterator iterator() { + return new Iterator() { + + private int nextPos = bitset.nextSetBit(0); + + @Override + public boolean hasNext() { + return nextPos != -1; + } + + @Override + public Location next() { + final int pos = nextPos; + nextPos = bitset.nextSetBit(nextPos + 1); + return loc.with(pos, 0, true); + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + }; + } + + /** + * Looks up a key, and return a {@link Location} handle that can be used to test existence + * and read/write values. + * + * This function always return the same {@link Location} instance to avoid object allocation. + */ + public Location lookup( + Object keyBaseObject, + long keyBaseOffset, + int keyRowLengthBytes) { + if (enablePerfMetrics) { + numKeyLookups++; + } + final int hashcode = HASHER.hashUnsafeWords(keyBaseObject, keyBaseOffset, keyRowLengthBytes); + int pos = hashcode & mask; + int step = 1; + while (true) { + if (enablePerfMetrics) { + numProbes++; + } + if (!bitset.isSet(pos)) { + // This is a new key. + return loc.with(pos, hashcode, false); + } else { + long stored = longArray.get(pos * 2 + 1); + if ((int) (stored) == hashcode) { + // Full hash code matches. Let's compare the keys for equality. + loc.with(pos, hashcode, true); + if (loc.getKeyLength() == keyRowLengthBytes) { + final MemoryLocation keyAddress = loc.getKeyAddress(); + final Object storedKeyBaseObject = keyAddress.getBaseObject(); + final long storedKeyBaseOffset = keyAddress.getBaseOffset(); + final boolean areEqual = ByteArrayMethods.wordAlignedArrayEquals( + keyBaseObject, + keyBaseOffset, + storedKeyBaseObject, + storedKeyBaseOffset, + keyRowLengthBytes + ); + if (areEqual) { + return loc; + } else { + if (enablePerfMetrics) { + numHashCollisions++; + } + } + } + } + } + pos = (pos + step) & mask; + step++; + } + } + + /** + * Handle returned by {@link BytesToBytesMap#lookup(Object, long, int)} function. 
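The probe loop in `lookup()` above advances by an ever-growing step, which is quadratic probing with triangular-number increments. A quick standalone check (illustrative, not part of the patch) that this sequence really does visit every slot of a power-of-2 table, as the class comment claims:

    import java.util.BitSet;

    public class TriangularProbeCheck {
      public static void main(String[] args) {
        final int capacity = 1 << 10; // must be a power of 2
        final int mask = capacity - 1;
        BitSet visited = new BitSet(capacity);
        int pos = 12345 & mask;       // arbitrary starting hashcode
        int step = 1;
        // Same update as lookup(): pos = (pos + step) & mask, step incremented each probe.
        for (int probes = 0; probes < capacity; probes++) {
          visited.set(pos);
          pos = (pos + step) & mask;
          step++;
        }
        System.out.println(visited.cardinality() == capacity); // true: space exhausted
      }
    }
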
+ */ + public final class Location { + /** An index into the hash map's Long array */ + private int pos; + /** True if this location points to a position where a key is defined, false otherwise */ + private boolean isDefined; + /** + * The hashcode of the most recent key passed to + * {@link BytesToBytesMap#lookup(Object, long, int)}. Caching this hashcode here allows us to + * avoid re-hashing the key when storing a value for that key. + */ + private int keyHashcode; + private final MemoryLocation keyMemoryLocation = new MemoryLocation(); + private final MemoryLocation valueMemoryLocation = new MemoryLocation(); + private int keyLength; + private int valueLength; + + private void updateAddressesAndSizes(long fullKeyAddress) { + final Object page = memoryManager.getPage(fullKeyAddress); + final long keyOffsetInPage = memoryManager.getOffsetInPage(fullKeyAddress); + long position = keyOffsetInPage; + keyLength = (int) PlatformDependent.UNSAFE.getLong(page, position); + position += 8; // word used to store the key size + keyMemoryLocation.setObjAndOffset(page, position); + position += keyLength; + valueLength = (int) PlatformDependent.UNSAFE.getLong(page, position); + position += 8; // word used to store the key size + valueMemoryLocation.setObjAndOffset(page, position); + } + + Location with(int pos, int keyHashcode, boolean isDefined) { + this.pos = pos; + this.isDefined = isDefined; + this.keyHashcode = keyHashcode; + if (isDefined) { + final long fullKeyAddress = longArray.get(pos * 2); + updateAddressesAndSizes(fullKeyAddress); + } + return this; + } + + /** + * Returns true if the key is defined at this position, and false otherwise. + */ + public boolean isDefined() { + return isDefined; + } + + /** + * Returns the address of the key defined at this position. + * This points to the first byte of the key data. + * Unspecified behavior if the key is not defined. + * For efficiency reasons, calls to this method always returns the same MemoryLocation object. + */ + public MemoryLocation getKeyAddress() { + assert (isDefined); + return keyMemoryLocation; + } + + /** + * Returns the length of the key defined at this position. + * Unspecified behavior if the key is not defined. + */ + public int getKeyLength() { + assert (isDefined); + return keyLength; + } + + /** + * Returns the address of the value defined at this position. + * This points to the first byte of the value data. + * Unspecified behavior if the key is not defined. + * For efficiency reasons, calls to this method always returns the same MemoryLocation object. + */ + public MemoryLocation getValueAddress() { + assert (isDefined); + return valueMemoryLocation; + } + + /** + * Returns the length of the value defined at this position. + * Unspecified behavior if the key is not defined. + */ + public int getValueLength() { + assert (isDefined); + return valueLength; + } + + /** + * Store a new key and value. This method may only be called once for a given key; if you want + * to update the value associated with a key, then you can directly manipulate the bytes stored + * at the value address. + *

+ * It is only valid to call this method immediately after calling `lookup()` using the same key. + *

+ * After calling this method, calls to `get[Key|Value]Address()` and `get[Key|Value]Length` + * will return information on the data stored by this `putNewKey` call. + *

+ * As an example usage, here's the proper way to store a new key: + *

+ *

+     *   Location loc = map.lookup(keyBaseObject, keyBaseOffset, keyLengthInBytes);
+     *   if (!loc.isDefined()) {
+     *     loc.putNewKey(keyBaseObject, keyBaseOffset, keyLengthInBytes, ...)
+     *   }
+     * 
+ *

+ * Unspecified behavior if the key is not defined. + */ + public void putNewKey( + Object keyBaseObject, + long keyBaseOffset, + int keyLengthBytes, + Object valueBaseObject, + long valueBaseOffset, + int valueLengthBytes) { + assert (!isDefined) : "Can only set value once for a key"; + isDefined = true; + assert (keyLengthBytes % 8 == 0); + assert (valueLengthBytes % 8 == 0); + // Here, we'll copy the data into our data pages. Because we only store a relative offset from + // the key address instead of storing the absolute address of the value, the key and value + // must be stored in the same memory page. + // (8 byte key length) (key) (8 byte value length) (value) + final long requiredSize = 8 + keyLengthBytes + 8 + valueLengthBytes; + assert(requiredSize <= PAGE_SIZE_BYTES); + size++; + bitset.set(pos); + + // If there's not enough space in the current page, allocate a new page: + if (currentDataPage == null || PAGE_SIZE_BYTES - pageCursor < requiredSize) { + MemoryBlock newPage = memoryManager.allocatePage(PAGE_SIZE_BYTES); + dataPages.add(newPage); + pageCursor = 0; + currentDataPage = newPage; + } + + // Compute all of our offsets up-front: + final Object pageBaseObject = currentDataPage.getBaseObject(); + final long pageBaseOffset = currentDataPage.getBaseOffset(); + final long keySizeOffsetInPage = pageBaseOffset + pageCursor; + pageCursor += 8; // word used to store the key size + final long keyDataOffsetInPage = pageBaseOffset + pageCursor; + pageCursor += keyLengthBytes; + final long valueSizeOffsetInPage = pageBaseOffset + pageCursor; + pageCursor += 8; // word used to store the value size + final long valueDataOffsetInPage = pageBaseOffset + pageCursor; + pageCursor += valueLengthBytes; + + // Copy the key + PlatformDependent.UNSAFE.putLong(pageBaseObject, keySizeOffsetInPage, keyLengthBytes); + PlatformDependent.UNSAFE.copyMemory( + keyBaseObject, keyBaseOffset, pageBaseObject, keyDataOffsetInPage, keyLengthBytes); + // Copy the value + PlatformDependent.UNSAFE.putLong(pageBaseObject, valueSizeOffsetInPage, valueLengthBytes); + PlatformDependent.UNSAFE.copyMemory( + valueBaseObject, valueBaseOffset, pageBaseObject, valueDataOffsetInPage, valueLengthBytes); + + final long storedKeyAddress = memoryManager.encodePageNumberAndOffset( + currentDataPage, keySizeOffsetInPage); + longArray.set(pos * 2, storedKeyAddress); + longArray.set(pos * 2 + 1, keyHashcode); + updateAddressesAndSizes(storedKeyAddress); + isDefined = true; + if (size > growthThreshold) { + growAndRehash(); + } + } + } + + /** + * Allocate new data structures for this map. When calling this outside of the constructor, + * make sure to keep references to the old data structures so that you can free them. + * + * @param capacity the new map capacity + */ + private void allocate(int capacity) { + capacity = Math.max((int) Math.min(Integer.MAX_VALUE, nextPowerOf2(capacity)), 64); + longArray = new LongArray(memoryManager.allocate(capacity * 8 * 2)); + bitset = new BitSet(memoryManager.allocate(capacity / 8).zero()); + + this.growthThreshold = (int) (capacity * loadFactor); + this.mask = capacity - 1; + } + + /** + * Free all allocated memory associated with this map, including the storage for keys and values + * as well as the hash map array itself. + * + * This method is idempotent. 
+ */ + public void free() { + if (longArray != null) { + memoryManager.free(longArray.memoryBlock()); + longArray = null; + } + if (bitset != null) { + memoryManager.free(bitset.memoryBlock()); + bitset = null; + } + Iterator dataPagesIterator = dataPages.iterator(); + while (dataPagesIterator.hasNext()) { + memoryManager.freePage(dataPagesIterator.next()); + dataPagesIterator.remove(); + } + assert(dataPages.isEmpty()); + } + + /** Returns the total amount of memory, in bytes, consumed by this map's managed structures. */ + public long getTotalMemoryConsumption() { + return ( + dataPages.size() * PAGE_SIZE_BYTES + + bitset.memoryBlock().size() + + longArray.memoryBlock().size()); + } + + /** + * Returns the total amount of time spent resizing this map (in nanoseconds). + */ + public long getTimeSpentResizingNs() { + if (!enablePerfMetrics) { + throw new IllegalStateException(); + } + return timeSpentResizingNs; + } + + + /** + * Returns the average number of probes per key lookup. + */ + public double getAverageProbesPerLookup() { + if (!enablePerfMetrics) { + throw new IllegalStateException(); + } + return (1.0 * numProbes) / numKeyLookups; + } + + public long getNumHashCollisions() { + if (!enablePerfMetrics) { + throw new IllegalStateException(); + } + return numHashCollisions; + } + + /** + * Grows the size of the hash table and re-hash everything. + */ + private void growAndRehash() { + long resizeStartTime = -1; + if (enablePerfMetrics) { + resizeStartTime = System.nanoTime(); + } + // Store references to the old data structures to be used when we re-hash + final LongArray oldLongArray = longArray; + final BitSet oldBitSet = bitset; + final int oldCapacity = (int) oldBitSet.capacity(); + + // Allocate the new data structures + allocate(Math.min(Integer.MAX_VALUE, growthStrategy.nextCapacity(oldCapacity))); + + // Re-mask (we don't recompute the hashcode because we stored all 32 bits of it) + for (int pos = oldBitSet.nextSetBit(0); pos >= 0; pos = oldBitSet.nextSetBit(pos + 1)) { + final long keyPointer = oldLongArray.get(pos * 2); + final int hashcode = (int) oldLongArray.get(pos * 2 + 1); + int newPos = hashcode & mask; + int step = 1; + boolean keepGoing = true; + + // No need to check for equality here when we insert so this has one less if branch than + // the similar code path in addWithoutResize. + while (keepGoing) { + if (!bitset.isSet(newPos)) { + bitset.set(newPos); + longArray.set(newPos * 2, keyPointer); + longArray.set(newPos * 2 + 1, hashcode); + keepGoing = false; + } else { + newPos = (newPos + step) & mask; + step++; + } + } + } + + // Deallocate the old data structures. + memoryManager.free(oldLongArray.memoryBlock()); + memoryManager.free(oldBitSet.memoryBlock()); + if (enablePerfMetrics) { + timeSpentResizingNs += System.nanoTime() - resizeStartTime; + } + } + + /** Returns the next number greater or equal num that is power of 2. */ + private static long nextPowerOf2(long num) { + final long highBit = Long.highestOneBit(num); + return (highBit == num) ? num : highBit << 1; + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java b/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java new file mode 100644 index 0000000000000..7c321baffe82d --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.map; + +/** + * Interface that defines how we can grow the size of a hash map when it is over a threshold. + */ +public interface HashMapGrowthStrategy { + + int nextCapacity(int currentCapacity); + + /** + * Double the size of the hash map every time. + */ + HashMapGrowthStrategy DOUBLING = new Doubling(); + + class Doubling implements HashMapGrowthStrategy { + @Override + public int nextCapacity(int currentCapacity) { + return currentCapacity * 2; + } + } + +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java new file mode 100644 index 0000000000000..62c29c8cc1e4d --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +/** + * Manages memory for an executor. Individual operators / tasks allocate memory through + * {@link TaskMemoryManager} objects, which obtain their memory from ExecutorMemoryManager. + */ +public class ExecutorMemoryManager { + + /** + * Allocator, exposed for enabling untracked allocations of temporary data structures. + */ + public final MemoryAllocator allocator; + + /** + * Tracks whether memory will be allocated on the JVM heap or off-heap using sun.misc.Unsafe. + */ + final boolean inHeap; + + /** + * Construct a new ExecutorMemoryManager. + * + * @param allocator the allocator that will be used + */ + public ExecutorMemoryManager(MemoryAllocator allocator) { + this.inHeap = allocator instanceof HeapMemoryAllocator; + this.allocator = allocator; + } + + /** + * Allocates a contiguous block of memory. Note that the allocated memory is not guaranteed + * to be zeroed out (call `zero()` on the result if this is necessary). 
+ */ + MemoryBlock allocate(long size) throws OutOfMemoryError { + return allocator.allocate(size); + } + + void free(MemoryBlock memory) { + allocator.free(memory); + } + +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java new file mode 100644 index 0000000000000..bbe83d36cf36b --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +/** + * A simple {@link MemoryAllocator} that can allocate up to 16GB using a JVM long primitive array. + */ +public class HeapMemoryAllocator implements MemoryAllocator { + + @Override + public MemoryBlock allocate(long size) throws OutOfMemoryError { + long[] array = new long[(int) (size / 8)]; + return MemoryBlock.fromLongArray(array); + } + + @Override + public void free(MemoryBlock memory) { + // Do nothing + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java new file mode 100644 index 0000000000000..5192f68c862cf --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +public interface MemoryAllocator { + + /** + * Allocates a contiguous block of memory. Note that the allocated memory is not guaranteed + * to be zeroed out (call `zero()` on the result if this is necessary). 
+ */ + MemoryBlock allocate(long size) throws OutOfMemoryError; + + void free(MemoryBlock memory); + + MemoryAllocator UNSAFE = new UnsafeMemoryAllocator(); + + MemoryAllocator HEAP = new HeapMemoryAllocator(); +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java new file mode 100644 index 0000000000000..0beb743e5644e --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import javax.annotation.Nullable; + +import org.apache.spark.unsafe.PlatformDependent; + +/** + * A consecutive block of memory, starting at a {@link MemoryLocation} with a fixed size. + */ +public class MemoryBlock extends MemoryLocation { + + private final long length; + + /** + * Optional page number; used when this MemoryBlock represents a page allocated by a + * MemoryManager. This is package-private and is modified by MemoryManager. + */ + int pageNumber = -1; + + MemoryBlock(@Nullable Object obj, long offset, long length) { + super(obj, offset); + this.length = length; + } + + /** + * Returns the size of the memory block. + */ + public long size() { + return length; + } + + /** + * Clear the contents of this memory block. Returns `this` to facilitate chaining. + */ + public MemoryBlock zero() { + PlatformDependent.UNSAFE.setMemory(obj, offset, length, (byte) 0); + return this; + } + + /** + * Creates a memory block pointing to the memory used by the long array. + */ + public static MemoryBlock fromLongArray(final long[] array) { + return new MemoryBlock(array, PlatformDependent.LONG_ARRAY_OFFSET, array.length * 8); + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java new file mode 100644 index 0000000000000..74ebc87dc978c --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import javax.annotation.Nullable; + +/** + * A memory location. Tracked either by a memory address (with off-heap allocation), + * or by an offset from a JVM object (in-heap allocation). + */ +public class MemoryLocation { + + @Nullable + Object obj; + + long offset; + + public MemoryLocation(@Nullable Object obj, long offset) { + this.obj = obj; + this.offset = offset; + } + + public MemoryLocation() { + this(null, 0); + } + + public void setObjAndOffset(Object newObj, long newOffset) { + this.obj = newObj; + this.offset = newOffset; + } + + public final Object getBaseObject() { + return obj; + } + + public final long getBaseOffset() { + return offset; + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java new file mode 100644 index 0000000000000..9224988e6ad69 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java @@ -0,0 +1,237 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import java.util.*; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Manages the memory allocated by an individual task. + *

+ * Most of the complexity in this class deals with encoding of off-heap addresses into 64-bit longs. + * In off-heap mode, memory can be directly addressed with 64-bit longs. In on-heap mode, memory is + * addressed by the combination of a base Object reference and a 64-bit offset within that object. + * This is a problem when we want to store pointers to data structures inside of other structures, + * such as record pointers inside hashmaps or sorting buffers. Even if we decided to use 128 bits + * to address memory, we can't just store the address of the base object since it's not guaranteed + * to remain stable as the heap gets reorganized due to GC. + *

+ * Instead, we use the following approach to encode record pointers in 64-bit longs: for off-heap + * mode, just store the raw address, and for on-heap mode use the upper 13 bits of the address to + * store a "page number" and the lower 51 bits to store an offset within this page. These page + * numbers are used to index into a "page table" array inside of the MemoryManager in order to + * retrieve the base object. + *

+ * This allows us to address 8192 pages. In on-heap mode, the maximum page size is limited by the + * maximum size of a long[] array, allowing us to address 8192 * 2^32 * 8 bytes, which is + * approximately 35 terabytes of memory. + */ +public final class TaskMemoryManager { + + private final Logger logger = LoggerFactory.getLogger(TaskMemoryManager.class); + + /** + * The number of entries in the page table. + */ + private static final int PAGE_TABLE_SIZE = 1 << 13; + + /** Bit mask for the lower 51 bits of a long. */ + private static final long MASK_LONG_LOWER_51_BITS = 0x7FFFFFFFFFFFFL; + + /** Bit mask for the upper 13 bits of a long */ + private static final long MASK_LONG_UPPER_13_BITS = ~MASK_LONG_LOWER_51_BITS; + + /** + * Similar to an operating system's page table, this array maps page numbers into base object + * pointers, allowing us to translate between the hashtable's internal 64-bit address + * representation and the baseObject+offset representation which we use to support both in- and + * off-heap addresses. When using an off-heap allocator, every entry in this map will be `null`. + * When using an in-heap allocator, the entries in this map will point to pages' base objects. + * Entries are added to this map as new data pages are allocated. + */ + private final MemoryBlock[] pageTable = new MemoryBlock[PAGE_TABLE_SIZE]; + + /** + * Bitmap for tracking free pages. + */ + private final BitSet allocatedPages = new BitSet(PAGE_TABLE_SIZE); + + /** + * Tracks memory allocated with {@link TaskMemoryManager#allocate(long)}, used to detect / clean + * up leaked memory. + */ + private final HashSet allocatedNonPageMemory = new HashSet(); + + private final ExecutorMemoryManager executorMemoryManager; + + /** + * Tracks whether we're in-heap or off-heap. For off-heap, we short-circuit most of these methods + * without doing any masking or lookups. Since this branching should be well-predicted by the JIT, + * this extra layer of indirection / abstraction hopefully shouldn't be too expensive. + */ + private final boolean inHeap; + + /** + * Construct a new MemoryManager. + */ + public TaskMemoryManager(ExecutorMemoryManager executorMemoryManager) { + this.inHeap = executorMemoryManager.inHeap; + this.executorMemoryManager = executorMemoryManager; + } + + /** + * Allocate a block of memory that will be tracked in the MemoryManager's page table; this is + * intended for allocating large blocks of memory that will be shared between operators. + */ + public MemoryBlock allocatePage(long size) { + if (logger.isTraceEnabled()) { + logger.trace("Allocating {} byte page", size); + } + if (size >= (1L << 51)) { + throw new IllegalArgumentException("Cannot allocate a page with more than 2^51 bytes"); + } + + final int pageNumber; + synchronized (this) { + pageNumber = allocatedPages.nextClearBit(0); + if (pageNumber >= PAGE_TABLE_SIZE) { + throw new IllegalStateException( + "Have already allocated a maximum of " + PAGE_TABLE_SIZE + " pages"); + } + allocatedPages.set(pageNumber); + } + final MemoryBlock page = executorMemoryManager.allocate(size); + page.pageNumber = pageNumber; + pageTable[pageNumber] = page; + if (logger.isDebugEnabled()) { + logger.debug("Allocate page number {} ({} bytes)", pageNumber, size); + } + return page; + } + + /** + * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage(long)}. 
+ */ + public void freePage(MemoryBlock page) { + if (logger.isTraceEnabled()) { + logger.trace("Freeing page number {} ({} bytes)", page.pageNumber, page.size()); + } + assert (page.pageNumber != -1) : + "Called freePage() on memory that wasn't allocated with allocatePage()"; + executorMemoryManager.free(page); + synchronized (this) { + allocatedPages.clear(page.pageNumber); + } + pageTable[page.pageNumber] = null; + if (logger.isDebugEnabled()) { + logger.debug("Freed page number {} ({} bytes)", page.pageNumber, page.size()); + } + } + + /** + * Allocates a contiguous block of memory. Note that the allocated memory is not guaranteed + * to be zeroed out (call `zero()` on the result if this is necessary). This method is intended + * to be used for allocating operators' internal data structures. For data pages that you want to + * exchange between operators, consider using {@link TaskMemoryManager#allocatePage(long)}, since + * that will enable intra-memory pointers (see + * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} and this class's + * top-level Javadoc for more details). + */ + public MemoryBlock allocate(long size) throws OutOfMemoryError { + final MemoryBlock memory = executorMemoryManager.allocate(size); + allocatedNonPageMemory.add(memory); + return memory; + } + + /** + * Free memory allocated by {@link TaskMemoryManager#allocate(long)}. + */ + public void free(MemoryBlock memory) { + assert (memory.pageNumber == -1) : "Should call freePage() for pages, not free()"; + executorMemoryManager.free(memory); + final boolean wasAlreadyRemoved = !allocatedNonPageMemory.remove(memory); + assert (!wasAlreadyRemoved) : "Called free() on memory that was already freed!"; + } + + /** + * Given a memory page and offset within that page, encode this address into a 64-bit long. + * This address will remain valid as long as the corresponding page has not been freed. + */ + public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) { + if (inHeap) { + assert (page.pageNumber != -1) : "encodePageNumberAndOffset called with invalid page"; + return (((long) page.pageNumber) << 51) | (offsetInPage & MASK_LONG_LOWER_51_BITS); + } else { + return offsetInPage; + } + } + + /** + * Get the page associated with an address encoded by + * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} + */ + public Object getPage(long pagePlusOffsetAddress) { + if (inHeap) { + final int pageNumber = (int) ((pagePlusOffsetAddress & MASK_LONG_UPPER_13_BITS) >>> 51); + assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE); + final Object page = pageTable[pageNumber].getBaseObject(); + assert (page != null); + return page; + } else { + return null; + } + } + + /** + * Get the offset associated with an address encoded by + * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} + */ + public long getOffsetInPage(long pagePlusOffsetAddress) { + if (inHeap) { + return (pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS); + } else { + return pagePlusOffsetAddress; + } + } + + /** + * Clean up all allocated memory and pages. Returns the number of bytes freed. A non-zero return + * value can be used to detect memory leaks. 
+ */ + public long cleanUpAllAllocatedMemory() { + long freedBytes = 0; + for (MemoryBlock page : pageTable) { + if (page != null) { + freedBytes += page.size(); + freePage(page); + } + } + final Iterator iter = allocatedNonPageMemory.iterator(); + while (iter.hasNext()) { + final MemoryBlock memory = iter.next(); + freedBytes += memory.size(); + // We don't call free() here because that calls Set.remove, which would lead to a + // ConcurrentModificationException here. + executorMemoryManager.free(memory); + iter.remove(); + } + return freedBytes; + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java new file mode 100644 index 0000000000000..15898771fef25 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import org.apache.spark.unsafe.PlatformDependent; + +/** + * A simple {@link MemoryAllocator} that uses {@code Unsafe} to allocate off-heap memory. + */ +public class UnsafeMemoryAllocator implements MemoryAllocator { + + @Override + public MemoryBlock allocate(long size) throws OutOfMemoryError { + long address = PlatformDependent.UNSAFE.allocateMemory(size); + return new MemoryBlock(null, address, size); + } + + @Override + public void free(MemoryBlock memory) { + assert (memory.obj == null) : + "baseObject not null; are you trying to use the off-heap allocator to free on-heap memory?"; + PlatformDependent.UNSAFE.freeMemory(memory.offset); + } +} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java new file mode 100644 index 0000000000000..5974cf91ff993 --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.unsafe.array; + +import org.junit.Assert; +import org.junit.Test; + +import org.apache.spark.unsafe.memory.MemoryBlock; + +public class LongArraySuite { + + @Test + public void basicTest() { + long[] bytes = new long[2]; + LongArray arr = new LongArray(MemoryBlock.fromLongArray(bytes)); + arr.set(0, 1L); + arr.set(1, 2L); + arr.set(1, 3L); + Assert.assertEquals(2, arr.size()); + Assert.assertEquals(1L, arr.get(0)); + Assert.assertEquals(3L, arr.get(1)); + } +} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java new file mode 100644 index 0000000000000..4bf132fd4053e --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.bitset; + +import junit.framework.Assert; +import org.apache.spark.unsafe.bitset.BitSet; +import org.junit.Test; + +import org.apache.spark.unsafe.memory.MemoryBlock; + +public class BitSetSuite { + + private static BitSet createBitSet(int capacity) { + assert capacity % 64 == 0; + return new BitSet(MemoryBlock.fromLongArray(new long[capacity / 64]).zero()); + } + + @Test + public void basicOps() { + BitSet bs = createBitSet(64); + Assert.assertEquals(64, bs.capacity()); + + // Make sure the bit set starts empty. + for (int i = 0; i < bs.capacity(); i++) { + Assert.assertFalse(bs.isSet(i)); + } + + // Set every bit and check it. + for (int i = 0; i < bs.capacity(); i++) { + bs.set(i); + Assert.assertTrue(bs.isSet(i)); + } + + // Unset every bit and check it. 
+ for (int i = 0; i < bs.capacity(); i++) { + Assert.assertTrue(bs.isSet(i)); + bs.unset(i); + Assert.assertFalse(bs.isSet(i)); + } + } + + @Test + public void traversal() { + BitSet bs = createBitSet(256); + + Assert.assertEquals(-1, bs.nextSetBit(0)); + Assert.assertEquals(-1, bs.nextSetBit(10)); + Assert.assertEquals(-1, bs.nextSetBit(64)); + + bs.set(10); + Assert.assertEquals(10, bs.nextSetBit(0)); + Assert.assertEquals(10, bs.nextSetBit(1)); + Assert.assertEquals(10, bs.nextSetBit(10)); + Assert.assertEquals(-1, bs.nextSetBit(11)); + + bs.set(11); + Assert.assertEquals(10, bs.nextSetBit(10)); + Assert.assertEquals(11, bs.nextSetBit(11)); + + // Skip a whole word and find it + bs.set(190); + Assert.assertEquals(190, bs.nextSetBit(12)); + + Assert.assertEquals(-1, bs.nextSetBit(191)); + Assert.assertEquals(-1, bs.nextSetBit(256)); + } +} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java b/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java new file mode 100644 index 0000000000000..3b9175835229c --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.hash; + +import java.util.HashSet; +import java.util.Random; +import java.util.Set; + +import junit.framework.Assert; +import org.apache.spark.unsafe.PlatformDependent; +import org.junit.Test; + +/** + * Test file based on Guava's Murmur3Hash32Test. + */ +public class Murmur3_x86_32Suite { + + private static final Murmur3_x86_32 hasher = new Murmur3_x86_32(0); + + @Test + public void testKnownIntegerInputs() { + Assert.assertEquals(593689054, hasher.hashInt(0)); + Assert.assertEquals(-189366624, hasher.hashInt(-42)); + Assert.assertEquals(-1134849565, hasher.hashInt(42)); + Assert.assertEquals(-1718298732, hasher.hashInt(Integer.MIN_VALUE)); + Assert.assertEquals(-1653689534, hasher.hashInt(Integer.MAX_VALUE)); + } + + @Test + public void testKnownLongInputs() { + Assert.assertEquals(1669671676, hasher.hashLong(0L)); + Assert.assertEquals(-846261623, hasher.hashLong(-42L)); + Assert.assertEquals(1871679806, hasher.hashLong(42L)); + Assert.assertEquals(1366273829, hasher.hashLong(Long.MIN_VALUE)); + Assert.assertEquals(-2106506049, hasher.hashLong(Long.MAX_VALUE)); + } + + @Test + public void randomizedStressTest() { + int size = 65536; + Random rand = new Random(); + + // A set used to track collision rate. 
+ Set hashcodes = new HashSet(); + for (int i = 0; i < size; i++) { + int vint = rand.nextInt(); + long lint = rand.nextLong(); + Assert.assertEquals(hasher.hashInt(vint), hasher.hashInt(vint)); + Assert.assertEquals(hasher.hashLong(lint), hasher.hashLong(lint)); + + hashcodes.add(hasher.hashLong(lint)); + } + + // A very loose bound. + Assert.assertTrue(hashcodes.size() > size * 0.95); + } + + @Test + public void randomizedStressTestBytes() { + int size = 65536; + Random rand = new Random(); + + // A set used to track collision rate. + Set hashcodes = new HashSet(); + for (int i = 0; i < size; i++) { + int byteArrSize = rand.nextInt(100) * 8; + byte[] bytes = new byte[byteArrSize]; + rand.nextBytes(bytes); + + Assert.assertEquals( + hasher.hashUnsafeWords(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize), + hasher.hashUnsafeWords(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize)); + + hashcodes.add(hasher.hashUnsafeWords( + bytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize)); + } + + // A very loose bound. + Assert.assertTrue(hashcodes.size() > size * 0.95); + } + + @Test + public void randomizedStressTestPaddedStrings() { + int size = 64000; + // A set used to track collision rate. + Set hashcodes = new HashSet(); + for (int i = 0; i < size; i++) { + int byteArrSize = 8; + byte[] strBytes = ("" + i).getBytes(); + byte[] paddedBytes = new byte[byteArrSize]; + System.arraycopy(strBytes, 0, paddedBytes, 0, strBytes.length); + + Assert.assertEquals( + hasher.hashUnsafeWords(paddedBytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize), + hasher.hashUnsafeWords(paddedBytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize)); + + hashcodes.add(hasher.hashUnsafeWords( + paddedBytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize)); + } + + // A very loose bound. + Assert.assertTrue(hashcodes.size() > size * 0.95); + } +} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java new file mode 100644 index 0000000000000..9038cf567f1e2 --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.unsafe.map; + +import java.lang.Exception; +import java.nio.ByteBuffer; +import java.util.*; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.PlatformDependent; +import static org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET; +import org.apache.spark.unsafe.memory.ExecutorMemoryManager; +import org.apache.spark.unsafe.memory.MemoryAllocator; +import org.apache.spark.unsafe.memory.MemoryLocation; +import org.apache.spark.unsafe.memory.TaskMemoryManager; + +public abstract class AbstractBytesToBytesMapSuite { + + private final Random rand = new Random(42); + + private TaskMemoryManager memoryManager; + + @Before + public void setup() { + memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(getMemoryAllocator())); + } + + @After + public void tearDown() { + if (memoryManager != null) { + memoryManager.cleanUpAllAllocatedMemory(); + memoryManager = null; + } + } + + protected abstract MemoryAllocator getMemoryAllocator(); + + private static byte[] getByteArray(MemoryLocation loc, int size) { + final byte[] arr = new byte[size]; + PlatformDependent.UNSAFE.copyMemory( + loc.getBaseObject(), + loc.getBaseOffset(), + arr, + BYTE_ARRAY_OFFSET, + size + ); + return arr; + } + + private byte[] getRandomByteArray(int numWords) { + Assert.assertTrue(numWords > 0); + final int lengthInBytes = numWords * 8; + final byte[] bytes = new byte[lengthInBytes]; + rand.nextBytes(bytes); + return bytes; + } + + /** + * Fast equality checking for byte arrays, since these comparisons are a bottleneck + * in our stress tests. + */ + private static boolean arrayEquals( + byte[] expected, + MemoryLocation actualAddr, + long actualLengthBytes) { + return (actualLengthBytes == expected.length) && ByteArrayMethods.wordAlignedArrayEquals( + expected, + BYTE_ARRAY_OFFSET, + actualAddr.getBaseObject(), + actualAddr.getBaseOffset(), + expected.length + ); + } + + @Test + public void emptyMap() { + BytesToBytesMap map = new BytesToBytesMap(memoryManager, 64); + try { + Assert.assertEquals(0, map.size()); + final int keyLengthInWords = 10; + final int keyLengthInBytes = keyLengthInWords * 8; + final byte[] key = getRandomByteArray(keyLengthInWords); + Assert.assertFalse(map.lookup(key, BYTE_ARRAY_OFFSET, keyLengthInBytes).isDefined()); + } finally { + map.free(); + } + } + + @Test + public void setAndRetrieveAKey() { + BytesToBytesMap map = new BytesToBytesMap(memoryManager, 64); + final int recordLengthWords = 10; + final int recordLengthBytes = recordLengthWords * 8; + final byte[] keyData = getRandomByteArray(recordLengthWords); + final byte[] valueData = getRandomByteArray(recordLengthWords); + try { + final BytesToBytesMap.Location loc = + map.lookup(keyData, BYTE_ARRAY_OFFSET, recordLengthBytes); + Assert.assertFalse(loc.isDefined()); + loc.putNewKey( + keyData, + BYTE_ARRAY_OFFSET, + recordLengthBytes, + valueData, + BYTE_ARRAY_OFFSET, + recordLengthBytes + ); + // After storing the key and value, the other location methods should return results that + // reflect the result of this store without us having to call lookup() again on the same key. 
+ Assert.assertEquals(recordLengthBytes, loc.getKeyLength()); + Assert.assertEquals(recordLengthBytes, loc.getValueLength()); + Assert.assertArrayEquals(keyData, getByteArray(loc.getKeyAddress(), recordLengthBytes)); + Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), recordLengthBytes)); + + // After calling lookup() the location should still point to the correct data. + Assert.assertTrue(map.lookup(keyData, BYTE_ARRAY_OFFSET, recordLengthBytes).isDefined()); + Assert.assertEquals(recordLengthBytes, loc.getKeyLength()); + Assert.assertEquals(recordLengthBytes, loc.getValueLength()); + Assert.assertArrayEquals(keyData, getByteArray(loc.getKeyAddress(), recordLengthBytes)); + Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), recordLengthBytes)); + + try { + loc.putNewKey( + keyData, + BYTE_ARRAY_OFFSET, + recordLengthBytes, + valueData, + BYTE_ARRAY_OFFSET, + recordLengthBytes + ); + Assert.fail("Should not be able to set a new value for a key"); + } catch (AssertionError e) { + // Expected exception; do nothing. + } + } finally { + map.free(); + } + } + + @Test + public void iteratorTest() throws Exception { + final int size = 128; + BytesToBytesMap map = new BytesToBytesMap(memoryManager, size / 2); + try { + for (long i = 0; i < size; i++) { + final long[] value = new long[] { i }; + final BytesToBytesMap.Location loc = + map.lookup(value, PlatformDependent.LONG_ARRAY_OFFSET, 8); + Assert.assertFalse(loc.isDefined()); + loc.putNewKey( + value, + PlatformDependent.LONG_ARRAY_OFFSET, + 8, + value, + PlatformDependent.LONG_ARRAY_OFFSET, + 8 + ); + } + final java.util.BitSet valuesSeen = new java.util.BitSet(size); + final Iterator iter = map.iterator(); + while (iter.hasNext()) { + final BytesToBytesMap.Location loc = iter.next(); + Assert.assertTrue(loc.isDefined()); + final MemoryLocation keyAddress = loc.getKeyAddress(); + final MemoryLocation valueAddress = loc.getValueAddress(); + final long key = PlatformDependent.UNSAFE.getLong( + keyAddress.getBaseObject(), keyAddress.getBaseOffset()); + final long value = PlatformDependent.UNSAFE.getLong( + valueAddress.getBaseObject(), valueAddress.getBaseOffset()); + Assert.assertEquals(key, value); + valuesSeen.set((int) value); + } + Assert.assertEquals(size, valuesSeen.cardinality()); + } finally { + map.free(); + } + } + + @Test + public void randomizedStressTest() { + final int size = 65536; + // Java arrays' hashCodes() aren't based on the arrays' contents, so we need to wrap arrays + // into ByteBuffers in order to use them as keys here. 
+ final Map expected = new HashMap(); + final BytesToBytesMap map = new BytesToBytesMap(memoryManager, size); + + try { + // Fill the map to 90% full so that we can trigger probing + for (int i = 0; i < size * 0.9; i++) { + final byte[] key = getRandomByteArray(rand.nextInt(256) + 1); + final byte[] value = getRandomByteArray(rand.nextInt(512) + 1); + if (!expected.containsKey(ByteBuffer.wrap(key))) { + expected.put(ByteBuffer.wrap(key), value); + final BytesToBytesMap.Location loc = map.lookup( + key, + BYTE_ARRAY_OFFSET, + key.length + ); + Assert.assertFalse(loc.isDefined()); + loc.putNewKey( + key, + BYTE_ARRAY_OFFSET, + key.length, + value, + BYTE_ARRAY_OFFSET, + value.length + ); + // After calling putNewKey, the following should be true, even before calling + // lookup(): + Assert.assertTrue(loc.isDefined()); + Assert.assertEquals(key.length, loc.getKeyLength()); + Assert.assertEquals(value.length, loc.getValueLength()); + Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), key.length)); + Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), value.length)); + } + } + + for (Map.Entry entry : expected.entrySet()) { + final byte[] key = entry.getKey().array(); + final byte[] value = entry.getValue(); + final BytesToBytesMap.Location loc = map.lookup(key, BYTE_ARRAY_OFFSET, key.length); + Assert.assertTrue(loc.isDefined()); + Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), loc.getKeyLength())); + Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), loc.getValueLength())); + } + } finally { + map.free(); + } + } +} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java new file mode 100644 index 0000000000000..5a10de49f54fe --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.map; + +import org.apache.spark.unsafe.memory.MemoryAllocator; + +public class BytesToBytesMapOffHeapSuite extends AbstractBytesToBytesMapSuite { + + @Override + protected MemoryAllocator getMemoryAllocator() { + return MemoryAllocator.UNSAFE; + } + +} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java new file mode 100644 index 0000000000000..12cc9b25d93b3 --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.map; + +import org.apache.spark.unsafe.memory.MemoryAllocator; + +public class BytesToBytesMapOnHeapSuite extends AbstractBytesToBytesMapSuite { + + @Override + protected MemoryAllocator getMemoryAllocator() { + return MemoryAllocator.HEAP; + } + +} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java new file mode 100644 index 0000000000000..932882f1ca248 --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import org.junit.Assert; +import org.junit.Test; + +public class TaskMemoryManagerSuite { + + @Test + public void leakedNonPageMemoryIsDetected() { + final TaskMemoryManager manager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + manager.allocate(1024); // leak memory + Assert.assertEquals(1024, manager.cleanUpAllAllocatedMemory()); + } + + @Test + public void leakedPageMemoryIsDetected() { + final TaskMemoryManager manager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + manager.allocatePage(4096); // leak memory + Assert.assertEquals(4096, manager.cleanUpAllAllocatedMemory()); + } + +} From baed3f2c73afd9c7d9de34f0485c507ac6a498b0 Mon Sep 17 00:00:00 2001 From: Dean Chen Date: Wed, 29 Apr 2015 08:58:33 -0500 Subject: [PATCH 176/191] [SPARK-6918] [YARN] Secure HBase support. Obtain HBase security token with Kerberos credentials locally to be sent to executors. Tested on eBay's secure HBase cluster. Similar to obtainTokenForNamenodes and fails gracefully if HBase classes are not included in path. Requires hbase-site.xml to be in the classpath(typically via conf dir) for the zookeeper configuration. Should that go in the docs somewhere? Did not see an HBase section. 
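In sketch form, the token fetch added below amounts to the following (simplified; the real change in Client.scala adds logging and falls back gracefully when the HBase classes are not on the classpath):

    // Simplified sketch only -- see obtainTokenForHBase in the diff for the full version.
    val mirror = universe.runtimeMirror(getClass.getClassLoader)
    // Build an HBaseConfiguration and obtain a delegation token purely via reflection,
    // so the YARN client does not need a compile-time dependency on HBase.
    val hbaseConf = mirror.classLoader
      .loadClass("org.apache.hadoop.hbase.HBaseConfiguration")
      .getMethod("create", classOf[Configuration])
      .invoke(null, conf)
    val token = mirror.classLoader
      .loadClass("org.apache.hadoop.hbase.security.token.TokenUtil")
      .getMethod("obtainToken", classOf[Configuration])
      .invoke(null, hbaseConf)
      .asInstanceOf[Token[TokenIdentifier]]
    credentials.addToken(token.getService, token)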
Author: Dean Chen Closes #5586 from deanchen/master and squashes the following commits: 0c190ef [Dean Chen] [SPARK-6918][YARN] Secure HBase support. --- .../org/apache/spark/deploy/yarn/Client.scala | 38 ++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 741239c953794..4abcf7307a388 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -39,7 +39,7 @@ import org.apache.hadoop.io.Text import org.apache.hadoop.mapred.Master import org.apache.hadoop.mapreduce.MRJobConfig import org.apache.hadoop.security.{Credentials, UserGroupInformation} -import org.apache.hadoop.security.token.Token +import org.apache.hadoop.security.token.{TokenIdentifier, Token} import org.apache.hadoop.util.StringUtils import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.ApplicationConstants.Environment @@ -226,6 +226,7 @@ private[spark] class Client( val distributedUris = new HashSet[String] obtainTokensForNamenodes(nns, hadoopConf, credentials) obtainTokenForHiveMetastore(hadoopConf, credentials) + obtainTokenForHBase(hadoopConf, credentials) val replication = sparkConf.getInt("spark.yarn.submit.file.replication", fs.getDefaultReplication(dst)).toShort @@ -1084,6 +1085,41 @@ object Client extends Logging { } } + /** + * Obtain security token for HBase. + */ + def obtainTokenForHBase(conf: Configuration, credentials: Credentials): Unit = { + if (UserGroupInformation.isSecurityEnabled) { + val mirror = universe.runtimeMirror(getClass.getClassLoader) + + try { + val confCreate = mirror.classLoader. + loadClass("org.apache.hadoop.hbase.HBaseConfiguration"). + getMethod("create", classOf[Configuration]) + val obtainToken = mirror.classLoader. + loadClass("org.apache.hadoop.hbase.security.token.TokenUtil"). + getMethod("obtainToken", classOf[Configuration]) + + logDebug("Attempting to fetch HBase security token.") + + val hbaseConf = confCreate.invoke(null, conf) + val token = obtainToken.invoke(null, hbaseConf).asInstanceOf[Token[TokenIdentifier]] + credentials.addToken(token.getService, token) + + logInfo("Added HBase security token to credentials.") + } catch { + case e:java.lang.NoSuchMethodException => + logInfo("HBase Method not found: " + e) + case e:java.lang.ClassNotFoundException => + logDebug("HBase Class not found: " + e) + case e:java.lang.NoClassDefFoundError => + logDebug("HBase Class not found: " + e) + case e:Exception => + logError("Exception when obtaining HBase security token: " + e) + } + } + } + /** * Return whether the two file systems are the same. */ From 687273d9150e1c89a74aa1473f0c6495f56509af Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 29 Apr 2015 09:46:37 -0700 Subject: [PATCH 177/191] [SPARK-7223] Rename RPC askWithReply -> askWithReply, sendWithReply -> ask. The old naming scheme was very confusing between askWithReply and sendWithReply. I also divided RpcEnv.scala into multiple files. Author: Reynold Xin Closes #5768 from rxin/rpc-rename and squashes the following commits: a84058e [Reynold Xin] [SPARK-7223] Rename RPC askWithReply -> askWithReply, sendWithReply -> ask. 
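At a call site the new names look like this (illustrative only; the real call sites are in the diff below):

    // Was sendWithReply: returns a Future that completes with the reply.
    driver.ask[RegisteredExecutor.type](
      RegisterExecutor(executorId, self, hostPort, cores, extractLogUrls))
    // Was askWithReply: blocks for the reply and retries on failure.
    val props = driver.askWithRetry[Seq[(String, String)]](RetrieveSparkProps)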
--- .../org/apache/spark/MapOutputTracker.scala | 2 +- .../scala/org/apache/spark/SparkContext.scala | 2 +- .../CoarseGrainedExecutorBackend.scala | 4 +- .../org/apache/spark/executor/Executor.scala | 2 +- .../org/apache/spark/rpc/RpcCallContext.scala | 41 +++ .../org/apache/spark/rpc/RpcEndpoint.scala | 148 +++++++++ .../org/apache/spark/rpc/RpcEndpointRef.scala | 119 +++++++ .../scala/org/apache/spark/rpc/RpcEnv.scala | 313 ++---------------- .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 2 +- .../apache/spark/scheduler/DAGScheduler.scala | 2 +- .../scheduler/OutputCommitCoordinator.scala | 2 +- .../CoarseGrainedSchedulerBackend.scala | 6 +- .../cluster/YarnSchedulerBackend.scala | 8 +- .../spark/scheduler/local/LocalBackend.scala | 2 +- .../spark/storage/BlockManagerMaster.scala | 28 +- .../storage/BlockManagerMasterEndpoint.scala | 12 +- .../apache/spark/HeartbeatReceiverSuite.scala | 4 +- .../org/apache/spark/rpc/RpcEnvSuite.scala | 20 +- .../spark/storage/BlockManagerSuite.scala | 2 +- .../receiver/ReceiverSupervisorImpl.scala | 6 +- 20 files changed, 394 insertions(+), 331 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala create mode 100644 core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala create mode 100644 core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index d65c94e410662..16072283edbe9 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -106,7 +106,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging */ protected def askTracker[T: ClassTag](message: Any): T = { try { - trackerEndpoint.askWithReply[T](message) + trackerEndpoint.askWithRetry[T](message) } catch { case e: Exception => logError("Error communicating with MapOutputTracker", e) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index d0cf2a8dd01cd..5ae8fb81de809 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -555,7 +555,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli SparkEnv.executorActorSystemName, RpcAddress(host, port), ExecutorEndpoint.EXECUTOR_ENDPOINT_NAME) - Some(endpointRef.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump)) + Some(endpointRef.askWithRetry[Array[ThreadStackTrace]](TriggerThreadDump)) } } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 8af46f3327adb..79aed90b53e2f 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -57,7 +57,7 @@ private[spark] class CoarseGrainedExecutorBackend( logInfo("Connecting to driver: " + driverUrl) rpcEnv.asyncSetupEndpointRefByURI(driverUrl).flatMap { ref => driver = Some(ref) - ref.sendWithReply[RegisteredExecutor.type]( + ref.ask[RegisteredExecutor.type]( RegisterExecutor(executorId, self, hostPort, cores, extractLogUrls)) } onComplete { case Success(msg) => Utils.tryLogNonFatalError { @@ -154,7 +154,7 @@ private[spark] object CoarseGrainedExecutorBackend 
extends Logging { executorConf, new SecurityManager(executorConf)) val driver = fetcher.setupEndpointRefByURI(driverUrl) - val props = driver.askWithReply[Seq[(String, String)]](RetrieveSparkProps) ++ + val props = driver.askWithRetry[Seq[(String, String)]](RetrieveSparkProps) ++ Seq[(String, String)](("spark.app.id", appId)) fetcher.shutdown() diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index dd1c48e6cb953..8f916e0502ecb 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -441,7 +441,7 @@ private[spark] class Executor( val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId) try { - val response = heartbeatReceiverRef.askWithReply[HeartbeatResponse](message) + val response = heartbeatReceiverRef.askWithRetry[HeartbeatResponse](message) if (response.reregisterBlockManager) { logWarning("Told to re-register on heartbeat") env.blockManager.reregister() diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala new file mode 100644 index 0000000000000..3e5b64265e919 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc + +/** + * A callback that [[RpcEndpoint]] can use it to send back a message or failure. It's thread-safe + * and can be called in any thread. + */ +private[spark] trait RpcCallContext { + + /** + * Reply a message to the sender. If the sender is [[RpcEndpoint]], its [[RpcEndpoint.receive]] + * will be called. + */ + def reply(response: Any): Unit + + /** + * Report a failure to the sender. + */ + def sendFailure(e: Throwable): Unit + + /** + * The sender of this message. + */ + def sender: RpcEndpointRef +} diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala new file mode 100644 index 0000000000000..d2b2baef1d8c4 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc + +import org.apache.spark.SparkException + +/** + * A factory class to create the [[RpcEnv]]. It must have an empty constructor so that it can be + * created using Reflection. + */ +private[spark] trait RpcEnvFactory { + + def create(config: RpcEnvConfig): RpcEnv +} + +/** + * A trait that requires RpcEnv thread-safely sending messages to it. + * + * Thread-safety means processing of one message happens before processing of the next message by + * the same [[ThreadSafeRpcEndpoint]]. In the other words, changes to internal fields of a + * [[ThreadSafeRpcEndpoint]] are visible when processing the next message, and fields in the + * [[ThreadSafeRpcEndpoint]] need not be volatile or equivalent. + * + * However, there is no guarantee that the same thread will be executing the same + * [[ThreadSafeRpcEndpoint]] for different messages. + */ +private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint + + +/** + * An end point for the RPC that defines what functions to trigger given a message. + * + * It is guaranteed that `onStart`, `receive` and `onStop` will be called in sequence. + * + * The lift-cycle will be: + * + * constructor onStart receive* onStop + * + * Note: `receive` can be called concurrently. If you want `receive` is thread-safe, please use + * [[ThreadSafeRpcEndpoint]] + * + * If any error is thrown from one of [[RpcEndpoint]] methods except `onError`, `onError` will be + * invoked with the cause. If `onError` throws an error, [[RpcEnv]] will ignore it. + */ +private[spark] trait RpcEndpoint { + + /** + * The [[RpcEnv]] that this [[RpcEndpoint]] is registered to. + */ + val rpcEnv: RpcEnv + + /** + * The [[RpcEndpointRef]] of this [[RpcEndpoint]]. `self` will become valid when `onStart` is + * called. And `self` will become `null` when `onStop` is called. + * + * Note: Because before `onStart`, [[RpcEndpoint]] has not yet been registered and there is not + * valid [[RpcEndpointRef]] for it. So don't call `self` before `onStart` is called. + */ + final def self: RpcEndpointRef = { + require(rpcEnv != null, "rpcEnv has not been initialized") + rpcEnv.endpointRef(this) + } + + /** + * Process messages from [[RpcEndpointRef.send]] or [[RpcCallContext.reply)]]. If receiving a + * unmatched message, [[SparkException]] will be thrown and sent to `onError`. + */ + def receive: PartialFunction[Any, Unit] = { + case _ => throw new SparkException(self + " does not implement 'receive'") + } + + /** + * Process messages from [[RpcEndpointRef.ask]]. If receiving a unmatched message, + * [[SparkException]] will be thrown and sent to `onError`. + */ + def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case _ => context.sendFailure(new SparkException(self + " won't reply anything")) + } + + /** + * Invoked when any exception is thrown during handling messages. + */ + def onError(cause: Throwable): Unit = { + // By default, throw e and let RpcEnv handle it + throw cause + } + + /** + * Invoked before [[RpcEndpoint]] starts to handle any message. + */ + def onStart(): Unit = { + // By default, do nothing. 
+ } + + /** + * Invoked when [[RpcEndpoint]] is stopping. + */ + def onStop(): Unit = { + // By default, do nothing. + } + + /** + * Invoked when `remoteAddress` is connected to the current node. + */ + def onConnected(remoteAddress: RpcAddress): Unit = { + // By default, do nothing. + } + + /** + * Invoked when `remoteAddress` is lost. + */ + def onDisconnected(remoteAddress: RpcAddress): Unit = { + // By default, do nothing. + } + + /** + * Invoked when some network error happens in the connection between the current node and + * `remoteAddress`. + */ + def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + // By default, do nothing. + } + + /** + * A convenient method to stop [[RpcEndpoint]]. + */ + final def stop(): Unit = { + val _self = self + if (_self != null) { + rpcEnv.stop(_self) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala new file mode 100644 index 0000000000000..69181edb9ad44 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc + +import scala.concurrent.{Await, Future} +import scala.concurrent.duration.FiniteDuration +import scala.reflect.ClassTag + +import org.apache.spark.util.RpcUtils +import org.apache.spark.{SparkException, Logging, SparkConf} + +/** + * A reference for a remote [[RpcEndpoint]]. [[RpcEndpointRef]] is thread-safe. + */ +private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) + extends Serializable with Logging { + + private[this] val maxRetries = RpcUtils.numRetries(conf) + private[this] val retryWaitMs = RpcUtils.retryWaitMs(conf) + private[this] val defaultAskTimeout = RpcUtils.askTimeout(conf) + + /** + * return the address for the [[RpcEndpointRef]] + */ + def address: RpcAddress + + def name: String + + /** + * Sends a one-way asynchronous message. Fire-and-forget semantics. + */ + def send(message: Any): Unit + + /** + * Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a [[Future]] to + * receive the reply within the specified timeout. + * + * This method only sends the message once and never retries. + */ + def ask[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] + + /** + * Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a [[Future]] to + * receive the reply within a default timeout. + * + * This method only sends the message once and never retries. 
+ */ + def ask[T: ClassTag](message: Any): Future[T] = ask(message, defaultAskTimeout) + + /** + * Send a message to the corresponding [[RpcEndpoint]] and get its result within a default + * timeout, or throw a SparkException if this fails even after the default number of retries. + * The default `timeout` will be used in every trial of calling `sendWithReply`. Because this + * method retries, the message handling in the receiver side should be idempotent. + * + * Note: this is a blocking action which may cost a lot of time, so don't call it in an message + * loop of [[RpcEndpoint]]. + * + * @param message the message to send + * @tparam T type of the reply message + * @return the reply message from the corresponding [[RpcEndpoint]] + */ + def askWithRetry[T: ClassTag](message: Any): T = askWithRetry(message, defaultAskTimeout) + + /** + * Send a message to the corresponding [[RpcEndpoint.receive]] and get its result within a + * specified timeout, throw a SparkException if this fails even after the specified number of + * retries. `timeout` will be used in every trial of calling `sendWithReply`. Because this method + * retries, the message handling in the receiver side should be idempotent. + * + * Note: this is a blocking action which may cost a lot of time, so don't call it in an message + * loop of [[RpcEndpoint]]. + * + * @param message the message to send + * @param timeout the timeout duration + * @tparam T type of the reply message + * @return the reply message from the corresponding [[RpcEndpoint]] + */ + def askWithRetry[T: ClassTag](message: Any, timeout: FiniteDuration): T = { + // TODO: Consider removing multiple attempts + var attempts = 0 + var lastException: Exception = null + while (attempts < maxRetries) { + attempts += 1 + try { + val future = ask[T](message, timeout) + val result = Await.result(future, timeout) + if (result == null) { + throw new SparkException("Actor returned null") + } + return result + } catch { + case ie: InterruptedException => throw ie + case e: Exception => + lastException = e + logWarning(s"Error sending message [message = $message] in $attempts attempts", e) + } + Thread.sleep(retryWaitMs) + } + + throw new SparkException( + s"Error sending message [message = $message]", lastException) + } +} diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index a5336b7563802..12b6b28d4d7ec 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -20,13 +20,41 @@ package org.apache.spark.rpc import java.net.URI import scala.concurrent.{Await, Future} -import scala.concurrent.duration._ import scala.language.postfixOps -import scala.reflect.ClassTag -import org.apache.spark.{Logging, SparkException, SecurityManager, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.util.{RpcUtils, Utils} + +/** + * A RpcEnv implementation must have a [[RpcEnvFactory]] implementation with an empty constructor + * so that it can be created via Reflection. + */ +private[spark] object RpcEnv { + + private def getRpcEnvFactory(conf: SparkConf): RpcEnvFactory = { + // Add more RpcEnv implementations here + val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory") + val rpcEnvName = conf.get("spark.rpc", "akka") + val rpcEnvFactoryClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName) + Class.forName(rpcEnvFactoryClassName, true, Utils.getContextOrSparkClassLoader). 
+ newInstance().asInstanceOf[RpcEnvFactory] + } + + def create( + name: String, + host: String, + port: Int, + conf: SparkConf, + securityManager: SecurityManager): RpcEnv = { + // Using Reflection to create the RpcEnv to avoid to depend on Akka directly + val config = RpcEnvConfig(conf, name, host, port, securityManager) + getRpcEnvFactory(conf).create(config) + } + +} + + /** * An RPC environment. [[RpcEndpoint]]s need to register itself with a name to [[RpcEnv]] to * receives messages. Then [[RpcEnv]] will process messages sent from [[RpcEndpointRef]] or remote @@ -112,6 +140,7 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { def uriOf(systemName: String, address: RpcAddress, endpointName: String): String } + private[spark] case class RpcEnvConfig( conf: SparkConf, name: String, @@ -119,261 +148,9 @@ private[spark] case class RpcEnvConfig( port: Int, securityManager: SecurityManager) -/** - * A RpcEnv implementation must have a [[RpcEnvFactory]] implementation with an empty constructor - * so that it can be created via Reflection. - */ -private[spark] object RpcEnv { - - private def getRpcEnvFactory(conf: SparkConf): RpcEnvFactory = { - // Add more RpcEnv implementations here - val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory") - val rpcEnvName = conf.get("spark.rpc", "akka") - val rpcEnvFactoryClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName) - Class.forName(rpcEnvFactoryClassName, true, Utils.getContextOrSparkClassLoader). - newInstance().asInstanceOf[RpcEnvFactory] - } - - def create( - name: String, - host: String, - port: Int, - conf: SparkConf, - securityManager: SecurityManager): RpcEnv = { - // Using Reflection to create the RpcEnv to avoid to depend on Akka directly - val config = RpcEnvConfig(conf, name, host, port, securityManager) - getRpcEnvFactory(conf).create(config) - } - -} - -/** - * A factory class to create the [[RpcEnv]]. It must have an empty constructor so that it can be - * created using Reflection. - */ -private[spark] trait RpcEnvFactory { - - def create(config: RpcEnvConfig): RpcEnv -} /** - * An end point for the RPC that defines what functions to trigger given a message. - * - * It is guaranteed that `onStart`, `receive` and `onStop` will be called in sequence. - * - * The lift-cycle will be: - * - * constructor onStart receive* onStop - * - * Note: `receive` can be called concurrently. If you want `receive` is thread-safe, please use - * [[ThreadSafeRpcEndpoint]] - * - * If any error is thrown from one of [[RpcEndpoint]] methods except `onError`, `onError` will be - * invoked with the cause. If `onError` throws an error, [[RpcEnv]] will ignore it. - */ -private[spark] trait RpcEndpoint { - - /** - * The [[RpcEnv]] that this [[RpcEndpoint]] is registered to. - */ - val rpcEnv: RpcEnv - - /** - * The [[RpcEndpointRef]] of this [[RpcEndpoint]]. `self` will become valid when `onStart` is - * called. And `self` will become `null` when `onStop` is called. - * - * Note: Because before `onStart`, [[RpcEndpoint]] has not yet been registered and there is not - * valid [[RpcEndpointRef]] for it. So don't call `self` before `onStart` is called. - */ - final def self: RpcEndpointRef = { - require(rpcEnv != null, "rpcEnv has not been initialized") - rpcEnv.endpointRef(this) - } - - /** - * Process messages from [[RpcEndpointRef.send]] or [[RpcCallContext.reply)]]. If receiving a - * unmatched message, [[SparkException]] will be thrown and sent to `onError`. 
- */ - def receive: PartialFunction[Any, Unit] = { - case _ => throw new SparkException(self + " does not implement 'receive'") - } - - /** - * Process messages from [[RpcEndpointRef.sendWithReply]]. If receiving a unmatched message, - * [[SparkException]] will be thrown and sent to `onError`. - */ - def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case _ => context.sendFailure(new SparkException(self + " won't reply anything")) - } - - /** - * Call onError when any exception is thrown during handling messages. - * - * @param cause - */ - def onError(cause: Throwable): Unit = { - // By default, throw e and let RpcEnv handle it - throw cause - } - - /** - * Invoked before [[RpcEndpoint]] starts to handle any message. - */ - def onStart(): Unit = { - // By default, do nothing. - } - - /** - * Invoked when [[RpcEndpoint]] is stopping. - */ - def onStop(): Unit = { - // By default, do nothing. - } - - /** - * Invoked when `remoteAddress` is connected to the current node. - */ - def onConnected(remoteAddress: RpcAddress): Unit = { - // By default, do nothing. - } - - /** - * Invoked when `remoteAddress` is lost. - */ - def onDisconnected(remoteAddress: RpcAddress): Unit = { - // By default, do nothing. - } - - /** - * Invoked when some network error happens in the connection between the current node and - * `remoteAddress`. - */ - def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { - // By default, do nothing. - } - - /** - * A convenient method to stop [[RpcEndpoint]]. - */ - final def stop(): Unit = { - val _self = self - if (_self != null) { - rpcEnv.stop(_self) - } - } -} - -/** - * A trait that requires RpcEnv thread-safely sending messages to it. - * - * Thread-safety means processing of one message happens before processing of the next message by - * the same [[ThreadSafeRpcEndpoint]]. In the other words, changes to internal fields of a - * [[ThreadSafeRpcEndpoint]] are visible when processing the next message, and fields in the - * [[ThreadSafeRpcEndpoint]] need not be volatile or equivalent. - * - * However, there is no guarantee that the same thread will be executing the same - * [[ThreadSafeRpcEndpoint]] for different messages. - */ -trait ThreadSafeRpcEndpoint extends RpcEndpoint - -/** - * A reference for a remote [[RpcEndpoint]]. [[RpcEndpointRef]] is thread-safe. - */ -private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) - extends Serializable with Logging { - - private[this] val maxRetries = RpcUtils.numRetries(conf) - private[this] val retryWaitMs = RpcUtils.retryWaitMs(conf) - private[this] val defaultAskTimeout = RpcUtils.askTimeout(conf) - - /** - * return the address for the [[RpcEndpointRef]] - */ - def address: RpcAddress - - def name: String - - /** - * Sends a one-way asynchronous message. Fire-and-forget semantics. - */ - def send(message: Any): Unit - - /** - * Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a `Future` to - * receive the reply within a default timeout. - * - * This method only sends the message once and never retries. - */ - def sendWithReply[T: ClassTag](message: Any): Future[T] = - sendWithReply(message, defaultAskTimeout) - - /** - * Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a `Future` to - * receive the reply within the specified timeout. - * - * This method only sends the message once and never retries. 
- */ - def sendWithReply[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] - - /** - * Send a message to the corresponding [[RpcEndpoint]] and get its result within a default - * timeout, or throw a SparkException if this fails even after the default number of retries. - * The default `timeout` will be used in every trial of calling `sendWithReply`. Because this - * method retries, the message handling in the receiver side should be idempotent. - * - * Note: this is a blocking action which may cost a lot of time, so don't call it in an message - * loop of [[RpcEndpoint]]. - * - * @param message the message to send - * @tparam T type of the reply message - * @return the reply message from the corresponding [[RpcEndpoint]] - */ - def askWithReply[T: ClassTag](message: Any): T = askWithReply(message, defaultAskTimeout) - - /** - * Send a message to the corresponding [[RpcEndpoint.receive]] and get its result within a - * specified timeout, throw a SparkException if this fails even after the specified number of - * retries. `timeout` will be used in every trial of calling `sendWithReply`. Because this method - * retries, the message handling in the receiver side should be idempotent. - * - * Note: this is a blocking action which may cost a lot of time, so don't call it in an message - * loop of [[RpcEndpoint]]. - * - * @param message the message to send - * @param timeout the timeout duration - * @tparam T type of the reply message - * @return the reply message from the corresponding [[RpcEndpoint]] - */ - def askWithReply[T: ClassTag](message: Any, timeout: FiniteDuration): T = { - // TODO: Consider removing multiple attempts - var attempts = 0 - var lastException: Exception = null - while (attempts < maxRetries) { - attempts += 1 - try { - val future = sendWithReply[T](message, timeout) - val result = Await.result(future, timeout) - if (result == null) { - throw new SparkException("Actor returned null") - } - return result - } catch { - case ie: InterruptedException => throw ie - case e: Exception => - lastException = e - logWarning(s"Error sending message [message = $message] in $attempts attempts", e) - } - Thread.sleep(retryWaitMs) - } - - throw new SparkException( - s"Error sending message [message = $message]", lastException) - } - -} - -/** - * Represent a host with a port + * Represents a host and port. */ private[spark] case class RpcAddress(host: String, port: Int) { // TODO do we need to add the type of RpcEnv in the address? @@ -383,6 +160,7 @@ private[spark] case class RpcAddress(host: String, port: Int) { override val toString: String = hostPort } + private[spark] object RpcAddress { /** @@ -404,26 +182,3 @@ private[spark] object RpcAddress { RpcAddress(host, port) } } - -/** - * A callback that [[RpcEndpoint]] can use it to send back a message or failure. It's thread-safe - * and can be called in any thread. - */ -private[spark] trait RpcCallContext { - - /** - * Reply a message to the sender. If the sender is [[RpcEndpoint]], its [[RpcEndpoint.receive]] - * will be called. - */ - def reply(response: Any): Unit - - /** - * Report a failure to the sender. - */ - def sendFailure(e: Throwable): Unit - - /** - * The sender of this message. 
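For readers following the rename in this file (sendWithReply becomes ask, askWithReply becomes askWithRetry), a minimal usage sketch may help. It is illustrative only and not part of the patch: EchoEndpoint and the "echo" registration name are invented, and it assumes an already-created RpcEnv inside Spark's own code (the traits are private[spark]). It uses only the API introduced above: ThreadSafeRpcEndpoint, receive/receiveAndReply, RpcCallContext.reply, RpcEnv.setupEndpoint, and send/ask/askWithRetry on RpcEndpointRef.

    class EchoEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint {
      // Messages delivered with send(...) arrive here; no reply channel is available.
      override def receive: PartialFunction[Any, Unit] = {
        case msg: String => ()  // fire-and-forget, e.g. update internal state
      }
      // Messages delivered with ask(...) or askWithRetry(...) arrive here and must be answered.
      override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
        case msg: String => context.reply(msg)
      }
    }

    val ref = rpcEnv.setupEndpoint("echo", new EchoEndpoint(rpcEnv))
    ref.send("fire and forget")                    // one-way, no reply
    val future = ref.ask[String]("hello")          // single attempt, returns a Future
    val reply = ref.askWithRetry[String]("hello")  // blocking and retried, so the handler must be idempotent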
- */ - def sender: RpcEndpointRef -} diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 652e52f2b2e73..ba0d468f111ef 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -293,7 +293,7 @@ private[akka] class AkkaRpcEndpointRef( actorRef ! AkkaMessage(message, false) } - override def sendWithReply[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] = { + override def ask[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] = { import scala.concurrent.ExecutionContext.Implicits.global actorRef.ask(AkkaMessage(message, true))(timeout).flatMap { case msg @ AkkaMessage(message, reply) => diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index b7901c06a16e0..b511c306ca508 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -167,7 +167,7 @@ class DAGScheduler( taskMetrics: Array[(Long, Int, Int, TaskMetrics)], // (taskId, stageId, stateAttempt, metrics) blockManagerId: BlockManagerId): Boolean = { listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, taskMetrics)) - blockManagerMaster.driverEndpoint.askWithReply[Boolean]( + blockManagerMaster.driverEndpoint.askWithRetry[Boolean]( BlockManagerHeartbeat(blockManagerId), 600 seconds) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index 7c184b1dcb308..0b1d47cff3746 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -85,7 +85,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging { val msg = AskPermissionToCommitOutput(stage, partition, attempt) coordinatorRef match { case Some(endpointRef) => - endpointRef.askWithReply[Boolean](msg) + endpointRef.askWithRetry[Boolean](msg) case None => logError( "canCommit called after coordinator was stopped (is SparkEnv shutdown in progress)?") diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 9656fb76858ea..7352fa1fe9ebd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -252,7 +252,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp try { if (driverEndpoint != null) { logInfo("Shutting down all executors") - driverEndpoint.askWithReply[Boolean](StopExecutors) + driverEndpoint.askWithRetry[Boolean](StopExecutors) } } catch { case e: Exception => @@ -264,7 +264,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp stopExecutors() try { if (driverEndpoint != null) { - driverEndpoint.askWithReply[Boolean](StopDriver) + driverEndpoint.askWithRetry[Boolean](StopDriver) } } catch { case e: Exception => @@ -287,7 +287,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Called by subclasses when notified of a lost worker def 
removeExecutor(executorId: String, reason: String) { try { - driverEndpoint.askWithReply[Boolean](RemoveExecutor(executorId, reason)) + driverEndpoint.askWithRetry[Boolean](RemoveExecutor(executorId, reason)) } catch { case e: Exception => throw new SparkException("Error notifying standalone scheduler's driver endpoint", e) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index d987c7d563579..2a3a5d925d06f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -53,14 +53,14 @@ private[spark] abstract class YarnSchedulerBackend( * This includes executors already pending or running. */ override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { - yarnSchedulerEndpoint.askWithReply[Boolean](RequestExecutors(requestedTotal)) + yarnSchedulerEndpoint.askWithRetry[Boolean](RequestExecutors(requestedTotal)) } /** * Request that the ApplicationMaster kill the specified executors. */ override def doKillExecutors(executorIds: Seq[String]): Boolean = { - yarnSchedulerEndpoint.askWithReply[Boolean](KillExecutors(executorIds)) + yarnSchedulerEndpoint.askWithRetry[Boolean](KillExecutors(executorIds)) } override def sufficientResourcesRegistered(): Boolean = { @@ -115,7 +115,7 @@ private[spark] abstract class YarnSchedulerBackend( amEndpoint match { case Some(am) => Future { - context.reply(am.askWithReply[Boolean](r)) + context.reply(am.askWithRetry[Boolean](r)) } onFailure { case NonFatal(e) => logError(s"Sending $r to AM was unsuccessful", e) @@ -130,7 +130,7 @@ private[spark] abstract class YarnSchedulerBackend( amEndpoint match { case Some(am) => Future { - context.reply(am.askWithReply[Boolean](k)) + context.reply(am.askWithRetry[Boolean](k)) } onFailure { case NonFatal(e) => logError(s"Sending $k to AM was unsuccessful", e) diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index ac5b524517818..e64d06c4d3cfc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -123,7 +123,7 @@ private[spark] class LocalBackend( } override def stop() { - localEndpoint.sendWithReply(StopExecutor) + localEndpoint.ask(StopExecutor) } override def reviveOffers() { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index c798843bd5d8a..9bfc4201d37c0 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -55,7 +55,7 @@ class BlockManagerMaster( memSize: Long, diskSize: Long, tachyonSize: Long): Boolean = { - val res = driverEndpoint.askWithReply[Boolean]( + val res = driverEndpoint.askWithRetry[Boolean]( UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize, tachyonSize)) logDebug(s"Updated info of block $blockId") res @@ -63,12 +63,12 @@ class BlockManagerMaster( /** Get locations of the blockId from the driver */ def getLocations(blockId: BlockId): Seq[BlockManagerId] = { - driverEndpoint.askWithReply[Seq[BlockManagerId]](GetLocations(blockId)) + 
driverEndpoint.askWithRetry[Seq[BlockManagerId]](GetLocations(blockId)) } /** Get locations of multiple blockIds from the driver */ def getLocations(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = { - driverEndpoint.askWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) + driverEndpoint.askWithRetry[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) } /** @@ -81,11 +81,11 @@ class BlockManagerMaster( /** Get ids of other nodes in the cluster from the driver */ def getPeers(blockManagerId: BlockManagerId): Seq[BlockManagerId] = { - driverEndpoint.askWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId)) + driverEndpoint.askWithRetry[Seq[BlockManagerId]](GetPeers(blockManagerId)) } def getRpcHostPortForExecutor(executorId: String): Option[(String, Int)] = { - driverEndpoint.askWithReply[Option[(String, Int)]](GetRpcHostPortForExecutor(executorId)) + driverEndpoint.askWithRetry[Option[(String, Int)]](GetRpcHostPortForExecutor(executorId)) } /** @@ -93,12 +93,12 @@ class BlockManagerMaster( * blocks that the driver knows about. */ def removeBlock(blockId: BlockId) { - driverEndpoint.askWithReply[Boolean](RemoveBlock(blockId)) + driverEndpoint.askWithRetry[Boolean](RemoveBlock(blockId)) } /** Remove all blocks belonging to the given RDD. */ def removeRdd(rddId: Int, blocking: Boolean) { - val future = driverEndpoint.askWithReply[Future[Seq[Int]]](RemoveRdd(rddId)) + val future = driverEndpoint.askWithRetry[Future[Seq[Int]]](RemoveRdd(rddId)) future.onFailure { case e: Exception => logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}}") @@ -110,7 +110,7 @@ class BlockManagerMaster( /** Remove all blocks belonging to the given shuffle. */ def removeShuffle(shuffleId: Int, blocking: Boolean) { - val future = driverEndpoint.askWithReply[Future[Seq[Boolean]]](RemoveShuffle(shuffleId)) + val future = driverEndpoint.askWithRetry[Future[Seq[Boolean]]](RemoveShuffle(shuffleId)) future.onFailure { case e: Exception => logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}}") @@ -122,7 +122,7 @@ class BlockManagerMaster( /** Remove all blocks belonging to the given broadcast. */ def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean, blocking: Boolean) { - val future = driverEndpoint.askWithReply[Future[Seq[Int]]]( + val future = driverEndpoint.askWithRetry[Future[Seq[Int]]]( RemoveBroadcast(broadcastId, removeFromMaster)) future.onFailure { case e: Exception => @@ -141,11 +141,11 @@ class BlockManagerMaster( * amount of remaining memory. */ def getMemoryStatus: Map[BlockManagerId, (Long, Long)] = { - driverEndpoint.askWithReply[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus) + driverEndpoint.askWithRetry[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus) } def getStorageStatus: Array[StorageStatus] = { - driverEndpoint.askWithReply[Array[StorageStatus]](GetStorageStatus) + driverEndpoint.askWithRetry[Array[StorageStatus]](GetStorageStatus) } /** @@ -166,7 +166,7 @@ class BlockManagerMaster( * master endpoint for a response to a prior message. */ val response = driverEndpoint. 
- askWithReply[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg) + askWithRetry[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg) val (blockManagerIds, futures) = response.unzip val result = Await.result(Future.sequence(futures), timeout) if (result == null) { @@ -190,7 +190,7 @@ class BlockManagerMaster( filter: BlockId => Boolean, askSlaves: Boolean): Seq[BlockId] = { val msg = GetMatchingBlockIds(filter, askSlaves) - val future = driverEndpoint.askWithReply[Future[Seq[BlockId]]](msg) + val future = driverEndpoint.askWithRetry[Future[Seq[BlockId]]](msg) Await.result(future, timeout) } @@ -205,7 +205,7 @@ class BlockManagerMaster( /** Send a one-way message to the master endpoint, to which we expect it to reply with true. */ private def tell(message: Any) { - if (!driverEndpoint.askWithReply[Boolean](message)) { + if (!driverEndpoint.askWithRetry[Boolean](message)) { throw new SparkException("BlockManagerMasterEndpoint returned false, expected true.") } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 4682167912ff0..7212362df5d71 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -132,7 +132,7 @@ class BlockManagerMasterEndpoint( val removeMsg = RemoveRdd(rddId) Future.sequence( blockManagerInfo.values.map { bm => - bm.slaveEndpoint.sendWithReply[Int](removeMsg) + bm.slaveEndpoint.ask[Int](removeMsg) }.toSeq ) } @@ -142,7 +142,7 @@ class BlockManagerMasterEndpoint( val removeMsg = RemoveShuffle(shuffleId) Future.sequence( blockManagerInfo.values.map { bm => - bm.slaveEndpoint.sendWithReply[Boolean](removeMsg) + bm.slaveEndpoint.ask[Boolean](removeMsg) }.toSeq ) } @@ -159,7 +159,7 @@ class BlockManagerMasterEndpoint( } Future.sequence( requiredBlockManagers.map { bm => - bm.slaveEndpoint.sendWithReply[Int](removeMsg) + bm.slaveEndpoint.ask[Int](removeMsg) }.toSeq ) } @@ -214,7 +214,7 @@ class BlockManagerMasterEndpoint( // Remove the block from the slave's BlockManager. // Doesn't actually wait for a confirmation and the message might get lost. // If message loss becomes frequent, we should add retry logic here. 
- blockManager.get.slaveEndpoint.sendWithReply[Boolean](RemoveBlock(blockId)) + blockManager.get.slaveEndpoint.ask[Boolean](RemoveBlock(blockId)) } } } @@ -253,7 +253,7 @@ class BlockManagerMasterEndpoint( blockManagerInfo.values.map { info => val blockStatusFuture = if (askSlaves) { - info.slaveEndpoint.sendWithReply[Option[BlockStatus]](getBlockStatus) + info.slaveEndpoint.ask[Option[BlockStatus]](getBlockStatus) } else { Future { info.getStatus(blockId) } } @@ -277,7 +277,7 @@ class BlockManagerMasterEndpoint( blockManagerInfo.values.map { info => val future = if (askSlaves) { - info.slaveEndpoint.sendWithReply[Seq[BlockId]](getMatchingBlockIds) + info.slaveEndpoint.ask[Seq[BlockId]](getMatchingBlockIds) } else { Future { info.blocks.keys.filter(filter).toSeq } } diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index 0fd570e5297d9..b789912e9ebef 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -48,7 +48,7 @@ class HeartbeatReceiverSuite extends FunSuite with LocalSparkContext { val metrics = new TaskMetrics val blockManagerId = BlockManagerId("executor-1", "localhost", 12345) - val response = receiverRef.askWithReply[HeartbeatResponse]( + val response = receiverRef.askWithRetry[HeartbeatResponse]( Heartbeat("executor-1", Array(1L -> metrics), blockManagerId)) verify(scheduler).executorHeartbeatReceived( @@ -71,7 +71,7 @@ class HeartbeatReceiverSuite extends FunSuite with LocalSparkContext { val metrics = new TaskMetrics val blockManagerId = BlockManagerId("executor-1", "localhost", 12345) - val response = receiverRef.askWithReply[HeartbeatResponse]( + val response = receiverRef.askWithRetry[HeartbeatResponse]( Heartbeat("executor-1", Array(1L -> metrics), blockManagerId)) verify(scheduler).executorHeartbeatReceived( diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 44c88b00c442a..ae3339d80f9c6 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -100,8 +100,8 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { } val rpcEndpointRef = env.setupEndpoint("send-ref", endpoint) - val newRpcEndpointRef = rpcEndpointRef.askWithReply[RpcEndpointRef]("Hello") - val reply = newRpcEndpointRef.askWithReply[String]("Echo") + val newRpcEndpointRef = rpcEndpointRef.askWithRetry[RpcEndpointRef]("Hello") + val reply = newRpcEndpointRef.askWithRetry[String]("Echo") assert("Echo" === reply) } @@ -115,7 +115,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { } } }) - val reply = rpcEndpointRef.askWithReply[String]("hello") + val reply = rpcEndpointRef.askWithRetry[String]("hello") assert("hello" === reply) } @@ -134,7 +134,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-remotely") try { - val reply = rpcEndpointRef.askWithReply[String]("hello") + val reply = rpcEndpointRef.askWithRetry[String]("hello") assert("hello" === reply) } finally { anotherEnv.shutdown() @@ -162,7 +162,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-timeout") try { val e 
= intercept[Exception] { - rpcEndpointRef.askWithReply[String]("hello", 1 millis) + rpcEndpointRef.askWithRetry[String]("hello", 1 millis) } assert(e.isInstanceOf[TimeoutException] || e.getCause.isInstanceOf[TimeoutException]) } finally { @@ -399,7 +399,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { } }) - val f = endpointRef.sendWithReply[String]("Hi") + val f = endpointRef.ask[String]("Hi") val ack = Await.result(f, 5 seconds) assert("ack" === ack) @@ -419,7 +419,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "sendWithReply-remotely") try { - val f = rpcEndpointRef.sendWithReply[String]("hello") + val f = rpcEndpointRef.ask[String]("hello") val ack = Await.result(f, 5 seconds) assert("ack" === ack) } finally { @@ -437,7 +437,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { } }) - val f = endpointRef.sendWithReply[String]("Hi") + val f = endpointRef.ask[String]("Hi") val e = intercept[SparkException] { Await.result(f, 5 seconds) } @@ -460,7 +460,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val rpcEndpointRef = anotherEnv.setupEndpointRef( "local", env.address, "sendWithReply-remotely-error") try { - val f = rpcEndpointRef.sendWithReply[String]("hello") + val f = rpcEndpointRef.ask[String]("hello") val e = intercept[SparkException] { Await.result(f, 5 seconds) } @@ -529,7 +529,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val rpcEndpointRef = anotherEnv.setupEndpointRef( "local", env.address, "sendWithReply-unserializable-error") try { - val f = rpcEndpointRef.sendWithReply[String]("hello") + val f = rpcEndpointRef.ask[String]("hello") intercept[TimeoutException] { Await.result(f, 1 seconds) } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 6957bc72e9903..f5b410f41da27 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -356,7 +356,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach master.removeExecutor(store.blockManagerId.executorId) assert(master.getLocations("a1").size == 0, "a1 was not removed from master") - val reregister = !master.driverEndpoint.askWithReply[Boolean]( + val reregister = !master.driverEndpoint.askWithRetry[Boolean]( BlockManagerHeartbeat(store.blockManagerId)) assert(reregister == true) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 89af40330b9d9..f2379366f36ef 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -146,7 +146,7 @@ private[streaming] class ReceiverSupervisorImpl( logDebug(s"Pushed block $blockId in ${(System.currentTimeMillis - time)} ms") val blockInfo = ReceivedBlockInfo(streamId, numRecords, blockStoreResult) - trackerEndpoint.askWithReply[Boolean](AddBlock(blockInfo)) + trackerEndpoint.askWithRetry[Boolean](AddBlock(blockInfo)) logDebug(s"Reported block $blockId") } @@ -169,13 +169,13 @@ private[streaming] class ReceiverSupervisorImpl( override 
protected def onReceiverStart() { val msg = RegisterReceiver( streamId, receiver.getClass.getSimpleName, Utils.localHostName(), endpoint) - trackerEndpoint.askWithReply[Boolean](msg) + trackerEndpoint.askWithRetry[Boolean](msg) } override protected def onReceiverStop(message: String, error: Option[Throwable]) { logInfo("Deregistering receiver " + streamId) val errorString = error.map(Throwables.getStackTraceAsString).getOrElse("") - trackerEndpoint.askWithReply[Boolean](DeregisterReceiver(streamId, message, errorString)) + trackerEndpoint.askWithRetry[Boolean](DeregisterReceiver(streamId, message, errorString)) logInfo("Stopped receiver " + streamId) } From 3df9c5dd0ce9ce82f9029c2846ad2f6164c501f3 Mon Sep 17 00:00:00 2001 From: ksonj Date: Wed, 29 Apr 2015 09:48:47 -0700 Subject: [PATCH 178/191] Better error message on access to non-existing attribute I believe column access via `__getattr__` is bad and shouldn't be implicitly encouraged by the error message when accessing a non-existing attribute on DataFrame. This patch changes the error message from 'no such column' to the more generic 'no such attribute', which is also what Pandas DFs will throw. Author: ksonj Closes #5771 from ksonj/master and squashes the following commits: bcc2220 [ksonj] Better error message on access to non-existing attribute --- python/pyspark/sql/dataframe.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 6879fe0805c58..d9cbbc68b3bf0 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -633,7 +633,8 @@ def __getattr__(self, name): [Row(age=2), Row(age=5)] """ if name not in self.columns: - raise AttributeError("No such column: %s" % name) + raise AttributeError( + "'%s' object has no attribute '%s'" % (self.__class__.__name__, name)) jc = self._jdf.apply(name) return Column(jc) From 81ea42bf39db7366e266bc1e5ce7e459fa2ef409 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 29 Apr 2015 09:49:24 -0700 Subject: [PATCH 179/191] [SQL][Minor] fix java doc for DataFrame.agg Author: Wenchen Fan Closes #5712 from cloud-fan/minor and squashes the following commits: be23064 [Wenchen Fan] fix java doc for DataFrame.agg --- .../scala/org/apache/spark/sql/DataFrame.scala | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 2affba7d42cc8..0e896e5693b98 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -623,17 +623,12 @@ class DataFrame private[sql]( } /** - * (Scala-specific) Compute aggregates by specifying a map from column name to - * aggregate methods. The resulting [[DataFrame]] will also contain the grouping columns. - * - * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. - * {{{ - * // Selects the age of the oldest employee and the aggregate expense for each department - * df.groupBy("department").agg( - * "age" -> "max", - * "expense" -> "sum" - * ) - * }}} + * (Scala-specific) Aggregates on the entire [[DataFrame]] without groups. + * {{ + * // df.agg(...) is a shorthand for df.groupBy().agg(...) 
+ * df.agg("age" -> "max", "salary" -> "avg") + * df.groupBy().agg("age" -> "max", "salary" -> "avg") + * }} * @group dfops */ def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = { From c0c0ba6d2a11febc8e874c437c4f676dd36ae059 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Wed, 29 Apr 2015 10:13:48 -0700 Subject: [PATCH 180/191] Fix a typo of "threshold" mengxr Author: Xusen Yin Closes #5769 from yinxusen/patch-1 and squashes the following commits: 43235f4 [Xusen Yin] Update PearsonCorrelation.scala f7287ee [Xusen Yin] Fix a typo of "threshold" --- .../spark/mllib/stat/correlation/PearsonCorrelation.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/PearsonCorrelation.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/PearsonCorrelation.scala index 23b291eee070b..8a821d1b23bab 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/PearsonCorrelation.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/PearsonCorrelation.scala @@ -101,7 +101,7 @@ private[stat] object PearsonCorrelation extends Correlation with Logging { Matrices.fromBreeze(cov) } - private def closeToZero(value: Double, threshhold: Double = 1e-12): Boolean = { - math.abs(value) <= threshhold + private def closeToZero(value: Double, threshold: Double = 1e-12): Boolean = { + math.abs(value) <= threshold } } From 1868bd40dcce23990b98748b0239bd00452b1ca5 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 29 Apr 2015 13:06:11 -0700 Subject: [PATCH 181/191] [SPARK-7056] [STREAMING] Make the Write Ahead Log pluggable Users may want the WAL data to be written to non-HDFS data storage systems. To allow that, we have to make the WAL pluggable. The following design doc outlines the plan. https://docs.google.com/a/databricks.com/document/d/1A2XaOLRFzvIZSi18i_luNw5Rmm9j2j4AigktXxIYxmY/edit?usp=sharing Things to add. * Unit tests for WriteAheadLogUtils Author: Tathagata Das Closes #5645 from tdas/wal-pluggable and squashes the following commits: 2c431fd [Tathagata Das] Minor fixes. c2bc7384 [Tathagata Das] More changes based on PR comments. 569a416 [Tathagata Das] fixed long line bde26b1 [Tathagata Das] Renamed segment to record handle everywhere b65e155 [Tathagata Das] More changes based on PR comments. d7cd15b [Tathagata Das] Fixed test 1a32a4b [Tathagata Das] Fixed test e0d19fb [Tathagata Das] Fixed defaults 9310cbf [Tathagata Das] style fix. 86abcb1 [Tathagata Das] Refactored WriteAheadLogUtils, and consolidated all WAL related configuration into it. 84ce469 [Tathagata Das] Added unit test and fixed compilation error. bce5e75 [Tathagata Das] Fixed long lines. 837c4f5 [Tathagata Das] Merge remote-tracking branch 'apache-github/master' into wal-pluggable 754fbf8 [Tathagata Das] Added license and docs. 
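The commit message above describes the plugin point only in prose; the sketch below is an illustrative, deliberately naive in-memory implementation of the WriteAheadLog abstract class that this patch adds (WriteAheadLog.java, further down). It is not part of the patch: the class names are invented, a real plugin would write to durable storage rather than memory, and the plugin class is wired in through the configuration handled by the new WriteAheadLogUtils (the exact key is not shown in this excerpt). The element type returned by readAll() is assumed to be ByteBuffer.

    import java.nio.ByteBuffer
    import java.util.concurrent.ConcurrentSkipListMap
    import java.util.concurrent.atomic.AtomicLong

    import scala.collection.JavaConverters._

    import org.apache.spark.streaming.util.{WriteAheadLog, WriteAheadLogRecordHandle}

    class InMemoryRecordHandle(val index: Long) extends WriteAheadLogRecordHandle

    class InMemoryWriteAheadLog extends WriteAheadLog {
      // Keyed by a monotonically increasing index so that readAll() preserves write order.
      private val records = new ConcurrentSkipListMap[java.lang.Long, (Long, ByteBuffer)]()
      private val nextIndex = new AtomicLong(0L)

      override def write(record: ByteBuffer, time: Long): WriteAheadLogRecordHandle = {
        val index = nextIndex.getAndIncrement()
        records.put(index, (time, record))
        new InMemoryRecordHandle(index)
      }

      override def read(handle: WriteAheadLogRecordHandle): ByteBuffer = {
        val stored = records.get(handle.asInstanceOf[InMemoryRecordHandle].index)
        if (stored == null) null else stored._2
      }

      override def readAll(): java.util.Iterator[ByteBuffer] =
        records.values().asScala.iterator.map(_._2).asJava

      override def clean(threshTime: Long, waitForCompletion: Boolean): Unit = {
        // waitForCompletion is moot here because the removal below completes before returning.
        val entries = records.entrySet().iterator()
        while (entries.hasNext) {
          if (entries.next().getValue._1 < threshTime) entries.remove()
        }
      }

      override def close(): Unit = records.clear()
    }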
09bc6fe [Tathagata Das] Merge remote-tracking branch 'apache-github/master' into wal-pluggable 7dd2d4b [Tathagata Das] Added pluggable WriteAheadLog interface, and refactored all code along with it --- .../spark/streaming/kafka/KafkaUtils.scala | 3 +- .../spark/streaming/util/WriteAheadLog.java | 60 ++++++ .../util/WriteAheadLogRecordHandle.java | 30 +++ .../dstream/ReceiverInputDStream.scala | 2 +- .../rdd/WriteAheadLogBackedBlockRDD.scala | 79 +++++-- .../receiver/ReceivedBlockHandler.scala | 38 ++-- .../receiver/ReceiverSupervisorImpl.scala | 5 +- .../scheduler/ReceivedBlockTracker.scala | 38 ++-- .../streaming/scheduler/ReceiverTracker.scala | 3 +- ...ger.scala => FileBasedWriteAheadLog.scala} | 76 ++++--- ... FileBasedWriteAheadLogRandomReader.scala} | 8 +- ...ala => FileBasedWriteAheadLogReader.scala} | 4 +- ...la => FileBasedWriteAheadLogSegment.scala} | 3 +- ...ala => FileBasedWriteAheadLogWriter.scala} | 9 +- .../streaming/util/WriteAheadLogUtils.scala | 129 ++++++++++++ .../streaming/JavaWriteAheadLogSuite.java | 129 ++++++++++++ .../streaming/ReceivedBlockHandlerSuite.scala | 18 +- .../streaming/ReceivedBlockTrackerSuite.scala | 28 +-- .../spark/streaming/ReceiverSuite.scala | 2 +- .../streaming/StreamingContextSuite.scala | 4 +- .../WriteAheadLogBackedBlockRDDSuite.scala | 31 +-- .../streaming/util/WriteAheadLogSuite.scala | 194 ++++++++++++------ 22 files changed, 686 insertions(+), 207 deletions(-) create mode 100644 streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java create mode 100644 streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLogRecordHandle.java rename streaming/src/main/scala/org/apache/spark/streaming/util/{WriteAheadLogManager.scala => FileBasedWriteAheadLog.scala} (79%) rename streaming/src/main/scala/org/apache/spark/streaming/util/{WriteAheadLogRandomReader.scala => FileBasedWriteAheadLogRandomReader.scala} (83%) rename streaming/src/main/scala/org/apache/spark/streaming/util/{WriteAheadLogReader.scala => FileBasedWriteAheadLogReader.scala} (93%) rename streaming/src/main/scala/org/apache/spark/streaming/util/{WriteAheadLogFileSegment.scala => FileBasedWriteAheadLogSegment.scala} (86%) rename streaming/src/main/scala/org/apache/spark/streaming/util/{WriteAheadLogWriter.scala => FileBasedWriteAheadLogWriter.scala} (88%) create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala create mode 100644 streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index 0721ddaf7055a..d7cf500577c2a 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -31,6 +31,7 @@ import kafka.message.MessageAndMetadata import kafka.serializer.{DefaultDecoder, Decoder, StringDecoder} import org.apache.spark.api.java.function.{Function => JFunction} +import org.apache.spark.streaming.util.WriteAheadLogUtils import org.apache.spark.{SparkContext, SparkException} import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD @@ -80,7 +81,7 @@ object KafkaUtils { topics: Map[String, Int], storageLevel: StorageLevel ): ReceiverInputDStream[(K, V)] = { - val walEnabled = ssc.conf.getBoolean("spark.streaming.receiver.writeAheadLog.enable", false) + val walEnabled = 
WriteAheadLogUtils.enableReceiverLog(ssc.conf) new KafkaInputDStream[K, V, U, T](ssc, kafkaParams, topics, walEnabled, storageLevel) } diff --git a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java new file mode 100644 index 0000000000000..8c0fdfa9c7478 --- /dev/null +++ b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.util; + +import java.nio.ByteBuffer; +import java.util.Iterator; + +/** + * This abstract class represents a write ahead log (aka journal) that is used by Spark Streaming + * to save the received data (by receivers) and associated metadata to a reliable storage, so that + * they can be recovered after driver failures. See the Spark documentation for more information + * on how to plug in your own custom implementation of a write ahead log. + */ +@org.apache.spark.annotation.DeveloperApi +public abstract class WriteAheadLog { + /** + * Write the record to the log and return a record handle, which contains all the information + * necessary to read back the written record. The time is used to the index the record, + * such that it can be cleaned later. Note that implementations of this abstract class must + * ensure that the written data is durable and readable (using the record handle) by the + * time this function returns. + */ + abstract public WriteAheadLogRecordHandle write(ByteBuffer record, long time); + + /** + * Read a written record based on the given record handle. + */ + abstract public ByteBuffer read(WriteAheadLogRecordHandle handle); + + /** + * Read and return an iterator of all the records that have been written but not yet cleaned up. + */ + abstract public Iterator readAll(); + + /** + * Clean all the records that are older than the threshold time. It can wait for + * the completion of the deletion. + */ + abstract public void clean(long threshTime, boolean waitForCompletion); + + /** + * Close this log and release any resources. + */ + abstract public void close(); +} diff --git a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLogRecordHandle.java b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLogRecordHandle.java new file mode 100644 index 0000000000000..02324189b7822 --- /dev/null +++ b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLogRecordHandle.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.util; + +/** + * This abstract class represents a handle that refers to a record written in a + * {@link org.apache.spark.streaming.util.WriteAheadLog WriteAheadLog}. + * It must contain all the information necessary for the record to be read and returned by + * an implemenation of the WriteAheadLog class. + * + * @see org.apache.spark.streaming.util.WriteAheadLog + */ +@org.apache.spark.annotation.DeveloperApi +public abstract class WriteAheadLogRecordHandle implements java.io.Serializable { +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index 8be04314c4285..4c7fd2c57c006 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -82,7 +82,7 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont // WriteAheadLogBackedBlockRDD else create simple BlockRDD. if (resultTypes.size == 1 && resultTypes.head == classOf[WriteAheadLogBasedStoreResult]) { val logSegments = blockStoreResults.map { - _.asInstanceOf[WriteAheadLogBasedStoreResult].segment + _.asInstanceOf[WriteAheadLogBasedStoreResult].walRecordHandle }.toArray // Since storeInBlockManager = false, the storage level does not matter. new WriteAheadLogBackedBlockRDD[T](ssc.sparkContext, diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index 93caa4ba35c7f..ebdf418f4ab6a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -16,14 +16,17 @@ */ package org.apache.spark.streaming.rdd +import java.nio.ByteBuffer + import scala.reflect.ClassTag +import scala.util.control.NonFatal -import org.apache.hadoop.conf.Configuration +import org.apache.commons.io.FileUtils import org.apache.spark._ import org.apache.spark.rdd.BlockRDD import org.apache.spark.storage.{BlockId, StorageLevel} -import org.apache.spark.streaming.util.{HdfsUtils, WriteAheadLogFileSegment, WriteAheadLogRandomReader} +import org.apache.spark.streaming.util._ /** * Partition class for [[org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD]]. @@ -31,26 +34,27 @@ import org.apache.spark.streaming.util.{HdfsUtils, WriteAheadLogFileSegment, Wri * the segment of the write ahead log that backs the partition. 
* @param index index of the partition * @param blockId id of the block having the partition data - * @param segment segment of the write ahead log having the partition data + * @param walRecordHandle Handle of the record in a write ahead log having the partition data */ private[streaming] class WriteAheadLogBackedBlockRDDPartition( val index: Int, val blockId: BlockId, - val segment: WriteAheadLogFileSegment) + val walRecordHandle: WriteAheadLogRecordHandle) extends Partition /** * This class represents a special case of the BlockRDD where the data blocks in - * the block manager are also backed by segments in write ahead logs. For reading + * the block manager are also backed by data in write ahead logs. For reading * the data, this RDD first looks up the blocks by their ids in the block manager. - * If it does not find them, it looks up the corresponding file segment. + * If it does not find them, it looks up the corresponding data in the write ahead log. * * @param sc SparkContext * @param blockIds Ids of the blocks that contains this RDD's data - * @param segments Segments in write ahead logs that contain this RDD's data - * @param storeInBlockManager Whether to store in the block manager after reading from the segment + * @param walRecordHandles Record handles in write ahead logs that contain this RDD's data + * @param storeInBlockManager Whether to store in the block manager after reading + * from the WAL record * @param storageLevel storage level to store when storing in block manager * (applicable when storeInBlockManager = true) */ @@ -58,15 +62,15 @@ private[streaming] class WriteAheadLogBackedBlockRDD[T: ClassTag]( @transient sc: SparkContext, @transient blockIds: Array[BlockId], - @transient segments: Array[WriteAheadLogFileSegment], + @transient walRecordHandles: Array[WriteAheadLogRecordHandle], storeInBlockManager: Boolean, storageLevel: StorageLevel) extends BlockRDD[T](sc, blockIds) { require( - blockIds.length == segments.length, + blockIds.length == walRecordHandles.length, s"Number of block ids (${blockIds.length}) must be " + - s"the same as number of segments (${segments.length}})!") + s"the same as number of WAL record handles (${walRecordHandles.length}})!") // Hadoop configuration is not serializable, so broadcast it as a serializable. @transient private val hadoopConfig = sc.hadoopConfiguration @@ -75,13 +79,13 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( override def getPartitions: Array[Partition] = { assertValid() Array.tabulate(blockIds.size) { i => - new WriteAheadLogBackedBlockRDDPartition(i, blockIds(i), segments(i)) + new WriteAheadLogBackedBlockRDDPartition(i, blockIds(i), walRecordHandles(i)) } } /** * Gets the partition data by getting the corresponding block from the block manager. - * If the block does not exist, then the data is read from the corresponding segment + * If the block does not exist, then the data is read from the corresponding record * in write ahead log files. 
*/ override def compute(split: Partition, context: TaskContext): Iterator[T] = { @@ -96,10 +100,35 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( logDebug(s"Read partition data of $this from block manager, block $blockId") iterator case None => // Data not found in Block Manager, grab it from write ahead log file - val reader = new WriteAheadLogRandomReader(partition.segment.path, hadoopConf) - val dataRead = reader.read(partition.segment) - reader.close() - logInfo(s"Read partition data of $this from write ahead log, segment ${partition.segment}") + var dataRead: ByteBuffer = null + var writeAheadLog: WriteAheadLog = null + try { + // The WriteAheadLogUtils.createLog*** method needs a directory to create a + // WriteAheadLog object as the default FileBasedWriteAheadLog needs a directory for + // writing log data. However, the directory is not needed if data needs to be read, hence + // a dummy path is provided to satisfy the method parameter requirements. + // FileBasedWriteAheadLog will not create any file or directory at that path. + val dummyDirectory = FileUtils.getTempDirectoryPath() + writeAheadLog = WriteAheadLogUtils.createLogForReceiver( + SparkEnv.get.conf, dummyDirectory, hadoopConf) + dataRead = writeAheadLog.read(partition.walRecordHandle) + } catch { + case NonFatal(e) => + throw new SparkException( + s"Could not read data from write ahead log record ${partition.walRecordHandle}", e) + } finally { + if (writeAheadLog != null) { + writeAheadLog.close() + writeAheadLog = null + } + } + if (dataRead == null) { + throw new SparkException( + s"Could not read data from write ahead log record ${partition.walRecordHandle}, " + + s"read returned null") + } + logInfo(s"Read partition data of $this from write ahead log, record handle " + + partition.walRecordHandle) if (storeInBlockManager) { blockManager.putBytes(blockId, dataRead, storageLevel) logDebug(s"Stored partition data of $this into block manager with level $storageLevel") @@ -111,14 +140,20 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( /** * Get the preferred location of the partition. This returns the locations of the block - * if it is present in the block manager, else it returns the location of the - * corresponding segment in HDFS. + * if it is present in the block manager, else if FileBasedWriteAheadLogSegment is used, + * it returns the location of the corresponding file segment in HDFS . 
*/ override def getPreferredLocations(split: Partition): Seq[String] = { val partition = split.asInstanceOf[WriteAheadLogBackedBlockRDDPartition] val blockLocations = getBlockIdLocations().get(partition.blockId) - blockLocations.getOrElse( - HdfsUtils.getFileSegmentLocations( - partition.segment.path, partition.segment.offset, partition.segment.length, hadoopConfig)) + blockLocations.getOrElse { + partition.walRecordHandle match { + case fileSegment: FileBasedWriteAheadLogSegment => + HdfsUtils.getFileSegmentLocations( + fileSegment.path, fileSegment.offset, fileSegment.length, hadoopConfig) + case _ => + Seq.empty + } + } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala index 297bf04c0c25e..4b3d9ee4b0090 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -17,18 +17,18 @@ package org.apache.spark.streaming.receiver -import scala.concurrent.{Await, ExecutionContext, Future} import scala.concurrent.duration._ +import scala.concurrent.{Await, ExecutionContext, Future} import scala.language.{existentials, postfixOps} -import WriteAheadLogBasedBlockHandler._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.storage._ -import org.apache.spark.streaming.util.{WriteAheadLogFileSegment, WriteAheadLogManager} -import org.apache.spark.util.{ThreadUtils, Clock, SystemClock} +import org.apache.spark.streaming.receiver.WriteAheadLogBasedBlockHandler._ +import org.apache.spark.streaming.util.{WriteAheadLogRecordHandle, WriteAheadLogUtils} +import org.apache.spark.util.{Clock, SystemClock, ThreadUtils} +import org.apache.spark.{Logging, SparkConf, SparkException} /** Trait that represents the metadata related to storage of blocks */ private[streaming] trait ReceivedBlockStoreResult { @@ -96,7 +96,7 @@ private[streaming] class BlockManagerBasedBlockHandler( */ private[streaming] case class WriteAheadLogBasedStoreResult( blockId: StreamBlockId, - segment: WriteAheadLogFileSegment + walRecordHandle: WriteAheadLogRecordHandle ) extends ReceivedBlockStoreResult @@ -116,10 +116,6 @@ private[streaming] class WriteAheadLogBasedBlockHandler( private val blockStoreTimeout = conf.getInt( "spark.streaming.receiver.blockStoreTimeout", 30).seconds - private val rollingInterval = conf.getInt( - "spark.streaming.receiver.writeAheadLog.rollingInterval", 60) - private val maxFailures = conf.getInt( - "spark.streaming.receiver.writeAheadLog.maxFailures", 3) private val effectiveStorageLevel = { if (storageLevel.deserialized) { @@ -139,13 +135,9 @@ private[streaming] class WriteAheadLogBasedBlockHandler( s"$effectiveStorageLevel when write ahead log is enabled") } - // Manages rolling log files - private val logManager = new WriteAheadLogManager( - checkpointDirToLogDir(checkpointDir, streamId), - hadoopConf, rollingInterval, maxFailures, - callerName = this.getClass.getSimpleName, - clock = clock - ) + // Write ahead log manages + private val writeAheadLog = WriteAheadLogUtils.createLogForReceiver( + conf, checkpointDirToLogDir(checkpointDir, streamId), hadoopConf) // For processing futures used in parallel block storing into block manager and write ahead log // # threads = 2, so that both writing to BM and WAL can proceed in 
parallel @@ -183,21 +175,21 @@ private[streaming] class WriteAheadLogBasedBlockHandler( // Store the block in write ahead log val storeInWriteAheadLogFuture = Future { - logManager.writeToLog(serializedBlock) + writeAheadLog.write(serializedBlock, clock.getTimeMillis()) } - // Combine the futures, wait for both to complete, and return the write ahead log segment + // Combine the futures, wait for both to complete, and return the write ahead log record handle val combinedFuture = storeInBlockManagerFuture.zip(storeInWriteAheadLogFuture).map(_._2) - val segment = Await.result(combinedFuture, blockStoreTimeout) - WriteAheadLogBasedStoreResult(blockId, segment) + val walRecordHandle = Await.result(combinedFuture, blockStoreTimeout) + WriteAheadLogBasedStoreResult(blockId, walRecordHandle) } def cleanupOldBlocks(threshTime: Long) { - logManager.cleanupOldLogs(threshTime, waitForCompletion = false) + writeAheadLog.clean(threshTime, false) } def stop() { - logManager.stop() + writeAheadLog.close() } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index f2379366f36ef..93f047b91018f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -25,12 +25,13 @@ import scala.collection.mutable.ArrayBuffer import com.google.common.base.Throwables import org.apache.hadoop.conf.Configuration -import org.apache.spark.{Logging, SparkEnv, SparkException} import org.apache.spark.rpc.{RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.Time import org.apache.spark.streaming.scheduler._ +import org.apache.spark.streaming.util.WriteAheadLogUtils import org.apache.spark.util.{RpcUtils, Utils} +import org.apache.spark.{Logging, SparkEnv, SparkException} /** * Concrete implementation of [[org.apache.spark.streaming.receiver.ReceiverSupervisor]] @@ -46,7 +47,7 @@ private[streaming] class ReceiverSupervisorImpl( ) extends ReceiverSupervisor(receiver, env.conf) with Logging { private val receivedBlockHandler: ReceivedBlockHandler = { - if (env.conf.getBoolean("spark.streaming.receiver.writeAheadLog.enable", false)) { + if (WriteAheadLogUtils.enableReceiverLog(env.conf)) { if (checkpointDirOption.isEmpty) { throw new SparkException( "Cannot enable receiver write-ahead log without checkpoint directory set. 
" + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index 200cf4ef4b0f1..14e769a281f51 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -25,10 +25,10 @@ import scala.language.implicitConversions import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkException, Logging, SparkConf} import org.apache.spark.streaming.Time -import org.apache.spark.streaming.util.WriteAheadLogManager +import org.apache.spark.streaming.util.{WriteAheadLog, WriteAheadLogUtils} import org.apache.spark.util.{Clock, Utils} +import org.apache.spark.{Logging, SparkConf, SparkException} /** Trait representing any event in the ReceivedBlockTracker that updates its state. */ private[streaming] sealed trait ReceivedBlockTrackerLogEvent @@ -70,7 +70,7 @@ private[streaming] class ReceivedBlockTracker( private val streamIdToUnallocatedBlockQueues = new mutable.HashMap[Int, ReceivedBlockQueue] private val timeToAllocatedBlocks = new mutable.HashMap[Time, AllocatedBlocks] - private val logManagerOption = createLogManager() + private val writeAheadLogOption = createWriteAheadLog() private var lastAllocatedBatchTime: Time = null @@ -155,12 +155,12 @@ private[streaming] class ReceivedBlockTracker( logInfo("Deleting batches " + timesToCleanup) writeToLog(BatchCleanupEvent(timesToCleanup)) timeToAllocatedBlocks --= timesToCleanup - logManagerOption.foreach(_.cleanupOldLogs(cleanupThreshTime.milliseconds, waitForCompletion)) + writeAheadLogOption.foreach(_.clean(cleanupThreshTime.milliseconds, waitForCompletion)) } /** Stop the block tracker. 
*/ def stop() { - logManagerOption.foreach { _.stop() } + writeAheadLogOption.foreach { _.close() } } /** @@ -190,9 +190,10 @@ private[streaming] class ReceivedBlockTracker( timeToAllocatedBlocks --= batchTimes } - logManagerOption.foreach { logManager => + writeAheadLogOption.foreach { writeAheadLog => logInfo(s"Recovering from write ahead logs in ${checkpointDirOption.get}") - logManager.readFromLog().foreach { byteBuffer => + import scala.collection.JavaConversions._ + writeAheadLog.readAll().foreach { byteBuffer => logTrace("Recovering record " + byteBuffer) Utils.deserialize[ReceivedBlockTrackerLogEvent](byteBuffer.array) match { case BlockAdditionEvent(receivedBlockInfo) => @@ -208,10 +209,10 @@ private[streaming] class ReceivedBlockTracker( /** Write an update to the tracker to the write ahead log */ private def writeToLog(record: ReceivedBlockTrackerLogEvent) { - if (isLogManagerEnabled) { + if (isWriteAheadLogEnabled) { logDebug(s"Writing to log $record") - logManagerOption.foreach { logManager => - logManager.writeToLog(ByteBuffer.wrap(Utils.serialize(record))) + writeAheadLogOption.foreach { logManager => + logManager.write(ByteBuffer.wrap(Utils.serialize(record)), clock.getTimeMillis()) } } } @@ -222,8 +223,8 @@ private[streaming] class ReceivedBlockTracker( } /** Optionally create the write ahead log manager only if the feature is enabled */ - private def createLogManager(): Option[WriteAheadLogManager] = { - if (conf.getBoolean("spark.streaming.receiver.writeAheadLog.enable", false)) { + private def createWriteAheadLog(): Option[WriteAheadLog] = { + if (WriteAheadLogUtils.enableReceiverLog(conf)) { if (checkpointDirOption.isEmpty) { throw new SparkException( "Cannot enable receiver write-ahead log without checkpoint directory set. " + @@ -231,19 +232,16 @@ private[streaming] class ReceivedBlockTracker( "See documentation for more details.") } val logDir = ReceivedBlockTracker.checkpointDirToLogDir(checkpointDirOption.get) - val rollingIntervalSecs = conf.getInt( - "spark.streaming.receivedBlockTracker.writeAheadLog.rotationIntervalSecs", 60) - val logManager = new WriteAheadLogManager(logDir, hadoopConf, - rollingIntervalSecs = rollingIntervalSecs, clock = clock, - callerName = "ReceivedBlockHandlerMaster") - Some(logManager) + + val log = WriteAheadLogUtils.createLogForDriver(conf, logDir, hadoopConf) + Some(log) } else { None } } - /** Check if the log manager is enabled. This is only used for testing purposes. */ - private[streaming] def isLogManagerEnabled: Boolean = logManagerOption.nonEmpty + /** Check if the write ahead log is enabled. This is only used for testing purposes. 
*/ + private[streaming] def isWriteAheadLogEnabled: Boolean = writeAheadLogOption.nonEmpty } private[streaming] object ReceivedBlockTracker { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index c4ead6f30a63d..1af65716d3003 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -20,6 +20,7 @@ package org.apache.spark.streaming.scheduler import scala.collection.mutable.{HashMap, SynchronizedMap} import scala.language.existentials +import org.apache.spark.streaming.util.WriteAheadLogUtils import org.apache.spark.{Logging, SerializableWritable, SparkEnv, SparkException} import org.apache.spark.rpc._ import org.apache.spark.streaming.{StreamingContext, Time} @@ -125,7 +126,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false receivedBlockTracker.cleanupOldBatches(cleanupThreshTime, waitForCompletion = false) // Signal the receivers to delete old block data - if (ssc.conf.getBoolean("spark.streaming.receiver.writeAheadLog.enable", false)) { + if (WriteAheadLogUtils.enableReceiverLog(ssc.conf)) { logInfo(s"Cleanup old received batch data: $cleanupThreshTime") receiverInfo.values.flatMap { info => Option(info.endpoint) } .foreach { _.send(CleanupOldBlocks(cleanupThreshTime)) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogManager.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala similarity index 79% rename from streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogManager.scala rename to streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index 38a93cc3c9a1f..9985fedc35141 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogManager.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -17,6 +17,7 @@ package org.apache.spark.streaming.util import java.nio.ByteBuffer +import java.util.{Iterator => JIterator} import scala.collection.mutable.ArrayBuffer import scala.concurrent.{Await, ExecutionContext, Future} @@ -24,9 +25,9 @@ import scala.language.postfixOps import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.spark.Logging -import org.apache.spark.util.{ThreadUtils, Clock, SystemClock} -import WriteAheadLogManager._ + +import org.apache.spark.util.ThreadUtils +import org.apache.spark.{Logging, SparkConf} /** * This class manages write ahead log files. @@ -34,37 +35,32 @@ import WriteAheadLogManager._ * - Recovers the log files and the reads the recovered records upon failures. * - Cleans up old log files. * - * Uses [[org.apache.spark.streaming.util.WriteAheadLogWriter]] to write - * and [[org.apache.spark.streaming.util.WriteAheadLogReader]] to read. + * Uses [[org.apache.spark.streaming.util.FileBasedWriteAheadLogWriter]] to write + * and [[org.apache.spark.streaming.util.FileBasedWriteAheadLogReader]] to read. * * @param logDirectory Directory when rotating log files will be created. * @param hadoopConf Hadoop configuration for reading/writing log files. - * @param rollingIntervalSecs The interval in seconds with which logs will be rolled over. - * Default is one minute. 
- * @param maxFailures Max number of failures that is tolerated for every attempt to write to log. - * Default is three. - * @param callerName Optional name of the class who is using this manager. - * @param clock Optional clock that is used to check for rotation interval. */ -private[streaming] class WriteAheadLogManager( +private[streaming] class FileBasedWriteAheadLog( + conf: SparkConf, logDirectory: String, hadoopConf: Configuration, - rollingIntervalSecs: Int = 60, - maxFailures: Int = 3, - callerName: String = "", - clock: Clock = new SystemClock - ) extends Logging { + rollingIntervalSecs: Int, + maxFailures: Int + ) extends WriteAheadLog with Logging { + + import FileBasedWriteAheadLog._ private val pastLogs = new ArrayBuffer[LogInfo] - private val callerNameTag = - if (callerName.nonEmpty) s" for $callerName" else "" + private val callerNameTag = getCallerName.map(c => s" for $c").getOrElse("") + private val threadpoolName = s"WriteAheadLogManager $callerNameTag" implicit private val executionContext = ExecutionContext.fromExecutorService( ThreadUtils.newDaemonSingleThreadExecutor(threadpoolName)) override protected val logName = s"WriteAheadLogManager $callerNameTag" private var currentLogPath: Option[String] = None - private var currentLogWriter: WriteAheadLogWriter = null + private var currentLogWriter: FileBasedWriteAheadLogWriter = null private var currentLogWriterStartTime: Long = -1L private var currentLogWriterStopTime: Long = -1L @@ -75,14 +71,14 @@ private[streaming] class WriteAheadLogManager( * ByteBuffer to HDFS. When this method returns, the data is guaranteed to have been flushed * to HDFS, and will be available for readers to read. */ - def writeToLog(byteBuffer: ByteBuffer): WriteAheadLogFileSegment = synchronized { - var fileSegment: WriteAheadLogFileSegment = null + def write(byteBuffer: ByteBuffer, time: Long): FileBasedWriteAheadLogSegment = synchronized { + var fileSegment: FileBasedWriteAheadLogSegment = null var failures = 0 var lastException: Exception = null var succeeded = false while (!succeeded && failures < maxFailures) { try { - fileSegment = getLogWriter(clock.getTimeMillis()).write(byteBuffer) + fileSegment = getLogWriter(time).write(byteBuffer) succeeded = true } catch { case ex: Exception => @@ -99,6 +95,19 @@ private[streaming] class WriteAheadLogManager( fileSegment } + def read(segment: WriteAheadLogRecordHandle): ByteBuffer = { + val fileSegment = segment.asInstanceOf[FileBasedWriteAheadLogSegment] + var reader: FileBasedWriteAheadLogRandomReader = null + var byteBuffer: ByteBuffer = null + try { + reader = new FileBasedWriteAheadLogRandomReader(fileSegment.path, hadoopConf) + byteBuffer = reader.read(fileSegment) + } finally { + reader.close() + } + byteBuffer + } + /** * Read all the existing logs from the log directory. * @@ -108,12 +117,14 @@ private[streaming] class WriteAheadLogManager( * the latest the records. This does not deal with currently active log files, and * hence the implementation is kept simple. 
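As a rough sketch of how this file-based log is driven through the new WriteAheadLog interface: the directory, timestamps and object name below are invented for illustration, and WriteAheadLogUtils is private[streaming], so code like this has to live under the org.apache.spark.streaming package.

    // Hypothetical location; anything under org.apache.spark.streaming sees the private[streaming] utils.
    package org.apache.spark.streaming.example

    import java.nio.ByteBuffer

    import org.apache.hadoop.conf.Configuration

    import org.apache.spark.SparkConf
    import org.apache.spark.streaming.util.WriteAheadLogUtils

    object FileBasedWalRoundTrip {
      def main(args: Array[String]): Unit = {
        // No custom WAL class is configured, so this returns the default FileBasedWriteAheadLog.
        val wal = WriteAheadLogUtils.createLogForDriver(
          new SparkConf(), "/tmp/wal-example", new Configuration())

        // The time passed to write() indexes the record for later cleanup.
        val handle = wal.write(ByteBuffer.wrap("event-1".getBytes("UTF-8")), 1000L)
        wal.write(ByteBuffer.wrap("event-2".getBytes("UTF-8")), 2000L)

        // Random access through the handle (a FileBasedWriteAheadLogSegment underneath).
        println(new String(wal.read(handle).array(), "UTF-8"))

        // Delete log files whose records are all older than t = 1500, waiting for completion.
        wal.clean(1500L, true)
        wal.close()
      }
    }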
*/ - def readFromLog(): Iterator[ByteBuffer] = synchronized { + def readAll(): JIterator[ByteBuffer] = synchronized { + import scala.collection.JavaConversions._ val logFilesToRead = pastLogs.map{ _.path} ++ currentLogPath logInfo("Reading from the logs: " + logFilesToRead.mkString("\n")) + logFilesToRead.iterator.map { file => logDebug(s"Creating log reader with $file") - new WriteAheadLogReader(file, hadoopConf) + new FileBasedWriteAheadLogReader(file, hadoopConf) } flatMap { x => x } } @@ -129,7 +140,7 @@ private[streaming] class WriteAheadLogManager( * deleted. This should be set to true only for testing. Else the files will be deleted * asynchronously. */ - def cleanupOldLogs(threshTime: Long, waitForCompletion: Boolean): Unit = { + def clean(threshTime: Long, waitForCompletion: Boolean): Unit = { val oldLogFiles = synchronized { pastLogs.filter { _.endTime < threshTime } } logInfo(s"Attempting to clear ${oldLogFiles.size} old log files in $logDirectory " + s"older than $threshTime: ${oldLogFiles.map { _.path }.mkString("\n")}") @@ -160,7 +171,7 @@ private[streaming] class WriteAheadLogManager( /** Stop the manager, close any open log writer */ - def stop(): Unit = synchronized { + def close(): Unit = synchronized { if (currentLogWriter != null) { currentLogWriter.close() } @@ -169,7 +180,7 @@ private[streaming] class WriteAheadLogManager( } /** Get the current log writer while taking care of rotation */ - private def getLogWriter(currentTime: Long): WriteAheadLogWriter = synchronized { + private def getLogWriter(currentTime: Long): FileBasedWriteAheadLogWriter = synchronized { if (currentLogWriter == null || currentTime > currentLogWriterStopTime) { resetWriter() currentLogPath.foreach { @@ -180,7 +191,7 @@ private[streaming] class WriteAheadLogManager( val newLogPath = new Path(logDirectory, timeToLogFile(currentLogWriterStartTime, currentLogWriterStopTime)) currentLogPath = Some(newLogPath.toString) - currentLogWriter = new WriteAheadLogWriter(currentLogPath.get, hadoopConf) + currentLogWriter = new FileBasedWriteAheadLogWriter(currentLogPath.get, hadoopConf) } currentLogWriter } @@ -207,7 +218,7 @@ private[streaming] class WriteAheadLogManager( } } -private[util] object WriteAheadLogManager { +private[streaming] object FileBasedWriteAheadLog { case class LogInfo(startTime: Long, endTime: Long, path: String) @@ -217,6 +228,11 @@ private[util] object WriteAheadLogManager { s"log-$startTime-$stopTime" } + def getCallerName(): Option[String] = { + val stackTraceClasses = Thread.currentThread.getStackTrace().map(_.getClassName) + stackTraceClasses.find(!_.contains("WriteAheadLog")).flatMap(_.split(".").lastOption) + } + /** Convert a sequence of files to a sequence of sorted LogInfo objects */ def logFilesTologInfo(files: Seq[Path]): Seq[LogInfo] = { files.flatMap { file => diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogRandomReader.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogRandomReader.scala similarity index 83% rename from streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogRandomReader.scala rename to streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogRandomReader.scala index 003989092a42a..f7168229ec15a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogRandomReader.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogRandomReader.scala @@ -23,16 +23,16 @@ import 
org.apache.hadoop.conf.Configuration /** * A random access reader for reading write ahead log files written using - * [[org.apache.spark.streaming.util.WriteAheadLogWriter]]. Given the file segment info, - * this reads the record (bytebuffer) from the log file. + * [[org.apache.spark.streaming.util.FileBasedWriteAheadLogWriter]]. Given the file segment info, + * this reads the record (ByteBuffer) from the log file. */ -private[streaming] class WriteAheadLogRandomReader(path: String, conf: Configuration) +private[streaming] class FileBasedWriteAheadLogRandomReader(path: String, conf: Configuration) extends Closeable { private val instream = HdfsUtils.getInputStream(path, conf) private var closed = false - def read(segment: WriteAheadLogFileSegment): ByteBuffer = synchronized { + def read(segment: FileBasedWriteAheadLogSegment): ByteBuffer = synchronized { assertOpen() instream.seek(segment.offset) val nextLength = instream.readInt() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogReader.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogReader.scala similarity index 93% rename from streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogReader.scala rename to streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogReader.scala index 2afc0d1551acf..c3bb59f3fef94 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogReader.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogReader.scala @@ -24,11 +24,11 @@ import org.apache.spark.Logging /** * A reader for reading write ahead log files written using - * [[org.apache.spark.streaming.util.WriteAheadLogWriter]]. This reads + * [[org.apache.spark.streaming.util.FileBasedWriteAheadLogWriter]]. This reads * the records (bytebuffers) in the log file sequentially and return them as an * iterator of bytebuffers. 
*/ -private[streaming] class WriteAheadLogReader(path: String, conf: Configuration) +private[streaming] class FileBasedWriteAheadLogReader(path: String, conf: Configuration) extends Iterator[ByteBuffer] with Closeable with Logging { private val instream = HdfsUtils.getInputStream(path, conf) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogFileSegment.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogSegment.scala similarity index 86% rename from streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogFileSegment.scala rename to streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogSegment.scala index 1005a2c8ec303..2e1f1528fad20 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogFileSegment.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogSegment.scala @@ -17,4 +17,5 @@ package org.apache.spark.streaming.util /** Class for representing a segment of data in a write ahead log file */ -private[streaming] case class WriteAheadLogFileSegment (path: String, offset: Long, length: Int) +private[streaming] case class FileBasedWriteAheadLogSegment(path: String, offset: Long, length: Int) + extends WriteAheadLogRecordHandle diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogWriter.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala similarity index 88% rename from streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogWriter.scala rename to streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala index 679f6a6dfd7c1..e146bec32a456 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogWriter.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala @@ -17,18 +17,17 @@ package org.apache.spark.streaming.util import java.io._ -import java.net.URI import java.nio.ByteBuffer import scala.util.Try import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FSDataOutputStream, FileSystem} +import org.apache.hadoop.fs.FSDataOutputStream /** * A writer for writing byte-buffers to a write ahead log file. 
*/ -private[streaming] class WriteAheadLogWriter(path: String, hadoopConf: Configuration) +private[streaming] class FileBasedWriteAheadLogWriter(path: String, hadoopConf: Configuration) extends Closeable { private lazy val stream = HdfsUtils.getOutputStream(path, hadoopConf) @@ -43,11 +42,11 @@ private[streaming] class WriteAheadLogWriter(path: String, hadoopConf: Configura private var closed = false /** Write the bytebuffer to the log file */ - def write(data: ByteBuffer): WriteAheadLogFileSegment = synchronized { + def write(data: ByteBuffer): FileBasedWriteAheadLogSegment = synchronized { assertOpen() data.rewind() // Rewind to ensure all data in the buffer is retrieved val lengthToWrite = data.remaining() - val segment = new WriteAheadLogFileSegment(path, nextOffset, lengthToWrite) + val segment = new FileBasedWriteAheadLogSegment(path, nextOffset, lengthToWrite) stream.writeInt(lengthToWrite) if (data.hasArray) { stream.write(data.array()) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala new file mode 100644 index 0000000000000..7f6ff12c58d47 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.util + +import scala.util.control.NonFatal + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SparkConf, SparkException} + +/** A helper class with utility functions related to the WriteAheadLog interface */ +private[streaming] object WriteAheadLogUtils extends Logging { + val RECEIVER_WAL_ENABLE_CONF_KEY = "spark.streaming.receiver.writeAheadLog.enable" + val RECEIVER_WAL_CLASS_CONF_KEY = "spark.streaming.receiver.writeAheadLog.class" + val RECEIVER_WAL_ROLLING_INTERVAL_CONF_KEY = + "spark.streaming.receiver.writeAheadLog.rollingIntervalSecs" + val RECEIVER_WAL_MAX_FAILURES_CONF_KEY = "spark.streaming.receiver.writeAheadLog.maxFailures" + + val DRIVER_WAL_CLASS_CONF_KEY = "spark.streaming.driver.writeAheadLog.class" + val DRIVER_WAL_ROLLING_INTERVAL_CONF_KEY = + "spark.streaming.driver.writeAheadLog.rollingIntervalSecs" + val DRIVER_WAL_MAX_FAILURES_CONF_KEY = "spark.streaming.driver.writeAheadLog.maxFailures" + + val DEFAULT_ROLLING_INTERVAL_SECS = 60 + val DEFAULT_MAX_FAILURES = 3 + + def enableReceiverLog(conf: SparkConf): Boolean = { + conf.getBoolean(RECEIVER_WAL_ENABLE_CONF_KEY, false) + } + + def getRollingIntervalSecs(conf: SparkConf, isDriver: Boolean): Int = { + if (isDriver) { + conf.getInt(DRIVER_WAL_ROLLING_INTERVAL_CONF_KEY, DEFAULT_ROLLING_INTERVAL_SECS) + } else { + conf.getInt(RECEIVER_WAL_ROLLING_INTERVAL_CONF_KEY, DEFAULT_ROLLING_INTERVAL_SECS) + } + } + + def getMaxFailures(conf: SparkConf, isDriver: Boolean): Int = { + if (isDriver) { + conf.getInt(DRIVER_WAL_MAX_FAILURES_CONF_KEY, DEFAULT_MAX_FAILURES) + } else { + conf.getInt(RECEIVER_WAL_MAX_FAILURES_CONF_KEY, DEFAULT_MAX_FAILURES) + } + } + + /** + * Create a WriteAheadLog for the driver. If configured with custom WAL class, it will try + * to create instance of that class, otherwise it will create the default FileBasedWriteAheadLog. + */ + def createLogForDriver( + sparkConf: SparkConf, + fileWalLogDirectory: String, + fileWalHadoopConf: Configuration + ): WriteAheadLog = { + createLog(true, sparkConf, fileWalLogDirectory, fileWalHadoopConf) + } + + /** + * Create a WriteAheadLog for the receiver. If configured with custom WAL class, it will try + * to create instance of that class, otherwise it will create the default FileBasedWriteAheadLog. + */ + def createLogForReceiver( + sparkConf: SparkConf, + fileWalLogDirectory: String, + fileWalHadoopConf: Configuration + ): WriteAheadLog = { + createLog(false, sparkConf, fileWalLogDirectory, fileWalHadoopConf) + } + + /** + * Create a WriteAheadLog based on the value of the given config key. The config key is used + * to get the class name from the SparkConf. If the class is configured, it will try to + * create instance of that class by first trying `new CustomWAL(sparkConf, logDir)` then trying + * `new CustomWAL(sparkConf)`. If either fails, it will fail. If no class is configured, then + * it will create the default FileBasedWriteAheadLog. 
+ */ + private def createLog( + isDriver: Boolean, + sparkConf: SparkConf, + fileWalLogDirectory: String, + fileWalHadoopConf: Configuration + ): WriteAheadLog = { + + val classNameOption = if (isDriver) { + sparkConf.getOption(DRIVER_WAL_CLASS_CONF_KEY) + } else { + sparkConf.getOption(RECEIVER_WAL_CLASS_CONF_KEY) + } + classNameOption.map { className => + try { + instantiateClass( + Utils.classForName(className).asInstanceOf[Class[_ <: WriteAheadLog]], sparkConf) + } catch { + case NonFatal(e) => + throw new SparkException(s"Could not create a write ahead log of class $className", e) + } + }.getOrElse { + new FileBasedWriteAheadLog(sparkConf, fileWalLogDirectory, fileWalHadoopConf, + getRollingIntervalSecs(sparkConf, isDriver), getMaxFailures(sparkConf, isDriver)) + } + } + + /** Instantiate the class, either using single arg constructor or zero arg constructor */ + private def instantiateClass(cls: Class[_ <: WriteAheadLog], conf: SparkConf): WriteAheadLog = { + try { + cls.getConstructor(classOf[SparkConf]).newInstance(conf) + } catch { + case nsme: NoSuchMethodException => + cls.getConstructor().newInstance() + } + } +} diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java new file mode 100644 index 0000000000000..50e8f9fc159c8 --- /dev/null +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming; + +import java.util.ArrayList; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collection; + +import org.apache.commons.collections.CollectionUtils; +import org.apache.commons.collections.Transformer; +import org.apache.spark.SparkConf; +import org.apache.spark.streaming.util.WriteAheadLog; +import org.apache.spark.streaming.util.WriteAheadLogRecordHandle; +import org.apache.spark.streaming.util.WriteAheadLogUtils; + +import org.junit.Test; +import org.junit.Assert; + +class JavaWriteAheadLogSuiteHandle extends WriteAheadLogRecordHandle { + int index = -1; + public JavaWriteAheadLogSuiteHandle(int idx) { + index = idx; + } +} + +public class JavaWriteAheadLogSuite extends WriteAheadLog { + + class Record { + long time; + int index; + ByteBuffer buffer; + + public Record(long tym, int idx, ByteBuffer buf) { + index = idx; + time = tym; + buffer = buf; + } + } + private int index = -1; + private ArrayList records = new ArrayList(); + + + // Methods for WriteAheadLog + @Override + public WriteAheadLogRecordHandle write(java.nio.ByteBuffer record, long time) { + index += 1; + records.add(new org.apache.spark.streaming.JavaWriteAheadLogSuite.Record(time, index, record)); + return new JavaWriteAheadLogSuiteHandle(index); + } + + @Override + public java.nio.ByteBuffer read(WriteAheadLogRecordHandle handle) { + if (handle instanceof JavaWriteAheadLogSuiteHandle) { + int reqdIndex = ((JavaWriteAheadLogSuiteHandle) handle).index; + for (Record record: records) { + if (record.index == reqdIndex) { + return record.buffer; + } + } + } + return null; + } + + @Override + public java.util.Iterator readAll() { + Collection buffers = CollectionUtils.collect(records, new Transformer() { + @Override + public Object transform(Object input) { + return ((Record) input).buffer; + } + }); + return buffers.iterator(); + } + + @Override + public void clean(long threshTime, boolean waitForCompletion) { + for (int i = 0; i < records.size(); i++) { + if (records.get(i).time < threshTime) { + records.remove(i); + i--; + } + } + } + + @Override + public void close() { + records.clear(); + } + + @Test + public void testCustomWAL() { + SparkConf conf = new SparkConf(); + conf.set("spark.streaming.driver.writeAheadLog.class", JavaWriteAheadLogSuite.class.getName()); + WriteAheadLog wal = WriteAheadLogUtils.createLogForDriver(conf, null, null); + + String data1 = "data1"; + WriteAheadLogRecordHandle handle = wal.write(ByteBuffer.wrap(data1.getBytes()), 1234); + Assert.assertTrue(handle instanceof JavaWriteAheadLogSuiteHandle); + Assert.assertTrue(new String(wal.read(handle).array()).equals(data1)); + + wal.write(ByteBuffer.wrap("data2".getBytes()), 1235); + wal.write(ByteBuffer.wrap("data3".getBytes()), 1236); + wal.write(ByteBuffer.wrap("data4".getBytes()), 1237); + wal.clean(1236, false); + + java.util.Iterator dataIterator = wal.readAll(); + ArrayList readData = new ArrayList(); + while (dataIterator.hasNext()) { + readData.add(new String(dataIterator.next().array())); + } + Assert.assertTrue(readData.equals(Arrays.asList("data3", "data4"))); + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index c090eaec2928d..23804237bda80 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ 
-43,7 +43,7 @@ import WriteAheadLogSuite._ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matchers with Logging { - val conf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.rollingInterval", "1") + val conf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.rollingIntervalSecs", "1") val hadoopConf = new Configuration() val storageLevel = StorageLevel.MEMORY_ONLY_SER val streamId = 1 @@ -130,10 +130,13 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche "Unexpected store result type" ) // Verify the data in write ahead log files is correct - val fileSegments = storeResults.map { _.asInstanceOf[WriteAheadLogBasedStoreResult].segment} - val loggedData = fileSegments.flatMap { segment => - val reader = new WriteAheadLogRandomReader(segment.path, hadoopConf) - val bytes = reader.read(segment) + val walSegments = storeResults.map { result => + result.asInstanceOf[WriteAheadLogBasedStoreResult].walRecordHandle + } + val loggedData = walSegments.flatMap { walSegment => + val fileSegment = walSegment.asInstanceOf[FileBasedWriteAheadLogSegment] + val reader = new FileBasedWriteAheadLogRandomReader(fileSegment.path, hadoopConf) + val bytes = reader.read(fileSegment) reader.close() blockManager.dataDeserialize(generateBlockId(), bytes).toList } @@ -148,13 +151,13 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche } } - test("WriteAheadLogBasedBlockHandler - cleanup old blocks") { + test("WriteAheadLogBasedBlockHandler - clean old blocks") { withWriteAheadLogBasedBlockHandler { handler => val blocks = Seq.tabulate(10) { i => IteratorBlock(Iterator(1 to i)) } storeBlocks(handler, blocks) val preCleanupLogFiles = getWriteAheadLogFiles() - preCleanupLogFiles.size should be > 1 + require(preCleanupLogFiles.size > 1) // this depends on the number of blocks inserted using generateAndStoreData() manualClock.getTimeMillis() shouldEqual 5000L @@ -218,6 +221,7 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche /** Instantiate a WriteAheadLogBasedBlockHandler and run a code with it */ private def withWriteAheadLogBasedBlockHandler(body: WriteAheadLogBasedBlockHandler => Unit) { + require(WriteAheadLogUtils.getRollingIntervalSecs(conf, isDriver = false) === 1) val receivedBlockHandler = new WriteAheadLogBasedBlockHandler(blockManager, 1, storageLevel, conf, hadoopConf, tempDirectory.toString, manualClock) try { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index b63b37d9f9cef..8317fb9720416 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -32,7 +32,7 @@ import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult import org.apache.spark.streaming.scheduler._ -import org.apache.spark.streaming.util.WriteAheadLogReader +import org.apache.spark.streaming.util.{WriteAheadLogUtils, FileBasedWriteAheadLogReader} import org.apache.spark.streaming.util.WriteAheadLogSuite._ import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils} @@ -59,7 +59,7 @@ class ReceivedBlockTrackerSuite test("block addition, and block to batch allocation") { val receivedBlockTracker = 
createTracker(setCheckpointDir = false) - receivedBlockTracker.isLogManagerEnabled should be (false) // should be disable by default + receivedBlockTracker.isWriteAheadLogEnabled should be (false) // should be disable by default receivedBlockTracker.getUnallocatedBlocks(streamId) shouldEqual Seq.empty val blockInfos = generateBlockInfos() @@ -88,7 +88,7 @@ class ReceivedBlockTrackerSuite receivedBlockTracker.getUnallocatedBlocks(streamId) shouldEqual blockInfos } - test("block addition, block to batch allocation and cleanup with write ahead log") { + test("block addition, block to batch allocation and clean up with write ahead log") { val manualClock = new ManualClock // Set the time increment level to twice the rotation interval so that every increment creates // a new log file @@ -113,11 +113,15 @@ class ReceivedBlockTrackerSuite logInfo(s"\n\n=====================\n$message\n$fileContents\n=====================\n") } - // Start tracker and add blocks + // Set WAL configuration conf.set("spark.streaming.receiver.writeAheadLog.enable", "true") - conf.set("spark.streaming.receivedBlockTracker.writeAheadLog.rotationIntervalSecs", "1") + conf.set("spark.streaming.driver.writeAheadLog.rollingIntervalSecs", "1") + require(WriteAheadLogUtils.enableReceiverLog(conf)) + require(WriteAheadLogUtils.getRollingIntervalSecs(conf, isDriver = true) === 1) + + // Start tracker and add blocks val tracker1 = createTracker(clock = manualClock) - tracker1.isLogManagerEnabled should be (true) + tracker1.isWriteAheadLogEnabled should be (true) val blockInfos1 = addBlockInfos(tracker1) tracker1.getUnallocatedBlocks(streamId).toList shouldEqual blockInfos1 @@ -171,7 +175,7 @@ class ReceivedBlockTrackerSuite eventually(timeout(10 seconds), interval(10 millisecond)) { getWriteAheadLogFiles() should not contain oldestLogFile } - printLogFiles("After cleanup") + printLogFiles("After clean") // Restart tracker and verify recovered state, specifically whether info about the first // batch has been removed, but not the second batch @@ -192,17 +196,17 @@ class ReceivedBlockTrackerSuite test("setting checkpoint dir but not enabling write ahead log") { // When WAL config is not set, log manager should not be enabled val tracker1 = createTracker(setCheckpointDir = true) - tracker1.isLogManagerEnabled should be (false) + tracker1.isWriteAheadLogEnabled should be (false) // When WAL is explicitly disabled, log manager should not be enabled conf.set("spark.streaming.receiver.writeAheadLog.enable", "false") val tracker2 = createTracker(setCheckpointDir = true) - tracker2.isLogManagerEnabled should be(false) + tracker2.isWriteAheadLogEnabled should be(false) } /** * Create tracker object with the optional provided clock. Use fake clock if you - * want to control time by manually incrementing it to test log cleanup. + * want to control time by manually incrementing it to test log clean. 
*/ def createTracker( setCheckpointDir: Boolean = true, @@ -231,7 +235,7 @@ class ReceivedBlockTrackerSuite def getWrittenLogData(logFiles: Seq[String] = getWriteAheadLogFiles) : Seq[ReceivedBlockTrackerLogEvent] = { logFiles.flatMap { - file => new WriteAheadLogReader(file, hadoopConf).toSeq + file => new FileBasedWriteAheadLogReader(file, hadoopConf).toSeq }.map { byteBuffer => Utils.deserialize[ReceivedBlockTrackerLogEvent](byteBuffer.array) }.toList @@ -250,7 +254,7 @@ class ReceivedBlockTrackerSuite BatchAllocationEvent(time, AllocatedBlocks(Map((streamId -> blockInfos)))) } - /** Create batch cleanup object from the given info */ + /** Create batch clean object from the given info */ def createBatchCleanup(time: Long, moreTimes: Long*): BatchCleanupEvent = { BatchCleanupEvent((Seq(time) ++ moreTimes).map(Time.apply)) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala index b84129fd70dd4..393a360cfe150 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala @@ -225,7 +225,7 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { .setAppName(framework) .set("spark.ui.enabled", "true") .set("spark.streaming.receiver.writeAheadLog.enable", "true") - .set("spark.streaming.receiver.writeAheadLog.rollingInterval", "1") + .set("spark.streaming.receiver.writeAheadLog.rollingIntervalSecs", "1") val batchDuration = Milliseconds(500) val tempDirectory = Utils.createTempDir() val logDirectory1 = new File(checkpointDirToLogDir(tempDirectory.getAbsolutePath, 0)) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 58353a5f97c8a..09440b1e791ab 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -363,7 +363,7 @@ class TestReceiver extends Receiver[Int](StorageLevel.MEMORY_ONLY) with Logging } def onStop() { - // no cleanup to be done, the receiving thread should stop on it own + // no clean to be done, the receiving thread should stop on it own } } @@ -396,7 +396,7 @@ class SlowTestReceiver(totalRecords: Int, recordsPerSecond: Int) def onStop() { // Simulate slow receiver by waiting for all records to be produced while(!SlowTestReceiver.receivedAllRecords) Thread.sleep(100) - // no cleanup to be done, the receiving thread should stop on it own + // no clean to be done, the receiving thread should stop on it own } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala index c3602a5b73732..8b300d8dd3fbe 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala @@ -21,12 +21,12 @@ import java.io.File import scala.util.Random import org.apache.hadoop.conf.Configuration -import org.scalatest.{BeforeAndAfterEach, BeforeAndAfterAll, FunSuite} +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite} -import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.storage.{BlockId, BlockManager, 
StorageLevel, StreamBlockId} -import org.apache.spark.streaming.util.{WriteAheadLogFileSegment, WriteAheadLogWriter} +import org.apache.spark.streaming.util.{FileBasedWriteAheadLogSegment, FileBasedWriteAheadLogWriter} import org.apache.spark.util.Utils +import org.apache.spark.{SparkConf, SparkContext} class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach { @@ -100,9 +100,10 @@ class WriteAheadLogBackedBlockRDDSuite blockManager.putIterator(blockId, block.iterator, StorageLevel.MEMORY_ONLY_SER) } - // Generate write ahead log segments - val segments = generateFakeSegments(numPartitionsInBM) ++ - writeLogSegments(data.takeRight(numPartitionsInWAL), blockIds.takeRight(numPartitionsInWAL)) + // Generate write ahead log file segments + val recordHandles = generateFakeRecordHandles(numPartitionsInBM) ++ + generateWALRecordHandles(data.takeRight(numPartitionsInWAL), + blockIds.takeRight(numPartitionsInWAL)) // Make sure that the left `numPartitionsInBM` blocks are in block manager, and others are not require( @@ -116,24 +117,24 @@ class WriteAheadLogBackedBlockRDDSuite // Make sure that the right `numPartitionsInWAL` blocks are in WALs, and other are not require( - segments.takeRight(numPartitionsInWAL).forall(s => + recordHandles.takeRight(numPartitionsInWAL).forall(s => new File(s.path.stripPrefix("file://")).exists()), "Expected blocks not in write ahead log" ) require( - segments.take(numPartitionsInBM).forall(s => + recordHandles.take(numPartitionsInBM).forall(s => !new File(s.path.stripPrefix("file://")).exists()), "Unexpected blocks in write ahead log" ) // Create the RDD and verify whether the returned data is correct val rdd = new WriteAheadLogBackedBlockRDD[String](sparkContext, blockIds.toArray, - segments.toArray, storeInBlockManager = false, StorageLevel.MEMORY_ONLY) + recordHandles.toArray, storeInBlockManager = false, StorageLevel.MEMORY_ONLY) assert(rdd.collect() === data.flatten) if (testStoreInBM) { val rdd2 = new WriteAheadLogBackedBlockRDD[String](sparkContext, blockIds.toArray, - segments.toArray, storeInBlockManager = true, StorageLevel.MEMORY_ONLY) + recordHandles.toArray, storeInBlockManager = true, StorageLevel.MEMORY_ONLY) assert(rdd2.collect() === data.flatten) assert( blockIds.forall(blockManager.get(_).nonEmpty), @@ -142,12 +143,12 @@ class WriteAheadLogBackedBlockRDDSuite } } - private def writeLogSegments( + private def generateWALRecordHandles( blockData: Seq[Seq[String]], blockIds: Seq[BlockId] - ): Seq[WriteAheadLogFileSegment] = { + ): Seq[FileBasedWriteAheadLogSegment] = { require(blockData.size === blockIds.size) - val writer = new WriteAheadLogWriter(new File(dir, "logFile").toString, hadoopConf) + val writer = new FileBasedWriteAheadLogWriter(new File(dir, "logFile").toString, hadoopConf) val segments = blockData.zip(blockIds).map { case (data, id) => writer.write(blockManager.dataSerialize(id, data.iterator)) } @@ -155,7 +156,7 @@ class WriteAheadLogBackedBlockRDDSuite segments } - private def generateFakeSegments(count: Int): Seq[WriteAheadLogFileSegment] = { - Array.fill(count)(new WriteAheadLogFileSegment("random", 0L, 0)) + private def generateFakeRecordHandles(count: Int): Seq[FileBasedWriteAheadLogSegment] = { + Array.fill(count)(new FileBasedWriteAheadLogSegment("random", 0L, 0)) } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index a3919c43b95b4..79098bcf4861c 100644 
--- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -18,33 +18,38 @@ package org.apache.spark.streaming.util import java.io._ import java.nio.ByteBuffer +import java.util import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.language.{implicitConversions, postfixOps} +import scala.reflect.ClassTag -import WriteAheadLogSuite._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.spark.util.{ManualClock, Utils} -import org.scalatest.{BeforeAndAfter, FunSuite} import org.scalatest.concurrent.Eventually._ +import org.scalatest.{BeforeAndAfter, FunSuite} + +import org.apache.spark.util.{ManualClock, Utils} +import org.apache.spark.{SparkConf, SparkException} class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { + import WriteAheadLogSuite._ + val hadoopConf = new Configuration() var tempDir: File = null var testDir: String = null var testFile: String = null - var manager: WriteAheadLogManager = null + var writeAheadLog: FileBasedWriteAheadLog = null before { tempDir = Utils.createTempDir() testDir = tempDir.toString testFile = new File(tempDir, "testFile").toString - if (manager != null) { - manager.stop() - manager = null + if (writeAheadLog != null) { + writeAheadLog.close() + writeAheadLog = null } } @@ -52,16 +57,60 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { Utils.deleteRecursively(tempDir) } - test("WriteAheadLogWriter - writing data") { + test("WriteAheadLogUtils - log selection and creation") { + val logDir = Utils.createTempDir().getAbsolutePath() + + def assertDriverLogClass[T <: WriteAheadLog: ClassTag](conf: SparkConf): WriteAheadLog = { + val log = WriteAheadLogUtils.createLogForDriver(conf, logDir, hadoopConf) + assert(log.getClass === implicitly[ClassTag[T]].runtimeClass) + log + } + + def assertReceiverLogClass[T: ClassTag](conf: SparkConf): WriteAheadLog = { + val log = WriteAheadLogUtils.createLogForReceiver(conf, logDir, hadoopConf) + assert(log.getClass === implicitly[ClassTag[T]].runtimeClass) + log + } + + val emptyConf = new SparkConf() // no log configuration + assertDriverLogClass[FileBasedWriteAheadLog](emptyConf) + assertReceiverLogClass[FileBasedWriteAheadLog](emptyConf) + + // Verify setting driver WAL class + val conf1 = new SparkConf().set("spark.streaming.driver.writeAheadLog.class", + classOf[MockWriteAheadLog0].getName()) + assertDriverLogClass[MockWriteAheadLog0](conf1) + assertReceiverLogClass[FileBasedWriteAheadLog](conf1) + + // Verify setting receiver WAL class + val receiverWALConf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class", + classOf[MockWriteAheadLog0].getName()) + assertDriverLogClass[FileBasedWriteAheadLog](receiverWALConf) + assertReceiverLogClass[MockWriteAheadLog0](receiverWALConf) + + // Verify setting receiver WAL class with 1-arg constructor + val receiverWALConf2 = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class", + classOf[MockWriteAheadLog1].getName()) + assertReceiverLogClass[MockWriteAheadLog1](receiverWALConf2) + + // Verify failure setting receiver WAL class with 2-arg constructor + intercept[SparkException] { + val receiverWALConf3 = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class", + classOf[MockWriteAheadLog2].getName()) + assertReceiverLogClass[MockWriteAheadLog1](receiverWALConf3) + } + } + + test("FileBasedWriteAheadLogWriter - writing 
data") { val dataToWrite = generateRandomData() val segments = writeDataUsingWriter(testFile, dataToWrite) val writtenData = readDataManually(segments) assert(writtenData === dataToWrite) } - test("WriteAheadLogWriter - syncing of data by writing and reading immediately") { + test("FileBasedWriteAheadLogWriter - syncing of data by writing and reading immediately") { val dataToWrite = generateRandomData() - val writer = new WriteAheadLogWriter(testFile, hadoopConf) + val writer = new FileBasedWriteAheadLogWriter(testFile, hadoopConf) dataToWrite.foreach { data => val segment = writer.write(stringToByteBuffer(data)) val dataRead = readDataManually(Seq(segment)).head @@ -70,10 +119,10 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { writer.close() } - test("WriteAheadLogReader - sequentially reading data") { + test("FileBasedWriteAheadLogReader - sequentially reading data") { val writtenData = generateRandomData() writeDataManually(writtenData, testFile) - val reader = new WriteAheadLogReader(testFile, hadoopConf) + val reader = new FileBasedWriteAheadLogReader(testFile, hadoopConf) val readData = reader.toSeq.map(byteBufferToString) assert(readData === writtenData) assert(reader.hasNext === false) @@ -83,14 +132,14 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { reader.close() } - test("WriteAheadLogReader - sequentially reading data written with writer") { + test("FileBasedWriteAheadLogReader - sequentially reading data written with writer") { val dataToWrite = generateRandomData() writeDataUsingWriter(testFile, dataToWrite) val readData = readDataUsingReader(testFile) assert(readData === dataToWrite) } - test("WriteAheadLogReader - reading data written with writer after corrupted write") { + test("FileBasedWriteAheadLogReader - reading data written with writer after corrupted write") { // Write data manually for testing the sequential reader val dataToWrite = generateRandomData() writeDataUsingWriter(testFile, dataToWrite) @@ -113,38 +162,38 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { assert(readDataUsingReader(testFile) === (dataToWrite.dropRight(1))) } - test("WriteAheadLogRandomReader - reading data using random reader") { + test("FileBasedWriteAheadLogRandomReader - reading data using random reader") { // Write data manually for testing the random reader val writtenData = generateRandomData() val segments = writeDataManually(writtenData, testFile) // Get a random order of these segments and read them back val writtenDataAndSegments = writtenData.zip(segments).toSeq.permutations.take(10).flatten - val reader = new WriteAheadLogRandomReader(testFile, hadoopConf) + val reader = new FileBasedWriteAheadLogRandomReader(testFile, hadoopConf) writtenDataAndSegments.foreach { case (data, segment) => assert(data === byteBufferToString(reader.read(segment))) } reader.close() } - test("WriteAheadLogRandomReader - reading data using random reader written with writer") { + test("FileBasedWriteAheadLogRandomReader- reading data using random reader written with writer") { // Write data using writer for testing the random reader val data = generateRandomData() val segments = writeDataUsingWriter(testFile, data) // Read a random sequence of segments and verify read data val dataAndSegments = data.zip(segments).toSeq.permutations.take(10).flatten - val reader = new WriteAheadLogRandomReader(testFile, hadoopConf) + val reader = new FileBasedWriteAheadLogRandomReader(testFile, hadoopConf) dataAndSegments.foreach { case (data, segment) => 
assert(data === byteBufferToString(reader.read(segment))) } reader.close() } - test("WriteAheadLogManager - write rotating logs") { - // Write data using manager + test("FileBasedWriteAheadLog - write rotating logs") { + // Write data with rotation using WriteAheadLog class val dataToWrite = generateRandomData() - writeDataUsingManager(testDir, dataToWrite) + writeDataUsingWriteAheadLog(testDir, dataToWrite) // Read data manually to verify the written data val logFiles = getLogFilesInDirectory(testDir) @@ -153,8 +202,8 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { assert(writtenData === dataToWrite) } - test("WriteAheadLogManager - read rotating logs") { - // Write data manually for testing reading through manager + test("FileBasedWriteAheadLog - read rotating logs") { + // Write data manually for testing reading through WriteAheadLog val writtenData = (1 to 10).map { i => val data = generateRandomData() val file = testDir + s"/log-$i-$i" @@ -167,25 +216,25 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { assert(fileSystem.exists(logDirectoryPath) === true) // Read data using manager and verify - val readData = readDataUsingManager(testDir) + val readData = readDataUsingWriteAheadLog(testDir) assert(readData === writtenData) } - test("WriteAheadLogManager - recover past logs when creating new manager") { + test("FileBasedWriteAheadLog - recover past logs when creating new manager") { // Write data with manager, recover with new manager and verify val dataToWrite = generateRandomData() - writeDataUsingManager(testDir, dataToWrite) + writeDataUsingWriteAheadLog(testDir, dataToWrite) val logFiles = getLogFilesInDirectory(testDir) assert(logFiles.size > 1) - val readData = readDataUsingManager(testDir) + val readData = readDataUsingWriteAheadLog(testDir) assert(dataToWrite === readData) } - test("WriteAheadLogManager - cleanup old logs") { + test("FileBasedWriteAheadLog - clean old logs") { logCleanUpTest(waitForCompletion = false) } - test("WriteAheadLogManager - cleanup old logs synchronously") { + test("FileBasedWriteAheadLog - clean old logs synchronously") { logCleanUpTest(waitForCompletion = true) } @@ -193,11 +242,11 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { // Write data with manager, recover with new manager and verify val manualClock = new ManualClock val dataToWrite = generateRandomData() - manager = writeDataUsingManager(testDir, dataToWrite, manualClock, stopManager = false) + writeAheadLog = writeDataUsingWriteAheadLog(testDir, dataToWrite, manualClock, closeLog = false) val logFiles = getLogFilesInDirectory(testDir) assert(logFiles.size > 1) - manager.cleanupOldLogs(manualClock.getTimeMillis() / 2, waitForCompletion) + writeAheadLog.clean(manualClock.getTimeMillis() / 2, waitForCompletion) if (waitForCompletion) { assert(getLogFilesInDirectory(testDir).size < logFiles.size) @@ -208,11 +257,11 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { } } - test("WriteAheadLogManager - handling file errors while reading rotating logs") { + test("FileBasedWriteAheadLog - handling file errors while reading rotating logs") { // Generate a set of log files val manualClock = new ManualClock val dataToWrite1 = generateRandomData() - writeDataUsingManager(testDir, dataToWrite1, manualClock) + writeDataUsingWriteAheadLog(testDir, dataToWrite1, manualClock) val logFiles1 = getLogFilesInDirectory(testDir) assert(logFiles1.size > 1) @@ -220,12 +269,12 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { // 
Recover old files and generate a second set of log files val dataToWrite2 = generateRandomData() manualClock.advance(100000) - writeDataUsingManager(testDir, dataToWrite2, manualClock) + writeDataUsingWriteAheadLog(testDir, dataToWrite2, manualClock) val logFiles2 = getLogFilesInDirectory(testDir) assert(logFiles2.size > logFiles1.size) // Read the files and verify that all the written data can be read - val readData1 = readDataUsingManager(testDir) + val readData1 = readDataUsingWriteAheadLog(testDir) assert(readData1 === (dataToWrite1 ++ dataToWrite2)) // Corrupt the first set of files so that they are basically unreadable @@ -236,25 +285,51 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { } // Verify that the corrupted files do not prevent reading of the second set of data - val readData = readDataUsingManager(testDir) + val readData = readDataUsingWriteAheadLog(testDir) assert(readData === dataToWrite2) } + + test("FileBasedWriteAheadLog - do not create directories or files unless write") { + val nonexistentTempPath = File.createTempFile("test", "") + nonexistentTempPath.delete() + assert(!nonexistentTempPath.exists()) + + val writtenSegment = writeDataManually(generateRandomData(), testFile) + val wal = new FileBasedWriteAheadLog( + new SparkConf(), tempDir.getAbsolutePath, new Configuration(), 1, 1) + assert(!nonexistentTempPath.exists(), "Directory created just by creating log object") + wal.read(writtenSegment.head) + assert(!nonexistentTempPath.exists(), "Directory created just by attempting to read segment") + } } object WriteAheadLogSuite { + class MockWriteAheadLog0() extends WriteAheadLog { + override def write(record: ByteBuffer, time: Long): WriteAheadLogRecordHandle = { null } + override def read(handle: WriteAheadLogRecordHandle): ByteBuffer = { null } + override def readAll(): util.Iterator[ByteBuffer] = { null } + override def clean(threshTime: Long, waitForCompletion: Boolean): Unit = { } + override def close(): Unit = { } + } + + class MockWriteAheadLog1(val conf: SparkConf) extends MockWriteAheadLog0() + + class MockWriteAheadLog2(val conf: SparkConf, x: Int) extends MockWriteAheadLog0() + + private val hadoopConf = new Configuration() /** Write data to a file directly and return an array of the file segments written. */ - def writeDataManually(data: Seq[String], file: String): Seq[WriteAheadLogFileSegment] = { - val segments = new ArrayBuffer[WriteAheadLogFileSegment]() + def writeDataManually(data: Seq[String], file: String): Seq[FileBasedWriteAheadLogSegment] = { + val segments = new ArrayBuffer[FileBasedWriteAheadLogSegment]() val writer = HdfsUtils.getOutputStream(file, hadoopConf) data.foreach { item => val offset = writer.getPos val bytes = Utils.serialize(item) writer.writeInt(bytes.size) writer.write(bytes) - segments += WriteAheadLogFileSegment(file, offset, bytes.size) + segments += FileBasedWriteAheadLogSegment(file, offset, bytes.size) } writer.close() segments @@ -263,8 +338,11 @@ object WriteAheadLogSuite { /** * Write data to a file using the writer class and return an array of the file segments written. 
*/ - def writeDataUsingWriter(filePath: String, data: Seq[String]): Seq[WriteAheadLogFileSegment] = { - val writer = new WriteAheadLogWriter(filePath, hadoopConf) + def writeDataUsingWriter( + filePath: String, + data: Seq[String] + ): Seq[FileBasedWriteAheadLogSegment] = { + val writer = new FileBasedWriteAheadLogWriter(filePath, hadoopConf) val segments = data.map { item => writer.write(item) } @@ -272,27 +350,27 @@ object WriteAheadLogSuite { segments } - /** Write data to rotating files in log directory using the manager class. */ - def writeDataUsingManager( + /** Write data to rotating files in log directory using the WriteAheadLog class. */ + def writeDataUsingWriteAheadLog( logDirectory: String, data: Seq[String], manualClock: ManualClock = new ManualClock, - stopManager: Boolean = true - ): WriteAheadLogManager = { + closeLog: Boolean = true + ): FileBasedWriteAheadLog = { if (manualClock.getTimeMillis() < 100000) manualClock.setTime(10000) - val manager = new WriteAheadLogManager(logDirectory, hadoopConf, - rollingIntervalSecs = 1, callerName = "WriteAheadLogSuite", clock = manualClock) + val wal = new FileBasedWriteAheadLog(new SparkConf(), logDirectory, hadoopConf, 1, 1) + // Ensure that 500 does not get sorted after 2000, so put a high base value. data.foreach { item => manualClock.advance(500) - manager.writeToLog(item) + wal.write(item, manualClock.getTimeMillis()) } - if (stopManager) manager.stop() - manager + if (closeLog) wal.close() + wal } /** Read data from a segments of a log file directly and return the list of byte buffers. */ - def readDataManually(segments: Seq[WriteAheadLogFileSegment]): Seq[String] = { + def readDataManually(segments: Seq[FileBasedWriteAheadLogSegment]): Seq[String] = { segments.map { segment => val reader = HdfsUtils.getInputStream(segment.path, hadoopConf) try { @@ -331,18 +409,18 @@ object WriteAheadLogSuite { /** Read all the data from a log file using reader class and return the list of byte buffers. */ def readDataUsingReader(file: String): Seq[String] = { - val reader = new WriteAheadLogReader(file, hadoopConf) + val reader = new FileBasedWriteAheadLogReader(file, hadoopConf) val readData = reader.toList.map(byteBufferToString) reader.close() readData } - /** Read all the data in the log file in a directory using the manager class. */ - def readDataUsingManager(logDirectory: String): Seq[String] = { - val manager = new WriteAheadLogManager(logDirectory, hadoopConf, - callerName = "WriteAheadLogSuite") - val data = manager.readFromLog().map(byteBufferToString).toSeq - manager.stop() + /** Read all the data in the log file in a directory using the WriteAheadLog class. */ + def readDataUsingWriteAheadLog(logDirectory: String): Seq[String] = { + import scala.collection.JavaConversions._ + val wal = new FileBasedWriteAheadLog(new SparkConf(), logDirectory, hadoopConf, 1, 1) + val data = wal.readAll().map(byteBufferToString).toSeq + wal.close() data } From a9c4e29950a14e32acaac547e9a0e8879fd37fc9 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 29 Apr 2015 13:10:31 -0700 Subject: [PATCH 182/191] [SPARK-6752] [STREAMING] [REOPENED] Allow StreamingContext to be recreated from checkpoint and existing SparkContext Original PR #5428 got reverted due to issues between MutableBoolean and Hadoop 1.0.4 (see JIRA). This replaces MutableBoolean with AtomicBoolean. 
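For readers unfamiliar with the original failure: the fix keeps the same "did the creating function run?" flag used by the tests, but backs it with the JDK's java.util.concurrent.atomic.AtomicBoolean instead of Commons Lang's MutableBoolean, so nothing beyond the JDK is needed on the classpath. A minimal sketch of that flag pattern (names are illustrative; the real usage appears in the JavaAPISuite changes further down):

    import java.util.concurrent.atomic.AtomicBoolean

    // Thread-safe flag flipped inside the creating function so a test can tell
    // whether a fresh context was built or an existing checkpoint was recovered.
    val newContextCreated = new AtomicBoolean(false)

    def markCreated(): Unit = newContextCreated.set(true)  // called by the creating function

    newContextCreated.set(false)
    markCreated()
    assert(newContextCreated.get(), "new context not created")
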
srowen pwendell Author: Tathagata Das Closes #5773 from tdas/SPARK-6752 and squashes the following commits: a0c0ead [Tathagata Das] Fix for hadoop 1.0.4 70ae85b [Tathagata Das] Merge remote-tracking branch 'apache-github/master' into SPARK-6752 94db63c [Tathagata Das] Fix long line. 524f519 [Tathagata Das] Many changes based on PR comments. eabd092 [Tathagata Das] Added Function0, Java API and unit tests for StreamingContext.getOrCreate 36a7823 [Tathagata Das] Minor changes. 204814e [Tathagata Das] Added StreamingContext.getOrCreate with existing SparkContext --- .../spark/api/java/function/Function0.java | 27 +++ .../apache/spark/streaming/Checkpoint.scala | 26 ++- .../spark/streaming/StreamingContext.scala | 85 ++++++++-- .../api/java/JavaStreamingContext.scala | 119 ++++++++++++- .../apache/spark/streaming/JavaAPISuite.java | 145 ++++++++++++---- .../spark/streaming/CheckpointSuite.scala | 3 +- .../streaming/StreamingContextSuite.scala | 159 ++++++++++++++++++ 7 files changed, 503 insertions(+), 61 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/api/java/function/Function0.java diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function0.java b/core/src/main/java/org/apache/spark/api/java/function/Function0.java new file mode 100644 index 0000000000000..38e410c5debe6 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/Function0.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; + +/** + * A zero-argument function that returns an R. 
+ */ +public interface Function0 extends Serializable { + public R call() throws Exception; +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 0a50485118588..7bfae253c3a0c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -77,7 +77,8 @@ object Checkpoint extends Logging { } /** Get checkpoint files present in the give directory, ordered by oldest-first */ - def getCheckpointFiles(checkpointDir: String, fs: FileSystem): Seq[Path] = { + def getCheckpointFiles(checkpointDir: String, fsOption: Option[FileSystem] = None): Seq[Path] = { + def sortFunc(path1: Path, path2: Path): Boolean = { val (time1, bk1) = path1.getName match { case REGEX(x, y) => (x.toLong, !y.isEmpty) } val (time2, bk2) = path2.getName match { case REGEX(x, y) => (x.toLong, !y.isEmpty) } @@ -85,6 +86,7 @@ object Checkpoint extends Logging { } val path = new Path(checkpointDir) + val fs = fsOption.getOrElse(path.getFileSystem(new Configuration())) if (fs.exists(path)) { val statuses = fs.listStatus(path) if (statuses != null) { @@ -160,7 +162,7 @@ class CheckpointWriter( } // Delete old checkpoint files - val allCheckpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, fs) + val allCheckpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, Some(fs)) if (allCheckpointFiles.size > 10) { allCheckpointFiles.take(allCheckpointFiles.size - 10).foreach(file => { logInfo("Deleting " + file) @@ -234,15 +236,24 @@ class CheckpointWriter( private[streaming] object CheckpointReader extends Logging { - def read(checkpointDir: String, conf: SparkConf, hadoopConf: Configuration): Option[Checkpoint] = - { + /** + * Read checkpoint files present in the given checkpoint directory. If there are no checkpoint + * files, then return None, else try to return the latest valid checkpoint object. If no + * checkpoint files could be read correctly, then return None (if ignoreReadError = true), + * or throw exception (if ignoreReadError = false). + */ + def read( + checkpointDir: String, + conf: SparkConf, + hadoopConf: Configuration, + ignoreReadError: Boolean = false): Option[Checkpoint] = { val checkpointPath = new Path(checkpointDir) // TODO(rxin): Why is this a def?! 
def fs: FileSystem = checkpointPath.getFileSystem(hadoopConf) // Try to find the checkpoint files - val checkpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, fs).reverse + val checkpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, Some(fs)).reverse if (checkpointFiles.isEmpty) { return None } @@ -282,7 +293,10 @@ object CheckpointReader extends Logging { }) // If none of checkpoint files could be read, then throw exception - throw new SparkException("Failed to read checkpoint from directory " + checkpointPath) + if (!ignoreReadError) { + throw new SparkException(s"Failed to read checkpoint from directory $checkpointPath") + } + None } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index f57f295874645..90c8b47aebce0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -107,6 +107,19 @@ class StreamingContext private[streaming] ( */ def this(path: String) = this(path, new Configuration) + /** + * Recreate a StreamingContext from a checkpoint file using an existing SparkContext. + * @param path Path to the directory that was specified as the checkpoint directory + * @param sparkContext Existing SparkContext + */ + def this(path: String, sparkContext: SparkContext) = { + this( + sparkContext, + CheckpointReader.read(path, sparkContext.conf, sparkContext.hadoopConfiguration).get, + null) + } + + if (sc_ == null && cp_ == null) { throw new Exception("Spark Streaming cannot be initialized with " + "both SparkContext and checkpoint as null") @@ -115,10 +128,12 @@ class StreamingContext private[streaming] ( private[streaming] val isCheckpointPresent = (cp_ != null) private[streaming] val sc: SparkContext = { - if (isCheckpointPresent) { + if (sc_ != null) { + sc_ + } else if (isCheckpointPresent) { new SparkContext(cp_.createSparkConf()) } else { - sc_ + throw new SparkException("Cannot create StreamingContext without a SparkContext") } } @@ -129,7 +144,7 @@ class StreamingContext private[streaming] ( private[streaming] val conf = sc.conf - private[streaming] val env = SparkEnv.get + private[streaming] val env = sc.env private[streaming] val graph: DStreamGraph = { if (isCheckpointPresent) { @@ -174,7 +189,9 @@ class StreamingContext private[streaming] ( /** Register streaming source to metrics system */ private val streamingSource = new StreamingSource(this) - SparkEnv.get.metricsSystem.registerSource(streamingSource) + assert(env != null) + assert(env.metricsSystem != null) + env.metricsSystem.registerSource(streamingSource) /** Enumeration to identify current state of the StreamingContext */ private[streaming] object StreamingContextState extends Enumeration { @@ -621,19 +638,59 @@ object StreamingContext extends Logging { hadoopConf: Configuration = new Configuration(), createOnError: Boolean = false ): StreamingContext = { - val checkpointOption = try { - CheckpointReader.read(checkpointPath, new SparkConf(), hadoopConf) - } catch { - case e: Exception => - if (createOnError) { - None - } else { - throw e - } - } + val checkpointOption = CheckpointReader.read( + checkpointPath, new SparkConf(), hadoopConf, createOnError) checkpointOption.map(new StreamingContext(null, _, null)).getOrElse(creatingFunc()) } + + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. 
+ * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the StreamingContext + * will be created by called the provided `creatingFunc` on the provided `sparkContext`. Note + * that the SparkConf configuration in the checkpoint data will not be restored as the + * SparkContext has already been created. + * + * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program + * @param creatingFunc Function to create a new StreamingContext using the given SparkContext + * @param sparkContext SparkContext using which the StreamingContext will be created + */ + def getOrCreate( + checkpointPath: String, + creatingFunc: SparkContext => StreamingContext, + sparkContext: SparkContext + ): StreamingContext = { + getOrCreate(checkpointPath, creatingFunc, sparkContext, createOnError = false) + } + + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the StreamingContext + * will be created by called the provided `creatingFunc` on the provided `sparkContext`. Note + * that the SparkConf configuration in the checkpoint data will not be restored as the + * SparkContext has already been created. + * + * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program + * @param creatingFunc Function to create a new StreamingContext using the given SparkContext + * @param sparkContext SparkContext using which the StreamingContext will be created + * @param createOnError Whether to create a new StreamingContext if there is an + * error in reading checkpoint data. By default, an exception will be + * thrown on error. + */ + def getOrCreate( + checkpointPath: String, + creatingFunc: SparkContext => StreamingContext, + sparkContext: SparkContext, + createOnError: Boolean + ): StreamingContext = { + val checkpointOption = CheckpointReader.read( + checkpointPath, sparkContext.conf, sparkContext.hadoopConfiguration, createOnError) + checkpointOption.map(new StreamingContext(sparkContext, _, null)) + .getOrElse(creatingFunc(sparkContext)) + } + /** * Find the JAR from which a given class was loaded, to make it easy for users to pass * their JARs to StreamingContext. 
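The diff above only adds the definitions of the new overloads, so a short usage sketch may help. Assuming an application already owns a SparkContext (the checkpoint directory, app name, and the body of createContext below are illustrative), the overload taking an existing SparkContext recovers the streaming graph from the checkpoint when one exists, and otherwise calls the supplied function; per the scaladoc above, the checkpointed SparkConf is not restored because the SparkContext already exists:

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.streaming.{Seconds, StreamingContext}

    val sc = new SparkContext(new SparkConf().setAppName("example").setMaster("local[2]"))
    val checkpointDir = "/tmp/streaming-checkpoint"  // illustrative path

    // Invoked only when no usable checkpoint exists under checkpointDir;
    // on recovery the checkpointed graph is restored and this function is skipped.
    def createContext(sparkContext: SparkContext): StreamingContext = {
      val ssc = new StreamingContext(sparkContext, Seconds(1))
      ssc.checkpoint(checkpointDir)
      // ... define input DStreams and output operations here ...
      ssc
    }

    val ssc = StreamingContext.getOrCreate(checkpointDir, createContext _, sc)
    ssc.start()
    ssc.awaitTermination()
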
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 4095a7cc84946..572d7d8e8753d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -32,13 +32,14 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2} +import org.apache.spark.api.java.function.{Function0 => JFunction0} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ import org.apache.spark.streaming.scheduler.StreamingListener -import org.apache.hadoop.conf.Configuration -import org.apache.spark.streaming.dstream.{PluggableInputDStream, ReceiverInputDStream, DStream} +import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.receiver.Receiver +import org.apache.hadoop.conf.Configuration /** * A Java-friendly version of [[org.apache.spark.streaming.StreamingContext]] which is the main @@ -655,6 +656,7 @@ object JavaStreamingContext { * @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext */ + @deprecated("use getOrCreate without JavaStreamingContextFactor", "1.4.0") def getOrCreate( checkpointPath: String, factory: JavaStreamingContextFactory @@ -676,6 +678,7 @@ object JavaStreamingContext { * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible * file system */ + @deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0") def getOrCreate( checkpointPath: String, hadoopConf: Configuration, @@ -700,6 +703,7 @@ object JavaStreamingContext { * @param createOnError Whether to create a new JavaStreamingContext if there is an * error in reading checkpoint data. */ + @deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0") def getOrCreate( checkpointPath: String, hadoopConf: Configuration, @@ -712,6 +716,117 @@ object JavaStreamingContext { new JavaStreamingContext(ssc) } + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the provided factory + * will be used to create a JavaStreamingContext. + * + * @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program + * @param creatingFunc Function to create a new JavaStreamingContext + */ + def getOrCreate( + checkpointPath: String, + creatingFunc: JFunction0[JavaStreamingContext] + ): JavaStreamingContext = { + val ssc = StreamingContext.getOrCreate(checkpointPath, () => { + creatingFunc.call().ssc + }) + new JavaStreamingContext(ssc) + } + + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the provided factory + * will be used to create a JavaStreamingContext. 
+ * + * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program + * @param creatingFunc Function to create a new JavaStreamingContext + * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible + * file system + */ + def getOrCreate( + checkpointPath: String, + creatingFunc: JFunction0[JavaStreamingContext], + hadoopConf: Configuration + ): JavaStreamingContext = { + val ssc = StreamingContext.getOrCreate(checkpointPath, () => { + creatingFunc.call().ssc + }, hadoopConf) + new JavaStreamingContext(ssc) + } + + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the provided factory + * will be used to create a JavaStreamingContext. + * + * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program + * @param creatingFunc Function to create a new JavaStreamingContext + * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible + * file system + * @param createOnError Whether to create a new JavaStreamingContext if there is an + * error in reading checkpoint data. + */ + def getOrCreate( + checkpointPath: String, + creatingFunc: JFunction0[JavaStreamingContext], + hadoopConf: Configuration, + createOnError: Boolean + ): JavaStreamingContext = { + val ssc = StreamingContext.getOrCreate(checkpointPath, () => { + creatingFunc.call().ssc + }, hadoopConf, createOnError) + new JavaStreamingContext(ssc) + } + + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the provided factory + * will be used to create a JavaStreamingContext. + * + * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program + * @param creatingFunc Function to create a new JavaStreamingContext + * @param sparkContext SparkContext using which the StreamingContext will be created + */ + def getOrCreate( + checkpointPath: String, + creatingFunc: JFunction[JavaSparkContext, JavaStreamingContext], + sparkContext: JavaSparkContext + ): JavaStreamingContext = { + val ssc = StreamingContext.getOrCreate(checkpointPath, (sparkContext: SparkContext) => { + creatingFunc.call(new JavaSparkContext(sparkContext)).ssc + }, sparkContext.sc) + new JavaStreamingContext(ssc) + } + + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the provided factory + * will be used to create a JavaStreamingContext. + * + * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program + * @param creatingFunc Function to create a new JavaStreamingContext + * @param sparkContext SparkContext using which the StreamingContext will be created + * @param createOnError Whether to create a new JavaStreamingContext if there is an + * error in reading checkpoint data. 
+ */ + def getOrCreate( + checkpointPath: String, + creatingFunc: JFunction[JavaSparkContext, JavaStreamingContext], + sparkContext: JavaSparkContext, + createOnError: Boolean + ): JavaStreamingContext = { + val ssc = StreamingContext.getOrCreate(checkpointPath, (sparkContext: SparkContext) => { + creatingFunc.call(new JavaSparkContext(sparkContext)).ssc + }, sparkContext.sc, createOnError) + new JavaStreamingContext(ssc) + } + /** * Find the JAR from which a given class was loaded, to make it easy for users to pass * their JARs to StreamingContext. diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index 90340753a4eed..b1adf881dd0f5 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -21,11 +21,13 @@ import java.lang.Iterable; import java.nio.charset.Charset; import java.util.*; +import java.util.concurrent.atomic.AtomicBoolean; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.mapreduce.lib.input.TextInputFormat; + import scala.Tuple2; import org.junit.Assert; @@ -45,6 +47,7 @@ import org.apache.spark.storage.StorageLevel; import org.apache.spark.streaming.api.java.*; import org.apache.spark.util.Utils; +import org.apache.spark.SparkConf; // The test suite itself is Serializable so that anonymous Function implementations can be // serialized, as an alternative to converting these anonymous classes to static inner classes; @@ -929,7 +932,7 @@ public void testPairMap() { // Maps pair -> pair of different type public Tuple2 call(Tuple2 in) throws Exception { return in.swap(); } - }); + }); JavaTestUtils.attachTestOutputStream(reversed); List>> result = JavaTestUtils.runStreams(ssc, 2, 2); @@ -987,12 +990,12 @@ public void testPairMap2() { // Maps pair -> single JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaDStream reversed = pairStream.map( - new Function, Integer>() { - @Override - public Integer call(Tuple2 in) throws Exception { - return in._2(); - } - }); + new Function, Integer>() { + @Override + public Integer call(Tuple2 in) throws Exception { + return in._2(); + } + }); JavaTestUtils.attachTestOutputStream(reversed); List> result = JavaTestUtils.runStreams(ssc, 2, 2); @@ -1123,7 +1126,7 @@ public void testCombineByKey() { JavaPairDStream combined = pairStream.combineByKey( new Function() { - @Override + @Override public Integer call(Integer i) throws Exception { return i; } @@ -1144,14 +1147,14 @@ public void testCountByValue() { Arrays.asList("hello")); List>> expected = Arrays.asList( - Arrays.asList( - new Tuple2("hello", 1L), - new Tuple2("world", 1L)), - Arrays.asList( - new Tuple2("hello", 1L), - new Tuple2("moon", 1L)), - Arrays.asList( - new Tuple2("hello", 1L))); + Arrays.asList( + new Tuple2("hello", 1L), + new Tuple2("world", 1L)), + Arrays.asList( + new Tuple2("hello", 1L), + new Tuple2("moon", 1L)), + Arrays.asList( + new Tuple2("hello", 1L))); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream counted = stream.countByValue(); @@ -1249,17 +1252,17 @@ public void testUpdateStateByKey() { JavaPairDStream updated = pairStream.updateStateByKey( new Function2, Optional, Optional>() { - @Override - public Optional call(List values, Optional 
state) { - int out = 0; - if (state.isPresent()) { - out = out + state.get(); - } - for (Integer v: values) { - out = out + v; + @Override + public Optional call(List values, Optional state) { + int out = 0; + if (state.isPresent()) { + out = out + state.get(); + } + for (Integer v : values) { + out = out + v; + } + return Optional.of(out); } - return Optional.of(out); - } }); JavaTestUtils.attachTestOutputStream(updated); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -1292,17 +1295,17 @@ public void testUpdateStateByKeyWithInitial() { JavaPairDStream updated = pairStream.updateStateByKey( new Function2, Optional, Optional>() { - @Override - public Optional call(List values, Optional state) { - int out = 0; - if (state.isPresent()) { - out = out + state.get(); - } - for (Integer v: values) { - out = out + v; + @Override + public Optional call(List values, Optional state) { + int out = 0; + if (state.isPresent()) { + out = out + state.get(); + } + for (Integer v : values) { + out = out + v; + } + return Optional.of(out); } - return Optional.of(out); - } }, new HashPartitioner(1), initialRDD); JavaTestUtils.attachTestOutputStream(updated); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -1328,7 +1331,7 @@ public void testReduceByKeyAndWindowWithInverse() { JavaPairDStream reduceWindowed = pairStream.reduceByKeyAndWindow(new IntegerSum(), new IntegerDifference(), - new Duration(2000), new Duration(1000)); + new Duration(2000), new Duration(1000)); JavaTestUtils.attachTestOutputStream(reduceWindowed); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -1707,6 +1710,74 @@ public Integer call(String s) throws Exception { Utils.deleteRecursively(tempDir); } + @SuppressWarnings("unchecked") + @Test + public void testContextGetOrCreate() throws InterruptedException { + + final SparkConf conf = new SparkConf() + .setMaster("local[2]") + .setAppName("test") + .set("newContext", "true"); + + File emptyDir = Files.createTempDir(); + emptyDir.deleteOnExit(); + StreamingContextSuite contextSuite = new StreamingContextSuite(); + String corruptedCheckpointDir = contextSuite.createCorruptedCheckpoint(); + String checkpointDir = contextSuite.createValidCheckpoint(); + + // Function to create JavaStreamingContext without any output operations + // (used to detect the new context) + final AtomicBoolean newContextCreated = new AtomicBoolean(false); + Function0 creatingFunc = new Function0() { + public JavaStreamingContext call() { + newContextCreated.set(true); + return new JavaStreamingContext(conf, Seconds.apply(1)); + } + }; + + newContextCreated.set(false); + ssc = JavaStreamingContext.getOrCreate(emptyDir.getAbsolutePath(), creatingFunc); + Assert.assertTrue("new context not created", newContextCreated.get()); + ssc.stop(); + + newContextCreated.set(false); + ssc = JavaStreamingContext.getOrCreate(corruptedCheckpointDir, creatingFunc, + new org.apache.hadoop.conf.Configuration(), true); + Assert.assertTrue("new context not created", newContextCreated.get()); + ssc.stop(); + + newContextCreated.set(false); + ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc, + new org.apache.hadoop.conf.Configuration()); + Assert.assertTrue("old context not recovered", !newContextCreated.get()); + ssc.stop(); + + // Function to create JavaStreamingContext using existing JavaSparkContext + // without any output operations (used to detect the new context) + Function creatingFunc2 = + new Function() { + public JavaStreamingContext call(JavaSparkContext context) { + 
newContextCreated.set(true); + return new JavaStreamingContext(context, Seconds.apply(1)); + } + }; + + JavaSparkContext sc = new JavaSparkContext(conf); + newContextCreated.set(false); + ssc = JavaStreamingContext.getOrCreate(emptyDir.getAbsolutePath(), creatingFunc2, sc); + Assert.assertTrue("new context not created", newContextCreated.get()); + ssc.stop(false); + + newContextCreated.set(false); + ssc = JavaStreamingContext.getOrCreate(corruptedCheckpointDir, creatingFunc2, sc, true); + Assert.assertTrue("new context not created", newContextCreated.get()); + ssc.stop(false); + + newContextCreated.set(false); + ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc2, sc); + Assert.assertTrue("old context not recovered", !newContextCreated.get()); + ssc.stop(); + } /* TEST DISABLED: Pending a discussion about checkpoint() semantics with TD @SuppressWarnings("unchecked") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 54c30440a6e8d..6b0a3f91d4d06 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -430,9 +430,8 @@ class CheckpointSuite extends TestSuiteBase { assert(recordedFiles(ssc) === Seq(1, 2, 3) && batchCounter.getNumStartedBatches === 3) } // Wait for a checkpoint to be written - val fs = new Path(checkpointDir).getFileSystem(ssc.sc.hadoopConfiguration) eventually(eventuallyTimeout) { - assert(Checkpoint.getCheckpointFiles(checkpointDir, fs).size === 6) + assert(Checkpoint.getCheckpointFiles(checkpointDir).size === 6) } ssc.stop() // Check that we shut down while the third batch was being processed diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 09440b1e791ab..5207b7109e69b 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -17,8 +17,10 @@ package org.apache.spark.streaming +import java.io.File import java.util.concurrent.atomic.AtomicInteger +import org.apache.commons.io.FileUtils import org.scalatest.{Assertions, BeforeAndAfter, FunSuite} import org.scalatest.concurrent.Timeouts import org.scalatest.concurrent.Eventually._ @@ -330,6 +332,139 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w } } + test("getOrCreate") { + val conf = new SparkConf().setMaster(master).setAppName(appName) + + // Function to create StreamingContext that has a config to identify it to be new context + var newContextCreated = false + def creatingFunction(): StreamingContext = { + newContextCreated = true + new StreamingContext(conf, batchDuration) + } + + // Call ssc.stop after a body of code + def testGetOrCreate(body: => Unit): Unit = { + newContextCreated = false + try { + body + } finally { + if (ssc != null) { + ssc.stop() + } + ssc = null + } + } + + val emptyPath = Utils.createTempDir().getAbsolutePath() + + // getOrCreate should create new context with empty path + testGetOrCreate { + ssc = StreamingContext.getOrCreate(emptyPath, creatingFunction _) + assert(ssc != null, "no context created") + assert(newContextCreated, "new context not created") + } + + val corrutedCheckpointPath = createCorruptedCheckpoint() + + // getOrCreate should throw exception with 
fake checkpoint file and createOnError = false + intercept[Exception] { + ssc = StreamingContext.getOrCreate(corrutedCheckpointPath, creatingFunction _) + } + + // getOrCreate should throw exception with fake checkpoint file + intercept[Exception] { + ssc = StreamingContext.getOrCreate( + corrutedCheckpointPath, creatingFunction _, createOnError = false) + } + + // getOrCreate should create new context with fake checkpoint file and createOnError = true + testGetOrCreate { + ssc = StreamingContext.getOrCreate( + corrutedCheckpointPath, creatingFunction _, createOnError = true) + assert(ssc != null, "no context created") + assert(newContextCreated, "new context not created") + } + + val checkpointPath = createValidCheckpoint() + + // getOrCreate should recover context with checkpoint path, and recover old configuration + testGetOrCreate { + ssc = StreamingContext.getOrCreate(checkpointPath, creatingFunction _) + assert(ssc != null, "no context created") + assert(!newContextCreated, "old context not recovered") + assert(ssc.conf.get("someKey") === "someValue") + } + } + + test("getOrCreate with existing SparkContext") { + val conf = new SparkConf().setMaster(master).setAppName(appName) + sc = new SparkContext(conf) + + // Function to create StreamingContext that has a config to identify it to be new context + var newContextCreated = false + def creatingFunction(sparkContext: SparkContext): StreamingContext = { + newContextCreated = true + new StreamingContext(sparkContext, batchDuration) + } + + // Call ssc.stop(stopSparkContext = false) after a body of cody + def testGetOrCreate(body: => Unit): Unit = { + newContextCreated = false + try { + body + } finally { + if (ssc != null) { + ssc.stop(stopSparkContext = false) + } + ssc = null + } + } + + val emptyPath = Utils.createTempDir().getAbsolutePath() + + // getOrCreate should create new context with empty path + testGetOrCreate { + ssc = StreamingContext.getOrCreate(emptyPath, creatingFunction _, sc, createOnError = true) + assert(ssc != null, "no context created") + assert(newContextCreated, "new context not created") + assert(ssc.sparkContext === sc, "new StreamingContext does not use existing SparkContext") + } + + val corrutedCheckpointPath = createCorruptedCheckpoint() + + // getOrCreate should throw exception with fake checkpoint file and createOnError = false + intercept[Exception] { + ssc = StreamingContext.getOrCreate(corrutedCheckpointPath, creatingFunction _, sc) + } + + // getOrCreate should throw exception with fake checkpoint file + intercept[Exception] { + ssc = StreamingContext.getOrCreate( + corrutedCheckpointPath, creatingFunction _, sc, createOnError = false) + } + + // getOrCreate should create new context with fake checkpoint file and createOnError = true + testGetOrCreate { + ssc = StreamingContext.getOrCreate( + corrutedCheckpointPath, creatingFunction _, sc, createOnError = true) + assert(ssc != null, "no context created") + assert(newContextCreated, "new context not created") + assert(ssc.sparkContext === sc, "new StreamingContext does not use existing SparkContext") + } + + val checkpointPath = createValidCheckpoint() + + // StreamingContext.getOrCreate should recover context with checkpoint path + testGetOrCreate { + ssc = StreamingContext.getOrCreate(checkpointPath, creatingFunction _, sc) + assert(ssc != null, "no context created") + assert(!newContextCreated, "old context not recovered") + assert(ssc.sparkContext === sc, "new StreamingContext does not use existing SparkContext") + 
assert(!ssc.conf.contains("someKey"), + "recovered StreamingContext unexpectedly has old config") + } + } + test("DStream and generated RDD creation sites") { testPackage.test() } @@ -339,6 +474,30 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w val inputStream = new TestInputStream(s, input, 1) inputStream } + + def createValidCheckpoint(): String = { + val testDirectory = Utils.createTempDir().getAbsolutePath() + val checkpointDirectory = Utils.createTempDir().getAbsolutePath() + val conf = new SparkConf().setMaster(master).setAppName(appName) + conf.set("someKey", "someValue") + ssc = new StreamingContext(conf, batchDuration) + ssc.checkpoint(checkpointDirectory) + ssc.textFileStream(testDirectory).foreachRDD { rdd => rdd.count() } + ssc.start() + eventually(timeout(10000 millis)) { + assert(Checkpoint.getCheckpointFiles(checkpointDirectory).size > 1) + } + ssc.stop() + checkpointDirectory + } + + def createCorruptedCheckpoint(): String = { + val checkpointDirectory = Utils.createTempDir().getAbsolutePath() + val fakeCheckpointFile = Checkpoint.checkpointFile(checkpointDirectory, Time(1000)) + FileUtils.write(new File(fakeCheckpointFile.toString()), "blablabla") + assert(Checkpoint.getCheckpointFiles(checkpointDirectory).nonEmpty) + checkpointDirectory + } } class TestException(msg: String) extends Exception(msg) From 3a180c19a4ef165366186e23d8e8844c5baaecdd Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 29 Apr 2015 13:31:52 -0700 Subject: [PATCH 183/191] [SPARK-6629] cancelJobGroup() may not work for jobs whose job groups are inherited from parent threads When a job is submitted with a job group and that job group is inherited from a parent thread, there are multiple bugs that may prevent this job from being cancelable via `SparkContext.cancelJobGroup()`: - When filtering jobs based on their job group properties, DAGScheduler calls `get()` instead of `getProperty()`, which does not respect inheritance, so it will skip over jobs whose job group properties were inherited. - `Properties` objects are mutable, but we do not make defensive copies / snapshots, so modifications of the parent thread's job group will cause running jobs' groups to change; this also breaks cancelation. Both of these issues are easy to fix: use `getProperty()` and perform defensive copying. Author: Josh Rosen Closes #5288 from JoshRosen/localProperties-mutability-race and squashes the following commits: 9e29654 [Josh Rosen] Fix style issue 5d90750 [Josh Rosen] Merge remote-tracking branch 'origin/master' into localProperties-mutability-race 3f7b9e8 [Josh Rosen] Add JIRA reference; move clone into DAGScheduler 707e417 [Josh Rosen] Clone local properties to prevent mutations from breaking job cancellation. b376114 [Josh Rosen] Fix bug that prevented jobs with inherited job group properties from being cancelled. 
--- .../apache/spark/scheduler/DAGScheduler.scala | 15 ++++-- .../apache/spark/JobCancellationSuite.scala | 35 ++++++++++++++ .../org/apache/spark/ThreadingSuite.scala | 46 ++++++++++++++++++- 3 files changed, 91 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index b511c306ca508..05b8ab0d0a1f9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -28,6 +28,8 @@ import scala.language.existentials import scala.language.postfixOps import scala.util.control.NonFatal +import org.apache.commons.lang3.SerializationUtils + import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.executor.TaskMetrics @@ -510,7 +512,8 @@ class DAGScheduler( val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler) eventProcessLoop.post(JobSubmitted( - jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties)) + jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, + SerializationUtils.clone(properties))) waiter } @@ -547,7 +550,8 @@ class DAGScheduler( val partitions = (0 until rdd.partitions.size).toArray val jobId = nextJobId.getAndIncrement() eventProcessLoop.post(JobSubmitted( - jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties)) + jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, + SerializationUtils.clone(properties))) listener.awaitResult() // Will throw an exception if the job fails } @@ -704,8 +708,11 @@ class DAGScheduler( private[scheduler] def handleJobGroupCancelled(groupId: String) { // Cancel all jobs belonging to this job group. // First finds all active jobs with this group id, and then kill stages for them. - val activeInGroup = activeJobs.filter(activeJob => - Option(activeJob.properties).exists(_.get(SparkContext.SPARK_JOB_GROUP_ID) == groupId)) + val activeInGroup = activeJobs.filter { activeJob => + Option(activeJob.properties).exists { + _.getProperty(SparkContext.SPARK_JOB_GROUP_ID) == groupId + } + } val jobIds = activeInGroup.map(_.jobId) jobIds.foreach(handleJobCancellation(_, "part of cancelled job group %s".format(groupId))) submitWaitingStages() diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index 4d3e09793faff..ae17fc60e4a43 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -141,6 +141,41 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter assert(jobB.get() === 100) } + test("inherited job group (SPARK-6629)") { + sc = new SparkContext("local[2]", "test") + + // Add a listener to release the semaphore once any tasks are launched. 
+ val sem = new Semaphore(0) + sc.addSparkListener(new SparkListener { + override def onTaskStart(taskStart: SparkListenerTaskStart) { + sem.release() + } + }) + + sc.setJobGroup("jobA", "this is a job to be cancelled") + @volatile var exception: Exception = null + val jobA = new Thread() { + // The job group should be inherited by this thread + override def run(): Unit = { + exception = intercept[SparkException] { + sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.count() + } + } + } + jobA.start() + + // Block until both tasks of job A have started and cancel job A. + sem.acquire(2) + sc.cancelJobGroup("jobA") + jobA.join(10000) + assert(!jobA.isAlive) + assert(exception.getMessage contains "cancel") + + // Once A is cancelled, job B should finish fairly quickly. + val jobB = sc.parallelize(1 to 100, 2).countAsync() + assert(jobB.get() === 100) + } + test("job group with interruption") { sc = new SparkContext("local[2]", "test") diff --git a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala index b5383d553add1..10917c866cc7d 100644 --- a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala +++ b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark -import java.util.concurrent.Semaphore +import java.util.concurrent.{TimeUnit, Semaphore} import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.atomic.AtomicInteger +import org.apache.spark.scheduler._ import org.scalatest.FunSuite /** @@ -189,4 +190,47 @@ class ThreadingSuite extends FunSuite with LocalSparkContext { assert(sc.getLocalProperty("test") === "parent") assert(sc.getLocalProperty("Foo") === null) } + + test("mutations to local properties should not affect submitted jobs (SPARK-6629)") { + val jobStarted = new Semaphore(0) + val jobEnded = new Semaphore(0) + @volatile var jobResult: JobResult = null + + sc = new SparkContext("local", "test") + sc.setJobGroup("originalJobGroupId", "description") + sc.addSparkListener(new SparkListener { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + jobStarted.release() + } + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { + jobResult = jobEnd.jobResult + jobEnded.release() + } + }) + + // Create a new thread which will inherit the current thread's properties + val thread = new Thread() { + override def run(): Unit = { + assert(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID) === "originalJobGroupId") + // Sleeps for a total of 10 seconds, but allows cancellation to interrupt the task + try { + sc.parallelize(1 to 100).foreach { x => + Thread.sleep(100) + } + } catch { + case s: SparkException => // ignored so that we don't print noise in test logs + } + } + } + thread.start() + // Wait for the job to start, then mutate the original properties, which should have been + // inherited by the running job but hopefully defensively copied or snapshotted: + jobStarted.tryAcquire(10, TimeUnit.SECONDS) + sc.setJobGroup("modifiedJobGroupId", "description") + // Canceling the original job group should cancel the running job. 
In other words, the + // modification of the properties object should not affect the properties of running jobs + sc.cancelJobGroup("originalJobGroupId") + jobEnded.tryAcquire(10, TimeUnit.SECONDS) + assert(jobResult.isInstanceOf[JobFailed]) + } } From 15995c883aa248235fdebf0cbeeaa3ef12c97e9c Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Wed, 29 Apr 2015 14:53:37 -0700 Subject: [PATCH 184/191] [SPARK-7222] [ML] Added mathematical derivation in comment and compressed the model, removed the correction terms in LinearRegression with ElasticNet Added detailed mathematical derivation of how scaling and LeastSquaresAggregator work. Refactored the code so the model is compressed based on the storage. We may try compression based on the prediction time. Also, I found that diffSum will be always zero mathematically, so no corrections are required. Author: DB Tsai Closes #5767 from dbtsai/lir-doc and squashes the following commits: 5e346c9 [DB Tsai] refactoring fc9f582 [DB Tsai] doc 58456d8 [DB Tsai] address feedback 69757b8 [DB Tsai] actually diffSum is mathematically zero! No correction is needed. 5929e49 [DB Tsai] typo 63f7d1e [DB Tsai] Added compression to the model based on storage 203a295 [DB Tsai] Add more documentation to LinearRegression in new ML framework. --- .../ml/regression/LinearRegression.scala | 115 ++++++++++++++---- .../GeneralizedLinearAlgorithm.scala | 2 +- .../mllib/util/LinearDataGenerator.scala | 21 ++-- 3 files changed, 101 insertions(+), 37 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index f92c6816eb54c..cc9ad22cb860f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -34,6 +34,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame import org.apache.spark.storage.StorageLevel import org.apache.spark.util.StatCounter +import org.apache.spark.Logging /** * Params for linear regression. @@ -48,7 +49,7 @@ private[regression] trait LinearRegressionParams extends RegressorParams */ @AlphaComponent class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel] - with LinearRegressionParams { + with LinearRegressionParams with Logging { /** * Set the regularization parameter. @@ -110,6 +111,15 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress val yMean = statCounter.mean val yStd = math.sqrt(statCounter.variance) + // If the yStd is zero, then the intercept is yMean with zero weights; + // as a result, training is not needed. + if (yStd == 0.0) { + logWarning(s"The standard deviation of the label is zero, so the weights will be zeros " + + s"and the intercept will be the mean of the label; as a result, training is not needed.") + if (handlePersistence) instances.unpersist() + return new LinearRegressionModel(this, paramMap, Vectors.sparse(numFeatures, Seq()), yMean) + } + val featuresMean = summarizer.mean.toArray val featuresStd = summarizer.variance.toArray.map(math.sqrt) @@ -141,7 +151,6 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress } lossHistory += state.value - // TODO: Based on the sparsity of weights, we may convert the weights to the sparse vector. // The weights are trained in the scaled space; we're converting them back to // the original space. 
val weights = { @@ -158,11 +167,10 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress // converged. See the following discussion for detail. // http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet val intercept = yMean - dot(weights, Vectors.dense(featuresMean)) + if (handlePersistence) instances.unpersist() - if (handlePersistence) { - instances.unpersist() - } - new LinearRegressionModel(this, paramMap, weights, intercept) + // TODO: Converts to sparse format based on the storage, but may base on the scoring speed. + new LinearRegressionModel(this, paramMap, weights.compressed, intercept) } } @@ -198,15 +206,84 @@ class LinearRegressionModel private[ml] ( * Two LeastSquaresAggregator can be merged together to have a summary of loss and gradient of * the corresponding joint dataset. * - - * * Compute gradient and loss for a Least-squared loss function, as used in linear regression. - * This is correct for the averaged least squares loss function (mean squared error) - * L = 1/2n ||A weights-y||^2 - * See also the documentation for the precise formulation. + * For improving the convergence rate during the optimization process, and also preventing against + * features with very large variances exerting an overly large influence during model training, + * package like R's GLMNET performs the scaling to unit variance and removing the mean to reduce + * the condition number, and then trains the model in scaled space but returns the weights in + * the original scale. See page 9 in http://cran.r-project.org/web/packages/glmnet/glmnet.pdf + * + * However, we don't want to apply the `StandardScaler` on the training dataset, and then cache + * the standardized dataset since it will create a lot of overhead. As a result, we perform the + * scaling implicitly when we compute the objective function. The following is the mathematical + * derivation. + * + * Note that we don't deal with intercept by adding bias here, because the intercept + * can be computed using closed form after the coefficients are converged. + * See this discussion for detail. + * http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet + * + * The objective function in the scaled space is given by + * {{{ + * L = 1/2n ||\sum_i w_i(x_i - \bar{x_i}) / \hat{x_i} - (y - \bar{y}) / \hat{y}||^2, + * }}} + * where \bar{x_i} is the mean of x_i, \hat{x_i} is the standard deviation of x_i, + * \bar{y} is the mean of label, and \hat{y} is the standard deviation of label. + * + * This can be rewritten as + * {{{ + * L = 1/2n ||\sum_i (w_i/\hat{x_i})x_i - \sum_i (w_i/\hat{x_i})\bar{x_i} - y / \hat{y} + * + \bar{y} / \hat{y}||^2 + * = 1/2n ||\sum_i w_i^\prime x_i - y / \hat{y} + offset||^2 = 1/2n diff^2 + * }}} + * where w_i^\prime is the effective weights defined by w_i/\hat{x_i}, offset is + * {{{ + * - \sum_i (w_i/\hat{x_i})\bar{x_i} + \bar{y} / \hat{y}. + * }}}, and diff is + * {{{ + * \sum_i w_i^\prime x_i - y / \hat{y} + offset + * }}} * - * @param weights weights/coefficients corresponding to features + * Note that the effective weights and offset don't depend on training dataset, + * so they can be precomputed. * - * @param updater Updater to be used to update weights after every iteration. 
+ * Now, the first derivative of the objective function in scaled space is
+ * {{{
+ * \frac{\partial L}{\partial w_i} = diff/N (x_i - \bar{x_i}) / \hat{x_i}
+ * }}}
+ * However, (x_i - \bar{x_i}) will densify the computation, so it's not
+ * an ideal formula when the training dataset is in sparse format.
+ *
+ * This can be addressed by adding the dense \bar{x_i} / \hat{x_i} terms
+ * in the end by keeping the sum of diff. The first derivative of the total
+ * objective function from all the samples is
+ * {{{
+ * \frac{\partial L}{\partial w_i} =
+ *     1/N \sum_j diff_j (x_{ij} - \bar{x_i}) / \hat{x_i}
+ *   = 1/N ((\sum_j diff_j x_{ij} / \hat{x_i}) - diffSum \bar{x_i} / \hat{x_i})
+ *   = 1/N ((\sum_j diff_j x_{ij} / \hat{x_i}) + correction_i)
+ * }}}
+ * where correction_i = - diffSum \bar{x_i} / \hat{x_i} and diffSum = \sum_j diff_j.
+ *
+ * A little math shows that diffSum is actually zero, so we don't even
+ * need to add the correction terms in the end. From the definition of diff,
+ * {{{
+ * diffSum = \sum_j (\sum_i w_i(x_{ij} - \bar{x_i}) / \hat{x_i} - (y_j - \bar{y}) / \hat{y})
+ *         = N * (\sum_i w_i(\bar{x_i} - \bar{x_i}) / \hat{x_i} - (\bar{y} - \bar{y}) / \hat{y})
+ *         = 0
+ * }}}
+ *
+ * As a result, the first derivative of the total objective function only depends on
+ * the training dataset, which can be easily computed in a distributed fashion, and is
+ * sparse-format friendly.
+ * {{{
+ * \frac{\partial L}{\partial w_i} = 1/N (\sum_j diff_j x_{ij} / \hat{x_i})
+ * }}}
+ *
+ * @param weights The weights/coefficients corresponding to the features.
+ * @param labelStd The standard deviation value of the label.
+ * @param labelMean The mean value of the label.
+ * @param featuresStd The standard deviation values of the features.
+ * @param featuresMean The mean values of the features.
*/ private class LeastSquaresAggregator( weights: Vector, @@ -302,18 +379,6 @@ private class LeastSquaresAggregator( def gradient: Vector = { val result = Vectors.dense(gradientSumArray.clone()) - - val correction = { - val temp = effectiveWeightsArray.clone() - var i = 0 - while (i < temp.length) { - temp(i) *= featuresMean(i) - i += 1 - } - Vectors.dense(temp) - } - - axpy(-diffSum, correction, result) scal(1.0 / totalCnt, result) result } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 9fd60ff7a0c79..26be30ff9d6fd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -225,7 +225,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] throw new SparkException("Input validation failed.") } - /* + /** * Scaling columns to unit variance as a heuristic to reduce the condition number: * * During the optimization process, the convergence (rate) depends on the condition number of diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala index d7bb943e84f53..91c2f4ccd3d17 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala @@ -103,17 +103,16 @@ object LinearDataGenerator { val rnd = new Random(seed) val x = Array.fill[Array[Double]](nPoints)( - Array.fill[Double](weights.length)(rnd.nextDouble)) - - x.map(vector => { - // This doesn't work if `vector` is a sparse vector. - val vectorArray = vector.toArray - var i = 0 - while (i < vectorArray.size) { - vectorArray(i) = (vectorArray(i) - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i) - i += 1 - } - }) + Array.fill[Double](weights.length)(rnd.nextDouble())) + + x.foreach { + case v => + var i = 0 + while (i < v.length) { + v(i) = (v(i) - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i) + i += 1 + } + } val y = x.map { xi => blas.ddot(weights.length, xi, 1, weights, 1) + intercept + eps * rnd.nextGaussian() From c9d530e2e5123dbd4fd13fc487c890d6076b24bf Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Wed, 29 Apr 2015 14:55:32 -0700 Subject: [PATCH 185/191] [SPARK-6529] [ML] Add Word2Vec transformer See JIRA issue [here](https://issues.apache.org/jira/browse/SPARK-6529). There are some notes: 1. I add `learningRate` in sharedParams since it is a common parameter for ML algorithms. 2. We will not support transform of finding synonyms from a `Vector`, which will support in further JIRA issues. 3. Word2Vec is different with other ML models that its training set and transformed set are different. Its training set is an `RDD[Iterable[String]]` which represents documents, but the transformed set we want is an `RDD[String]` that represents unique words. So you have to switch your `inputCol` in these two stages. 
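For reference, a minimal usage sketch of the estimator added here (hypothetical data and column names; assumes an active SQLContext with its implicits imported, mirroring the test suite in this patch):

```scala
// Build a tiny DataFrame of tokenized sentences, the input format Word2Vec expects.
val docDF = sc.parallelize(Seq(
  "a b c d".split(" ").toSeq,
  "d c b a".split(" ").toSeq
)).map(Tuple1.apply).toDF("text")

val model = new Word2Vec()
  .setInputCol("text")     // column of Seq[String] (tokenized sentences)
  .setOutputCol("result")  // output column of Vector
  .setVectorSize(3)
  .setMinCount(0)          // keep all tokens in this toy vocabulary
  .fit(docDF)

// Each sentence is mapped to the average of its word vectors.
model.transform(docDF).select("result").show()
```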
Author: Xusen Yin Closes #5596 from yinxusen/SPARK-6529 and squashes the following commits: ee2b37a [Xusen Yin] merge with former HEAD 4945462 [Xusen Yin] merge with #5626 3bc2cbd [Xusen Yin] change foldLeft to for loop and use blas 5dd4ee7 [Xusen Yin] fix scala style 743e0d5 [Xusen Yin] fix comments and code style 04c48e9 [Xusen Yin] ensure the functionality a190f2c [Xusen Yin] fix code style and refine the transform function of word2vec 02848fa [Xusen Yin] refine comments 34a55c0 [Xusen Yin] fix errors 109d124 [Xusen Yin] add test suite and pass it 04dde06 [Xusen Yin] add shared params c594095 [Xusen Yin] add word2vec transformer 23d77fa [Xusen Yin] merge with #5626 e8cfaf7 [Xusen Yin] fix conflict with master 66e7bd3 [Xusen Yin] change foldLeft to for loop and use blas 566ec20 [Xusen Yin] fix scala style b54399f [Xusen Yin] fix comments and code style 1211e86 [Xusen Yin] ensure the functionality 6b97ec8 [Xusen Yin] fix code style and refine the transform function of word2vec 7cde18f [Xusen Yin] rm sharedParams 618abd0 [Xusen Yin] refine comments e29680a [Xusen Yin] fix errors fe3afe9 [Xusen Yin] add test suite and pass it 02767fb [Xusen Yin] add shared params 6a514f1 [Xusen Yin] add word2vec transformer --- .../apache/spark/ml/feature/Word2Vec.scala | 185 ++++++++++++++++++ .../ml/param/shared/SharedParamsCodeGen.scala | 3 +- .../spark/ml/param/shared/sharedParams.scala | 17 ++ .../spark/ml/feature/Word2VecSuite.scala | 63 ++++++ 4 files changed, 267 insertions(+), 1 deletion(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala new file mode 100644 index 0000000000000..0163fa8bd8a5b --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.mllib.feature +import org.apache.spark.mllib.linalg.BLAS._ +import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{DataFrame, Row} + +/** + * Params for [[Word2Vec]] and [[Word2VecModel]]. 
+ */ +private[feature] trait Word2VecBase extends Params + with HasInputCol with HasOutputCol with HasMaxIter with HasStepSize with HasSeed { + + /** + * The dimension of the code that you want to transform from words. + */ + final val vectorSize = new IntParam( + this, "vectorSize", "the dimension of codes after transforming from words") + setDefault(vectorSize -> 100) + + /** @group getParam */ + def getVectorSize: Int = getOrDefault(vectorSize) + + /** + * Number of partitions for sentences of words. + */ + final val numPartitions = new IntParam( + this, "numPartitions", "number of partitions for sentences of words") + setDefault(numPartitions -> 1) + + /** @group getParam */ + def getNumPartitions: Int = getOrDefault(numPartitions) + + /** + * The minimum number of times a token must appear to be included in the word2vec model's + * vocabulary. + */ + final val minCount = new IntParam(this, "minCount", "the minimum number of times a token must " + + "appear to be included in the word2vec model's vocabulary") + setDefault(minCount -> 5) + + /** @group getParam */ + def getMinCount: Int = getOrDefault(minCount) + + setDefault(stepSize -> 0.025) + setDefault(maxIter -> 1) + setDefault(seed -> 42L) + + /** + * Validate and transform the input schema. + */ + protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = extractParamMap(paramMap) + SchemaUtils.checkColumnType(schema, map(inputCol), new ArrayType(StringType, true)) + SchemaUtils.appendColumn(schema, map(outputCol), new VectorUDT) + } +} + +/** + * :: AlphaComponent :: + * Word2Vec trains a model of `Map(String, Vector)`, i.e. transforms a word into a code for further + * natural language processing or machine learning process. + */ +@AlphaComponent +final class Word2Vec extends Estimator[Word2VecModel] with Word2VecBase { + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** @group setParam */ + def setVectorSize(value: Int): this.type = set(vectorSize, value) + + /** @group setParam */ + def setStepSize(value: Double): this.type = set(stepSize, value) + + /** @group setParam */ + def setNumPartitions(value: Int): this.type = set(numPartitions, value) + + /** @group setParam */ + def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** @group setParam */ + def setSeed(value: Long): this.type = set(seed, value) + + /** @group setParam */ + def setMinCount(value: Int): this.type = set(minCount, value) + + override def fit(dataset: DataFrame, paramMap: ParamMap): Word2VecModel = { + transformSchema(dataset.schema, paramMap, logging = true) + val map = extractParamMap(paramMap) + val input = dataset.select(map(inputCol)).map { case Row(v: Seq[String]) => v } + val wordVectors = new feature.Word2Vec() + .setLearningRate(map(stepSize)) + .setMinCount(map(minCount)) + .setNumIterations(map(maxIter)) + .setNumPartitions(map(numPartitions)) + .setSeed(map(seed)) + .setVectorSize(map(vectorSize)) + .fit(input) + val model = new Word2VecModel(this, map, wordVectors) + Params.inheritValues(map, this, model) + model + } + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap) + } +} + +/** + * :: AlphaComponent :: + * Model fitted by [[Word2Vec]]. 
+ */ +@AlphaComponent +class Word2VecModel private[ml] ( + override val parent: Word2Vec, + override val fittingParamMap: ParamMap, + wordVectors: feature.Word2VecModel) + extends Model[Word2VecModel] with Word2VecBase { + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** + * Transform a sentence column to a vector column to represent the whole sentence. The transform + * is performed by averaging all word vectors it contains. + */ + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + transformSchema(dataset.schema, paramMap, logging = true) + val map = extractParamMap(paramMap) + val bWordVectors = dataset.sqlContext.sparkContext.broadcast(wordVectors) + val word2Vec = udf { sentence: Seq[String] => + if (sentence.size == 0) { + Vectors.sparse(map(vectorSize), Array.empty[Int], Array.empty[Double]) + } else { + val cum = Vectors.zeros(map(vectorSize)) + val model = bWordVectors.value.getVectors + for (word <- sentence) { + if (model.contains(word)) { + axpy(1.0, bWordVectors.value.transform(word), cum) + } else { + // pass words which not belong to model + } + } + scal(1.0 / sentence.size, cum) + cum + } + } + dataset.withColumn(map(outputCol), word2Vec(col(map(inputCol)))) + } + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 3f7e8f5a0b22c..654cd72d53074 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -48,7 +48,8 @@ private[shared] object SharedParamsCodeGen { ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")), ParamDesc[Long]("seed", "random seed", Some("Utils.random.nextLong()")), ParamDesc[Double]("elasticNetParam", "the ElasticNet mixing parameter"), - ParamDesc[Double]("tol", "the convergence tolerance for iterative algorithms")) + ParamDesc[Double]("tol", "the convergence tolerance for iterative algorithms"), + ParamDesc[Double]("stepSize", "Step size to be used for each iteration of optimization.")) val code = genSharedParams(params) val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala" diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 7d2c76d6c62c8..96d11ed76fa8e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -310,4 +310,21 @@ trait HasTol extends Params { /** @group getParam */ final def getTol: Double = getOrDefault(tol) } + +/** + * :: DeveloperApi :: + * Trait for shared param stepSize. + */ +@DeveloperApi +trait HasStepSize extends Params { + + /** + * Param for Step size to be used for each iteration of optimization.. 
+ * @group param + */ + final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size to be used for each iteration of optimization.") + + /** @group getParam */ + final def getStepSize: Double = getOrDefault(stepSize) +} // scalastyle:on diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala new file mode 100644 index 0000000000000..03ba86670d453 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{Row, SQLContext} + +class Word2VecSuite extends FunSuite with MLlibTestSparkContext { + + test("Word2Vec") { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + val sentence = "a b " * 100 + "a c " * 10 + val numOfWords = sentence.split(" ").size + val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) + + val codes = Map( + "a" -> Array(-0.2811822295188904,-0.6356269121170044,-0.3020961284637451), + "b" -> Array(1.0309048891067505,-1.29472815990448,0.22276712954044342), + "c" -> Array(-0.08456747233867645,0.5137411952018738,0.11731560528278351) + ) + + val expected = doc.map { sentence => + Vectors.dense(sentence.map(codes.apply).reduce((word1, word2) => + word1.zip(word2).map { case (v1, v2) => v1 + v2 } + ).map(_ / numOfWords)) + } + + val docDF = doc.zip(expected).toDF("text", "expected") + + val model = new Word2Vec() + .setVectorSize(3) + .setInputCol("text") + .setOutputCol("result") + .fit(docDF) + + model.transform(docDF).select("result", "expected").collect().foreach { + case Row(vector1: Vector, vector2: Vector) => + assert(vector1 ~== vector2 absTol 1E-5, "Transformed vector is different with expected.") + } + } +} + From d7dbce8f7da8a7fd01df6633a6043f51161b7d18 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Wed, 29 Apr 2015 15:34:05 -0700 Subject: [PATCH 186/191] [SPARK-7156][SQL] support RandomSplit in DataFrames This is built on top of kaka1992 's PR #5711 using Logical plans. 
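A short sketch of the new DataFrame API (illustrative data and weights; the DataFrameSuite test below provides the real coverage):

```scala
// Split a DataFrame into train/test sets; weights are normalized if they
// don't sum to 1, and a fixed seed makes the split reproducible.
val df = sc.parallelize(1 to 1000).toDF("id")
val Array(train, test) = df.randomSplit(Array(0.75, 0.25), seed = 12345L)
// The splits partition the input: no row is lost or duplicated.
assert(train.count() + test.count() == 1000)
```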
Author: Burak Yavuz Closes #5761 from brkyvz/random-sample and squashes the following commits: a1fb0aa [Burak Yavuz] remove unrelated file 69669c3 [Burak Yavuz] fix broken test 1ddb3da [Burak Yavuz] copy base 6000328 [Burak Yavuz] added python api and fixed test 3c11d1b [Burak Yavuz] fixed broken test f400ade [Burak Yavuz] fix build errors 2384266 [Burak Yavuz] addressed comments v0.1 e98ebac [Burak Yavuz] [SPARK-7156][SQL] support RandomSplit in DataFrames --- .../main/scala/org/apache/spark/rdd/RDD.scala | 19 +++++++++- .../java/org/apache/spark/JavaAPISuite.java | 8 ++-- python/pyspark/sql/dataframe.py | 18 ++++++++- .../spark/sql/catalyst/dsl/package.scala | 6 --- .../plans/logical/basicOperators.scala | 18 ++++++++- .../org/apache/spark/sql/DataFrame.scala | 38 ++++++++++++++++++- .../spark/sql/execution/SparkStrategies.scala | 4 +- .../spark/sql/execution/basicOperators.scala | 20 +++++++++- .../org/apache/spark/sql/DataFrameSuite.scala | 17 +++++++++ .../org/apache/spark/sql/hive/HiveQl.scala | 4 +- 10 files changed, 130 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index d80d94a588346..330255f89247f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -407,11 +407,26 @@ abstract class RDD[T: ClassTag]( val sum = weights.sum val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) normalizedCumWeights.sliding(2).map { x => - new PartitionwiseSampledRDD[T, T]( - this, new BernoulliCellSampler[T](x(0), x(1)), true, seed) + randomSampleWithRange(x(0), x(1), seed) }.toArray } + /** + * Internal method exposed for Random Splits in DataFrames. Samples an RDD given a probability + * range. + * @param lb lower bound to use for the Bernoulli sampler + * @param ub upper bound to use for the Bernoulli sampler + * @param seed the seed for the Random number generator + * @return A random sub-sample of the RDD without replacement. 
+ */ + private[spark] def randomSampleWithRange(lb: Double, ub: Double, seed: Long): RDD[T] = { + this.mapPartitionsWithIndex { case (index, partition) => + val sampler = new BernoulliCellSampler[T](lb, ub) + sampler.setSeed(seed + index) + sampler.sample(partition) + } + } + /** * Return a fixed-size sampled subset of this RDD in an array * diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 34ac9361d46c6..c2089b0e56a1f 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -157,11 +157,11 @@ public void sample() { public void randomSplit() { List ints = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); JavaRDD rdd = sc.parallelize(ints); - JavaRDD[] splits = rdd.randomSplit(new double[] { 0.4, 0.6, 1.0 }, 11); + JavaRDD[] splits = rdd.randomSplit(new double[] { 0.4, 0.6, 1.0 }, 31); Assert.assertEquals(3, splits.length); - Assert.assertEquals(2, splits[0].count()); - Assert.assertEquals(3, splits[1].count()); - Assert.assertEquals(5, splits[2].count()); + Assert.assertEquals(1, splits[0].count()); + Assert.assertEquals(2, splits[1].count()); + Assert.assertEquals(7, splits[2].count()); } @Test diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index d9cbbc68b3bf0..3074af3ed2e83 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -426,7 +426,7 @@ def distinct(self): def sample(self, withReplacement, fraction, seed=None): """Returns a sampled subset of this :class:`DataFrame`. - >>> df.sample(False, 0.5, 97).count() + >>> df.sample(False, 0.5, 42).count() 1 """ assert fraction >= 0.0, "Negative fraction value: %s" % fraction @@ -434,6 +434,22 @@ def sample(self, withReplacement, fraction, seed=None): rdd = self._jdf.sample(withReplacement, fraction, long(seed)) return DataFrame(rdd, self.sql_ctx) + def randomSplit(self, weights, seed=None): + """Randomly splits this :class:`DataFrame` with the provided weights. + + >>> splits = df4.randomSplit([1.0, 2.0], 24) + >>> splits[0].count() + 1 + + >>> splits[1].count() + 3 + """ + for w in weights: + assert w >= 0.0, "Negative weight value: %s" % w + seed = seed if seed is not None else random.randint(0, sys.maxsize) + rdd_array = self._jdf.randomSplit(_to_seq(self.sql_ctx._sc, weights), long(seed)) + return [DataFrame(rdd, self.sql_ctx) for rdd in rdd_array] + @property def dtypes(self): """Returns all column names and their data types as a list. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 5d5aba9644ff7..fa6cc7a1a36cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -278,12 +278,6 @@ package object dsl { def sfilter[T1](arg1: Symbol)(udf: (T1) => Boolean): LogicalPlan = Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan) - def sample( - fraction: Double, - withReplacement: Boolean = true, - seed: Int = (math.random * 1000).toInt): LogicalPlan = - Sample(fraction, withReplacement, seed, logicalPlan) - // TODO specify the output column names def generate( generator: Generator, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 608e272da7784..21208c8a5c281 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -300,8 +300,22 @@ case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output.map(_.withQualifiers(alias :: Nil)) } -case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: LogicalPlan) - extends UnaryNode { +/** + * Sample the dataset. + * + * @param lowerBound Lower-bound of the sampling probability (usually 0.0) + * @param upperBound Upper-bound of the sampling probability. The expected fraction sampled + * will be ub - lb. + * @param withReplacement Whether to sample with replacement. + * @param seed the random seed + * @param child the LogicalPlan + */ +case class Sample( + lowerBound: Double, + upperBound: Double, + withReplacement: Boolean, + seed: Long, + child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 0e896e5693b98..0d02e14c21be0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -706,7 +706,7 @@ class DataFrame private[sql]( * @group dfops */ def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = { - Sample(fraction, withReplacement, seed, logicalPlan) + Sample(0.0, fraction, withReplacement, seed, logicalPlan) } /** @@ -720,6 +720,42 @@ class DataFrame private[sql]( sample(withReplacement, fraction, Utils.random.nextLong) } + /** + * Randomly splits this [[DataFrame]] with the provided weights. + * + * @param weights weights for splits, will be normalized if they don't sum to 1. + * @param seed Seed for sampling. + * @group dfops + */ + def randomSplit(weights: Array[Double], seed: Long): Array[DataFrame] = { + val sum = weights.sum + val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) + normalizedCumWeights.sliding(2).map { x => + new DataFrame(sqlContext, Sample(x(0), x(1), false, seed, logicalPlan)) + }.toArray + } + + /** + * Randomly splits this [[DataFrame]] with the provided weights. + * + * @param weights weights for splits, will be normalized if they don't sum to 1. 
+ * @group dfops + */ + def randomSplit(weights: Array[Double]): Array[DataFrame] = { + randomSplit(weights, Utils.random.nextLong) + } + + /** + * Randomly splits this [[DataFrame]] with the provided weights. Provided for the Python Api. + * + * @param weights weights for splits, will be normalized if they don't sum to 1. + * @param seed Seed for sampling. + * @group dfops + */ + def randomSplit(weights: List[Double], seed: Long): Array[DataFrame] = { + randomSplit(weights.toArray, seed) + } + /** * (Scala-specific) Returns a new [[DataFrame]] where each row has been expanded to zero or more * rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. The columns of diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index af58911cc0e6a..326e8ce4ca524 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -303,8 +303,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Expand(projections, output, planLater(child)) :: Nil case logical.Aggregate(group, agg, child) => execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil - case logical.Sample(fraction, withReplacement, seed, child) => - execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil + case logical.Sample(lb, ub, withReplacement, seed, child) => + execution.Sample(lb, ub, withReplacement, seed, planLater(child)) :: Nil case logical.LocalRelation(output, data) => LocalTableScan(output, data) :: Nil case logical.Limit(IntegerLiteral(limit), child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 1afdb409417ce..5ca11e67a9434 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -63,16 +63,32 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { /** * :: DeveloperApi :: + * Sample the dataset. + * @param lowerBound Lower-bound of the sampling probability (usually 0.0) + * @param upperBound Upper-bound of the sampling probability. The expected fraction sampled + * will be ub - lb. + * @param withReplacement Whether to sample with replacement. + * @param seed the random seed + * @param child the QueryPlan */ @DeveloperApi -case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: SparkPlan) +case class Sample( + lowerBound: Double, + upperBound: Double, + withReplacement: Boolean, + seed: Long, + child: SparkPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output // TODO: How to pick seed? 
override def execute(): RDD[Row] = { - child.execute().map(_.copy()).sample(withReplacement, fraction, seed) + if (withReplacement) { + child.execute().map(_.copy()).sample(withReplacement, upperBound - lowerBound, seed) + } else { + child.execute().map(_.copy()).randomSampleWithRange(lowerBound, upperBound, seed) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 5ec06d448e50f..b70e127b4ed1a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -510,6 +510,23 @@ class DataFrameSuite extends QueryTest { assert(df.schema.map(_.name).toSeq === Seq("key", "valueRenamed", "newCol")) } + test("randomSplit") { + val n = 600 + val data = TestSQLContext.sparkContext.parallelize(1 to n, 2).toDF("id") + for (seed <- 1 to 5) { + val splits = data.randomSplit(Array[Double](1, 2, 3), seed) + assert(splits.length == 3, "wrong number of splits") + + assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList == + data.collect().toList, "incomplete or wrong split") + + val s = splits.map(_.count()) + assert(math.abs(s(0) - 100) < 50) // std = 9.13 + assert(math.abs(s(1) - 200) < 50) // std = 11.55 + assert(math.abs(s(2) - 300) < 50) // std = 12.25 + } + } + test("describe") { val describeTestData = Seq( ("Bob", 16, 176), diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 2dc6463abafa7..0a86519e1412b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -887,13 +887,13 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C fraction.toDouble >= (0.0 - RandomSampler.roundingEpsilon) && fraction.toDouble <= (100.0 + RandomSampler.roundingEpsilon), s"Sampling fraction ($fraction) must be on interval [0, 100]") - Sample(fraction.toDouble / 100, withReplacement = false, (math.random * 1000).toInt, + Sample(0.0, fraction.toDouble / 100, withReplacement = false, (math.random * 1000).toInt, relation) case Token("TOK_TABLEBUCKETSAMPLE", Token(numerator, Nil) :: Token(denominator, Nil) :: Nil) => val fraction = numerator.toDouble / denominator.toDouble - Sample(fraction, withReplacement = false, (math.random * 1000).toInt, relation) + Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, relation) case a: ASTNode => throw new NotImplementedError( s"""No parse rules for sampling clause: ${a.getType}, text: ${a.getText} : From 7f4b583733714bbecb43fb0823134bf2ec720a17 Mon Sep 17 00:00:00 2001 From: Qiping Li Date: Wed, 29 Apr 2015 23:52:16 +0100 Subject: [PATCH 187/191] [SPARK-7181] [CORE] fix inifite loop in Externalsorter's mergeWithAggregation see [SPARK-7181](https://issues.apache.org/jira/browse/SPARK-7181). 
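The bug, in miniature (a simplified sketch, not the actual ExternalSorter code): `BufferedIterator.head` only peeks, so the pre-patch merge loop read `sorted.head` repeatedly without ever calling `next()` and spun forever once two adjacent elements shared a key.

```scala
// Simplified model of the fixed merge loop over a key-sorted iterator.
val sorted = Iterator((1, "a"), (1, "b"), (2, "c")).buffered
var (k, c) = sorted.next()
while (sorted.hasNext && sorted.head._1 == k) {
  val pair = sorted.next() // the fix: consume the element before merging it
  c = c + pair._2          // pre-patch code merged sorted.head and never advanced
}
println((k, c)) // (1,ab)
```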
Author: Qiping Li Closes #5737 from chouqin/externalsorter and squashes the following commits: 2924b93 [Qiping Li] fix inifite loop in Externalsorter's mergeWithAggregation --- .../org/apache/spark/util/collection/ExternalSorter.scala | 3 ++- .../apache/spark/util/collection/ExternalSorterSuite.scala | 5 ++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index ef3cac622505e..4ed8a740f99db 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -527,7 +527,8 @@ private[spark] class ExternalSorter[K, V, C]( val k = elem._1 var c = elem._2 while (sorted.hasNext && sorted.head._1 == k) { - c = mergeCombiners(c, sorted.head._2) + val pair = sorted.next() + c = mergeCombiners(c, pair._2) } (k, c) } diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index 9ff067f86af44..de26aa351b0d2 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -506,7 +506,10 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) val ord = implicitly[Ordering[Int]] val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), Some(ord), None) - sorter.insertAll((0 until 100000).iterator.map(i => (i / 2, i))) + + // avoid combine before spill + sorter.insertAll((0 until 50000).iterator.map(i => (i , 2 * i))) + sorter.insertAll((0 until 50000).iterator.map(i => (i, 2 * i + 1))) val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet val expected = (0 until 3).map(p => { (p, (0 until 50000).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet) From 3fc6cfd079d8cdd35574605cb9a4178ca7f2613d Mon Sep 17 00:00:00 2001 From: yongtang Date: Wed, 29 Apr 2015 23:55:51 +0100 Subject: [PATCH 188/191] [SPARK-7155] [CORE] Allow newAPIHadoopFile to support comma-separated list of files as input See JIRA: https://issues.apache.org/jira/browse/SPARK-7155 SparkContext's newAPIHadoopFile() does not support comma-separated list of files. For example, the following: ```scala sc.newAPIHadoopFile("/root/file1.txt,/root/file2.txt", classOf[TextInputFormat], classOf[LongWritable], classOf[Text]) ``` will throw ``` org.apache.hadoop.mapreduce.lib.input.InvalidInputException: Input path does not exist: file:/root/file1.txt,/root/file2.txt ``` However, the other API hadoopFile() is able to process comma-separated list of files correctly. In addition, since sc.textFile() uses hadoopFile(), it is also able to process comma-separated list of files correctly. That means the behaviors of hadoopFile() and newAPIHadoopFile() are not aligned. This pull request fix this issue and allows newAPIHadoopFile() to support comma-separated list of files as input. A unit test has also been added in SparkContextSuite.scala. It creates two temporary text files as the input and tested against sc.textFile(), sc.hadoopFile(), and sc.newAPIHadoopFile(). Note: The contribution is my original work and that I license the work to the project under the project's open source license. 
Author: yongtang Closes #5708 from yongtang/SPARK-7155 and squashes the following commits: 654c80c [yongtang] [SPARK-7155] [CORE] Remove unneeded temp file deletion in unit test as parent dir is already temporary. 26faa6a [yongtang] [SPARK-7155] [CORE] Support comma-separated list of files as input for newAPIHadoopFile, wholeTextFiles, and binaryFiles. Use setInputPaths for consistency. 73e1f16 [yongtang] [SPARK-7155] [CORE] Allow newAPIHadoopFile to support comma-separated list of files as input. --- .../scala/org/apache/spark/SparkContext.scala | 12 +++- .../org/apache/spark/SparkContextSuite.scala | 63 ++++++++++++++++++- 2 files changed, 71 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 5ae8fb81de809..bae951f388337 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -713,7 +713,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli RDD[(String, String)] = { assertNotStopped() val job = new NewHadoopJob(hadoopConfiguration) - NewFileInputFormat.addInputPath(job, new Path(path)) + // Use setInputPaths so that wholeTextFiles aligns with hadoopFile/textFile in taking + // comma separated files as input. (see SPARK-7155) + NewFileInputFormat.setInputPaths(job, path) val updateConf = job.getConfiguration new WholeTextFileRDD( this, @@ -759,7 +761,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli RDD[(String, PortableDataStream)] = { assertNotStopped() val job = new NewHadoopJob(hadoopConfiguration) - NewFileInputFormat.addInputPath(job, new Path(path)) + // Use setInputPaths so that binaryFiles aligns with hadoopFile/textFile in taking + // comma separated files as input. (see SPARK-7155) + NewFileInputFormat.setInputPaths(job, path) val updateConf = job.getConfiguration new BinaryFileRDD( this, @@ -935,7 +939,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // The call to new NewHadoopJob automatically adds security credentials to conf, // so we don't need to explicitly add them ourselves val job = new NewHadoopJob(conf) - NewFileInputFormat.addInputPath(job, new Path(path)) + // Use setInputPaths so that newAPIHadoopFile aligns with hadoopFile/textFile in taking + // comma separated files as input. 
(see SPARK-7155) + NewFileInputFormat.setInputPaths(job, path) val updatedConf = job.getConfiguration new NewHadoopRDD(this, fClass, kClass, vClass, updatedConf).setName(path) } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 728558a424780..9049db7755358 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -25,7 +25,9 @@ import com.google.common.io.Files import org.scalatest.FunSuite -import org.apache.hadoop.io.BytesWritable +import org.apache.hadoop.io.{BytesWritable, LongWritable, Text} +import org.apache.hadoop.mapred.TextInputFormat +import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat} import org.apache.spark.util.Utils import scala.concurrent.Await @@ -213,4 +215,63 @@ class SparkContextSuite extends FunSuite with LocalSparkContext { sc.stop() } } + + test("Comma separated paths for newAPIHadoopFile/wholeTextFiles/binaryFiles (SPARK-7155)") { + // Regression test for SPARK-7155 + // dir1 and dir2 are used for wholeTextFiles and binaryFiles + val dir1 = Utils.createTempDir() + val dir2 = Utils.createTempDir() + + val dirpath1=dir1.getAbsolutePath + val dirpath2=dir2.getAbsolutePath + + // file1 and file2 are placed inside dir1, they are also used for + // textFile, hadoopFile, and newAPIHadoopFile + // file3, file4 and file5 are placed inside dir2, they are used for + // textFile, hadoopFile, and newAPIHadoopFile as well + val file1 = new File(dir1, "part-00000") + val file2 = new File(dir1, "part-00001") + val file3 = new File(dir2, "part-00000") + val file4 = new File(dir2, "part-00001") + val file5 = new File(dir2, "part-00002") + + val filepath1=file1.getAbsolutePath + val filepath2=file2.getAbsolutePath + val filepath3=file3.getAbsolutePath + val filepath4=file4.getAbsolutePath + val filepath5=file5.getAbsolutePath + + + try { + // Create 5 text files. 
+ Files.write("someline1 in file1\nsomeline2 in file1\nsomeline3 in file1", file1, UTF_8) + Files.write("someline1 in file2\nsomeline2 in file2", file2, UTF_8) + Files.write("someline1 in file3", file3, UTF_8) + Files.write("someline1 in file4\nsomeline2 in file4", file4, UTF_8) + Files.write("someline1 in file2\nsomeline2 in file5", file5, UTF_8) + + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + + // Test textFile, hadoopFile, and newAPIHadoopFile for file1 and file2 + assert(sc.textFile(filepath1 + "," + filepath2).count() == 5L) + assert(sc.hadoopFile(filepath1 + "," + filepath2, + classOf[TextInputFormat], classOf[LongWritable], classOf[Text]).count() == 5L) + assert(sc.newAPIHadoopFile(filepath1 + "," + filepath2, + classOf[NewTextInputFormat], classOf[LongWritable], classOf[Text]).count() == 5L) + + // Test textFile, hadoopFile, and newAPIHadoopFile for file3, file4, and file5 + assert(sc.textFile(filepath3 + "," + filepath4 + "," + filepath5).count() == 5L) + assert(sc.hadoopFile(filepath3 + "," + filepath4 + "," + filepath5, + classOf[TextInputFormat], classOf[LongWritable], classOf[Text]).count() == 5L) + assert(sc.newAPIHadoopFile(filepath3 + "," + filepath4 + "," + filepath5, + classOf[NewTextInputFormat], classOf[LongWritable], classOf[Text]).count() == 5L) + + // Test wholeTextFiles, and binaryFiles for dir1 and dir2 + assert(sc.wholeTextFiles(dirpath1 + "," + dirpath2).count() == 5L) + assert(sc.binaryFiles(dirpath1 + "," + dirpath2).count() == 5L) + + } finally { + sc.stop() + } + } } From f8cbb0a4b37b0d4ba49515d888cb52dea9eb01f1 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Wed, 29 Apr 2015 16:23:34 -0700 Subject: [PATCH 189/191] [SPARK-7229] [SQL] SpecificMutableRow should take integer type as internal representation for Date Author: Cheng Hao Closes #5772 from chenghao-intel/specific_row and squashes the following commits: 2cd064d [Cheng Hao] scala style issue 60347a2 [Cheng Hao] SpecificMutableRow should take integer type as internal representation for DateType --- .../sql/catalyst/expressions/SpecificMutableRow.scala | 1 + .../test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala | 9 +++++++++ 2 files changed, 10 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 3475ed05f4454..aa4099e4d7bf9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -202,6 +202,7 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR case DoubleType => new MutableDouble case BooleanType => new MutableBoolean case LongType => new MutableLong + case DateType => new MutableInt // We use INT for DATE internally case _ => new MutableAny }.toArray) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index db096af4535a9..856a806781ec4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -256,6 +256,15 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { assert(cachedRows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) } + test("test DATE types in cache") { + val rows = 
TestSQLContext.jdbc(urlWithUserAndPass, "TEST.TIMETYPES").collect() + TestSQLContext + .jdbc(urlWithUserAndPass, "TEST.TIMETYPES").cache().registerTempTable("mycached_date") + val cachedRows = sql("select * from mycached_date").collect() + assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) + assert(cachedRows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) + } + test("H2 floating-point types") { val rows = sql("SELECT * FROM flttypes").collect() assert(rows(0).getDouble(0) === 1.00000000000000022) // Yes, I meant ==. From b1ef6a60ff6ea2adb43c6544e5311c11f4364f64 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 29 Apr 2015 16:35:17 -0700 Subject: [PATCH 190/191] [SPARK-7259] [ML] VectorIndexer: do not copy non-ML metadata to output column Changed VectorIndexer so it does not carry non-ML metadata from the input to the output column. Removed ml.util.TestingUtils since VectorIndexer was the only use. CC: mengxr Author: Joseph K. Bradley Closes #5789 from jkbradley/vector-indexer-metadata and squashes the following commits: b28e159 [Joseph K. Bradley] Changed VectorIndexer so it does not carry non-ML metadata from the input to the output column. Removed ml.util.TestingUtils since VectorIndexer was the only use. --- .../spark/ml/feature/VectorIndexer.scala | 69 ++++++++++--------- .../spark/ml/feature/VectorIndexerSuite.scala | 7 +- .../apache/spark/ml/util/TestingUtils.scala | 60 ---------------- 3 files changed, 37 insertions(+), 99 deletions(-) delete mode 100644 mllib/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index 452faa06e2021..1e5ffd15afa51 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -233,6 +233,7 @@ private object VectorIndexer { * - Continuous features (columns) are left unchanged. * This also appends metadata to the output column, marking features as Numeric (continuous), * Nominal (categorical), or Binary (either continuous or categorical). + * Non-ML metadata is not carried over from the input to the output column. * * This maintains vector sparsity. * @@ -283,34 +284,40 @@ class VectorIndexerModel private[ml] ( // TODO: Check more carefully about whether this whole class will be included in a closure. + /** Per-vector transform function */ private val transformFunc: Vector => Vector = { - val sortedCategoricalFeatureIndices = categoryMaps.keys.toArray.sorted + val sortedCatFeatureIndices = categoryMaps.keys.toArray.sorted val localVectorMap = categoryMaps - val f: Vector => Vector = { - case dv: DenseVector => - val tmpv = dv.copy - localVectorMap.foreach { case (featureIndex: Int, categoryMap: Map[Double, Int]) => - tmpv.values(featureIndex) = categoryMap(tmpv(featureIndex)) - } - tmpv - case sv: SparseVector => - // We use the fact that categorical value 0 is always mapped to index 0. 
- val tmpv = sv.copy - var catFeatureIdx = 0 // index into sortedCategoricalFeatureIndices - var k = 0 // index into non-zero elements of sparse vector - while (catFeatureIdx < sortedCategoricalFeatureIndices.length && k < tmpv.indices.length) { - val featureIndex = sortedCategoricalFeatureIndices(catFeatureIdx) - if (featureIndex < tmpv.indices(k)) { - catFeatureIdx += 1 - } else if (featureIndex > tmpv.indices(k)) { - k += 1 - } else { - tmpv.values(k) = localVectorMap(featureIndex)(tmpv.values(k)) - catFeatureIdx += 1 - k += 1 + val localNumFeatures = numFeatures + val f: Vector => Vector = { (v: Vector) => + assert(v.size == localNumFeatures, "VectorIndexerModel expected vector of length" + + s" $numFeatures but found length ${v.size}") + v match { + case dv: DenseVector => + val tmpv = dv.copy + localVectorMap.foreach { case (featureIndex: Int, categoryMap: Map[Double, Int]) => + tmpv.values(featureIndex) = categoryMap(tmpv(featureIndex)) } - } - tmpv + tmpv + case sv: SparseVector => + // We use the fact that categorical value 0 is always mapped to index 0. + val tmpv = sv.copy + var catFeatureIdx = 0 // index into sortedCatFeatureIndices + var k = 0 // index into non-zero elements of sparse vector + while (catFeatureIdx < sortedCatFeatureIndices.length && k < tmpv.indices.length) { + val featureIndex = sortedCatFeatureIndices(catFeatureIdx) + if (featureIndex < tmpv.indices(k)) { + catFeatureIdx += 1 + } else if (featureIndex > tmpv.indices(k)) { + k += 1 + } else { + tmpv.values(k) = localVectorMap(featureIndex)(tmpv.values(k)) + catFeatureIdx += 1 + k += 1 + } + } + tmpv + } } f } @@ -326,13 +333,6 @@ class VectorIndexerModel private[ml] ( val map = extractParamMap(paramMap) val newField = prepOutputField(dataset.schema, map) val newCol = callUDF(transformFunc, new VectorUDT, dataset(map(inputCol))) - // For now, just check the first row of inputCol for vector length. - val firstRow = dataset.select(map(inputCol)).take(1) - if (firstRow.length != 0) { - val actualNumFeatures = firstRow(0).getAs[Vector](0).size - require(numFeatures == actualNumFeatures, "VectorIndexerModel expected vector of length" + - s" $numFeatures but found length $actualNumFeatures") - } dataset.withColumn(map(outputCol), newCol.as(map(outputCol), newField.metadata)) } @@ -345,6 +345,7 @@ class VectorIndexerModel private[ml] ( s"VectorIndexerModel requires output column parameter: $outputCol") SchemaUtils.checkColumnType(schema, map(inputCol), dataType) + // If the input metadata specifies numFeatures, compare with expected numFeatures. val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol))) val origNumFeatures: Option[Int] = if (origAttrGroup.attributes.nonEmpty) { Some(origAttrGroup.attributes.get.length) @@ -364,7 +365,7 @@ class VectorIndexerModel private[ml] ( * Prepare the output column field, including per-feature metadata. * @param schema Input schema * @param map Parameter map (with this class' embedded parameter map folded in) - * @return Output column field + * @return Output column field. This field does not contain non-ML metadata. 
*/ private def prepOutputField(schema: StructType, map: ParamMap): StructField = { val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol))) @@ -391,6 +392,6 @@ class VectorIndexerModel private[ml] ( partialFeatureAttributes } val newAttributeGroup = new AttributeGroup(map(outputCol), featureAttributes) - newAttributeGroup.toStructField(schema(map(inputCol)).metadata) + newAttributeGroup.toStructField() } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index 1b261b2643854..38dc83b1241cf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -23,7 +23,6 @@ import org.scalatest.FunSuite import org.apache.spark.SparkException import org.apache.spark.ml.attribute._ -import org.apache.spark.ml.util.TestingUtils import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD @@ -111,8 +110,8 @@ class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext { val model = vectorIndexer.fit(densePoints1) // vectors of length 3 model.transform(densePoints1) // should work model.transform(sparsePoints1) // should work - intercept[IllegalArgumentException] { - model.transform(densePoints2) + intercept[SparkException] { + model.transform(densePoints2).collect() println("Did not throw error when fit, transform were called on vectors of different lengths") } intercept[SparkException] { @@ -245,8 +244,6 @@ class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext { // TODO: Once input features marked as categorical are handled correctly, check that here. } } - // Check that non-ML metadata are preserved. - TestingUtils.testPreserveMetadata(densePoints1WithMeta, model, "features", "indexed") } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala deleted file mode 100644 index c44cb61b34171..0000000000000 --- a/mllib/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ml.util - -import org.apache.spark.ml.Transformer -import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.types.MetadataBuilder -import org.scalatest.FunSuite - -private[ml] object TestingUtils extends FunSuite { - - /** - * Test whether unrelated metadata are preserved for this transformer. - * This attaches extra metadata to a column, transforms the column, and check to ensure the - * extra metadata have not changed. 
-   * @param data  Input dataset
-   * @param transformer  Transformer to test
-   * @param inputCol  Unique input column for Transformer. This must be the ONLY input column.
-   * @param outputCol  Output column to test for metadata presence.
-   */
-  def testPreserveMetadata(
-      data: DataFrame,
-      transformer: Transformer,
-      inputCol: String,
-      outputCol: String): Unit = {
-    // Create some fake metadata
-    val origMetadata = data.schema(inputCol).metadata
-    val metaKey = "__testPreserveMetadata__fake_key"
-    val metaValue = 12345
-    assert(!origMetadata.contains(metaKey),
-      s"Unit test with testPreserveMetadata will fail since metadata key was present: $metaKey")
-    val newMetadata =
-      new MetadataBuilder().withMetadata(origMetadata).putLong(metaKey, metaValue).build()
-    // Add metadata to the inputCol
-    val withMetadata = data.select(data(inputCol).as(inputCol, newMetadata))
-    // Transform, and ensure extra metadata was not affected
-    val transformed = transformer.transform(withMetadata)
-    val transMetadata = transformed.schema(outputCol).metadata
-    assert(transMetadata.contains(metaKey),
-      "Unit test with testPreserveMetadata failed; extra metadata key was not present.")
-    assert(transMetadata.getLong(metaKey) === metaValue,
-      "Unit test with testPreserveMetadata failed; extra metadata value was wrong." +
-        s" Expected $metaValue but found ${transMetadata.getLong(metaKey)}")
-  }
-}

From 1fdfdb47b44315ff8ccb0ef92e56d3f2a070f1f1 Mon Sep 17 00:00:00 2001
From: wangfei
Date: Wed, 29 Apr 2015 17:00:24 -0700
Subject: [PATCH 191/191] [SQL] [Minor] Print detailed query execution info when Spark's answer is not right

Print detailed query execution info, including the parsed/analyzed/optimized/physical plans, for a query when Spark's answer is not right.

```
Results do not match for query:
== Parsed Logical Plan ==
'Aggregate ['x.str], ['x.str,SUM('x.strCount) AS c1#46]
 'Join Inner, Some(('x.str = 'y.str))
  'UnresolvedRelation [df], Some(x)
  'UnresolvedRelation [df], Some(y)

== Analyzed Logical Plan ==
Aggregate [str#44], [str#44,SUM(strCount#45L) AS c1#46L]
 Join Inner, Some((str#44 = str#51))
  Subquery x
   Subquery df
    Aggregate [str#44], [str#44,COUNT(str#44) AS strCount#45L]
     Project [_1#41 AS int#43,_2#42 AS str#44]
      LocalRelation [_1#41,_2#42], [[1,1],[2,2],[3,3]]
  Subquery y
   Subquery df
    Aggregate [str#51], [str#51,COUNT(str#51) AS strCount#47L]
     Project [_1#41 AS int#50,_2#42 AS str#51]
      LocalRelation [_1#41,_2#42], [[1,1],[2,2],[3,3]]

== Optimized Logical Plan ==
Aggregate [str#44], [str#44,SUM(strCount#45L) AS c1#46L]
 Project [str#44,strCount#45L]
  Join Inner, Some((str#44 = str#51))
   Aggregate [str#44], [str#44,COUNT(str#44) AS strCount#45L]
    LocalRelation [str#44], [[1],[2],[3]]
   Aggregate [str#51], [str#51]
    LocalRelation [str#51], [[1],[2],[3]]

== Physical Plan ==
Aggregate false, [str#44], [str#44,CombineSum(PartialSum#53L) AS c1#46L]
 Aggregate true, [str#44], [str#44,SUM(strCount#45L) AS PartialSum#53L]
  Project [str#44,strCount#45L]
   BroadcastHashJoin [str#44], [str#51], BuildRight
    Aggregate false, [str#44], [str#44,Coalesce(SUM(PartialCount#55L),0) AS strCount#45L]
     Exchange (HashPartitioning [str#44], 5), []
      Aggregate true, [str#44], [str#44,COUNT(str#44) AS PartialCount#55L]
       LocalTableScan [str#44], [[1],[2],[3]]
    Aggregate false, [str#51], [str#51]
     Exchange (HashPartitioning [str#51], 5), []
      Aggregate true, [str#51], [str#51]
       LocalTableScan [str#51], [[1],[2],[3]]

Code Generation: false
== RDD ==
== Results ==
!== Correct Answer - 3 ==   == Spark Answer - 3 ==
 [1,1]                      [1,1]
![2,3]                      [2,1]
 [3,1]                      [3,1]
```

Author: wangfei

Closes #5774 from scwf/checkanswer and squashes the following commits:

5be6f78 [wangfei] print detail query execution info when Spark Answer is not right
---
 .../src/test/scala/org/apache/spark/sql/QueryTest.scala | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 59f9508444f25..bbf9ab113ca43 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -132,11 +132,7 @@ object QueryTest {
     val errorMessage =
       s"""
         |Results do not match for query:
-        |${df.logicalPlan}
-        |== Analyzed Plan ==
-        |${df.queryExecution.analyzed}
-        |== Physical Plan ==
-        |${df.queryExecution.executedPlan}
+        |${df.queryExecution}
        |== Results ==
        |${sideBySide(
          s"== Correct Answer - ${expectedAnswer.size} ==" +:
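The change above collapses the three separately printed plans into a single dump of the whole `QueryExecution`. As a rough illustration (not part of the patch; it assumes a Spark 1.x `SQLContext` with a temporary table named `df` containing a string column `str`), printing `df.queryExecution` yields the parsed, analyzed, optimized and physical plans in one string, exactly like the example output in the commit message:

```scala
// Hedged sketch, not the patched test utility itself: inspecting all plan
// phases of a DataFrame at once via its QueryExecution.
import org.apache.spark.sql.SQLContext

def dumpPlans(sqlContext: SQLContext): Unit = {
  // Assumes a temp table "df" with a string column "str" is already registered.
  val df = sqlContext.table("df").groupBy("str").count()
  println(df.queryExecution)           // parsed, analyzed, optimized and physical plans
  println(df.queryExecution.analyzed)  // or just a single phase, if that is all you need
}
```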

<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
<tr>
-  <td>spark.reducer.maxMbInFlight</td>
-  <td>48</td>
+  <td>spark.reducer.maxSizeInFlight</td>
+  <td>48m</td>
   <td>
-    Maximum size (in megabytes) of map outputs to fetch simultaneously from each reduce task. Since
+    Maximum size of map outputs to fetch simultaneously from each reduce task. Since
     each output requires us to create a buffer to receive it, this represents a fixed memory
     overhead per reduce task, so keep it small unless you have a large amount of memory.
   </td>
</tr>
<tr>
-  <td>spark.shuffle.file.buffer.kb</td>
-  <td>32</td>
+  <td>spark.shuffle.file.buffer</td>
+  <td>32k</td>
   <td>
-    Size of the in-memory buffer for each shuffle file output stream, in kilobytes. These buffers
+    Size of the in-memory buffer for each shuffle file output stream. These buffers
     reduce the number of disk seeks and system calls made in creating intermediate shuffle files.
   </td>
</tr>
<tr>
-  <td>spark.io.compression.lz4.block.size</td>
-  <td>32768</td>
+  <td>spark.io.compression.lz4.blockSize</td>
+  <td>32k</td>
   <td>
-    Block size (in bytes) used in LZ4 compression, in the case when LZ4 compression codec
+    Block size used in LZ4 compression, in the case when LZ4 compression codec
     is used. Lowering this block size will also lower shuffle memory usage when LZ4 is used.
   </td>
</tr>
<tr>
-  <td>spark.io.compression.snappy.block.size</td>
-  <td>32768</td>
+  <td>spark.io.compression.snappy.blockSize</td>
+  <td>32k</td>
   <td>
-    Block size (in bytes) used in Snappy compression, in the case when Snappy compression codec
+    Block size used in Snappy compression, in the case when Snappy compression codec
     is used. Lowering this block size will also lower shuffle memory usage when Snappy is used.
   </td>
</tr>
<tr>
-  <td>spark.kryoserializer.buffer.max.mb</td>
-  <td>64</td>
+  <td>spark.kryoserializer.buffer.max</td>
+  <td>64m</td>
   <td>
-    Maximum allowable size of Kryo serialization buffer, in megabytes. This must be larger than any
+    Maximum allowable size of Kryo serialization buffer. This must be larger than any
     object you attempt to serialize. Increase this if you get a "buffer limit exceeded" exception
     inside Kryo.
   </td>
</tr>
<tr>
-  <td>spark.kryoserializer.buffer.mb</td>
-  <td>0.064</td>
+  <td>spark.kryoserializer.buffer</td>
+  <td>64k</td>
   <td>
-    Initial size of Kryo's serialization buffer, in megabytes. Note that there will be one buffer
+    Initial size of Kryo's serialization buffer. Note that there will be one buffer
     per core on each worker. This buffer will grow up to spark.kryoserializer.buffer.max.mb
     if needed.
   </td>
</tr>
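Since the renamed properties above now take a size string with a unit suffix rather than a bare number of megabytes or kilobytes, a configuration that previously used the `.mb`/`.kb` keys would be written along these lines (a minimal sketch using the defaults listed above; values are only illustrative):

```scala
import org.apache.spark.SparkConf

// Hedged sketch: the renamed size properties take values with a unit suffix.
val conf = new SparkConf()
  .set("spark.reducer.maxSizeInFlight", "48m")           // was spark.reducer.maxMbInFlight = 48
  .set("spark.shuffle.file.buffer", "32k")               // was spark.shuffle.file.buffer.kb = 32
  .set("spark.io.compression.lz4.blockSize", "32k")      // was spark.io.compression.lz4.block.size = 32768
  .set("spark.io.compression.snappy.blockSize", "32k")   // was spark.io.compression.snappy.block.size = 32768
  .set("spark.kryoserializer.buffer", "64k")             // was spark.kryoserializer.buffer.mb = 0.064
  .set("spark.kryoserializer.buffer.max", "64m")         // was spark.kryoserializer.buffer.max.mb = 64
```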
<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
<tr>
   <td>spark.broadcast.blockSize</td>
-  <td>4096</td>
+  <td>4m</td>
   <td>
-    Size of each piece of a block in kilobytes for TorrentBroadcastFactory.
+    Size of each piece of a block for TorrentBroadcastFactory.
     Too large a value decreases parallelism during broadcast (makes it slower); however, if it is
     too small, BlockManager might take a performance hit.
   </td>
</tr>
<tr>
   <td>spark.storage.memoryMapThreshold</td>
-  <td>2097152</td>
+  <td>2m</td>
   <td>
-    Size of a block, in bytes, above which Spark memory maps when reading a block from disk.
+    Size of a block above which Spark memory maps when reading a block from disk.
     This prevents Spark from memory mapping very small blocks. In general, memory mapping has
     high overhead for blocks close to or below the page size of the operating system.
   </td>
</tr>
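For the two rows above, the keys keep their names; only the default moves from a bare number (kilobytes and bytes respectively) to a size string. A minimal sketch, with the listed defaults shown purely for illustration:

```scala
// Hedged sketch: same property names as before, but size-string values.
val conf = new org.apache.spark.SparkConf()
  .set("spark.broadcast.blockSize", "4m")          // was 4096 (KB)
  .set("spark.storage.memoryMapThreshold", "2m")   // was 2097152 (bytes)
```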
<tr>
  <td>{submission.submissionId}</td>
  <td>{submission.submissionDate}</td>
  <td>{submission.command.mainClass}</td>
  <td>cpus: {submission.cores}, mem: {submission.mem}</td>
</tr>
<tr>
  <td>{state.driverDescription.submissionId}</td>
  <td>{state.driverDescription.submissionDate}</td>
  <td>{state.driverDescription.command.mainClass}</td>
  <td>cpus: {state.driverDescription.cores}, mem: {state.driverDescription.mem}</td>
  <td>{state.startDate}</td>
  <td>{state.slaveId.getValue}</td>
  <td>{stateString(state.mesosTaskStatus)}</td>
</tr>
<tr>
  <td>{submission.submissionId}</td>
  <td>{submission.submissionDate}</td>
  <td>{submission.command.mainClass}</td>
  <td>{submission.retryState.get.lastFailureStatus}</td>
  <td>{submission.retryState.get.nextRetry}</td>
  <td>{submission.retryState.get.retries}</td>
</tr>
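These fragments read like Scala XML table rows from a cluster UI page (the surrounding cell markup was lost in extraction). As a rough, self-contained sketch of how such row builders are typically assembled into a listing table, under the assumption that each row describes a driver submission as above (the `Submission` type, function names, and column headers here are invented for illustration and are not from the patch set):

```scala
import scala.xml.Node

// Illustrative stand-in for the driver-submission data shown in the rows above.
case class Submission(submissionId: String, submissionDate: String,
                      mainClass: String, cores: Double, mem: Int)

// One <tr> per queued submission, mirroring the first row fragment above.
def queuedRow(submission: Submission): Seq[Node] =
  <tr>
    <td>{submission.submissionId}</td>
    <td>{submission.submissionDate}</td>
    <td>{submission.mainClass}</td>
    <td>cpus: {submission.cores}, mem: {submission.mem}</td>
  </tr>

// The page maps the row builder over its data inside a <table>.
def queuedTable(queued: Seq[Submission]): Seq[Node] =
  <table class="table">
    <tr><th>Submission ID</th><th>Submission Date</th><th>Main Class</th><th>Driver Resources</th></tr>
    {queued.map(queuedRow)}
  </table>
```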