From caf0136ec5838cf5bf61f39a5b3474a505a6ae11 Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Fri, 24 Apr 2015 12:52:07 -0700 Subject: [PATCH 01/39] [SPARK-6852] [SPARKR] Accept numeric as numPartitions in SparkR. Author: Sun Rui Closes #5613 from sun-rui/SPARK-6852 and squashes the following commits: abaf02e [Sun Rui] Change the type of default numPartitions from integer to numeric in generics.R. 29d67c1 [Sun Rui] [SPARK-6852][SPARKR] Accept numeric as numPartitions in SparkR. --- R/pkg/R/RDD.R | 2 +- R/pkg/R/generics.R | 12 ++++++------ R/pkg/R/pairRDD.R | 24 ++++++++++++------------ 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index cc09efb1e5418..1662d6bb3b1ac 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -967,7 +967,7 @@ setMethod("keyBy", setMethod("repartition", signature(x = "RDD", numPartitions = "numeric"), function(x, numPartitions) { - coalesce(x, numToInt(numPartitions), TRUE) + coalesce(x, numPartitions, TRUE) }) #' Return a new RDD that is reduced into numPartitions partitions. diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 6c6233390134c..34dbe84051c50 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -60,7 +60,7 @@ setGeneric("countByValue", function(x) { standardGeneric("countByValue") }) #' @rdname distinct #' @export -setGeneric("distinct", function(x, numPartitions = 1L) { standardGeneric("distinct") }) +setGeneric("distinct", function(x, numPartitions = 1) { standardGeneric("distinct") }) #' @rdname filterRDD #' @export @@ -182,7 +182,7 @@ setGeneric("setName", function(x, name) { standardGeneric("setName") }) #' @rdname sortBy #' @export setGeneric("sortBy", - function(x, func, ascending = TRUE, numPartitions = 1L) { + function(x, func, ascending = TRUE, numPartitions = 1) { standardGeneric("sortBy") }) @@ -244,7 +244,7 @@ setGeneric("flatMapValues", function(X, FUN) { standardGeneric("flatMapValues") #' @rdname intersection #' @export -setGeneric("intersection", function(x, other, numPartitions = 1L) { +setGeneric("intersection", function(x, other, numPartitions = 1) { standardGeneric("intersection") }) #' @rdname keys @@ -346,21 +346,21 @@ setGeneric("rightOuterJoin", function(x, y, numPartitions) { standardGeneric("ri #' @rdname sortByKey #' @export setGeneric("sortByKey", - function(x, ascending = TRUE, numPartitions = 1L) { + function(x, ascending = TRUE, numPartitions = 1) { standardGeneric("sortByKey") }) #' @rdname subtract #' @export setGeneric("subtract", - function(x, other, numPartitions = 1L) { + function(x, other, numPartitions = 1) { standardGeneric("subtract") }) #' @rdname subtractByKey #' @export setGeneric("subtractByKey", - function(x, other, numPartitions = 1L) { + function(x, other, numPartitions = 1) { standardGeneric("subtractByKey") }) diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index f99b474ff8f2a..9791e55791bae 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -190,7 +190,7 @@ setMethod("flatMapValues", #' @rdname partitionBy #' @aliases partitionBy,RDD,integer-method setMethod("partitionBy", - signature(x = "RDD", numPartitions = "integer"), + signature(x = "RDD", numPartitions = "numeric"), function(x, numPartitions, partitionFunc = hashCode) { #if (missing(partitionFunc)) { @@ -211,7 +211,7 @@ setMethod("partitionBy", # the content (key-val pairs). pairwiseRRDD <- newJObject("org.apache.spark.api.r.PairwiseRRDD", callJMethod(jrdd, "rdd"), - as.integer(numPartitions), + numToInt(numPartitions), serializedHashFuncBytes, getSerializedMode(x), packageNamesArr, @@ -221,7 +221,7 @@ setMethod("partitionBy", # Create a corresponding partitioner. rPartitioner <- newJObject("org.apache.spark.HashPartitioner", - as.integer(numPartitions)) + numToInt(numPartitions)) # Call partitionBy on the obtained PairwiseRDD. javaPairRDD <- callJMethod(pairwiseRRDD, "asJavaPairRDD") @@ -256,7 +256,7 @@ setMethod("partitionBy", #' @rdname groupByKey #' @aliases groupByKey,RDD,integer-method setMethod("groupByKey", - signature(x = "RDD", numPartitions = "integer"), + signature(x = "RDD", numPartitions = "numeric"), function(x, numPartitions) { shuffled <- partitionBy(x, numPartitions) groupVals <- function(part) { @@ -315,7 +315,7 @@ setMethod("groupByKey", #' @rdname reduceByKey #' @aliases reduceByKey,RDD,integer-method setMethod("reduceByKey", - signature(x = "RDD", combineFunc = "ANY", numPartitions = "integer"), + signature(x = "RDD", combineFunc = "ANY", numPartitions = "numeric"), function(x, combineFunc, numPartitions) { reduceVals <- function(part) { vals <- new.env() @@ -422,7 +422,7 @@ setMethod("reduceByKeyLocally", #' @aliases combineByKey,RDD,ANY,ANY,ANY,integer-method setMethod("combineByKey", signature(x = "RDD", createCombiner = "ANY", mergeValue = "ANY", - mergeCombiners = "ANY", numPartitions = "integer"), + mergeCombiners = "ANY", numPartitions = "numeric"), function(x, createCombiner, mergeValue, mergeCombiners, numPartitions) { combineLocally <- function(part) { combiners <- new.env() @@ -483,7 +483,7 @@ setMethod("combineByKey", #' @aliases aggregateByKey,RDD,ANY,ANY,ANY,integer-method setMethod("aggregateByKey", signature(x = "RDD", zeroValue = "ANY", seqOp = "ANY", - combOp = "ANY", numPartitions = "integer"), + combOp = "ANY", numPartitions = "numeric"), function(x, zeroValue, seqOp, combOp, numPartitions) { createCombiner <- function(v) { do.call(seqOp, list(zeroValue, v)) @@ -514,7 +514,7 @@ setMethod("aggregateByKey", #' @aliases foldByKey,RDD,ANY,ANY,integer-method setMethod("foldByKey", signature(x = "RDD", zeroValue = "ANY", - func = "ANY", numPartitions = "integer"), + func = "ANY", numPartitions = "numeric"), function(x, zeroValue, func, numPartitions) { aggregateByKey(x, zeroValue, func, func, numPartitions) }) @@ -553,7 +553,7 @@ setMethod("join", joinTaggedList(v, list(FALSE, FALSE)) } - joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numToInt(numPartitions)), + joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin) }) @@ -582,7 +582,7 @@ setMethod("join", #' @rdname join-methods #' @aliases leftOuterJoin,RDD,RDD-method setMethod("leftOuterJoin", - signature(x = "RDD", y = "RDD", numPartitions = "integer"), + signature(x = "RDD", y = "RDD", numPartitions = "numeric"), function(x, y, numPartitions) { xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) @@ -619,7 +619,7 @@ setMethod("leftOuterJoin", #' @rdname join-methods #' @aliases rightOuterJoin,RDD,RDD-method setMethod("rightOuterJoin", - signature(x = "RDD", y = "RDD", numPartitions = "integer"), + signature(x = "RDD", y = "RDD", numPartitions = "numeric"), function(x, y, numPartitions) { xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) @@ -659,7 +659,7 @@ setMethod("rightOuterJoin", #' @rdname join-methods #' @aliases fullOuterJoin,RDD,RDD-method setMethod("fullOuterJoin", - signature(x = "RDD", y = "RDD", numPartitions = "integer"), + signature(x = "RDD", y = "RDD", numPartitions = "numeric"), function(x, y, numPartitions) { xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) From 438859eb7c4e605bb4041d9a547a16be9c827c75 Mon Sep 17 00:00:00 2001 From: Calvin Jia Date: Fri, 24 Apr 2015 17:57:41 -0400 Subject: [PATCH 02/39] [SPARK-6122] [CORE] Upgrade tachyon-client version to 0.6.3 This is a reopening of #4867. A short summary of the issues resolved from the previous PR: 1. HTTPClient version mismatch: Selenium (used for UI tests) requires version 4.3.x, and Tachyon included 4.2.5 through a transitive dependency of its shaded thrift jar. To address this, Tachyon 0.6.3 will promote the transitive dependencies of the shaded jar so they can be excluded in spark. 2. Jackson-Mapper-ASL version mismatch: In lower versions of hadoop-client (ie. 1.0.4), version 1.0.1 is included. The parquet library used in spark sql requires version 1.8+. Its unclear to me why upgrading tachyon-client would cause this dependency to break. The solution was to exclude jackson-mapper-asl from hadoop-client. It seems that the dependency management in spark-parent will not work on transitive dependencies, one way to make sure jackson-mapper-asl is included with the correct version is to add it as a top level dependency. The best solution would be to exclude the dependency in the modules which require a higher version, but that did not fix the unit tests. Any suggestions on the best way to solve this would be appreciated! Author: Calvin Jia Closes #5354 from calvinjia/upgrade_tachyon_0.6.3 and squashes the following commits: 0eefe4d [Calvin Jia] Handle httpclient version in maven dependency management. Remove httpclient version setting from profiles. 7c00dfa [Calvin Jia] Set httpclient version to 4.3.2 for selenium. Specify version of httpclient for sql/hive (previously 4.2.5 transitive dependency of libthrift). 9263097 [Calvin Jia] Merge master to test latest changes dbfc1bd [Calvin Jia] Use Tachyon 0.6.4 for cleaner dependencies. e2ff80a [Calvin Jia] Exclude the jetty and curator promoted dependencies from tachyon-client. a3a29da [Calvin Jia] Update tachyon-client exclusions. 0ae6c97 [Calvin Jia] Change tachyon version to 0.6.3 a204df9 [Calvin Jia] Update make distribution tachyon version. a93c94f [Calvin Jia] Exclude jackson-mapper-asl from hadoop client since it has a lower version than spark's expected version. a8a923c [Calvin Jia] Exclude httpcomponents from Tachyon 910fabd [Calvin Jia] Update to master eed9230 [Calvin Jia] Update tachyon version to 0.6.1. 11907b3 [Calvin Jia] Use TachyonURI for tachyon paths instead of strings. 71bf441 [Calvin Jia] Upgrade Tachyon client version to 0.6.0. --- assembly/pom.xml | 10 ---------- core/pom.xml | 6 +++++- .../spark/storage/TachyonBlockManager.scala | 16 ++++++++-------- .../main/scala/org/apache/spark/util/Utils.scala | 4 +++- examples/pom.xml | 5 ----- launcher/pom.xml | 6 ++++++ make-distribution.sh | 2 +- pom.xml | 12 +++++++++++- sql/hive/pom.xml | 5 +++++ 9 files changed, 39 insertions(+), 27 deletions(-) diff --git a/assembly/pom.xml b/assembly/pom.xml index f1f8b0d3682e2..20593e710dedb 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -213,16 +213,6 @@ - - kinesis-asl - - - org.apache.httpcomponents - httpclient - ${commons.httpclient.version} - - - diff --git a/core/pom.xml b/core/pom.xml index e80829b7a7f3d..5e89d548cd47f 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -74,6 +74,10 @@ javax.servlet servlet-api + + org.codehaus.jackson + jackson-mapper-asl + @@ -275,7 +279,7 @@ org.tachyonproject tachyon-client - 0.5.0 + 0.6.4 org.apache.hadoop diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala index 951897cead996..583f1fdf0475b 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala @@ -20,8 +20,8 @@ package org.apache.spark.storage import java.text.SimpleDateFormat import java.util.{Date, Random} -import tachyon.client.TachyonFS -import tachyon.client.TachyonFile +import tachyon.TachyonURI +import tachyon.client.{TachyonFile, TachyonFS} import org.apache.spark.Logging import org.apache.spark.executor.ExecutorExitCode @@ -40,7 +40,7 @@ private[spark] class TachyonBlockManager( val master: String) extends Logging { - val client = if (master != null && master != "") TachyonFS.get(master) else null + val client = if (master != null && master != "") TachyonFS.get(new TachyonURI(master)) else null if (client == null) { logError("Failed to connect to the Tachyon as the master address is not configured") @@ -60,11 +60,11 @@ private[spark] class TachyonBlockManager( addShutdownHook() def removeFile(file: TachyonFile): Boolean = { - client.delete(file.getPath(), false) + client.delete(new TachyonURI(file.getPath()), false) } def fileExists(file: TachyonFile): Boolean = { - client.exist(file.getPath()) + client.exist(new TachyonURI(file.getPath())) } def getFile(filename: String): TachyonFile = { @@ -81,7 +81,7 @@ private[spark] class TachyonBlockManager( if (old != null) { old } else { - val path = tachyonDirs(dirId) + "/" + "%02x".format(subDirId) + val path = new TachyonURI(s"${tachyonDirs(dirId)}/${"%02x".format(subDirId)}") client.mkdir(path) val newDir = client.getFile(path) subDirs(dirId)(subDirId) = newDir @@ -89,7 +89,7 @@ private[spark] class TachyonBlockManager( } } } - val filePath = subDir + "/" + filename + val filePath = new TachyonURI(s"$subDir/$filename") if(!client.exist(filePath)) { client.createFile(filePath) } @@ -113,7 +113,7 @@ private[spark] class TachyonBlockManager( tries += 1 try { tachyonDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536)) - val path = rootDir + "/" + "spark-tachyon-" + tachyonDirId + val path = new TachyonURI(s"$rootDir/spark-tachyon-$tachyonDirId") if (!client.exist(path)) { foundLocalDir = client.mkdir(path) tachyonDir = client.getFile(path) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 2feb7341b159b..667aa168e7ef3 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -42,6 +42,8 @@ import org.apache.hadoop.security.UserGroupInformation import org.apache.log4j.PropertyConfigurator import org.eclipse.jetty.util.MultiException import org.json4s._ + +import tachyon.TachyonURI import tachyon.client.{TachyonFS, TachyonFile} import org.apache.spark._ @@ -955,7 +957,7 @@ private[spark] object Utils extends Logging { * Delete a file or directory and its contents recursively. */ def deleteRecursively(dir: TachyonFile, client: TachyonFS) { - if (!client.delete(dir.getPath(), true)) { + if (!client.delete(new TachyonURI(dir.getPath()), true)) { throw new IOException("Failed to delete the tachyon dir: " + dir) } } diff --git a/examples/pom.xml b/examples/pom.xml index afd7c6d52f0dd..df1717403b673 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -390,11 +390,6 @@ spark-streaming-kinesis-asl_${scala.binary.version} ${project.version} - - org.apache.httpcomponents - httpclient - ${commons.httpclient.version} - diff --git a/launcher/pom.xml b/launcher/pom.xml index 182e5f60218db..ebfa7685eaa18 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -68,6 +68,12 @@ org.apache.hadoop hadoop-client test + + + org.codehaus.jackson + jackson-mapper-asl + + diff --git a/make-distribution.sh b/make-distribution.sh index 738a9c4d69601..cb65932b4abc0 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -32,7 +32,7 @@ SPARK_HOME="$(cd "`dirname "$0"`"; pwd)" DISTDIR="$SPARK_HOME/dist" SPARK_TACHYON=false -TACHYON_VERSION="0.5.0" +TACHYON_VERSION="0.6.4" TACHYON_TGZ="tachyon-${TACHYON_VERSION}-bin.tar.gz" TACHYON_URL="https://github.com/amplab/tachyon/releases/download/v${TACHYON_VERSION}/${TACHYON_TGZ}" diff --git a/pom.xml b/pom.xml index bcc2f57f1af5d..4b0b0c85eff21 100644 --- a/pom.xml +++ b/pom.xml @@ -146,7 +146,7 @@ 0.7.1 1.8.3 1.1.0 - 4.2.6 + 4.3.2 3.4.1 ${project.build.directory}/spark-test-classpath.txt 2.10.4 @@ -420,6 +420,16 @@ jsr305 1.3.9 + + org.apache.httpcomponents + httpclient + ${commons.httpclient.version} + + + org.apache.httpcomponents + httpcore + ${commons.httpclient.version} + org.seleniumhq.selenium selenium-java diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 04440076a26a3..21dce8d8a565a 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -59,6 +59,11 @@ ${hive.group} hive-exec + + org.apache.httpcomponents + httpclient + ${commons.httpclient.version} + org.codehaus.jackson jackson-mapper-asl From d874f8b546d8fae95bc92d8461b8189e51cb731b Mon Sep 17 00:00:00 2001 From: linweizhong Date: Fri, 24 Apr 2015 20:23:19 -0700 Subject: [PATCH 03/39] [PySpark][Minor] Update sql example, so that can read file correctly To run Spark, default will read file from HDFS if we don't set the schema. Author: linweizhong Closes #5684 from Sephiroth-Lin/pyspark_example_minor and squashes the following commits: 19fe145 [linweizhong] Update example sql.py, so that can read file correctly --- examples/src/main/python/sql.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/examples/src/main/python/sql.py b/examples/src/main/python/sql.py index 87d7b088f077b..2c188759328f2 100644 --- a/examples/src/main/python/sql.py +++ b/examples/src/main/python/sql.py @@ -18,6 +18,7 @@ from __future__ import print_function import os +import sys from pyspark import SparkContext from pyspark.sql import SQLContext @@ -50,7 +51,11 @@ # A JSON dataset is pointed to by path. # The path can be either a single text file or a directory storing text files. - path = os.path.join(os.environ['SPARK_HOME'], "examples/src/main/resources/people.json") + if len(sys.argv) < 2: + path = "file://" + \ + os.path.join(os.environ['SPARK_HOME'], "examples/src/main/resources/people.json") + else: + path = sys.argv[1] # Create a DataFrame from the file(s) pointed to by path people = sqlContext.jsonFile(path) # root From 59b7cfc41b2c06fbfbf6aca16c1619496a8d1d00 Mon Sep 17 00:00:00 2001 From: Deborah Siegel Date: Fri, 24 Apr 2015 20:25:07 -0700 Subject: [PATCH 04/39] [SPARK-7136][Docs] Spark SQL and DataFrame Guide fix example file and paths Changes example file for Generic Load/Save Functions to users.parquet rather than people.parquet which doesn't exist unless a later example has already been executed. Also adds filepaths. Author: Deborah Siegel Author: DEBORAH SIEGEL Author: DEBORAH SIEGEL Author: DEBORAH SIEGEL Closes #5693 from d3borah/master and squashes the following commits: 4d5e43b [Deborah Siegel] sparkSQL doc change b15a497 [Deborah Siegel] Revert "sparkSQL doc change" 5a2863c [DEBORAH SIEGEL] Merge remote-tracking branch 'upstream/master' 91972fc [DEBORAH SIEGEL] sparkSQL doc change f000e59 [DEBORAH SIEGEL] Merge remote-tracking branch 'upstream/master' db54173 [DEBORAH SIEGEL] fixed aggregateMessages example in graphX doc --- docs/sql-programming-guide.md | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 49b1e69f0e9db..b8233ae06fdf3 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -681,8 +681,8 @@ In the simplest form, the default data source (`parquet` unless otherwise config
{% highlight scala %} -val df = sqlContext.load("people.parquet") -df.select("name", "age").save("namesAndAges.parquet") +val df = sqlContext.load("examples/src/main/resources/users.parquet") +df.select("name", "favorite_color").save("namesAndFavColors.parquet") {% endhighlight %}
@@ -691,8 +691,8 @@ df.select("name", "age").save("namesAndAges.parquet") {% highlight java %} -DataFrame df = sqlContext.load("people.parquet"); -df.select("name", "age").save("namesAndAges.parquet"); +DataFrame df = sqlContext.load("examples/src/main/resources/users.parquet"); +df.select("name", "favorite_color").save("namesAndFavColors.parquet"); {% endhighlight %} @@ -702,8 +702,8 @@ df.select("name", "age").save("namesAndAges.parquet"); {% highlight python %} -df = sqlContext.load("people.parquet") -df.select("name", "age").save("namesAndAges.parquet") +df = sqlContext.load("examples/src/main/resources/users.parquet") +df.select("name", "favorite_color").save("namesAndFavColors.parquet") {% endhighlight %} @@ -722,7 +722,7 @@ using this syntax.
{% highlight scala %} -val df = sqlContext.load("people.json", "json") +val df = sqlContext.load("examples/src/main/resources/people.json", "json") df.select("name", "age").save("namesAndAges.parquet", "parquet") {% endhighlight %} @@ -732,7 +732,7 @@ df.select("name", "age").save("namesAndAges.parquet", "parquet") {% highlight java %} -DataFrame df = sqlContext.load("people.json", "json"); +DataFrame df = sqlContext.load("examples/src/main/resources/people.json", "json"); df.select("name", "age").save("namesAndAges.parquet", "parquet"); {% endhighlight %} @@ -743,7 +743,7 @@ df.select("name", "age").save("namesAndAges.parquet", "parquet"); {% highlight python %} -df = sqlContext.load("people.json", "json") +df = sqlContext.load("examples/src/main/resources/people.json", "json") df.select("name", "age").save("namesAndAges.parquet", "parquet") {% endhighlight %} From cca9905b93483614b330b09b36c6526b551e17dc Mon Sep 17 00:00:00 2001 From: KeheCAI Date: Sat, 25 Apr 2015 08:42:38 -0400 Subject: [PATCH 05/39] update the deprecated CountMinSketchMonoid function to TopPctCMS function http://twitter.github.io/algebird/index.html#com.twitter.algebird.legacy.CountMinSketchMonoid$ The CountMinSketchMonoid has been deprecated since 0.8.1. Newer code should use TopPctCMS.monoid(). ![image](https://cloud.githubusercontent.com/assets/1327396/7269619/d8b48b92-e8d5-11e4-8902-087f630e6308.png) Author: KeheCAI Closes #5629 from caikehe/master and squashes the following commits: e8aa06f [KeheCAI] update algebird-core to version 0.9.0 from 0.8.1 5653351 [KeheCAI] change scala code style 4c0dfd1 [KeheCAI] update the deprecated CountMinSketchMonoid function to TopPctCMS function --- examples/pom.xml | 2 +- .../apache/spark/examples/streaming/TwitterAlgebirdCMS.scala | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/pom.xml b/examples/pom.xml index df1717403b673..5b04b4f8d6ca0 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -245,7 +245,7 @@ com.twitter algebird-core_${scala.binary.version} - 0.8.1 + 0.9.0 org.scalacheck diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala index 62f49530edb12..c10de84a80ffe 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala @@ -18,6 +18,7 @@ package org.apache.spark.examples.streaming import com.twitter.algebird._ +import com.twitter.algebird.CMSHasherImplicits._ import org.apache.spark.SparkConf import org.apache.spark.SparkContext._ @@ -67,7 +68,8 @@ object TwitterAlgebirdCMS { val users = stream.map(status => status.getUser.getId) - val cms = new CountMinSketchMonoid(EPS, DELTA, SEED, PERC) + // val cms = new CountMinSketchMonoid(EPS, DELTA, SEED, PERC) + val cms = TopPctCMS.monoid[Long](EPS, DELTA, SEED, PERC) var globalCMS = cms.zero val mm = new MapMonoid[Long, Int]() var globalExact = Map[Long, Int]() From a61d65fc8b97c01be0fa756b52afdc91c46a8561 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Sat, 25 Apr 2015 10:37:34 -0700 Subject: [PATCH 06/39] Revert "[SPARK-6752][Streaming] Allow StreamingContext to be recreated from checkpoint and existing SparkContext" This reverts commit 534f2a43625fbf1a3a65d09550a19875cd1dce43. --- .../spark/api/java/function/Function0.java | 27 --- .../apache/spark/streaming/Checkpoint.scala | 26 +-- .../spark/streaming/StreamingContext.scala | 85 ++-------- .../api/java/JavaStreamingContext.scala | 119 +------------ .../apache/spark/streaming/JavaAPISuite.java | 145 ++++------------ .../spark/streaming/CheckpointSuite.scala | 3 +- .../streaming/StreamingContextSuite.scala | 159 ------------------ 7 files changed, 61 insertions(+), 503 deletions(-) delete mode 100644 core/src/main/java/org/apache/spark/api/java/function/Function0.java diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function0.java b/core/src/main/java/org/apache/spark/api/java/function/Function0.java deleted file mode 100644 index 38e410c5debe6..0000000000000 --- a/core/src/main/java/org/apache/spark/api/java/function/Function0.java +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.api.java.function; - -import java.io.Serializable; - -/** - * A zero-argument function that returns an R. - */ -public interface Function0 extends Serializable { - public R call() throws Exception; -} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 7bfae253c3a0c..0a50485118588 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -77,8 +77,7 @@ object Checkpoint extends Logging { } /** Get checkpoint files present in the give directory, ordered by oldest-first */ - def getCheckpointFiles(checkpointDir: String, fsOption: Option[FileSystem] = None): Seq[Path] = { - + def getCheckpointFiles(checkpointDir: String, fs: FileSystem): Seq[Path] = { def sortFunc(path1: Path, path2: Path): Boolean = { val (time1, bk1) = path1.getName match { case REGEX(x, y) => (x.toLong, !y.isEmpty) } val (time2, bk2) = path2.getName match { case REGEX(x, y) => (x.toLong, !y.isEmpty) } @@ -86,7 +85,6 @@ object Checkpoint extends Logging { } val path = new Path(checkpointDir) - val fs = fsOption.getOrElse(path.getFileSystem(new Configuration())) if (fs.exists(path)) { val statuses = fs.listStatus(path) if (statuses != null) { @@ -162,7 +160,7 @@ class CheckpointWriter( } // Delete old checkpoint files - val allCheckpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, Some(fs)) + val allCheckpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, fs) if (allCheckpointFiles.size > 10) { allCheckpointFiles.take(allCheckpointFiles.size - 10).foreach(file => { logInfo("Deleting " + file) @@ -236,24 +234,15 @@ class CheckpointWriter( private[streaming] object CheckpointReader extends Logging { - /** - * Read checkpoint files present in the given checkpoint directory. If there are no checkpoint - * files, then return None, else try to return the latest valid checkpoint object. If no - * checkpoint files could be read correctly, then return None (if ignoreReadError = true), - * or throw exception (if ignoreReadError = false). - */ - def read( - checkpointDir: String, - conf: SparkConf, - hadoopConf: Configuration, - ignoreReadError: Boolean = false): Option[Checkpoint] = { + def read(checkpointDir: String, conf: SparkConf, hadoopConf: Configuration): Option[Checkpoint] = + { val checkpointPath = new Path(checkpointDir) // TODO(rxin): Why is this a def?! def fs: FileSystem = checkpointPath.getFileSystem(hadoopConf) // Try to find the checkpoint files - val checkpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, Some(fs)).reverse + val checkpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, fs).reverse if (checkpointFiles.isEmpty) { return None } @@ -293,10 +282,7 @@ object CheckpointReader extends Logging { }) // If none of checkpoint files could be read, then throw exception - if (!ignoreReadError) { - throw new SparkException(s"Failed to read checkpoint from directory $checkpointPath") - } - None + throw new SparkException("Failed to read checkpoint from directory " + checkpointPath) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 90c8b47aebce0..f57f295874645 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -107,19 +107,6 @@ class StreamingContext private[streaming] ( */ def this(path: String) = this(path, new Configuration) - /** - * Recreate a StreamingContext from a checkpoint file using an existing SparkContext. - * @param path Path to the directory that was specified as the checkpoint directory - * @param sparkContext Existing SparkContext - */ - def this(path: String, sparkContext: SparkContext) = { - this( - sparkContext, - CheckpointReader.read(path, sparkContext.conf, sparkContext.hadoopConfiguration).get, - null) - } - - if (sc_ == null && cp_ == null) { throw new Exception("Spark Streaming cannot be initialized with " + "both SparkContext and checkpoint as null") @@ -128,12 +115,10 @@ class StreamingContext private[streaming] ( private[streaming] val isCheckpointPresent = (cp_ != null) private[streaming] val sc: SparkContext = { - if (sc_ != null) { - sc_ - } else if (isCheckpointPresent) { + if (isCheckpointPresent) { new SparkContext(cp_.createSparkConf()) } else { - throw new SparkException("Cannot create StreamingContext without a SparkContext") + sc_ } } @@ -144,7 +129,7 @@ class StreamingContext private[streaming] ( private[streaming] val conf = sc.conf - private[streaming] val env = sc.env + private[streaming] val env = SparkEnv.get private[streaming] val graph: DStreamGraph = { if (isCheckpointPresent) { @@ -189,9 +174,7 @@ class StreamingContext private[streaming] ( /** Register streaming source to metrics system */ private val streamingSource = new StreamingSource(this) - assert(env != null) - assert(env.metricsSystem != null) - env.metricsSystem.registerSource(streamingSource) + SparkEnv.get.metricsSystem.registerSource(streamingSource) /** Enumeration to identify current state of the StreamingContext */ private[streaming] object StreamingContextState extends Enumeration { @@ -638,59 +621,19 @@ object StreamingContext extends Logging { hadoopConf: Configuration = new Configuration(), createOnError: Boolean = false ): StreamingContext = { - val checkpointOption = CheckpointReader.read( - checkpointPath, new SparkConf(), hadoopConf, createOnError) + val checkpointOption = try { + CheckpointReader.read(checkpointPath, new SparkConf(), hadoopConf) + } catch { + case e: Exception => + if (createOnError) { + None + } else { + throw e + } + } checkpointOption.map(new StreamingContext(null, _, null)).getOrElse(creatingFunc()) } - - /** - * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. - * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be - * recreated from the checkpoint data. If the data does not exist, then the StreamingContext - * will be created by called the provided `creatingFunc` on the provided `sparkContext`. Note - * that the SparkConf configuration in the checkpoint data will not be restored as the - * SparkContext has already been created. - * - * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program - * @param creatingFunc Function to create a new StreamingContext using the given SparkContext - * @param sparkContext SparkContext using which the StreamingContext will be created - */ - def getOrCreate( - checkpointPath: String, - creatingFunc: SparkContext => StreamingContext, - sparkContext: SparkContext - ): StreamingContext = { - getOrCreate(checkpointPath, creatingFunc, sparkContext, createOnError = false) - } - - /** - * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. - * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be - * recreated from the checkpoint data. If the data does not exist, then the StreamingContext - * will be created by called the provided `creatingFunc` on the provided `sparkContext`. Note - * that the SparkConf configuration in the checkpoint data will not be restored as the - * SparkContext has already been created. - * - * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program - * @param creatingFunc Function to create a new StreamingContext using the given SparkContext - * @param sparkContext SparkContext using which the StreamingContext will be created - * @param createOnError Whether to create a new StreamingContext if there is an - * error in reading checkpoint data. By default, an exception will be - * thrown on error. - */ - def getOrCreate( - checkpointPath: String, - creatingFunc: SparkContext => StreamingContext, - sparkContext: SparkContext, - createOnError: Boolean - ): StreamingContext = { - val checkpointOption = CheckpointReader.read( - checkpointPath, sparkContext.conf, sparkContext.hadoopConfiguration, createOnError) - checkpointOption.map(new StreamingContext(sparkContext, _, null)) - .getOrElse(creatingFunc(sparkContext)) - } - /** * Find the JAR from which a given class was loaded, to make it easy for users to pass * their JARs to StreamingContext. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 572d7d8e8753d..4095a7cc84946 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -32,14 +32,13 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2} -import org.apache.spark.api.java.function.{Function0 => JFunction0} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ import org.apache.spark.streaming.scheduler.StreamingListener -import org.apache.spark.streaming.dstream.DStream -import org.apache.spark.streaming.receiver.Receiver import org.apache.hadoop.conf.Configuration +import org.apache.spark.streaming.dstream.{PluggableInputDStream, ReceiverInputDStream, DStream} +import org.apache.spark.streaming.receiver.Receiver /** * A Java-friendly version of [[org.apache.spark.streaming.StreamingContext]] which is the main @@ -656,7 +655,6 @@ object JavaStreamingContext { * @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext */ - @deprecated("use getOrCreate without JavaStreamingContextFactor", "1.4.0") def getOrCreate( checkpointPath: String, factory: JavaStreamingContextFactory @@ -678,7 +676,6 @@ object JavaStreamingContext { * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible * file system */ - @deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0") def getOrCreate( checkpointPath: String, hadoopConf: Configuration, @@ -703,7 +700,6 @@ object JavaStreamingContext { * @param createOnError Whether to create a new JavaStreamingContext if there is an * error in reading checkpoint data. */ - @deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0") def getOrCreate( checkpointPath: String, hadoopConf: Configuration, @@ -716,117 +712,6 @@ object JavaStreamingContext { new JavaStreamingContext(ssc) } - /** - * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. - * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be - * recreated from the checkpoint data. If the data does not exist, then the provided factory - * will be used to create a JavaStreamingContext. - * - * @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program - * @param creatingFunc Function to create a new JavaStreamingContext - */ - def getOrCreate( - checkpointPath: String, - creatingFunc: JFunction0[JavaStreamingContext] - ): JavaStreamingContext = { - val ssc = StreamingContext.getOrCreate(checkpointPath, () => { - creatingFunc.call().ssc - }) - new JavaStreamingContext(ssc) - } - - /** - * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. - * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be - * recreated from the checkpoint data. If the data does not exist, then the provided factory - * will be used to create a JavaStreamingContext. - * - * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program - * @param creatingFunc Function to create a new JavaStreamingContext - * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible - * file system - */ - def getOrCreate( - checkpointPath: String, - creatingFunc: JFunction0[JavaStreamingContext], - hadoopConf: Configuration - ): JavaStreamingContext = { - val ssc = StreamingContext.getOrCreate(checkpointPath, () => { - creatingFunc.call().ssc - }, hadoopConf) - new JavaStreamingContext(ssc) - } - - /** - * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. - * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be - * recreated from the checkpoint data. If the data does not exist, then the provided factory - * will be used to create a JavaStreamingContext. - * - * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program - * @param creatingFunc Function to create a new JavaStreamingContext - * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible - * file system - * @param createOnError Whether to create a new JavaStreamingContext if there is an - * error in reading checkpoint data. - */ - def getOrCreate( - checkpointPath: String, - creatingFunc: JFunction0[JavaStreamingContext], - hadoopConf: Configuration, - createOnError: Boolean - ): JavaStreamingContext = { - val ssc = StreamingContext.getOrCreate(checkpointPath, () => { - creatingFunc.call().ssc - }, hadoopConf, createOnError) - new JavaStreamingContext(ssc) - } - - /** - * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. - * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be - * recreated from the checkpoint data. If the data does not exist, then the provided factory - * will be used to create a JavaStreamingContext. - * - * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program - * @param creatingFunc Function to create a new JavaStreamingContext - * @param sparkContext SparkContext using which the StreamingContext will be created - */ - def getOrCreate( - checkpointPath: String, - creatingFunc: JFunction[JavaSparkContext, JavaStreamingContext], - sparkContext: JavaSparkContext - ): JavaStreamingContext = { - val ssc = StreamingContext.getOrCreate(checkpointPath, (sparkContext: SparkContext) => { - creatingFunc.call(new JavaSparkContext(sparkContext)).ssc - }, sparkContext.sc) - new JavaStreamingContext(ssc) - } - - /** - * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. - * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be - * recreated from the checkpoint data. If the data does not exist, then the provided factory - * will be used to create a JavaStreamingContext. - * - * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program - * @param creatingFunc Function to create a new JavaStreamingContext - * @param sparkContext SparkContext using which the StreamingContext will be created - * @param createOnError Whether to create a new JavaStreamingContext if there is an - * error in reading checkpoint data. - */ - def getOrCreate( - checkpointPath: String, - creatingFunc: JFunction[JavaSparkContext, JavaStreamingContext], - sparkContext: JavaSparkContext, - createOnError: Boolean - ): JavaStreamingContext = { - val ssc = StreamingContext.getOrCreate(checkpointPath, (sparkContext: SparkContext) => { - creatingFunc.call(new JavaSparkContext(sparkContext)).ssc - }, sparkContext.sc, createOnError) - new JavaStreamingContext(ssc) - } - /** * Find the JAR from which a given class was loaded, to make it easy for users to pass * their JARs to StreamingContext. diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index cb2e8380b4933..90340753a4eed 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -22,12 +22,10 @@ import java.nio.charset.Charset; import java.util.*; -import org.apache.commons.lang.mutable.MutableBoolean; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.LongWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.mapreduce.lib.input.TextInputFormat; - import scala.Tuple2; import org.junit.Assert; @@ -47,7 +45,6 @@ import org.apache.spark.storage.StorageLevel; import org.apache.spark.streaming.api.java.*; import org.apache.spark.util.Utils; -import org.apache.spark.SparkConf; // The test suite itself is Serializable so that anonymous Function implementations can be // serialized, as an alternative to converting these anonymous classes to static inner classes; @@ -932,7 +929,7 @@ public void testPairMap() { // Maps pair -> pair of different type public Tuple2 call(Tuple2 in) throws Exception { return in.swap(); } - }); + }); JavaTestUtils.attachTestOutputStream(reversed); List>> result = JavaTestUtils.runStreams(ssc, 2, 2); @@ -990,12 +987,12 @@ public void testPairMap2() { // Maps pair -> single JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaDStream reversed = pairStream.map( - new Function, Integer>() { - @Override - public Integer call(Tuple2 in) throws Exception { - return in._2(); - } - }); + new Function, Integer>() { + @Override + public Integer call(Tuple2 in) throws Exception { + return in._2(); + } + }); JavaTestUtils.attachTestOutputStream(reversed); List> result = JavaTestUtils.runStreams(ssc, 2, 2); @@ -1126,7 +1123,7 @@ public void testCombineByKey() { JavaPairDStream combined = pairStream.combineByKey( new Function() { - @Override + @Override public Integer call(Integer i) throws Exception { return i; } @@ -1147,14 +1144,14 @@ public void testCountByValue() { Arrays.asList("hello")); List>> expected = Arrays.asList( - Arrays.asList( - new Tuple2("hello", 1L), - new Tuple2("world", 1L)), - Arrays.asList( - new Tuple2("hello", 1L), - new Tuple2("moon", 1L)), - Arrays.asList( - new Tuple2("hello", 1L))); + Arrays.asList( + new Tuple2("hello", 1L), + new Tuple2("world", 1L)), + Arrays.asList( + new Tuple2("hello", 1L), + new Tuple2("moon", 1L)), + Arrays.asList( + new Tuple2("hello", 1L))); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream counted = stream.countByValue(); @@ -1252,17 +1249,17 @@ public void testUpdateStateByKey() { JavaPairDStream updated = pairStream.updateStateByKey( new Function2, Optional, Optional>() { - @Override - public Optional call(List values, Optional state) { - int out = 0; - if (state.isPresent()) { - out = out + state.get(); - } - for (Integer v : values) { - out = out + v; - } - return Optional.of(out); + @Override + public Optional call(List values, Optional state) { + int out = 0; + if (state.isPresent()) { + out = out + state.get(); + } + for (Integer v: values) { + out = out + v; } + return Optional.of(out); + } }); JavaTestUtils.attachTestOutputStream(updated); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -1295,17 +1292,17 @@ public void testUpdateStateByKeyWithInitial() { JavaPairDStream updated = pairStream.updateStateByKey( new Function2, Optional, Optional>() { - @Override - public Optional call(List values, Optional state) { - int out = 0; - if (state.isPresent()) { - out = out + state.get(); - } - for (Integer v : values) { - out = out + v; - } - return Optional.of(out); + @Override + public Optional call(List values, Optional state) { + int out = 0; + if (state.isPresent()) { + out = out + state.get(); + } + for (Integer v: values) { + out = out + v; } + return Optional.of(out); + } }, new HashPartitioner(1), initialRDD); JavaTestUtils.attachTestOutputStream(updated); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -1331,7 +1328,7 @@ public void testReduceByKeyAndWindowWithInverse() { JavaPairDStream reduceWindowed = pairStream.reduceByKeyAndWindow(new IntegerSum(), new IntegerDifference(), - new Duration(2000), new Duration(1000)); + new Duration(2000), new Duration(1000)); JavaTestUtils.attachTestOutputStream(reduceWindowed); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -1710,74 +1707,6 @@ public Integer call(String s) throws Exception { Utils.deleteRecursively(tempDir); } - @SuppressWarnings("unchecked") - @Test - public void testContextGetOrCreate() throws InterruptedException { - - final SparkConf conf = new SparkConf() - .setMaster("local[2]") - .setAppName("test") - .set("newContext", "true"); - - File emptyDir = Files.createTempDir(); - emptyDir.deleteOnExit(); - StreamingContextSuite contextSuite = new StreamingContextSuite(); - String corruptedCheckpointDir = contextSuite.createCorruptedCheckpoint(); - String checkpointDir = contextSuite.createValidCheckpoint(); - - // Function to create JavaStreamingContext without any output operations - // (used to detect the new context) - final MutableBoolean newContextCreated = new MutableBoolean(false); - Function0 creatingFunc = new Function0() { - public JavaStreamingContext call() { - newContextCreated.setValue(true); - return new JavaStreamingContext(conf, Seconds.apply(1)); - } - }; - - newContextCreated.setValue(false); - ssc = JavaStreamingContext.getOrCreate(emptyDir.getAbsolutePath(), creatingFunc); - Assert.assertTrue("new context not created", newContextCreated.isTrue()); - ssc.stop(); - - newContextCreated.setValue(false); - ssc = JavaStreamingContext.getOrCreate(corruptedCheckpointDir, creatingFunc, - new org.apache.hadoop.conf.Configuration(), true); - Assert.assertTrue("new context not created", newContextCreated.isTrue()); - ssc.stop(); - - newContextCreated.setValue(false); - ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc, - new org.apache.hadoop.conf.Configuration()); - Assert.assertTrue("old context not recovered", newContextCreated.isFalse()); - ssc.stop(); - - // Function to create JavaStreamingContext using existing JavaSparkContext - // without any output operations (used to detect the new context) - Function creatingFunc2 = - new Function() { - public JavaStreamingContext call(JavaSparkContext context) { - newContextCreated.setValue(true); - return new JavaStreamingContext(context, Seconds.apply(1)); - } - }; - - JavaSparkContext sc = new JavaSparkContext(conf); - newContextCreated.setValue(false); - ssc = JavaStreamingContext.getOrCreate(emptyDir.getAbsolutePath(), creatingFunc2, sc); - Assert.assertTrue("new context not created", newContextCreated.isTrue()); - ssc.stop(false); - - newContextCreated.setValue(false); - ssc = JavaStreamingContext.getOrCreate(corruptedCheckpointDir, creatingFunc2, sc, true); - Assert.assertTrue("new context not created", newContextCreated.isTrue()); - ssc.stop(false); - - newContextCreated.setValue(false); - ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc2, sc); - Assert.assertTrue("old context not recovered", newContextCreated.isFalse()); - ssc.stop(); - } /* TEST DISABLED: Pending a discussion about checkpoint() semantics with TD @SuppressWarnings("unchecked") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 6b0a3f91d4d06..54c30440a6e8d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -430,8 +430,9 @@ class CheckpointSuite extends TestSuiteBase { assert(recordedFiles(ssc) === Seq(1, 2, 3) && batchCounter.getNumStartedBatches === 3) } // Wait for a checkpoint to be written + val fs = new Path(checkpointDir).getFileSystem(ssc.sc.hadoopConfiguration) eventually(eventuallyTimeout) { - assert(Checkpoint.getCheckpointFiles(checkpointDir).size === 6) + assert(Checkpoint.getCheckpointFiles(checkpointDir, fs).size === 6) } ssc.stop() // Check that we shut down while the third batch was being processed diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 4f193322ad33e..58353a5f97c8a 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -17,10 +17,8 @@ package org.apache.spark.streaming -import java.io.File import java.util.concurrent.atomic.AtomicInteger -import org.apache.commons.io.FileUtils import org.scalatest.{Assertions, BeforeAndAfter, FunSuite} import org.scalatest.concurrent.Timeouts import org.scalatest.concurrent.Eventually._ @@ -332,139 +330,6 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w } } - test("getOrCreate") { - val conf = new SparkConf().setMaster(master).setAppName(appName) - - // Function to create StreamingContext that has a config to identify it to be new context - var newContextCreated = false - def creatingFunction(): StreamingContext = { - newContextCreated = true - new StreamingContext(conf, batchDuration) - } - - // Call ssc.stop after a body of code - def testGetOrCreate(body: => Unit): Unit = { - newContextCreated = false - try { - body - } finally { - if (ssc != null) { - ssc.stop() - } - ssc = null - } - } - - val emptyPath = Utils.createTempDir().getAbsolutePath() - - // getOrCreate should create new context with empty path - testGetOrCreate { - ssc = StreamingContext.getOrCreate(emptyPath, creatingFunction _) - assert(ssc != null, "no context created") - assert(newContextCreated, "new context not created") - } - - val corrutedCheckpointPath = createCorruptedCheckpoint() - - // getOrCreate should throw exception with fake checkpoint file and createOnError = false - intercept[Exception] { - ssc = StreamingContext.getOrCreate(corrutedCheckpointPath, creatingFunction _) - } - - // getOrCreate should throw exception with fake checkpoint file - intercept[Exception] { - ssc = StreamingContext.getOrCreate( - corrutedCheckpointPath, creatingFunction _, createOnError = false) - } - - // getOrCreate should create new context with fake checkpoint file and createOnError = true - testGetOrCreate { - ssc = StreamingContext.getOrCreate( - corrutedCheckpointPath, creatingFunction _, createOnError = true) - assert(ssc != null, "no context created") - assert(newContextCreated, "new context not created") - } - - val checkpointPath = createValidCheckpoint() - - // getOrCreate should recover context with checkpoint path, and recover old configuration - testGetOrCreate { - ssc = StreamingContext.getOrCreate(checkpointPath, creatingFunction _) - assert(ssc != null, "no context created") - assert(!newContextCreated, "old context not recovered") - assert(ssc.conf.get("someKey") === "someValue") - } - } - - test("getOrCreate with existing SparkContext") { - val conf = new SparkConf().setMaster(master).setAppName(appName) - sc = new SparkContext(conf) - - // Function to create StreamingContext that has a config to identify it to be new context - var newContextCreated = false - def creatingFunction(sparkContext: SparkContext): StreamingContext = { - newContextCreated = true - new StreamingContext(sparkContext, batchDuration) - } - - // Call ssc.stop(stopSparkContext = false) after a body of cody - def testGetOrCreate(body: => Unit): Unit = { - newContextCreated = false - try { - body - } finally { - if (ssc != null) { - ssc.stop(stopSparkContext = false) - } - ssc = null - } - } - - val emptyPath = Utils.createTempDir().getAbsolutePath() - - // getOrCreate should create new context with empty path - testGetOrCreate { - ssc = StreamingContext.getOrCreate(emptyPath, creatingFunction _, sc, createOnError = true) - assert(ssc != null, "no context created") - assert(newContextCreated, "new context not created") - assert(ssc.sparkContext === sc, "new StreamingContext does not use existing SparkContext") - } - - val corrutedCheckpointPath = createCorruptedCheckpoint() - - // getOrCreate should throw exception with fake checkpoint file and createOnError = false - intercept[Exception] { - ssc = StreamingContext.getOrCreate(corrutedCheckpointPath, creatingFunction _, sc) - } - - // getOrCreate should throw exception with fake checkpoint file - intercept[Exception] { - ssc = StreamingContext.getOrCreate( - corrutedCheckpointPath, creatingFunction _, sc, createOnError = false) - } - - // getOrCreate should create new context with fake checkpoint file and createOnError = true - testGetOrCreate { - ssc = StreamingContext.getOrCreate( - corrutedCheckpointPath, creatingFunction _, sc, createOnError = true) - assert(ssc != null, "no context created") - assert(newContextCreated, "new context not created") - assert(ssc.sparkContext === sc, "new StreamingContext does not use existing SparkContext") - } - - val checkpointPath = createValidCheckpoint() - - // StreamingContext.getOrCreate should recover context with checkpoint path - testGetOrCreate { - ssc = StreamingContext.getOrCreate(checkpointPath, creatingFunction _, sc) - assert(ssc != null, "no context created") - assert(!newContextCreated, "old context not recovered") - assert(ssc.sparkContext === sc, "new StreamingContext does not use existing SparkContext") - assert(!ssc.conf.contains("someKey"), - "recovered StreamingContext unexpectedly has old config") - } - } - test("DStream and generated RDD creation sites") { testPackage.test() } @@ -474,30 +339,6 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w val inputStream = new TestInputStream(s, input, 1) inputStream } - - def createValidCheckpoint(): String = { - val testDirectory = Utils.createTempDir().getAbsolutePath() - val checkpointDirectory = Utils.createTempDir().getAbsolutePath() - val conf = new SparkConf().setMaster(master).setAppName(appName) - conf.set("someKey", "someValue") - ssc = new StreamingContext(conf, batchDuration) - ssc.checkpoint(checkpointDirectory) - ssc.textFileStream(testDirectory).foreachRDD { rdd => rdd.count() } - ssc.start() - eventually(timeout(10000 millis)) { - assert(Checkpoint.getCheckpointFiles(checkpointDirectory).size > 1) - } - ssc.stop() - checkpointDirectory - } - - def createCorruptedCheckpoint(): String = { - val checkpointDirectory = Utils.createTempDir().getAbsolutePath() - val fakeCheckpointFile = Checkpoint.checkpointFile(checkpointDirectory, Time(1000)) - FileUtils.write(new File(fakeCheckpointFile.toString()), "blablabla") - assert(Checkpoint.getCheckpointFiles(checkpointDirectory).nonEmpty) - checkpointDirectory - } } class TestException(msg: String) extends Exception(msg) From a7160c4e3aae22600d05e257d0b4d2428754b8ea Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Sat, 25 Apr 2015 12:27:19 -0700 Subject: [PATCH 07/39] [SPARK-6113] [ML] Tree ensembles for Pipelines API This is a continuation of [https://github.com/apache/spark/pull/5530] (which was for Decision Trees), but for ensembles: Random Forests and Gradient-Boosted Trees. Please refer to the JIRA [https://issues.apache.org/jira/browse/SPARK-6113], the design doc linked from the JIRA, and the previous PR linked above for design discussions. This PR follows the example set by the previous PR for Decision Trees. It includes a few cleanups to Decision Trees. Note: There is one issue which will be addressed in a separate PR: Ensembles' component Models have no parent or fittingParamMap. I plan to submit a separate PR which makes those values in Model be Options. It does not matter much which PR gets merged first. CC: mengxr manishamde codedeft chouqin Author: Joseph K. Bradley Closes #5626 from jkbradley/dt-api-ensembles and squashes the following commits: 729167a [Joseph K. Bradley] small cleanups based on code review bbae2a2 [Joseph K. Bradley] Updated per all comments in code review 855aa9a [Joseph K. Bradley] scala style fix ea3d901 [Joseph K. Bradley] Added GBT to spark.ml, with tests and examples c0f30c1 [Joseph K. Bradley] Added random forests and test suites to spark.ml. Not tested yet. Need to add example as well d045ebd [Joseph K. Bradley] some more updates, but far from done ee1a10b [Joseph K. Bradley] Added files from old PR and did some initial updates. --- .../examples/ml/DecisionTreeExample.scala | 139 ++++++---- .../apache/spark/examples/ml/GBTExample.scala | 238 +++++++++++++++++ .../examples/ml/RandomForestExample.scala | 248 +++++++++++++++++ .../mllib/GradientBoostedTreesRunner.scala | 1 + .../scala/org/apache/spark/ml/Model.scala | 2 + .../DecisionTreeClassifier.scala | 24 +- .../ml/classification/GBTClassifier.scala | 228 ++++++++++++++++ .../RandomForestClassifier.scala | 185 +++++++++++++ .../spark/ml/impl/tree/treeParams.scala | 249 +++++++++++++++--- .../ml/param/shared/SharedParamsCodeGen.scala | 4 +- .../spark/ml/param/shared/sharedParams.scala | 20 ++ .../ml/regression/DecisionTreeRegressor.scala | 14 +- .../spark/ml/regression/GBTRegressor.scala | 218 +++++++++++++++ .../ml/regression/RandomForestRegressor.scala | 167 ++++++++++++ .../scala/org/apache/spark/ml/tree/Node.scala | 6 +- .../org/apache/spark/ml/tree/Split.scala | 22 +- .../org/apache/spark/ml/tree/treeModels.scala | 46 +++- .../JavaDecisionTreeClassifierSuite.java | 10 +- .../JavaGBTClassifierSuite.java | 100 +++++++ .../JavaRandomForestClassifierSuite.java | 103 ++++++++ .../JavaDecisionTreeRegressorSuite.java | 26 +- .../ml/regression/JavaGBTRegressorSuite.java | 99 +++++++ .../JavaRandomForestRegressorSuite.java | 102 +++++++ .../DecisionTreeClassifierSuite.scala | 2 +- .../classification/GBTClassifierSuite.scala | 136 ++++++++++ .../RandomForestClassifierSuite.scala | 166 ++++++++++++ .../org/apache/spark/ml/impl/TreeTests.scala | 10 +- .../DecisionTreeRegressorSuite.scala | 2 +- .../ml/regression/GBTRegressorSuite.scala | 137 ++++++++++ .../RandomForestRegressorSuite.scala | 122 +++++++++ .../spark/mllib/tree/DecisionTreeSuite.scala | 6 +- 31 files changed, 2658 insertions(+), 174 deletions(-) create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala create mode 100644 mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java create mode 100644 mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java create mode 100644 mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java create mode 100644 mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java create mode 100644 mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala index 2cd515c89d3d2..9002e99d82ad3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala @@ -22,10 +22,9 @@ import scala.language.reflectiveCalls import scopt.OptionParser -import org.apache.spark.ml.tree.DecisionTreeModel import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.examples.mllib.AbstractParams -import org.apache.spark.ml.{Pipeline, PipelineStage} +import org.apache.spark.ml.{Pipeline, PipelineStage, Transformer} import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier} import org.apache.spark.ml.feature.{VectorIndexer, StringIndexer} import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor} @@ -64,8 +63,6 @@ object DecisionTreeExample { maxBins: Int = 32, minInstancesPerNode: Int = 1, minInfoGain: Double = 0.0, - numTrees: Int = 1, - featureSubsetStrategy: String = "auto", fracTest: Double = 0.2, cacheNodeIds: Boolean = false, checkpointDir: Option[String] = None, @@ -123,8 +120,8 @@ object DecisionTreeExample { .required() .action((x, c) => c.copy(input = x)) checkConfig { params => - if (params.fracTest < 0 || params.fracTest > 1) { - failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].") + if (params.fracTest < 0 || params.fracTest >= 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).") } else { success } @@ -200,9 +197,18 @@ object DecisionTreeExample { throw new IllegalArgumentException("Algo ${params.algo} not supported.") } } - val dataframes = splits.map(_.toDF()).map(labelsToStrings).map(_.cache()) + val dataframes = splits.map(_.toDF()).map(labelsToStrings) + val training = dataframes(0).cache() + val test = dataframes(1).cache() - (dataframes(0), dataframes(1)) + val numTraining = training.count() + val numTest = test.count() + val numFeatures = training.select("features").first().getAs[Vector](0).size + println("Loaded data:") + println(s" numTraining = $numTraining, numTest = $numTest") + println(s" numFeatures = $numFeatures") + + (training, test) } def run(params: Params) { @@ -217,13 +223,6 @@ object DecisionTreeExample { val (training: DataFrame, test: DataFrame) = loadDatasets(sc, params.input, params.dataFormat, params.testInput, algo, params.fracTest) - val numTraining = training.count() - val numTest = test.count() - val numFeatures = training.select("features").first().getAs[Vector](0).size - println("Loaded data:") - println(s" numTraining = $numTraining, numTest = $numTest") - println(s" numFeatures = $numFeatures") - // Set up Pipeline val stages = new mutable.ArrayBuffer[PipelineStage]() // (1) For classification, re-index classes. @@ -241,7 +240,7 @@ object DecisionTreeExample { .setOutputCol("indexedFeatures") .setMaxCategories(10) stages += featuresIndexer - // (3) Learn DecisionTree + // (3) Learn Decision Tree val dt = algo match { case "classification" => new DecisionTreeClassifier() @@ -275,62 +274,86 @@ object DecisionTreeExample { println(s"Training time: $elapsedTime seconds") // Get the trained Decision Tree from the fitted PipelineModel - val treeModel: DecisionTreeModel = algo match { + algo match { case "classification" => - pipelineModel.getModel[DecisionTreeClassificationModel]( + val treeModel = pipelineModel.getModel[DecisionTreeClassificationModel]( dt.asInstanceOf[DecisionTreeClassifier]) + if (treeModel.numNodes < 20) { + println(treeModel.toDebugString) // Print full model. + } else { + println(treeModel) // Print model summary. + } case "regression" => - pipelineModel.getModel[DecisionTreeRegressionModel](dt.asInstanceOf[DecisionTreeRegressor]) - case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.") - } - if (treeModel.numNodes < 20) { - println(treeModel.toDebugString) // Print full model. - } else { - println(treeModel) // Print model summary. - } - - // Predict on training - val trainingFullPredictions = pipelineModel.transform(training).cache() - val trainingPredictions = trainingFullPredictions.select("prediction") - .map(_.getDouble(0)) - val trainingLabels = trainingFullPredictions.select(labelColName).map(_.getDouble(0)) - // Predict on test data - val testFullPredictions = pipelineModel.transform(test).cache() - val testPredictions = testFullPredictions.select("prediction") - .map(_.getDouble(0)) - val testLabels = testFullPredictions.select(labelColName).map(_.getDouble(0)) - - // For classification, print number of classes for reference. - if (algo == "classification") { - val numClasses = - MetadataUtils.getNumClasses(trainingFullPredictions.schema(labelColName)) match { - case Some(n) => n - case None => throw new RuntimeException( - "DecisionTreeExample had unknown failure when indexing labels for classification.") + val treeModel = pipelineModel.getModel[DecisionTreeRegressionModel]( + dt.asInstanceOf[DecisionTreeRegressor]) + if (treeModel.numNodes < 20) { + println(treeModel.toDebugString) // Print full model. + } else { + println(treeModel) // Print model summary. } - println(s"numClasses = $numClasses.") + case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.") } // Evaluate model on training, test data algo match { case "classification" => - val trainingAccuracy = - new MulticlassMetrics(trainingPredictions.zip(trainingLabels)).precision - println(s"Train accuracy = $trainingAccuracy") - val testAccuracy = - new MulticlassMetrics(testPredictions.zip(testLabels)).precision - println(s"Test accuracy = $testAccuracy") + println("Training data results:") + evaluateClassificationModel(pipelineModel, training, labelColName) + println("Test data results:") + evaluateClassificationModel(pipelineModel, test, labelColName) case "regression" => - val trainingRMSE = - new RegressionMetrics(trainingPredictions.zip(trainingLabels)).rootMeanSquaredError - println(s"Training root mean squared error (RMSE) = $trainingRMSE") - val testRMSE = - new RegressionMetrics(testPredictions.zip(testLabels)).rootMeanSquaredError - println(s"Test root mean squared error (RMSE) = $testRMSE") + println("Training data results:") + evaluateRegressionModel(pipelineModel, training, labelColName) + println("Test data results:") + evaluateRegressionModel(pipelineModel, test, labelColName) case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.") } sc.stop() } + + /** + * Evaluate the given ClassificationModel on data. Print the results. + * @param model Must fit ClassificationModel abstraction + * @param data DataFrame with "prediction" and labelColName columns + * @param labelColName Name of the labelCol parameter for the model + * + * TODO: Change model type to ClassificationModel once that API is public. SPARK-5995 + */ + private[ml] def evaluateClassificationModel( + model: Transformer, + data: DataFrame, + labelColName: String): Unit = { + val fullPredictions = model.transform(data).cache() + val predictions = fullPredictions.select("prediction").map(_.getDouble(0)) + val labels = fullPredictions.select(labelColName).map(_.getDouble(0)) + // Print number of classes for reference + val numClasses = MetadataUtils.getNumClasses(fullPredictions.schema(labelColName)) match { + case Some(n) => n + case None => throw new RuntimeException( + "Unknown failure when indexing labels for classification.") + } + val accuracy = new MulticlassMetrics(predictions.zip(labels)).precision + println(s" Accuracy ($numClasses classes): $accuracy") + } + + /** + * Evaluate the given RegressionModel on data. Print the results. + * @param model Must fit RegressionModel abstraction + * @param data DataFrame with "prediction" and labelColName columns + * @param labelColName Name of the labelCol parameter for the model + * + * TODO: Change model type to RegressionModel once that API is public. SPARK-5995 + */ + private[ml] def evaluateRegressionModel( + model: Transformer, + data: DataFrame, + labelColName: String): Unit = { + val fullPredictions = model.transform(data).cache() + val predictions = fullPredictions.select("prediction").map(_.getDouble(0)) + val labels = fullPredictions.select(labelColName).map(_.getDouble(0)) + val RMSE = new RegressionMetrics(predictions.zip(labels)).rootMeanSquaredError + println(s" Root mean squared error (RMSE): $RMSE") + } } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala new file mode 100644 index 0000000000000..5fccb142d4c3d --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala @@ -0,0 +1,238 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +import scala.collection.mutable +import scala.language.reflectiveCalls + +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.examples.mllib.AbstractParams +import org.apache.spark.ml.{Pipeline, PipelineStage} +import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier} +import org.apache.spark.ml.feature.{StringIndexer, VectorIndexer} +import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor} +import org.apache.spark.sql.DataFrame + + +/** + * An example runner for decision trees. Run with + * {{{ + * ./bin/run-example ml.GBTExample [options] + * }}} + * Decision Trees and ensembles can take a large amount of memory. If the run-example command + * above fails, try running via spark-submit and specifying the amount of memory as at least 1g. + * For local mode, run + * {{{ + * ./bin/spark-submit --class org.apache.spark.examples.ml.GBTExample --driver-memory 1g + * [examples JAR path] [options] + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object GBTExample { + + case class Params( + input: String = null, + testInput: String = "", + dataFormat: String = "libsvm", + algo: String = "classification", + maxDepth: Int = 5, + maxBins: Int = 32, + minInstancesPerNode: Int = 1, + minInfoGain: Double = 0.0, + maxIter: Int = 10, + fracTest: Double = 0.2, + cacheNodeIds: Boolean = false, + checkpointDir: Option[String] = None, + checkpointInterval: Int = 10) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("GBTExample") { + head("GBTExample: an example Gradient-Boosted Trees app.") + opt[String]("algo") + .text(s"algorithm (classification, regression), default: ${defaultParams.algo}") + .action((x, c) => c.copy(algo = x)) + opt[Int]("maxDepth") + .text(s"max depth of the tree, default: ${defaultParams.maxDepth}") + .action((x, c) => c.copy(maxDepth = x)) + opt[Int]("maxBins") + .text(s"max number of bins, default: ${defaultParams.maxBins}") + .action((x, c) => c.copy(maxBins = x)) + opt[Int]("minInstancesPerNode") + .text(s"min number of instances required at child nodes to create the parent split," + + s" default: ${defaultParams.minInstancesPerNode}") + .action((x, c) => c.copy(minInstancesPerNode = x)) + opt[Double]("minInfoGain") + .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}") + .action((x, c) => c.copy(minInfoGain = x)) + opt[Int]("maxIter") + .text(s"number of trees in ensemble, default: ${defaultParams.maxIter}") + .action((x, c) => c.copy(maxIter = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing. If given option testInput, " + + s"this option is ignored. default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) + opt[Boolean]("cacheNodeIds") + .text(s"whether to use node Id cache during training, " + + s"default: ${defaultParams.cacheNodeIds}") + .action((x, c) => c.copy(cacheNodeIds = x)) + opt[String]("checkpointDir") + .text(s"checkpoint directory where intermediate node Id caches will be stored, " + + s"default: ${ + defaultParams.checkpointDir match { + case Some(strVal) => strVal + case None => "None" + } + }") + .action((x, c) => c.copy(checkpointDir = Some(x))) + opt[Int]("checkpointInterval") + .text(s"how often to checkpoint the node Id cache, " + + s"default: ${defaultParams.checkpointInterval}") + .action((x, c) => c.copy(checkpointInterval = x)) + opt[String]("testInput") + .text(s"input path to test dataset. If given, option fracTest is ignored." + + s" default: ${defaultParams.testInput}") + .action((x, c) => c.copy(testInput = x)) + opt[String]("") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(dataFormat = x)) + arg[String]("") + .text("input path to labeled examples") + .required() + .action((x, c) => c.copy(input = x)) + checkConfig { params => + if (params.fracTest < 0 || params.fracTest >= 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).") + } else { + success + } + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf().setAppName(s"GBTExample with $params") + val sc = new SparkContext(conf) + params.checkpointDir.foreach(sc.setCheckpointDir) + val algo = params.algo.toLowerCase + + println(s"GBTExample with parameters:\n$params") + + // Load training and test data and cache it. + val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc, params.input, + params.dataFormat, params.testInput, algo, params.fracTest) + + // Set up Pipeline + val stages = new mutable.ArrayBuffer[PipelineStage]() + // (1) For classification, re-index classes. + val labelColName = if (algo == "classification") "indexedLabel" else "label" + if (algo == "classification") { + val labelIndexer = new StringIndexer() + .setInputCol("labelString") + .setOutputCol(labelColName) + stages += labelIndexer + } + // (2) Identify categorical features using VectorIndexer. + // Features with more than maxCategories values will be treated as continuous. + val featuresIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(10) + stages += featuresIndexer + // (3) Learn GBT + val dt = algo match { + case "classification" => + new GBTClassifier() + .setFeaturesCol("indexedFeatures") + .setLabelCol(labelColName) + .setMaxDepth(params.maxDepth) + .setMaxBins(params.maxBins) + .setMinInstancesPerNode(params.minInstancesPerNode) + .setMinInfoGain(params.minInfoGain) + .setCacheNodeIds(params.cacheNodeIds) + .setCheckpointInterval(params.checkpointInterval) + .setMaxIter(params.maxIter) + case "regression" => + new GBTRegressor() + .setFeaturesCol("indexedFeatures") + .setLabelCol(labelColName) + .setMaxDepth(params.maxDepth) + .setMaxBins(params.maxBins) + .setMinInstancesPerNode(params.minInstancesPerNode) + .setMinInfoGain(params.minInfoGain) + .setCacheNodeIds(params.cacheNodeIds) + .setCheckpointInterval(params.checkpointInterval) + .setMaxIter(params.maxIter) + case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } + stages += dt + val pipeline = new Pipeline().setStages(stages.toArray) + + // Fit the Pipeline + val startTime = System.nanoTime() + val pipelineModel = pipeline.fit(training) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") + + // Get the trained GBT from the fitted PipelineModel + algo match { + case "classification" => + val rfModel = pipelineModel.getModel[GBTClassificationModel](dt.asInstanceOf[GBTClassifier]) + if (rfModel.totalNumNodes < 30) { + println(rfModel.toDebugString) // Print full model. + } else { + println(rfModel) // Print model summary. + } + case "regression" => + val rfModel = pipelineModel.getModel[GBTRegressionModel](dt.asInstanceOf[GBTRegressor]) + if (rfModel.totalNumNodes < 30) { + println(rfModel.toDebugString) // Print full model. + } else { + println(rfModel) // Print model summary. + } + case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } + + // Evaluate model on training, test data + algo match { + case "classification" => + println("Training data results:") + DecisionTreeExample.evaluateClassificationModel(pipelineModel, training, labelColName) + println("Test data results:") + DecisionTreeExample.evaluateClassificationModel(pipelineModel, test, labelColName) + case "regression" => + println("Training data results:") + DecisionTreeExample.evaluateRegressionModel(pipelineModel, training, labelColName) + println("Test data results:") + DecisionTreeExample.evaluateRegressionModel(pipelineModel, test, labelColName) + case _ => + throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } + + sc.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala new file mode 100644 index 0000000000000..9b909324ec82a --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala @@ -0,0 +1,248 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +import scala.collection.mutable +import scala.language.reflectiveCalls + +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.examples.mllib.AbstractParams +import org.apache.spark.ml.{Pipeline, PipelineStage} +import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier} +import org.apache.spark.ml.feature.{StringIndexer, VectorIndexer} +import org.apache.spark.ml.regression.{RandomForestRegressionModel, RandomForestRegressor} +import org.apache.spark.sql.DataFrame + + +/** + * An example runner for decision trees. Run with + * {{{ + * ./bin/run-example ml.RandomForestExample [options] + * }}} + * Decision Trees and ensembles can take a large amount of memory. If the run-example command + * above fails, try running via spark-submit and specifying the amount of memory as at least 1g. + * For local mode, run + * {{{ + * ./bin/spark-submit --class org.apache.spark.examples.ml.RandomForestExample --driver-memory 1g + * [examples JAR path] [options] + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object RandomForestExample { + + case class Params( + input: String = null, + testInput: String = "", + dataFormat: String = "libsvm", + algo: String = "classification", + maxDepth: Int = 5, + maxBins: Int = 32, + minInstancesPerNode: Int = 1, + minInfoGain: Double = 0.0, + numTrees: Int = 10, + featureSubsetStrategy: String = "auto", + fracTest: Double = 0.2, + cacheNodeIds: Boolean = false, + checkpointDir: Option[String] = None, + checkpointInterval: Int = 10) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("RandomForestExample") { + head("RandomForestExample: an example random forest app.") + opt[String]("algo") + .text(s"algorithm (classification, regression), default: ${defaultParams.algo}") + .action((x, c) => c.copy(algo = x)) + opt[Int]("maxDepth") + .text(s"max depth of the tree, default: ${defaultParams.maxDepth}") + .action((x, c) => c.copy(maxDepth = x)) + opt[Int]("maxBins") + .text(s"max number of bins, default: ${defaultParams.maxBins}") + .action((x, c) => c.copy(maxBins = x)) + opt[Int]("minInstancesPerNode") + .text(s"min number of instances required at child nodes to create the parent split," + + s" default: ${defaultParams.minInstancesPerNode}") + .action((x, c) => c.copy(minInstancesPerNode = x)) + opt[Double]("minInfoGain") + .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}") + .action((x, c) => c.copy(minInfoGain = x)) + opt[Int]("numTrees") + .text(s"number of trees in ensemble, default: ${defaultParams.numTrees}") + .action((x, c) => c.copy(numTrees = x)) + opt[String]("featureSubsetStrategy") + .text(s"number of features to use per node (supported:" + + s" ${RandomForestClassifier.supportedFeatureSubsetStrategies.mkString(",")})," + + s" default: ${defaultParams.numTrees}") + .action((x, c) => c.copy(featureSubsetStrategy = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing. If given option testInput, " + + s"this option is ignored. default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) + opt[Boolean]("cacheNodeIds") + .text(s"whether to use node Id cache during training, " + + s"default: ${defaultParams.cacheNodeIds}") + .action((x, c) => c.copy(cacheNodeIds = x)) + opt[String]("checkpointDir") + .text(s"checkpoint directory where intermediate node Id caches will be stored, " + + s"default: ${ + defaultParams.checkpointDir match { + case Some(strVal) => strVal + case None => "None" + } + }") + .action((x, c) => c.copy(checkpointDir = Some(x))) + opt[Int]("checkpointInterval") + .text(s"how often to checkpoint the node Id cache, " + + s"default: ${defaultParams.checkpointInterval}") + .action((x, c) => c.copy(checkpointInterval = x)) + opt[String]("testInput") + .text(s"input path to test dataset. If given, option fracTest is ignored." + + s" default: ${defaultParams.testInput}") + .action((x, c) => c.copy(testInput = x)) + opt[String]("") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(dataFormat = x)) + arg[String]("") + .text("input path to labeled examples") + .required() + .action((x, c) => c.copy(input = x)) + checkConfig { params => + if (params.fracTest < 0 || params.fracTest >= 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).") + } else { + success + } + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf().setAppName(s"RandomForestExample with $params") + val sc = new SparkContext(conf) + params.checkpointDir.foreach(sc.setCheckpointDir) + val algo = params.algo.toLowerCase + + println(s"RandomForestExample with parameters:\n$params") + + // Load training and test data and cache it. + val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc, params.input, + params.dataFormat, params.testInput, algo, params.fracTest) + + // Set up Pipeline + val stages = new mutable.ArrayBuffer[PipelineStage]() + // (1) For classification, re-index classes. + val labelColName = if (algo == "classification") "indexedLabel" else "label" + if (algo == "classification") { + val labelIndexer = new StringIndexer() + .setInputCol("labelString") + .setOutputCol(labelColName) + stages += labelIndexer + } + // (2) Identify categorical features using VectorIndexer. + // Features with more than maxCategories values will be treated as continuous. + val featuresIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(10) + stages += featuresIndexer + // (3) Learn Random Forest + val dt = algo match { + case "classification" => + new RandomForestClassifier() + .setFeaturesCol("indexedFeatures") + .setLabelCol(labelColName) + .setMaxDepth(params.maxDepth) + .setMaxBins(params.maxBins) + .setMinInstancesPerNode(params.minInstancesPerNode) + .setMinInfoGain(params.minInfoGain) + .setCacheNodeIds(params.cacheNodeIds) + .setCheckpointInterval(params.checkpointInterval) + .setFeatureSubsetStrategy(params.featureSubsetStrategy) + .setNumTrees(params.numTrees) + case "regression" => + new RandomForestRegressor() + .setFeaturesCol("indexedFeatures") + .setLabelCol(labelColName) + .setMaxDepth(params.maxDepth) + .setMaxBins(params.maxBins) + .setMinInstancesPerNode(params.minInstancesPerNode) + .setMinInfoGain(params.minInfoGain) + .setCacheNodeIds(params.cacheNodeIds) + .setCheckpointInterval(params.checkpointInterval) + .setFeatureSubsetStrategy(params.featureSubsetStrategy) + .setNumTrees(params.numTrees) + case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } + stages += dt + val pipeline = new Pipeline().setStages(stages.toArray) + + // Fit the Pipeline + val startTime = System.nanoTime() + val pipelineModel = pipeline.fit(training) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") + + // Get the trained Random Forest from the fitted PipelineModel + algo match { + case "classification" => + val rfModel = pipelineModel.getModel[RandomForestClassificationModel]( + dt.asInstanceOf[RandomForestClassifier]) + if (rfModel.totalNumNodes < 30) { + println(rfModel.toDebugString) // Print full model. + } else { + println(rfModel) // Print model summary. + } + case "regression" => + val rfModel = pipelineModel.getModel[RandomForestRegressionModel]( + dt.asInstanceOf[RandomForestRegressor]) + if (rfModel.totalNumNodes < 30) { + println(rfModel.toDebugString) // Print full model. + } else { + println(rfModel) // Print model summary. + } + case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } + + // Evaluate model on training, test data + algo match { + case "classification" => + println("Training data results:") + DecisionTreeExample.evaluateClassificationModel(pipelineModel, training, labelColName) + println("Test data results:") + DecisionTreeExample.evaluateClassificationModel(pipelineModel, test, labelColName) + case "regression" => + println("Training data results:") + DecisionTreeExample.evaluateRegressionModel(pipelineModel, training, labelColName) + println("Test data results:") + DecisionTreeExample.evaluateRegressionModel(pipelineModel, test, labelColName) + case _ => + throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } + + sc.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala index 431ead8c0c165..0763a7736305a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala @@ -25,6 +25,7 @@ import org.apache.spark.mllib.tree.GradientBoostedTrees import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo} import org.apache.spark.util.Utils + /** * An example runner for Gradient Boosting using decision trees as weak learners. Run with * {{{ diff --git a/mllib/src/main/scala/org/apache/spark/ml/Model.scala b/mllib/src/main/scala/org/apache/spark/ml/Model.scala index cae5082b51196..a491bc7ee8295 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Model.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Model.scala @@ -30,11 +30,13 @@ import org.apache.spark.ml.param.ParamMap abstract class Model[M <: Model[M]] extends Transformer { /** * The parent estimator that produced this model. + * Note: For ensembles' component Models, this value can be null. */ val parent: Estimator[M] /** * Fitting parameters, such that parent.fit(..., fittingParamMap) could reproduce the model. + * Note: For ensembles' component Models, this value can be null. */ val fittingParamMap: ParamMap } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 3855e396b5534..ee2a8dc6db171 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -43,8 +43,7 @@ import org.apache.spark.sql.DataFrame @AlphaComponent final class DecisionTreeClassifier extends Predictor[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel] - with DecisionTreeParams - with TreeClassifierParams { + with DecisionTreeParams with TreeClassifierParams { // Override parameter setters from parent trait for Java API compatibility. @@ -59,11 +58,9 @@ final class DecisionTreeClassifier override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) - override def setCacheNodeIds(value: Boolean): this.type = - super.setCacheNodeIds(value) + override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) - override def setCheckpointInterval(value: Int): this.type = - super.setCheckpointInterval(value) + override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) override def setImpurity(value: String): this.type = super.setImpurity(value) @@ -75,8 +72,9 @@ final class DecisionTreeClassifier val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema(paramMap(labelCol))) match { case Some(n: Int) => n case None => throw new IllegalArgumentException("DecisionTreeClassifier was given input" + - s" with invalid label column, without the number of classes specified.") - // TODO: Automatically index labels. + s" with invalid label column ${paramMap(labelCol)}, without the number of classes" + + " specified. See StringIndexer.") + // TODO: Automatically index labels: SPARK-7126 } val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap) val strategy = getOldStrategy(categoricalFeatures, numClasses) @@ -85,18 +83,16 @@ final class DecisionTreeClassifier } /** (private[ml]) Create a Strategy instance to use with the old API. */ - override private[ml] def getOldStrategy( + private[ml] def getOldStrategy( categoricalFeatures: Map[Int, Int], numClasses: Int): OldStrategy = { - val strategy = super.getOldStrategy(categoricalFeatures, numClasses) - strategy.algo = OldAlgo.Classification - strategy.setImpurity(getOldImpurity) - strategy + super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity, + subsamplingRate = 1.0) } } object DecisionTreeClassifier { - /** Accessor for supported impurities */ + /** Accessor for supported impurities: entropy, gini */ final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala new file mode 100644 index 0000000000000..d2e052fbbbf22 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import com.github.fommil.netlib.BLAS.{getInstance => blas} + +import org.apache.spark.Logging +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor} +import org.apache.spark.ml.impl.tree._ +import org.apache.spark.ml.param.{Param, Params, ParamMap} +import org.apache.spark.ml.regression.DecisionTreeRegressionModel +import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel} +import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{GradientBoostedTrees => OldGBT} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} +import org.apache.spark.mllib.tree.loss.{Loss => OldLoss, LogLoss => OldLogLoss} +import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + + +/** + * :: AlphaComponent :: + * + * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]] + * learning algorithm for classification. + * It supports binary labels, as well as both continuous and categorical features. + * Note: Multiclass labels are not currently supported. + */ +@AlphaComponent +final class GBTClassifier + extends Predictor[Vector, GBTClassifier, GBTClassificationModel] + with GBTParams with TreeClassifierParams with Logging { + + // Override parameter setters from parent trait for Java API compatibility. + + // Parameters from TreeClassifierParams: + + override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + + override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + + override def setMinInstancesPerNode(value: Int): this.type = + super.setMinInstancesPerNode(value) + + override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + + override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + + override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) + + override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) + + /** + * The impurity setting is ignored for GBT models. + * Individual trees are built using impurity "Variance." + */ + override def setImpurity(value: String): this.type = { + logWarning("GBTClassifier.setImpurity should NOT be used") + this + } + + // Parameters from TreeEnsembleParams: + + override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value) + + override def setSeed(value: Long): this.type = { + logWarning("The 'seed' parameter is currently ignored by Gradient Boosting.") + super.setSeed(value) + } + + // Parameters from GBTParams: + + override def setMaxIter(value: Int): this.type = super.setMaxIter(value) + + override def setStepSize(value: Double): this.type = super.setStepSize(value) + + // Parameters for GBTClassifier: + + /** + * Loss function which GBT tries to minimize. (case-insensitive) + * Supported: "logistic" + * (default = logistic) + * @group param + */ + val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + + " tries to minimize (case-insensitive). Supported options:" + + s" ${GBTClassifier.supportedLossTypes.mkString(", ")}") + + setDefault(lossType -> "logistic") + + /** @group setParam */ + def setLossType(value: String): this.type = { + val lossStr = value.toLowerCase + require(GBTClassifier.supportedLossTypes.contains(lossStr), "GBTClassifier was given bad loss" + + s" type: $value. Supported options: ${GBTClassifier.supportedLossTypes.mkString(", ")}") + set(lossType, lossStr) + this + } + + /** @group getParam */ + def getLossType: String = getOrDefault(lossType) + + /** (private[ml]) Convert new loss to old loss. */ + override private[ml] def getOldLossType: OldLoss = { + getLossType match { + case "logistic" => OldLogLoss + case _ => + // Should never happen because of check in setter method. + throw new RuntimeException(s"GBTClassifier was given bad loss type: $getLossType") + } + } + + override protected def train( + dataset: DataFrame, + paramMap: ParamMap): GBTClassificationModel = { + val categoricalFeatures: Map[Int, Int] = + MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol))) + val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema(paramMap(labelCol))) match { + case Some(n: Int) => n + case None => throw new IllegalArgumentException("GBTClassifier was given input" + + s" with invalid label column ${paramMap(labelCol)}, without the number of classes" + + " specified. See StringIndexer.") + // TODO: Automatically index labels: SPARK-7126 + } + require(numClasses == 2, + s"GBTClassifier only supports binary classification but was given numClasses = $numClasses") + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap) + val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) + val oldGBT = new OldGBT(boostingStrategy) + val oldModel = oldGBT.run(oldDataset) + GBTClassificationModel.fromOld(oldModel, this, paramMap, categoricalFeatures) + } +} + +object GBTClassifier { + // The losses below should be lowercase. + /** Accessor for supported loss settings: logistic */ + final val supportedLossTypes: Array[String] = Array("logistic").map(_.toLowerCase) +} + +/** + * :: AlphaComponent :: + * + * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]] + * model for classification. + * It supports binary labels, as well as both continuous and categorical features. + * Note: Multiclass labels are not currently supported. + * @param _trees Decision trees in the ensemble. + * @param _treeWeights Weights for the decision trees in the ensemble. + */ +@AlphaComponent +final class GBTClassificationModel( + override val parent: GBTClassifier, + override val fittingParamMap: ParamMap, + private val _trees: Array[DecisionTreeRegressionModel], + private val _treeWeights: Array[Double]) + extends PredictionModel[Vector, GBTClassificationModel] + with TreeEnsembleModel with Serializable { + + require(numTrees > 0, "GBTClassificationModel requires at least 1 tree.") + require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" + + s" of non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).") + + override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] + + override def treeWeights: Array[Double] = _treeWeights + + override protected def predict(features: Vector): Double = { + // TODO: Override transform() to broadcast model: SPARK-7127 + // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129 + // Classifies by thresholding sum of weighted tree predictions + val treePredictions = _trees.map(_.rootNode.predict(features)) + val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) + if (prediction > 0.0) 1.0 else 0.0 + } + + override protected def copy(): GBTClassificationModel = { + val m = new GBTClassificationModel(parent, fittingParamMap, _trees, _treeWeights) + Params.inheritValues(this.extractParamMap(), this, m) + m + } + + override def toString: String = { + s"GBTClassificationModel with $numTrees trees" + } + + /** (private[ml]) Convert to a model in the old API */ + private[ml] def toOld: OldGBTModel = { + new OldGBTModel(OldAlgo.Classification, _trees.map(_.toOld), _treeWeights) + } +} + +private[ml] object GBTClassificationModel { + + /** (private[ml]) Convert a model from the old API */ + def fromOld( + oldModel: OldGBTModel, + parent: GBTClassifier, + fittingParamMap: ParamMap, + categoricalFeatures: Map[Int, Int]): GBTClassificationModel = { + require(oldModel.algo == OldAlgo.Classification, "Cannot convert GradientBoostedTreesModel" + + s" with algo=${oldModel.algo} (old API) to GBTClassificationModel (new API).") + val newTrees = oldModel.trees.map { tree => + // parent, fittingParamMap for each tree is null since there are no good ways to set these. + DecisionTreeRegressionModel.fromOld(tree, null, null, categoricalFeatures) + } + new GBTClassificationModel(parent, fittingParamMap, newTrees, oldModel.treeWeights) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala new file mode 100644 index 0000000000000..cfd6508fce890 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import scala.collection.mutable + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor} +import org.apache.spark.ml.impl.tree._ +import org.apache.spark.ml.param.{Params, ParamMap} +import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel} +import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{RandomForest => OldRandomForest} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} +import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + + +/** + * :: AlphaComponent :: + * + * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] learning algorithm for + * classification. + * It supports both binary and multiclass labels, as well as both continuous and categorical + * features. + */ +@AlphaComponent +final class RandomForestClassifier + extends Predictor[Vector, RandomForestClassifier, RandomForestClassificationModel] + with RandomForestParams with TreeClassifierParams { + + // Override parameter setters from parent trait for Java API compatibility. + + // Parameters from TreeClassifierParams: + + override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + + override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + + override def setMinInstancesPerNode(value: Int): this.type = + super.setMinInstancesPerNode(value) + + override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + + override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + + override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) + + override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) + + override def setImpurity(value: String): this.type = super.setImpurity(value) + + // Parameters from TreeEnsembleParams: + + override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value) + + override def setSeed(value: Long): this.type = super.setSeed(value) + + // Parameters from RandomForestParams: + + override def setNumTrees(value: Int): this.type = super.setNumTrees(value) + + override def setFeatureSubsetStrategy(value: String): this.type = + super.setFeatureSubsetStrategy(value) + + override protected def train( + dataset: DataFrame, + paramMap: ParamMap): RandomForestClassificationModel = { + val categoricalFeatures: Map[Int, Int] = + MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol))) + val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema(paramMap(labelCol))) match { + case Some(n: Int) => n + case None => throw new IllegalArgumentException("RandomForestClassifier was given input" + + s" with invalid label column ${paramMap(labelCol)}, without the number of classes" + + " specified. See StringIndexer.") + // TODO: Automatically index labels: SPARK-7126 + } + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap) + val strategy = + super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity) + val oldModel = OldRandomForest.trainClassifier( + oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt) + RandomForestClassificationModel.fromOld(oldModel, this, paramMap, categoricalFeatures) + } +} + +object RandomForestClassifier { + /** Accessor for supported impurity settings: entropy, gini */ + final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities + + /** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */ + final val supportedFeatureSubsetStrategies: Array[String] = + RandomForestParams.supportedFeatureSubsetStrategies +} + +/** + * :: AlphaComponent :: + * + * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for classification. + * It supports both binary and multiclass labels, as well as both continuous and categorical + * features. + * @param _trees Decision trees in the ensemble. + * Warning: These have null parents. + */ +@AlphaComponent +final class RandomForestClassificationModel private[ml] ( + override val parent: RandomForestClassifier, + override val fittingParamMap: ParamMap, + private val _trees: Array[DecisionTreeClassificationModel]) + extends PredictionModel[Vector, RandomForestClassificationModel] + with TreeEnsembleModel with Serializable { + + require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.") + + override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] + + // Note: We may add support for weights (based on tree performance) later on. + private lazy val _treeWeights: Array[Double] = Array.fill[Double](numTrees)(1.0) + + override def treeWeights: Array[Double] = _treeWeights + + override protected def predict(features: Vector): Double = { + // TODO: Override transform() to broadcast model. SPARK-7127 + // TODO: When we add a generic Bagging class, handle transform there: SPARK-7128 + // Classifies using majority votes. + // Ignore the weights since all are 1.0 for now. + val votes = mutable.Map.empty[Int, Double] + _trees.view.foreach { tree => + val prediction = tree.rootNode.predict(features).toInt + votes(prediction) = votes.getOrElse(prediction, 0.0) + 1.0 // 1.0 = weight + } + votes.maxBy(_._2)._1 + } + + override protected def copy(): RandomForestClassificationModel = { + val m = new RandomForestClassificationModel(parent, fittingParamMap, _trees) + Params.inheritValues(this.extractParamMap(), this, m) + m + } + + override def toString: String = { + s"RandomForestClassificationModel with $numTrees trees" + } + + /** (private[ml]) Convert to a model in the old API */ + private[ml] def toOld: OldRandomForestModel = { + new OldRandomForestModel(OldAlgo.Classification, _trees.map(_.toOld)) + } +} + +private[ml] object RandomForestClassificationModel { + + /** (private[ml]) Convert a model from the old API */ + def fromOld( + oldModel: OldRandomForestModel, + parent: RandomForestClassifier, + fittingParamMap: ParamMap, + categoricalFeatures: Map[Int, Int]): RandomForestClassificationModel = { + require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" + + s" with algo=${oldModel.algo} (old API) to RandomForestClassificationModel (new API).") + val newTrees = oldModel.trees.map { tree => + // parent, fittingParamMap for each tree is null since there are no good ways to set these. + DecisionTreeClassificationModel.fromOld(tree, null, null, categoricalFeatures) + } + new RandomForestClassificationModel(parent, fittingParamMap, newTrees) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala index eb2609faef05a..ab6281b9b2e34 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala @@ -20,9 +20,12 @@ package org.apache.spark.ml.impl.tree import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.impl.estimator.PredictorParams import org.apache.spark.ml.param._ -import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} +import org.apache.spark.ml.param.shared.{HasSeed, HasMaxIter} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, + BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impurity.{Gini => OldGini, Entropy => OldEntropy, Impurity => OldImpurity, Variance => OldVariance} +import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} /** @@ -117,79 +120,68 @@ private[ml] trait DecisionTreeParams extends PredictorParams { def setMaxDepth(value: Int): this.type = { require(value >= 0, s"maxDepth parameter must be >= 0. Given bad value: $value") set(maxDepth, value) - this } /** @group getParam */ - def getMaxDepth: Int = getOrDefault(maxDepth) + final def getMaxDepth: Int = getOrDefault(maxDepth) /** @group setParam */ def setMaxBins(value: Int): this.type = { require(value >= 2, s"maxBins parameter must be >= 2. Given bad value: $value") set(maxBins, value) - this } /** @group getParam */ - def getMaxBins: Int = getOrDefault(maxBins) + final def getMaxBins: Int = getOrDefault(maxBins) /** @group setParam */ def setMinInstancesPerNode(value: Int): this.type = { require(value >= 1, s"minInstancesPerNode parameter must be >= 1. Given bad value: $value") set(minInstancesPerNode, value) - this } /** @group getParam */ - def getMinInstancesPerNode: Int = getOrDefault(minInstancesPerNode) + final def getMinInstancesPerNode: Int = getOrDefault(minInstancesPerNode) /** @group setParam */ - def setMinInfoGain(value: Double): this.type = { - set(minInfoGain, value) - this - } + def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) /** @group getParam */ - def getMinInfoGain: Double = getOrDefault(minInfoGain) + final def getMinInfoGain: Double = getOrDefault(minInfoGain) /** @group expertSetParam */ def setMaxMemoryInMB(value: Int): this.type = { require(value > 0, s"maxMemoryInMB parameter must be > 0. Given bad value: $value") set(maxMemoryInMB, value) - this } /** @group expertGetParam */ - def getMaxMemoryInMB: Int = getOrDefault(maxMemoryInMB) + final def getMaxMemoryInMB: Int = getOrDefault(maxMemoryInMB) /** @group expertSetParam */ - def setCacheNodeIds(value: Boolean): this.type = { - set(cacheNodeIds, value) - this - } + def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) /** @group expertGetParam */ - def getCacheNodeIds: Boolean = getOrDefault(cacheNodeIds) + final def getCacheNodeIds: Boolean = getOrDefault(cacheNodeIds) /** @group expertSetParam */ def setCheckpointInterval(value: Int): this.type = { require(value >= 1, s"checkpointInterval parameter must be >= 1. Given bad value: $value") set(checkpointInterval, value) - this } /** @group expertGetParam */ - def getCheckpointInterval: Int = getOrDefault(checkpointInterval) + final def getCheckpointInterval: Int = getOrDefault(checkpointInterval) - /** - * Create a Strategy instance to use with the old API. - * NOTE: The caller should set impurity and subsamplingRate (which is set to 1.0, - * the default for single trees). - */ + /** (private[ml]) Create a Strategy instance to use with the old API. */ private[ml] def getOldStrategy( categoricalFeatures: Map[Int, Int], - numClasses: Int): OldStrategy = { - val strategy = OldStrategy.defaultStategy(OldAlgo.Classification) + numClasses: Int, + oldAlgo: OldAlgo.Algo, + oldImpurity: OldImpurity, + subsamplingRate: Double): OldStrategy = { + val strategy = OldStrategy.defaultStategy(oldAlgo) + strategy.impurity = oldImpurity strategy.checkpointInterval = getCheckpointInterval strategy.maxBins = getMaxBins strategy.maxDepth = getMaxDepth @@ -199,13 +191,13 @@ private[ml] trait DecisionTreeParams extends PredictorParams { strategy.useNodeIdCache = getCacheNodeIds strategy.numClasses = numClasses strategy.categoricalFeaturesInfo = categoricalFeatures - strategy.subsamplingRate = 1.0 // default for individual trees + strategy.subsamplingRate = subsamplingRate strategy } } /** - * (private trait) Parameters for Decision Tree-based classification algorithms. + * Parameters for Decision Tree-based classification algorithms. */ private[ml] trait TreeClassifierParams extends Params { @@ -215,7 +207,7 @@ private[ml] trait TreeClassifierParams extends Params { * (default = gini) * @group param */ - val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + + final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + " information gain calculation (case-insensitive). Supported options:" + s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}") @@ -228,11 +220,10 @@ private[ml] trait TreeClassifierParams extends Params { s"Tree-based classifier was given unrecognized impurity: $value." + s" Supported options: ${TreeClassifierParams.supportedImpurities.mkString(", ")}") set(impurity, impurityStr) - this } /** @group getParam */ - def getImpurity: String = getOrDefault(impurity) + final def getImpurity: String = getOrDefault(impurity) /** Convert new impurity to old impurity. */ private[ml] def getOldImpurity: OldImpurity = { @@ -249,11 +240,11 @@ private[ml] trait TreeClassifierParams extends Params { private[ml] object TreeClassifierParams { // These options should be lowercase. - val supportedImpurities: Array[String] = Array("entropy", "gini").map(_.toLowerCase) + final val supportedImpurities: Array[String] = Array("entropy", "gini").map(_.toLowerCase) } /** - * (private trait) Parameters for Decision Tree-based regression algorithms. + * Parameters for Decision Tree-based regression algorithms. */ private[ml] trait TreeRegressorParams extends Params { @@ -263,7 +254,7 @@ private[ml] trait TreeRegressorParams extends Params { * (default = variance) * @group param */ - val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + + final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + " information gain calculation (case-insensitive). Supported options:" + s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}") @@ -276,11 +267,10 @@ private[ml] trait TreeRegressorParams extends Params { s"Tree-based regressor was given unrecognized impurity: $value." + s" Supported options: ${TreeRegressorParams.supportedImpurities.mkString(", ")}") set(impurity, impurityStr) - this } /** @group getParam */ - def getImpurity: String = getOrDefault(impurity) + final def getImpurity: String = getOrDefault(impurity) /** Convert new impurity to old impurity. */ private[ml] def getOldImpurity: OldImpurity = { @@ -296,5 +286,186 @@ private[ml] trait TreeRegressorParams extends Params { private[ml] object TreeRegressorParams { // These options should be lowercase. - val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase) + final val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase) +} + +/** + * :: DeveloperApi :: + * Parameters for Decision Tree-based ensemble algorithms. + * + * Note: Marked as private and DeveloperApi since this may be made public in the future. + */ +@DeveloperApi +private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed { + + /** + * Fraction of the training data used for learning each decision tree. + * (default = 1.0) + * @group param + */ + final val subsamplingRate: DoubleParam = new DoubleParam(this, "subsamplingRate", + "Fraction of the training data used for learning each decision tree.") + + setDefault(subsamplingRate -> 1.0) + + /** @group setParam */ + def setSubsamplingRate(value: Double): this.type = { + require(value > 0.0 && value <= 1.0, + s"Subsampling rate must be in range (0,1]. Bad rate: $value") + set(subsamplingRate, value) + } + + /** @group getParam */ + final def getSubsamplingRate: Double = getOrDefault(subsamplingRate) + + /** @group setParam */ + def setSeed(value: Long): this.type = set(seed, value) + + /** + * Create a Strategy instance to use with the old API. + * NOTE: The caller should set impurity and seed. + */ + private[ml] def getOldStrategy( + categoricalFeatures: Map[Int, Int], + numClasses: Int, + oldAlgo: OldAlgo.Algo, + oldImpurity: OldImpurity): OldStrategy = { + super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo, oldImpurity, getSubsamplingRate) + } +} + +/** + * :: DeveloperApi :: + * Parameters for Random Forest algorithms. + * + * Note: Marked as private and DeveloperApi since this may be made public in the future. + */ +@DeveloperApi +private[ml] trait RandomForestParams extends TreeEnsembleParams { + + /** + * Number of trees to train (>= 1). + * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done. + * TODO: Change to always do bootstrapping (simpler). SPARK-7130 + * (default = 20) + * @group param + */ + final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)") + + /** + * The number of features to consider for splits at each tree node. + * Supported options: + * - "auto": Choose automatically for task: + * If numTrees == 1, set to "all." + * If numTrees > 1 (forest), set to "sqrt" for classification and + * to "onethird" for regression. + * - "all": use all features + * - "onethird": use 1/3 of the features + * - "sqrt": use sqrt(number of features) + * - "log2": use log2(number of features) + * (default = "auto") + * + * These various settings are based on the following references: + * - log2: tested in Breiman (2001) + * - sqrt: recommended by Breiman manual for random forests + * - The defaults of sqrt (classification) and onethird (regression) match the R randomForest + * package. + * @see [[http://www.stat.berkeley.edu/~breiman/randomforest2001.pdf Breiman (2001)]] + * @see [[http://www.stat.berkeley.edu/~breiman/Using_random_forests_V3.1.pdf Breiman manual for + * random forests]] + * + * @group param + */ + final val featureSubsetStrategy: Param[String] = new Param[String](this, "featureSubsetStrategy", + "The number of features to consider for splits at each tree node." + + s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}") + + setDefault(numTrees -> 20, featureSubsetStrategy -> "auto") + + /** @group setParam */ + def setNumTrees(value: Int): this.type = { + require(value >= 1, s"Random Forest numTrees parameter cannot be $value; it must be >= 1.") + set(numTrees, value) + } + + /** @group getParam */ + final def getNumTrees: Int = getOrDefault(numTrees) + + /** @group setParam */ + def setFeatureSubsetStrategy(value: String): this.type = { + val strategyStr = value.toLowerCase + require(RandomForestParams.supportedFeatureSubsetStrategies.contains(strategyStr), + s"RandomForestParams was given unrecognized featureSubsetStrategy: $value. Supported" + + s" options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}") + set(featureSubsetStrategy, strategyStr) + } + + /** @group getParam */ + final def getFeatureSubsetStrategy: String = getOrDefault(featureSubsetStrategy) +} + +private[ml] object RandomForestParams { + // These options should be lowercase. + final val supportedFeatureSubsetStrategies: Array[String] = + Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase) +} + +/** + * :: DeveloperApi :: + * Parameters for Gradient-Boosted Tree algorithms. + * + * Note: Marked as private and DeveloperApi since this may be made public in the future. + */ +@DeveloperApi +private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter { + + /** + * Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each + * estimator. + * (default = 0.1) + * @group param + */ + final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size (a.k.a." + + " learning rate) in interval (0, 1] for shrinking the contribution of each estimator") + + /* TODO: Add this doc when we add this param. SPARK-7132 + * Threshold for stopping early when runWithValidation is used. + * If the error rate on the validation input changes by less than the validationTol, + * then learning will stop early (before [[numIterations]]). + * This parameter is ignored when run is used. + * (default = 1e-5) + * @group param + */ + // final val validationTol: DoubleParam = new DoubleParam(this, "validationTol", "") + // validationTol -> 1e-5 + + setDefault(maxIter -> 20, stepSize -> 0.1) + + /** @group setParam */ + def setMaxIter(value: Int): this.type = { + require(value >= 1, s"Gradient Boosting maxIter parameter cannot be $value; it must be >= 1.") + set(maxIter, value) + } + + /** @group setParam */ + def setStepSize(value: Double): this.type = { + require(value > 0.0 && value <= 1.0, + s"GBT given invalid step size ($value). Value should be in (0,1].") + set(stepSize, value) + } + + /** @group getParam */ + final def getStepSize: Double = getOrDefault(stepSize) + + /** (private[ml]) Create a BoostingStrategy instance to use with the old API. */ + private[ml] def getOldBoostingStrategy( + categoricalFeatures: Map[Int, Int], + oldAlgo: OldAlgo.Algo): OldBoostingStrategy = { + val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 2, oldAlgo, OldVariance) + // NOTE: The old API does not support "seed" so we ignore it. + new OldBoostingStrategy(strategy, getOldLossType, getMaxIter, getStepSize) + } + + /** Get old Gradient Boosting Loss type */ + private[ml] def getOldLossType: OldLoss } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 95d7e64790c79..e88c48741e99f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -45,7 +45,8 @@ private[shared] object SharedParamsCodeGen { ParamDesc[Array[String]]("inputCols", "input column names"), ParamDesc[String]("outputCol", "output column name"), ParamDesc[Int]("checkpointInterval", "checkpoint interval"), - ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true"))) + ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")), + ParamDesc[Long]("seed", "random seed", Some("Utils.random.nextLong()"))) val code = genSharedParams(params) val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala" @@ -154,6 +155,7 @@ private[shared] object SharedParamsCodeGen { | |import org.apache.spark.annotation.DeveloperApi |import org.apache.spark.ml.param._ + |import org.apache.spark.util.Utils | |// DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen. | diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 72b08bf276483..a860b8834cff9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.param.shared import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.param._ +import org.apache.spark.util.Utils // DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen. @@ -256,4 +257,23 @@ trait HasFitIntercept extends Params { /** @group getParam */ final def getFitIntercept: Boolean = getOrDefault(fitIntercept) } + +/** + * :: DeveloperApi :: + * Trait for shared param seed (default: Utils.random.nextLong()). + */ +@DeveloperApi +trait HasSeed extends Params { + + /** + * Param for random seed. + * @group param + */ + final val seed: LongParam = new LongParam(this, "seed", "random seed") + + setDefault(seed, Utils.random.nextLong()) + + /** @group getParam */ + final def getSeed: Long = getOrDefault(seed) +} // scalastyle:on diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 49a8b77acf960..756725a64b0f3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -42,8 +42,7 @@ import org.apache.spark.sql.DataFrame @AlphaComponent final class DecisionTreeRegressor extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel] - with DecisionTreeParams - with TreeRegressorParams { + with DecisionTreeParams with TreeRegressorParams { // Override parameter setters from parent trait for Java API compatibility. @@ -60,8 +59,7 @@ final class DecisionTreeRegressor override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) - override def setCheckpointInterval(value: Int): this.type = - super.setCheckpointInterval(value) + override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) override def setImpurity(value: String): this.type = super.setImpurity(value) @@ -78,15 +76,13 @@ final class DecisionTreeRegressor /** (private[ml]) Create a Strategy instance to use with the old API. */ private[ml] def getOldStrategy(categoricalFeatures: Map[Int, Int]): OldStrategy = { - val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 0) - strategy.algo = OldAlgo.Regression - strategy.setImpurity(getOldImpurity) - strategy + super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity, + subsamplingRate = 1.0) } } object DecisionTreeRegressor { - /** Accessor for supported impurities */ + /** Accessor for supported impurities: variance */ final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala new file mode 100644 index 0000000000000..c784cf39ed31a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import com.github.fommil.netlib.BLAS.{getInstance => blas} + +import org.apache.spark.Logging +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor} +import org.apache.spark.ml.impl.tree._ +import org.apache.spark.ml.param.{Params, ParamMap, Param} +import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel} +import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{GradientBoostedTrees => OldGBT} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} +import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss => OldLoss, + SquaredError => OldSquaredError} +import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + + +/** + * :: AlphaComponent :: + * + * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]] + * learning algorithm for regression. + * It supports both continuous and categorical features. + */ +@AlphaComponent +final class GBTRegressor + extends Predictor[Vector, GBTRegressor, GBTRegressionModel] + with GBTParams with TreeRegressorParams with Logging { + + // Override parameter setters from parent trait for Java API compatibility. + + // Parameters from TreeRegressorParams: + + override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + + override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + + override def setMinInstancesPerNode(value: Int): this.type = + super.setMinInstancesPerNode(value) + + override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + + override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + + override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) + + override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) + + /** + * The impurity setting is ignored for GBT models. + * Individual trees are built using impurity "Variance." + */ + override def setImpurity(value: String): this.type = { + logWarning("GBTRegressor.setImpurity should NOT be used") + this + } + + // Parameters from TreeEnsembleParams: + + override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value) + + override def setSeed(value: Long): this.type = { + logWarning("The 'seed' parameter is currently ignored by Gradient Boosting.") + super.setSeed(value) + } + + // Parameters from GBTParams: + + override def setMaxIter(value: Int): this.type = super.setMaxIter(value) + + override def setStepSize(value: Double): this.type = super.setStepSize(value) + + // Parameters for GBTRegressor: + + /** + * Loss function which GBT tries to minimize. (case-insensitive) + * Supported: "squared" (L2) and "absolute" (L1) + * (default = squared) + * @group param + */ + val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + + " tries to minimize (case-insensitive). Supported options:" + + s" ${GBTRegressor.supportedLossTypes.mkString(", ")}") + + setDefault(lossType -> "squared") + + /** @group setParam */ + def setLossType(value: String): this.type = { + val lossStr = value.toLowerCase + require(GBTRegressor.supportedLossTypes.contains(lossStr), "GBTRegressor was given bad loss" + + s" type: $value. Supported options: ${GBTRegressor.supportedLossTypes.mkString(", ")}") + set(lossType, lossStr) + this + } + + /** @group getParam */ + def getLossType: String = getOrDefault(lossType) + + /** (private[ml]) Convert new loss to old loss. */ + override private[ml] def getOldLossType: OldLoss = { + getLossType match { + case "squared" => OldSquaredError + case "absolute" => OldAbsoluteError + case _ => + // Should never happen because of check in setter method. + throw new RuntimeException(s"GBTRegressorParams was given bad loss type: $getLossType") + } + } + + override protected def train( + dataset: DataFrame, + paramMap: ParamMap): GBTRegressionModel = { + val categoricalFeatures: Map[Int, Int] = + MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol))) + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap) + val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression) + val oldGBT = new OldGBT(boostingStrategy) + val oldModel = oldGBT.run(oldDataset) + GBTRegressionModel.fromOld(oldModel, this, paramMap, categoricalFeatures) + } +} + +object GBTRegressor { + // The losses below should be lowercase. + /** Accessor for supported loss settings: squared (L2), absolute (L1) */ + final val supportedLossTypes: Array[String] = Array("squared", "absolute").map(_.toLowerCase) +} + +/** + * :: AlphaComponent :: + * + * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]] + * model for regression. + * It supports both continuous and categorical features. + * @param _trees Decision trees in the ensemble. + * @param _treeWeights Weights for the decision trees in the ensemble. + */ +@AlphaComponent +final class GBTRegressionModel( + override val parent: GBTRegressor, + override val fittingParamMap: ParamMap, + private val _trees: Array[DecisionTreeRegressionModel], + private val _treeWeights: Array[Double]) + extends PredictionModel[Vector, GBTRegressionModel] + with TreeEnsembleModel with Serializable { + + require(numTrees > 0, "GBTRegressionModel requires at least 1 tree.") + require(_trees.length == _treeWeights.length, "GBTRegressionModel given trees, treeWeights of" + + s" non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).") + + override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] + + override def treeWeights: Array[Double] = _treeWeights + + override protected def predict(features: Vector): Double = { + // TODO: Override transform() to broadcast model. SPARK-7127 + // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129 + // Classifies by thresholding sum of weighted tree predictions + val treePredictions = _trees.map(_.rootNode.predict(features)) + val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) + if (prediction > 0.0) 1.0 else 0.0 + } + + override protected def copy(): GBTRegressionModel = { + val m = new GBTRegressionModel(parent, fittingParamMap, _trees, _treeWeights) + Params.inheritValues(this.extractParamMap(), this, m) + m + } + + override def toString: String = { + s"GBTRegressionModel with $numTrees trees" + } + + /** (private[ml]) Convert to a model in the old API */ + private[ml] def toOld: OldGBTModel = { + new OldGBTModel(OldAlgo.Regression, _trees.map(_.toOld), _treeWeights) + } +} + +private[ml] object GBTRegressionModel { + + /** (private[ml]) Convert a model from the old API */ + def fromOld( + oldModel: OldGBTModel, + parent: GBTRegressor, + fittingParamMap: ParamMap, + categoricalFeatures: Map[Int, Int]): GBTRegressionModel = { + require(oldModel.algo == OldAlgo.Regression, "Cannot convert GradientBoostedTreesModel" + + s" with algo=${oldModel.algo} (old API) to GBTRegressionModel (new API).") + val newTrees = oldModel.trees.map { tree => + // parent, fittingParamMap for each tree is null since there are no good ways to set these. + DecisionTreeRegressionModel.fromOld(tree, null, null, categoricalFeatures) + } + new GBTRegressionModel(parent, fittingParamMap, newTrees, oldModel.treeWeights) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala new file mode 100644 index 0000000000000..2171ef3d32c26 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor} +import org.apache.spark.ml.impl.tree.{RandomForestParams, TreeRegressorParams} +import org.apache.spark.ml.param.{Params, ParamMap} +import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel} +import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{RandomForest => OldRandomForest} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} +import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + + +/** + * :: AlphaComponent :: + * + * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] learning algorithm for regression. + * It supports both continuous and categorical features. + */ +@AlphaComponent +final class RandomForestRegressor + extends Predictor[Vector, RandomForestRegressor, RandomForestRegressionModel] + with RandomForestParams with TreeRegressorParams { + + // Override parameter setters from parent trait for Java API compatibility. + + // Parameters from TreeRegressorParams: + + override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + + override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + + override def setMinInstancesPerNode(value: Int): this.type = + super.setMinInstancesPerNode(value) + + override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + + override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + + override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) + + override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) + + override def setImpurity(value: String): this.type = super.setImpurity(value) + + // Parameters from TreeEnsembleParams: + + override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value) + + override def setSeed(value: Long): this.type = super.setSeed(value) + + // Parameters from RandomForestParams: + + override def setNumTrees(value: Int): this.type = super.setNumTrees(value) + + override def setFeatureSubsetStrategy(value: String): this.type = + super.setFeatureSubsetStrategy(value) + + override protected def train( + dataset: DataFrame, + paramMap: ParamMap): RandomForestRegressionModel = { + val categoricalFeatures: Map[Int, Int] = + MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol))) + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap) + val strategy = + super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity) + val oldModel = OldRandomForest.trainRegressor( + oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt) + RandomForestRegressionModel.fromOld(oldModel, this, paramMap, categoricalFeatures) + } +} + +object RandomForestRegressor { + /** Accessor for supported impurity settings: variance */ + final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities + + /** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */ + final val supportedFeatureSubsetStrategies: Array[String] = + RandomForestParams.supportedFeatureSubsetStrategies +} + +/** + * :: AlphaComponent :: + * + * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for regression. + * It supports both continuous and categorical features. + * @param _trees Decision trees in the ensemble. + */ +@AlphaComponent +final class RandomForestRegressionModel private[ml] ( + override val parent: RandomForestRegressor, + override val fittingParamMap: ParamMap, + private val _trees: Array[DecisionTreeRegressionModel]) + extends PredictionModel[Vector, RandomForestRegressionModel] + with TreeEnsembleModel with Serializable { + + require(numTrees > 0, "RandomForestRegressionModel requires at least 1 tree.") + + override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] + + // Note: We may add support for weights (based on tree performance) later on. + private lazy val _treeWeights: Array[Double] = Array.fill[Double](numTrees)(1.0) + + override def treeWeights: Array[Double] = _treeWeights + + override protected def predict(features: Vector): Double = { + // TODO: Override transform() to broadcast model. SPARK-7127 + // TODO: When we add a generic Bagging class, handle transform there. SPARK-7128 + // Predict average of tree predictions. + // Ignore the weights since all are 1.0 for now. + _trees.map(_.rootNode.predict(features)).sum / numTrees + } + + override protected def copy(): RandomForestRegressionModel = { + val m = new RandomForestRegressionModel(parent, fittingParamMap, _trees) + Params.inheritValues(this.extractParamMap(), this, m) + m + } + + override def toString: String = { + s"RandomForestRegressionModel with $numTrees trees" + } + + /** (private[ml]) Convert to a model in the old API */ + private[ml] def toOld: OldRandomForestModel = { + new OldRandomForestModel(OldAlgo.Regression, _trees.map(_.toOld)) + } +} + +private[ml] object RandomForestRegressionModel { + + /** (private[ml]) Convert a model from the old API */ + def fromOld( + oldModel: OldRandomForestModel, + parent: RandomForestRegressor, + fittingParamMap: ParamMap, + categoricalFeatures: Map[Int, Int]): RandomForestRegressionModel = { + require(oldModel.algo == OldAlgo.Regression, "Cannot convert RandomForestModel" + + s" with algo=${oldModel.algo} (old API) to RandomForestRegressionModel (new API).") + val newTrees = oldModel.trees.map { tree => + // parent, fittingParamMap for each tree is null since there are no good ways to set these. + DecisionTreeRegressionModel.fromOld(tree, null, null, categoricalFeatures) + } + new RandomForestRegressionModel(parent, fittingParamMap, newTrees) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index d6e2203d9f937..d2dec0c76cb12 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -28,9 +28,9 @@ import org.apache.spark.mllib.tree.model.{InformationGainStats => OldInformation sealed abstract class Node extends Serializable { // TODO: Add aggregate stats (once available). This will happen after we move the DecisionTree - // code into the new API and deprecate the old API. + // code into the new API and deprecate the old API. SPARK-3727 - /** Prediction this node makes (or would make, if it is an internal node) */ + /** Prediction a leaf node makes, or which an internal node would make if it were a leaf node */ def prediction: Double /** Impurity measure at this node (for training data) */ @@ -194,7 +194,7 @@ private object InternalNode { s"$featureStr > ${contSplit.threshold}" } case catSplit: CategoricalSplit => - val categoriesStr = catSplit.getLeftCategories.mkString("{", ",", "}") + val categoriesStr = catSplit.leftCategories.mkString("{", ",", "}") if (left) { s"$featureStr in $categoriesStr" } else { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala index 708c769087dd0..90f1d052764d3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala @@ -44,7 +44,7 @@ private[tree] object Split { oldSplit.featureType match { case OldFeatureType.Categorical => new CategoricalSplit(featureIndex = oldSplit.feature, - leftCategories = oldSplit.categories.toArray, categoricalFeatures(oldSplit.feature)) + _leftCategories = oldSplit.categories.toArray, categoricalFeatures(oldSplit.feature)) case OldFeatureType.Continuous => new ContinuousSplit(featureIndex = oldSplit.feature, threshold = oldSplit.threshold) } @@ -54,30 +54,30 @@ private[tree] object Split { /** * Split which tests a categorical feature. * @param featureIndex Index of the feature to test - * @param leftCategories If the feature value is in this set of categories, then the split goes - * left. Otherwise, it goes right. + * @param _leftCategories If the feature value is in this set of categories, then the split goes + * left. Otherwise, it goes right. * @param numCategories Number of categories for this feature. */ final class CategoricalSplit private[ml] ( override val featureIndex: Int, - leftCategories: Array[Double], + _leftCategories: Array[Double], private val numCategories: Int) extends Split { - require(leftCategories.forall(cat => 0 <= cat && cat < numCategories), "Invalid leftCategories" + - s" (should be in range [0, $numCategories)): ${leftCategories.mkString(",")}") + require(_leftCategories.forall(cat => 0 <= cat && cat < numCategories), "Invalid leftCategories" + + s" (should be in range [0, $numCategories)): ${_leftCategories.mkString(",")}") /** * If true, then "categories" is the set of categories for splitting to the left, and vice versa. */ - private val isLeft: Boolean = leftCategories.length <= numCategories / 2 + private val isLeft: Boolean = _leftCategories.length <= numCategories / 2 /** Set of categories determining the splitting rule, along with [[isLeft]]. */ private val categories: Set[Double] = { if (isLeft) { - leftCategories.toSet + _leftCategories.toSet } else { - setComplement(leftCategories.toSet) + setComplement(_leftCategories.toSet) } } @@ -107,13 +107,13 @@ final class CategoricalSplit private[ml] ( } /** Get sorted categories which split to the left */ - def getLeftCategories: Array[Double] = { + def leftCategories: Array[Double] = { val cats = if (isLeft) categories else setComplement(categories) cats.toArray.sorted } /** Get sorted categories which split to the right */ - def getRightCategories: Array[Double] = { + def rightCategories: Array[Double] = { val cats = if (isLeft) setComplement(categories) else categories cats.toArray.sorted } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index 8e3bc3849dcf0..1929f9d02156e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -17,18 +17,13 @@ package org.apache.spark.ml.tree -import org.apache.spark.annotation.AlphaComponent - /** - * :: AlphaComponent :: - * * Abstraction for Decision Tree models. * - * TODO: Add support for predicting probabilities and raw predictions + * TODO: Add support for predicting probabilities and raw predictions SPARK-3727 */ -@AlphaComponent -trait DecisionTreeModel { +private[ml] trait DecisionTreeModel { /** Root of the decision tree */ def rootNode: Node @@ -58,3 +53,40 @@ trait DecisionTreeModel { header + rootNode.subtreeToString(2) } } + +/** + * Abstraction for models which are ensembles of decision trees + * + * TODO: Add support for predicting probabilities and raw predictions SPARK-3727 + */ +private[ml] trait TreeEnsembleModel { + + // Note: We use getTrees since subclasses of TreeEnsembleModel will store subclasses of + // DecisionTreeModel. + + /** Trees in this ensemble. Warning: These have null parent Estimators. */ + def trees: Array[DecisionTreeModel] + + /** Weights for each tree, zippable with [[trees]] */ + def treeWeights: Array[Double] + + /** Summary of the model */ + override def toString: String = { + // Implementing classes should generally override this method to be more descriptive. + s"TreeEnsembleModel with $numTrees trees" + } + + /** Full description of model */ + def toDebugString: String = { + val header = toString + "\n" + header + trees.zip(treeWeights).zipWithIndex.map { case ((tree, weight), treeIndex) => + s" Tree $treeIndex (weight $weight):\n" + tree.rootNode.subtreeToString(4) + }.fold("")(_ + _) + } + + /** Number of trees in ensemble */ + val numTrees: Int = trees.length + + /** Total number of nodes, summed over all trees in the ensemble. */ + lazy val totalNumNodes: Int = trees.map(_.numNodes).sum +} diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java index 43b8787f9dd7e..60f25e5cce437 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.ml.classification; -import java.io.File; import java.io.Serializable; import java.util.HashMap; import java.util.Map; @@ -32,7 +31,6 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.DataFrame; -import org.apache.spark.util.Utils; public class JavaDecisionTreeClassifierSuite implements Serializable { @@ -57,7 +55,7 @@ public void runDT() { double B = -1.5; JavaRDD data = sc.parallelize( - LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); Map categoricalFeatures = new HashMap(); DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); @@ -71,8 +69,8 @@ public void runDT() { .setCacheNodeIds(false) .setCheckpointInterval(10) .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern - for (int i = 0; i < DecisionTreeClassifier.supportedImpurities().length; ++i) { - dt.setImpurity(DecisionTreeClassifier.supportedImpurities()[i]); + for (String impurity: DecisionTreeClassifier.supportedImpurities()) { + dt.setImpurity(impurity); } DecisionTreeClassificationModel model = dt.fit(dataFrame); @@ -82,7 +80,7 @@ public void runDT() { model.toDebugString(); /* - // TODO: Add test once save/load are implemented. + // TODO: Add test once save/load are implemented. SPARK-6725 File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark"); String path = tempDir.toURI().toString(); try { diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java new file mode 100644 index 0000000000000..3c69467fa119e --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.impl.TreeTests; +import org.apache.spark.mllib.classification.LogisticRegressionSuite; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; + + +public class JavaGBTClassifierSuite implements Serializable { + + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaGBTClassifierSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void runDT() { + int nPoints = 20; + double A = 2.0; + double B = -1.5; + + JavaRDD data = sc.parallelize( + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); + Map categoricalFeatures = new HashMap(); + DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); + + // This tests setters. Training with various options is tested in Scala. + GBTClassifier rf = new GBTClassifier() + .setMaxDepth(2) + .setMaxBins(10) + .setMinInstancesPerNode(5) + .setMinInfoGain(0.0) + .setMaxMemoryInMB(256) + .setCacheNodeIds(false) + .setCheckpointInterval(10) + .setSubsamplingRate(1.0) + .setSeed(1234) + .setMaxIter(3) + .setStepSize(0.1) + .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern + for (String lossType: GBTClassifier.supportedLossTypes()) { + rf.setLossType(lossType); + } + GBTClassificationModel model = rf.fit(dataFrame); + + model.transform(dataFrame); + model.totalNumNodes(); + model.toDebugString(); + model.trees(); + model.treeWeights(); + + /* + // TODO: Add test once save/load are implemented. SPARK-6725 + File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark"); + String path = tempDir.toURI().toString(); + try { + model3.save(sc.sc(), path); + GBTClassificationModel sameModel = GBTClassificationModel.load(sc.sc(), path); + TreeTests.checkEqual(model3, sameModel); + } finally { + Utils.deleteRecursively(tempDir); + } + */ + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java new file mode 100644 index 0000000000000..32d0b3856b7e2 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.impl.TreeTests; +import org.apache.spark.mllib.classification.LogisticRegressionSuite; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; + + +public class JavaRandomForestClassifierSuite implements Serializable { + + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaRandomForestClassifierSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void runDT() { + int nPoints = 20; + double A = 2.0; + double B = -1.5; + + JavaRDD data = sc.parallelize( + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); + Map categoricalFeatures = new HashMap(); + DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); + + // This tests setters. Training with various options is tested in Scala. + RandomForestClassifier rf = new RandomForestClassifier() + .setMaxDepth(2) + .setMaxBins(10) + .setMinInstancesPerNode(5) + .setMinInfoGain(0.0) + .setMaxMemoryInMB(256) + .setCacheNodeIds(false) + .setCheckpointInterval(10) + .setSubsamplingRate(1.0) + .setSeed(1234) + .setNumTrees(3) + .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern + for (String impurity: RandomForestClassifier.supportedImpurities()) { + rf.setImpurity(impurity); + } + for (String featureSubsetStrategy: RandomForestClassifier.supportedFeatureSubsetStrategies()) { + rf.setFeatureSubsetStrategy(featureSubsetStrategy); + } + RandomForestClassificationModel model = rf.fit(dataFrame); + + model.transform(dataFrame); + model.totalNumNodes(); + model.toDebugString(); + model.trees(); + model.treeWeights(); + + /* + // TODO: Add test once save/load are implemented. SPARK-6725 + File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark"); + String path = tempDir.toURI().toString(); + try { + model3.save(sc.sc(), path); + RandomForestClassificationModel sameModel = + RandomForestClassificationModel.load(sc.sc(), path); + TreeTests.checkEqual(model3, sameModel); + } finally { + Utils.deleteRecursively(tempDir); + } + */ + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java index a3a339004f31c..71b041818d7ee 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.ml.regression; -import java.io.File; import java.io.Serializable; import java.util.HashMap; import java.util.Map; @@ -32,7 +31,6 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.DataFrame; -import org.apache.spark.util.Utils; public class JavaDecisionTreeRegressorSuite implements Serializable { @@ -57,22 +55,22 @@ public void runDT() { double B = -1.5; JavaRDD data = sc.parallelize( - LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); Map categoricalFeatures = new HashMap(); DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); // This tests setters. Training with various options is tested in Scala. DecisionTreeRegressor dt = new DecisionTreeRegressor() - .setMaxDepth(2) - .setMaxBins(10) - .setMinInstancesPerNode(5) - .setMinInfoGain(0.0) - .setMaxMemoryInMB(256) - .setCacheNodeIds(false) - .setCheckpointInterval(10) - .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern - for (int i = 0; i < DecisionTreeRegressor.supportedImpurities().length; ++i) { - dt.setImpurity(DecisionTreeRegressor.supportedImpurities()[i]); + .setMaxDepth(2) + .setMaxBins(10) + .setMinInstancesPerNode(5) + .setMinInfoGain(0.0) + .setMaxMemoryInMB(256) + .setCacheNodeIds(false) + .setCheckpointInterval(10) + .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern + for (String impurity: DecisionTreeRegressor.supportedImpurities()) { + dt.setImpurity(impurity); } DecisionTreeRegressionModel model = dt.fit(dataFrame); @@ -82,7 +80,7 @@ public void runDT() { model.toDebugString(); /* - // TODO: Add test once save/load are implemented. + // TODO: Add test once save/load are implemented. SPARK-6725 File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark"); String path = tempDir.toURI().toString(); try { diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java new file mode 100644 index 0000000000000..fc8c13db07e6f --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.impl.TreeTests; +import org.apache.spark.mllib.classification.LogisticRegressionSuite; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; + + +public class JavaGBTRegressorSuite implements Serializable { + + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaGBTRegressorSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void runDT() { + int nPoints = 20; + double A = 2.0; + double B = -1.5; + + JavaRDD data = sc.parallelize( + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); + Map categoricalFeatures = new HashMap(); + DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); + + GBTRegressor rf = new GBTRegressor() + .setMaxDepth(2) + .setMaxBins(10) + .setMinInstancesPerNode(5) + .setMinInfoGain(0.0) + .setMaxMemoryInMB(256) + .setCacheNodeIds(false) + .setCheckpointInterval(10) + .setSubsamplingRate(1.0) + .setSeed(1234) + .setMaxIter(3) + .setStepSize(0.1) + .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern + for (String lossType: GBTRegressor.supportedLossTypes()) { + rf.setLossType(lossType); + } + GBTRegressionModel model = rf.fit(dataFrame); + + model.transform(dataFrame); + model.totalNumNodes(); + model.toDebugString(); + model.trees(); + model.treeWeights(); + + /* + // TODO: Add test once save/load are implemented. SPARK-6725 + File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark"); + String path = tempDir.toURI().toString(); + try { + model2.save(sc.sc(), path); + GBTRegressionModel sameModel = GBTRegressionModel.load(sc.sc(), path); + TreeTests.checkEqual(model2, sameModel); + } finally { + Utils.deleteRecursively(tempDir); + } + */ + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java new file mode 100644 index 0000000000000..e306ebadfe7cf --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression; + +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.classification.LogisticRegressionSuite; +import org.apache.spark.ml.impl.TreeTests; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; + + +public class JavaRandomForestRegressorSuite implements Serializable { + + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaRandomForestRegressorSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void runDT() { + int nPoints = 20; + double A = 2.0; + double B = -1.5; + + JavaRDD data = sc.parallelize( + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); + Map categoricalFeatures = new HashMap(); + DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); + + // This tests setters. Training with various options is tested in Scala. + RandomForestRegressor rf = new RandomForestRegressor() + .setMaxDepth(2) + .setMaxBins(10) + .setMinInstancesPerNode(5) + .setMinInfoGain(0.0) + .setMaxMemoryInMB(256) + .setCacheNodeIds(false) + .setCheckpointInterval(10) + .setSubsamplingRate(1.0) + .setSeed(1234) + .setNumTrees(3) + .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern + for (String impurity: RandomForestRegressor.supportedImpurities()) { + rf.setImpurity(impurity); + } + for (String featureSubsetStrategy: RandomForestRegressor.supportedFeatureSubsetStrategies()) { + rf.setFeatureSubsetStrategy(featureSubsetStrategy); + } + RandomForestRegressionModel model = rf.fit(dataFrame); + + model.transform(dataFrame); + model.totalNumNodes(); + model.toDebugString(); + model.trees(); + model.treeWeights(); + + /* + // TODO: Add test once save/load are implemented. SPARK-6725 + File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark"); + String path = tempDir.toURI().toString(); + try { + model2.save(sc.sc(), path); + RandomForestRegressionModel sameModel = RandomForestRegressionModel.load(sc.sc(), path); + TreeTests.checkEqual(model2, sameModel); + } finally { + Utils.deleteRecursively(tempDir); + } + */ + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index af88595df5245..9b31adecdcb1c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -230,7 +230,7 @@ class DecisionTreeClassifierSuite extends FunSuite with MLlibTestSparkContext { // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// - // TODO: Reinstate test once save/load are implemented + // TODO: Reinstate test once save/load are implemented SPARK-6725 /* test("model save/load") { val tempDir = Utils.createTempDir() diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala new file mode 100644 index 0000000000000..e6ccc2c93cba8 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.scalatest.FunSuite + +import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + + +/** + * Test suite for [[GBTClassifier]]. + */ +class GBTClassifierSuite extends FunSuite with MLlibTestSparkContext { + + import GBTClassifierSuite.compareAPIs + + // Combinations for estimators, learning rates and subsamplingRate + private val testCombinations = + Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 0.5, 0.75), (10, 0.1, 0.75)) + + private var data: RDD[LabeledPoint] = _ + private var trainData: RDD[LabeledPoint] = _ + private var validationData: RDD[LabeledPoint] = _ + + override def beforeAll() { + super.beforeAll() + data = sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100), 2) + trainData = + sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 120), 2) + validationData = + sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80), 2) + } + + test("Binary classification with continuous features: Log Loss") { + val categoricalFeatures = Map.empty[Int, Int] + testCombinations.foreach { + case (maxIter, learningRate, subsamplingRate) => + val gbt = new GBTClassifier() + .setMaxDepth(2) + .setSubsamplingRate(subsamplingRate) + .setLossType("logistic") + .setMaxIter(maxIter) + .setStepSize(learningRate) + compareAPIs(data, None, gbt, categoricalFeatures) + } + } + + // TODO: Reinstate test once runWithValidation is implemented SPARK-7132 + /* + test("runWithValidation stops early and performs better on a validation dataset") { + val categoricalFeatures = Map.empty[Int, Int] + // Set maxIter large enough so that it stops early. + val maxIter = 20 + GBTClassifier.supportedLossTypes.foreach { loss => + val gbt = new GBTClassifier() + .setMaxIter(maxIter) + .setMaxDepth(2) + .setLossType(loss) + .setValidationTol(0.0) + compareAPIs(trainData, None, gbt, categoricalFeatures) + compareAPIs(trainData, Some(validationData), gbt, categoricalFeatures) + } + } + */ + + ///////////////////////////////////////////////////////////////////////////// + // Tests of model save/load + ///////////////////////////////////////////////////////////////////////////// + + // TODO: Reinstate test once save/load are implemented SPARK-6725 + /* + test("model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray + val treeWeights = Array(0.1, 0.3, 1.1) + val oldModel = new OldGBTModel(OldAlgo.Classification, trees, treeWeights) + val newModel = GBTClassificationModel.fromOld(oldModel) + + // Save model, load it back, and compare. + try { + newModel.save(sc, path) + val sameNewModel = GBTClassificationModel.load(sc, path) + TreeTests.checkEqual(newModel, sameNewModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } + */ +} + +private object GBTClassifierSuite { + + /** + * Train 2 models on the given dataset, one using the old API and one using the new API. + * Convert the old model to the new format, compare them, and fail if they are not exactly equal. + */ + def compareAPIs( + data: RDD[LabeledPoint], + validationData: Option[RDD[LabeledPoint]], + gbt: GBTClassifier, + categoricalFeatures: Map[Int, Int]): Unit = { + val oldBoostingStrategy = + gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) + val oldGBT = new OldGBT(oldBoostingStrategy) + val oldModel = oldGBT.run(data) + val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2) + val newModel = gbt.fit(newData) + // Use parent, fittingParamMap from newTree since these are not checked anyways. + val oldModelAsNew = GBTClassificationModel.fromOld(oldModel, newModel.parent, + newModel.fittingParamMap, categoricalFeatures) + TreeTests.checkEqual(oldModelAsNew, newModel) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala new file mode 100644 index 0000000000000..ed41a9664f94f --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.scalatest.FunSuite + +import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + + +/** + * Test suite for [[RandomForestClassifier]]. + */ +class RandomForestClassifierSuite extends FunSuite with MLlibTestSparkContext { + + import RandomForestClassifierSuite.compareAPIs + + private var orderedLabeledPoints50_1000: RDD[LabeledPoint] = _ + private var orderedLabeledPoints5_20: RDD[LabeledPoint] = _ + + override def beforeAll() { + super.beforeAll() + orderedLabeledPoints50_1000 = + sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)) + orderedLabeledPoints5_20 = + sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 5, 20)) + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests calling train() + ///////////////////////////////////////////////////////////////////////////// + + def binaryClassificationTestWithContinuousFeatures(rf: RandomForestClassifier) { + val categoricalFeatures = Map.empty[Int, Int] + val numClasses = 2 + val newRF = rf + .setImpurity("Gini") + .setMaxDepth(2) + .setNumTrees(1) + .setFeatureSubsetStrategy("auto") + .setSeed(123) + compareAPIs(orderedLabeledPoints50_1000, newRF, categoricalFeatures, numClasses) + } + + test("Binary classification with continuous features:" + + " comparing DecisionTree vs. RandomForest(numTrees = 1)") { + val rf = new RandomForestClassifier() + binaryClassificationTestWithContinuousFeatures(rf) + } + + test("Binary classification with continuous features and node Id cache:" + + " comparing DecisionTree vs. RandomForest(numTrees = 1)") { + val rf = new RandomForestClassifier() + .setCacheNodeIds(true) + binaryClassificationTestWithContinuousFeatures(rf) + } + + test("alternating categorical and continuous features with multiclass labels to test indexing") { + val arr = Array( + LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0, 3.0, 1.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0, 1.0, 2.0)), + LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0, 6.0, 3.0)), + LabeledPoint(2.0, Vectors.dense(0.0, 2.0, 1.0, 3.0, 2.0)) + ) + val rdd = sc.parallelize(arr) + val categoricalFeatures = Map(0 -> 3, 2 -> 2, 4 -> 4) + val numClasses = 3 + + val rf = new RandomForestClassifier() + .setImpurity("Gini") + .setMaxDepth(5) + .setNumTrees(2) + .setFeatureSubsetStrategy("sqrt") + .setSeed(12345) + compareAPIs(rdd, rf, categoricalFeatures, numClasses) + } + + test("subsampling rate in RandomForest"){ + val rdd = orderedLabeledPoints5_20 + val categoricalFeatures = Map.empty[Int, Int] + val numClasses = 2 + + val rf1 = new RandomForestClassifier() + .setImpurity("Gini") + .setMaxDepth(2) + .setCacheNodeIds(true) + .setNumTrees(3) + .setFeatureSubsetStrategy("auto") + .setSeed(123) + compareAPIs(rdd, rf1, categoricalFeatures, numClasses) + + val rf2 = rf1.setSubsamplingRate(0.5) + compareAPIs(rdd, rf2, categoricalFeatures, numClasses) + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests of model save/load + ///////////////////////////////////////////////////////////////////////////// + + // TODO: Reinstate test once save/load are implemented SPARK-6725 + /* + test("model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + val trees = + Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Classification)).toArray + val oldModel = new OldRandomForestModel(OldAlgo.Classification, trees) + val newModel = RandomForestClassificationModel.fromOld(oldModel) + + // Save model, load it back, and compare. + try { + newModel.save(sc, path) + val sameNewModel = RandomForestClassificationModel.load(sc, path) + TreeTests.checkEqual(newModel, sameNewModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } + */ +} + +private object RandomForestClassifierSuite { + + /** + * Train 2 models on the given dataset, one using the old API and one using the new API. + * Convert the old model to the new format, compare them, and fail if they are not exactly equal. + */ + def compareAPIs( + data: RDD[LabeledPoint], + rf: RandomForestClassifier, + categoricalFeatures: Map[Int, Int], + numClasses: Int): Unit = { + val oldStrategy = + rf.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, rf.getOldImpurity) + val oldModel = OldRandomForest.trainClassifier( + data, oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt) + val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses) + val newModel = rf.fit(newData) + // Use parent, fittingParamMap from newTree since these are not checked anyways. + val oldModelAsNew = RandomForestClassificationModel.fromOld(oldModel, newModel.parent, + newModel.fittingParamMap, categoricalFeatures) + TreeTests.checkEqual(oldModelAsNew, newModel) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala index 2e57d4ce37f1d..1505ad872536b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala @@ -23,8 +23,7 @@ import org.scalatest.FunSuite import org.apache.spark.api.java.JavaRDD import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} -import org.apache.spark.ml.impl.tree._ -import org.apache.spark.ml.tree.{DecisionTreeModel, InternalNode, LeafNode, Node} +import org.apache.spark.ml.tree._ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD import org.apache.spark.sql.{SQLContext, DataFrame} @@ -111,22 +110,19 @@ private[ml] object TreeTests extends FunSuite { } } - // TODO: Reinstate after adding ensembles /** * Check if the two models are exactly the same. * If the models are not equal, this throws an exception. */ - /* def checkEqual(a: TreeEnsembleModel, b: TreeEnsembleModel): Unit = { try { - a.getTrees.zip(b.getTrees).foreach { case (treeA, treeB) => + a.trees.zip(b.trees).foreach { case (treeA, treeB) => TreeTests.checkEqual(treeA, treeB) } - assert(a.getTreeWeights === b.getTreeWeights) + assert(a.treeWeights === b.treeWeights) } catch { case ex: Exception => throw new AssertionError( "checkEqual failed since the two tree ensembles were not identical") } } - */ } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 0b40fe33fae9d..c87a171b4b229 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -66,7 +66,7 @@ class DecisionTreeRegressorSuite extends FunSuite with MLlibTestSparkContext { // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// - // TODO: test("model save/load") + // TODO: test("model save/load") SPARK-6725 } private[ml] object DecisionTreeRegressorSuite extends FunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala new file mode 100644 index 0000000000000..4aec36948ac92 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.scalatest.FunSuite + +import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + + +/** + * Test suite for [[GBTRegressor]]. + */ +class GBTRegressorSuite extends FunSuite with MLlibTestSparkContext { + + import GBTRegressorSuite.compareAPIs + + // Combinations for estimators, learning rates and subsamplingRate + private val testCombinations = + Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 0.5, 0.75), (10, 0.1, 0.75)) + + private var data: RDD[LabeledPoint] = _ + private var trainData: RDD[LabeledPoint] = _ + private var validationData: RDD[LabeledPoint] = _ + + override def beforeAll() { + super.beforeAll() + data = sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100), 2) + trainData = + sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 120), 2) + validationData = + sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80), 2) + } + + test("Regression with continuous features: SquaredError") { + val categoricalFeatures = Map.empty[Int, Int] + GBTRegressor.supportedLossTypes.foreach { loss => + testCombinations.foreach { + case (maxIter, learningRate, subsamplingRate) => + val gbt = new GBTRegressor() + .setMaxDepth(2) + .setSubsamplingRate(subsamplingRate) + .setLossType(loss) + .setMaxIter(maxIter) + .setStepSize(learningRate) + compareAPIs(data, None, gbt, categoricalFeatures) + } + } + } + + // TODO: Reinstate test once runWithValidation is implemented SPARK-7132 + /* + test("runWithValidation stops early and performs better on a validation dataset") { + val categoricalFeatures = Map.empty[Int, Int] + // Set maxIter large enough so that it stops early. + val maxIter = 20 + GBTRegressor.supportedLossTypes.foreach { loss => + val gbt = new GBTRegressor() + .setMaxIter(maxIter) + .setMaxDepth(2) + .setLossType(loss) + .setValidationTol(0.0) + compareAPIs(trainData, None, gbt, categoricalFeatures) + compareAPIs(trainData, Some(validationData), gbt, categoricalFeatures) + } + } + */ + + ///////////////////////////////////////////////////////////////////////////// + // Tests of model save/load + ///////////////////////////////////////////////////////////////////////////// + + // TODO: Reinstate test once save/load are implemented SPARK-6725 + /* + test("model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray + val treeWeights = Array(0.1, 0.3, 1.1) + val oldModel = new OldGBTModel(OldAlgo.Regression, trees, treeWeights) + val newModel = GBTRegressionModel.fromOld(oldModel) + + // Save model, load it back, and compare. + try { + newModel.save(sc, path) + val sameNewModel = GBTRegressionModel.load(sc, path) + TreeTests.checkEqual(newModel, sameNewModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } + */ +} + +private object GBTRegressorSuite { + + /** + * Train 2 models on the given dataset, one using the old API and one using the new API. + * Convert the old model to the new format, compare them, and fail if they are not exactly equal. + */ + def compareAPIs( + data: RDD[LabeledPoint], + validationData: Option[RDD[LabeledPoint]], + gbt: GBTRegressor, + categoricalFeatures: Map[Int, Int]): Unit = { + val oldBoostingStrategy = gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression) + val oldGBT = new OldGBT(oldBoostingStrategy) + val oldModel = oldGBT.run(data) + val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) + val newModel = gbt.fit(newData) + // Use parent, fittingParamMap from newTree since these are not checked anyways. + val oldModelAsNew = GBTRegressionModel.fromOld(oldModel, newModel.parent, + newModel.fittingParamMap, categoricalFeatures) + TreeTests.checkEqual(oldModelAsNew, newModel) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala new file mode 100644 index 0000000000000..c6dc1cc29b6ff --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.scalatest.FunSuite + +import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + + +/** + * Test suite for [[RandomForestRegressor]]. + */ +class RandomForestRegressorSuite extends FunSuite with MLlibTestSparkContext { + + import RandomForestRegressorSuite.compareAPIs + + private var orderedLabeledPoints50_1000: RDD[LabeledPoint] = _ + + override def beforeAll() { + super.beforeAll() + orderedLabeledPoints50_1000 = + sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)) + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests calling train() + ///////////////////////////////////////////////////////////////////////////// + + def regressionTestWithContinuousFeatures(rf: RandomForestRegressor) { + val categoricalFeaturesInfo = Map.empty[Int, Int] + val newRF = rf + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(10) + .setNumTrees(1) + .setFeatureSubsetStrategy("auto") + .setSeed(123) + compareAPIs(orderedLabeledPoints50_1000, newRF, categoricalFeaturesInfo) + } + + test("Regression with continuous features:" + + " comparing DecisionTree vs. RandomForest(numTrees = 1)") { + val rf = new RandomForestRegressor() + regressionTestWithContinuousFeatures(rf) + } + + test("Regression with continuous features and node Id cache :" + + " comparing DecisionTree vs. RandomForest(numTrees = 1)") { + val rf = new RandomForestRegressor() + .setCacheNodeIds(true) + regressionTestWithContinuousFeatures(rf) + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests of model save/load + ///////////////////////////////////////////////////////////////////////////// + + // TODO: Reinstate test once save/load are implemented SPARK-6725 + /* + test("model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray + val oldModel = new OldRandomForestModel(OldAlgo.Regression, trees) + val newModel = RandomForestRegressionModel.fromOld(oldModel) + + // Save model, load it back, and compare. + try { + newModel.save(sc, path) + val sameNewModel = RandomForestRegressionModel.load(sc, path) + TreeTests.checkEqual(newModel, sameNewModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } + */ +} + +private object RandomForestRegressorSuite extends FunSuite { + + /** + * Train 2 models on the given dataset, one using the old API and one using the new API. + * Convert the old model to the new format, compare them, and fail if they are not exactly equal. + */ + def compareAPIs( + data: RDD[LabeledPoint], + rf: RandomForestRegressor, + categoricalFeatures: Map[Int, Int]): Unit = { + val oldStrategy = + rf.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, rf.getOldImpurity) + val oldModel = OldRandomForest.trainRegressor( + data, oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt) + val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) + val newModel = rf.fit(newData) + // Use parent, fittingParamMap from newTree since these are not checked anyways. + val oldModelAsNew = RandomForestRegressionModel.fromOld(oldModel, newModel.parent, + newModel.fittingParamMap, categoricalFeatures) + TreeTests.checkEqual(oldModelAsNew, newModel) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 249b8eae19b17..ce983eb27fa35 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -998,7 +998,7 @@ object DecisionTreeSuite extends FunSuite { node.split = Some(new Split(feature = 1, threshold = 0.0, Categorical, categories = List(0.0, 1.0))) } - // TODO: The information gain stats should be consistent with the same info stored in children. + // TODO: The information gain stats should be consistent with info in children: SPARK-7131 node.stats = Some(new InformationGainStats(gain = 0.1, impurity = 0.2, leftImpurity = 0.3, rightImpurity = 0.4, new Predict(1.0, 0.4), new Predict(0.0, 0.6))) node @@ -1006,9 +1006,9 @@ object DecisionTreeSuite extends FunSuite { /** * Create a tree model. This is deterministic and contains a variety of node and feature types. - * TODO: Update this to be a correct tree (with matching probabilities, impurities, etc.) + * TODO: Update to be a correct tree (with matching probabilities, impurities, etc.): SPARK-7131 */ - private[mllib] def createModel(algo: Algo): DecisionTreeModel = { + private[spark] def createModel(algo: Algo): DecisionTreeModel = { val topNode = createInternalNode(id = 1, Continuous) val (node2, node3) = (createLeafNode(id = 2), createInternalNode(id = 3, Categorical)) val (node6, node7) = (createLeafNode(id = 6), createLeafNode(id = 7)) From aa6966ff34dacc83c3ca675b5109b05e35015469 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Sat, 25 Apr 2015 13:43:39 -0700 Subject: [PATCH 08/39] [SQL] Update SQL readme to include instructions on generating golden answer files based on Hive 0.13.1. Author: Yin Huai Closes #5702 from yhuai/howToGenerateGoldenFiles and squashes the following commits: 9c4a7f8 [Yin Huai] Update readme to include instructions on generating golden answer files based on Hive 0.13.1. --- sql/README.md | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/sql/README.md b/sql/README.md index 237620e3fa808..46aec7cef7984 100644 --- a/sql/README.md +++ b/sql/README.md @@ -12,7 +12,10 @@ Spark SQL is broken up into four subprojects: Other dependencies for developers --------------------------------- -In order to create new hive test cases , you will need to set several environmental variables. +In order to create new hive test cases (i.e. a test suite based on `HiveComparisonTest`), +you will need to setup your development environment based on the following instructions. + +If you are working with Hive 0.12.0, you will need to set several environmental variables as follows. ``` export HIVE_HOME="/hive/build/dist" @@ -20,6 +23,24 @@ export HIVE_DEV_HOME="/hive/" export HADOOP_HOME="/hadoop-1.0.4" ``` +If you are working with Hive 0.13.1, the following steps are needed: + +1. Download Hive's [0.13.1](https://hive.apache.org/downloads.html) and set `HIVE_HOME` with `export HIVE_HOME=""`. Please do not set `HIVE_DEV_HOME` (See [SPARK-4119](https://issues.apache.org/jira/browse/SPARK-4119)). +2. Set `HADOOP_HOME` with `export HADOOP_HOME=""` +3. Download all Hive 0.13.1a jars (Hive jars actually used by Spark) from [here](http://mvnrepository.com/artifact/org.spark-project.hive) and replace corresponding original 0.13.1 jars in `$HIVE_HOME/lib`. +4. Download [Kryo 2.21 jar](http://mvnrepository.com/artifact/com.esotericsoftware.kryo/kryo/2.21) (Note: 2.22 jar does not work) and [Javolution 5.5.1 jar](http://mvnrepository.com/artifact/javolution/javolution/5.5.1) to `$HIVE_HOME/lib`. +5. This step is optional. But, when generating golden answer files, if a Hive query fails and you find that Hive tries to talk to HDFS or you find weird runtime NPEs, set the following in your test suite... + +``` +val testTempDir = Utils.createTempDir() +// We have to use kryo to let Hive correctly serialize some plans. +sql("set hive.plan.serialization.format=kryo") +// Explicitly set fs to local fs. +sql(s"set fs.default.name=file://$testTempDir/") +// Ask Hive to run jobs in-process as a single map and reduce task. +sql("set mapred.job.tracker=local") +``` + Using the console ================= An interactive scala console can be invoked by running `build/sbt hive/console`. From a11c8683c76c67f45749a1b50a0912a731fd2487 Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Sat, 25 Apr 2015 18:07:34 -0400 Subject: [PATCH 09/39] [SPARK-7092] Update spark scala version to 2.11.6 Author: Prashant Sharma Closes #5662 from ScrapCodes/SPARK-7092/scala-update-2.11.6 and squashes the following commits: 58cf4f9 [Prashant Sharma] [SPARK-7092] Update spark scala version to 2.11.6 --- pom.xml | 4 ++-- .../src/main/scala/org/apache/spark/repl/SparkIMain.scala | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pom.xml b/pom.xml index 4b0b0c85eff21..9fbce1d639d8b 100644 --- a/pom.xml +++ b/pom.xml @@ -1745,9 +1745,9 @@ scala-2.11 - 2.11.2 + 2.11.6 2.11 - 2.12 + 2.12.1 jline diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala index 1bb62c84abddc..1cb910f376060 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -1129,7 +1129,7 @@ class SparkIMain(@BeanProperty val factory: ScriptEngineFactory, initialSettings def apply(line: String): Result = debugging(s"""parse("$line")""") { var isIncomplete = false - currentRun.reporting.withIncompleteHandler((_, _) => isIncomplete = true) { + currentRun.parsing.withIncompleteHandler((_, _) => isIncomplete = true) { reporter.reset() val trees = newUnitParser(line).parseStats() if (reporter.hasErrors) Error From f5473c2bbf66cc1144a90b4c29f3ce54ad7cc419 Mon Sep 17 00:00:00 2001 From: Nishkam Ravi Date: Sat, 25 Apr 2015 20:02:23 -0400 Subject: [PATCH 10/39] [SPARK-6014] [CORE] [HOTFIX] Add try-catch block around ShutDownHook Add a try/catch block around removeShutDownHook else IllegalStateException thrown in YARN cluster mode (see https://github.com/apache/spark/pull/4690) cc andrewor14, srowen Author: Nishkam Ravi Author: nishkamravi2 Author: nravi Closes #5672 from nishkamravi2/master_nravi and squashes the following commits: 0f1abd0 [nishkamravi2] Update Utils.scala 474e3bf [nishkamravi2] Update DiskBlockManager.scala 97c383e [nishkamravi2] Update Utils.scala 8691e0c [Nishkam Ravi] Add a try/catch block around Utils.removeShutdownHook 2be1e76 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 1c13b79 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi bad4349 [nishkamravi2] Update Main.java 36a6f87 [Nishkam Ravi] Minor changes and bug fixes b7f4ae7 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 4a45d6a [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 458af39 [Nishkam Ravi] Locate the jar using getLocation, obviates the need to pass assembly path as an argument d9658d6 [Nishkam Ravi] Changes for SPARK-6406 ccdc334 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 3faa7a4 [Nishkam Ravi] Launcher library changes (SPARK-6406) 345206a [Nishkam Ravi] spark-class merge Merge branch 'master_nravi' of https://github.com/nishkamravi2/spark into master_nravi ac58975 [Nishkam Ravi] spark-class changes 06bfeb0 [nishkamravi2] Update spark-class 35af990 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 32c3ab3 [nishkamravi2] Update AbstractCommandBuilder.java 4bd4489 [nishkamravi2] Update AbstractCommandBuilder.java 746f35b [Nishkam Ravi] "hadoop" string in the assembly name should not be mandatory (everywhere else in spark we mandate spark-assembly*hadoop*.jar) bfe96e0 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi ee902fa [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi d453197 [nishkamravi2] Update NewHadoopRDD.scala 6f41a1d [nishkamravi2] Update NewHadoopRDD.scala 0ce2c32 [nishkamravi2] Update HadoopRDD.scala f7e33c2 [Nishkam Ravi] Merge branch 'master_nravi' of https://github.com/nishkamravi2/spark into master_nravi ba1eb8b [Nishkam Ravi] Try-catch block around the two occurrences of removeShutDownHook. Deletion of semi-redundant occurrences of expensive operation inShutDown. 71d0e17 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 494d8c0 [nishkamravi2] Update DiskBlockManager.scala 3c5ddba [nishkamravi2] Update DiskBlockManager.scala f0d12de [Nishkam Ravi] Workaround for IllegalStateException caused by recent changes to BlockManager.stop 79ea8b4 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi b446edc [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 5c9a4cb [nishkamravi2] Update TaskSetManagerSuite.scala 535295a [nishkamravi2] Update TaskSetManager.scala 3e1b616 [Nishkam Ravi] Modify test for maxResultSize 9f6583e [Nishkam Ravi] Changes to maxResultSize code (improve error message and add condition to check if maxResultSize > 0) 5f8f9ed [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi 636a9ff [nishkamravi2] Update YarnAllocator.scala 8f76c8b [Nishkam Ravi] Doc change for yarn memory overhead 35daa64 [Nishkam Ravi] Slight change in the doc for yarn memory overhead 5ac2ec1 [Nishkam Ravi] Remove out dac1047 [Nishkam Ravi] Additional documentation for yarn memory overhead issue 42c2c3d [Nishkam Ravi] Additional changes for yarn memory overhead issue 362da5e [Nishkam Ravi] Additional changes for yarn memory overhead c726bd9 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi f00fa31 [Nishkam Ravi] Improving logging for AM memoryOverhead 1cf2d1e [nishkamravi2] Update YarnAllocator.scala ebcde10 [Nishkam Ravi] Modify default YARN memory_overhead-- from an additive constant to a multiplier (redone to resolve merge conflicts) 2e69f11 [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark into master_nravi efd688a [Nishkam Ravi] Merge branch 'master' of https://github.com/apache/spark 2b630f9 [nravi] Accept memory input as "30g", "512M" instead of an int value, to be consistent with rest of Spark 3bf8fad [nravi] Merge branch 'master' of https://github.com/apache/spark 5423a03 [nravi] Merge branch 'master' of https://github.com/apache/spark eb663ca [nravi] Merge branch 'master' of https://github.com/apache/spark df2aeb1 [nravi] Improved fix for ConcurrentModificationIssue (Spark-1097, Hadoop-10456) 6b840f0 [nravi] Undo the fix for SPARK-1758 (the problem is fixed) 5108700 [nravi] Fix in Spark for the Concurrent thread modification issue (SPARK-1097, HADOOP-10456) 681b36f [nravi] Fix for SPARK-1758: failing test org.apache.spark.JavaAPISuite.wholeTextFiles --- .../scala/org/apache/spark/storage/DiskBlockManager.scala | 7 ++++++- core/src/main/scala/org/apache/spark/util/Utils.scala | 3 +-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 7ea5e54f9e1fe..5764c16902c66 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -148,7 +148,12 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon /** Cleanup local dirs and stop shuffle sender. */ private[spark] def stop() { // Remove the shutdown hook. It causes memory leaks if we leave it around. - Utils.removeShutdownHook(shutdownHook) + try { + Utils.removeShutdownHook(shutdownHook) + } catch { + case e: Exception => + logError(s"Exception while removing shutdown hook.", e) + } doStop() } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 667aa168e7ef3..c6c6df7cfa56e 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2172,7 +2172,7 @@ private [util] class SparkShutdownHookManager { def runAll(): Unit = synchronized { shuttingDown = true while (!hooks.isEmpty()) { - Utils.logUncaughtExceptions(hooks.poll().run()) + Try(Utils.logUncaughtExceptions(hooks.poll().run())) } } @@ -2184,7 +2184,6 @@ private [util] class SparkShutdownHookManager { } def remove(ref: AnyRef): Boolean = synchronized { - checkState() hooks.remove(ref) } From 9a5bbe05fc1b1141e139d32661821fef47d7a13c Mon Sep 17 00:00:00 2001 From: Alain Date: Sun, 26 Apr 2015 07:14:24 -0400 Subject: [PATCH 11/39] [MINOR] [MLLIB] Refactor toString method in MLLIB MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. predict(predict.toString) has already output prefix “predict” thus it’s duplicated to print ", predict = " again 2. there are some extra spaces Author: Alain Closes #5687 from AiHe/tree-node-issue-2 and squashes the following commits: 9862b9a [Alain] Pass scala coding style checking 44ba947 [Alain] Minor][MLLIB] Format toString method in MLLIB bdc402f [Alain] [Minor][MLLIB] Fix a formatting bug in toString method in Node 426eee7 [Alain] [Minor][MLLIB] Fix a formatting bug in toString method in Node.scala --- .../main/scala/org/apache/spark/mllib/linalg/Vectors.scala | 2 +- .../org/apache/spark/mllib/regression/LabeledPoint.scala | 2 +- .../apache/spark/mllib/tree/model/InformationGainStats.scala | 4 ++-- .../main/scala/org/apache/spark/mllib/tree/model/Node.scala | 4 ++-- .../scala/org/apache/spark/mllib/tree/model/Predict.scala | 4 +--- .../main/scala/org/apache/spark/mllib/tree/model/Split.scala | 5 ++--- 6 files changed, 9 insertions(+), 12 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 4ef171f4f0419..166c00cff634d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -526,7 +526,7 @@ class SparseVector( s" ${values.size} values.") override def toString: String = - "(%s,%s,%s)".format(size, indices.mkString("[", ",", "]"), values.mkString("[", ",", "]")) + s"($size,${indices.mkString("[", ",", "]")},${values.mkString("[", ",", "]")})" override def toArray: Array[Double] = { val data = new Array[Double](size) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala index 2067b36f246b3..d5fea822ad77b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala @@ -32,7 +32,7 @@ import org.apache.spark.SparkException @BeanInfo case class LabeledPoint(label: Double, features: Vector) { override def toString: String = { - "(%s,%s)".format(label, features) + s"($label,$features)" } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index f209fdafd3653..2d087c967f679 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -39,8 +39,8 @@ class InformationGainStats( val rightPredict: Predict) extends Serializable { override def toString: String = { - "gain = %f, impurity = %f, left impurity = %f, right impurity = %f" - .format(gain, impurity, leftImpurity, rightImpurity) + s"gain = $gain, impurity = $impurity, left impurity = $leftImpurity, " + + s"right impurity = $rightImpurity" } override def equals(o: Any): Boolean = o match { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 86390a20cb5cc..431a839817eac 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -51,8 +51,8 @@ class Node ( var stats: Option[InformationGainStats]) extends Serializable with Logging { override def toString: String = { - "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " + - "impurity = " + impurity + ", split = " + split + ", stats = " + stats + s"id = $id, isLeaf = $isLeaf, predict = $predict, impurity = $impurity, " + + s"split = $split, stats = $stats" } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala index 25990af7c6cf7..5cbe7c280dbee 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala @@ -29,9 +29,7 @@ class Predict( val predict: Double, val prob: Double = 0.0) extends Serializable { - override def toString: String = { - "predict = %f, prob = %f".format(predict, prob) - } + override def toString: String = s"$predict (prob = $prob)" override def equals(other: Any): Boolean = { other match { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala index fb35e70a8d077..be6c9b3de5479 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala @@ -39,8 +39,8 @@ case class Split( categories: List[Double]) { override def toString: String = { - "Feature = " + feature + ", threshold = " + threshold + ", featureType = " + featureType + - ", categories = " + categories + s"Feature = $feature, threshold = $threshold, featureType = $featureType, " + + s"categories = $categories" } } @@ -68,4 +68,3 @@ private[tree] class DummyHighSplit(feature: Int, featureType: FeatureType) */ private[tree] class DummyCategoricalSplit(feature: Int, featureType: FeatureType) extends Split(feature, Double.MaxValue, featureType, List()) - From ca55dc95b777d96b27d4e4c0457dd25145dcd6e9 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 26 Apr 2015 11:46:58 -0700 Subject: [PATCH 12/39] [SPARK-7152][SQL] Add a Column expression for partition ID. Author: Reynold Xin Closes #5705 from rxin/df-pid and squashes the following commits: 401018f [Reynold Xin] [SPARK-7152][SQL] Add a Column expression for partition ID. --- python/pyspark/sql/functions.py | 30 +++++++++----- .../expressions/SparkPartitionID.scala | 39 +++++++++++++++++++ .../sql/execution/expressions/package.scala | 23 +++++++++++ .../org/apache/spark/sql/functions.scala | 29 +++++++++----- .../spark/sql/ColumnExpressionSuite.scala | 8 ++++ 5 files changed, 110 insertions(+), 19 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/package.scala diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index bb47923f24b82..f48b7b5d10af7 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -75,6 +75,20 @@ def _(col): __all__.sort() +def approxCountDistinct(col, rsd=None): + """Returns a new :class:`Column` for approximate distinct count of ``col``. + + >>> df.agg(approxCountDistinct(df.age).alias('c')).collect() + [Row(c=2)] + """ + sc = SparkContext._active_spark_context + if rsd is None: + jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col)) + else: + jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col), rsd) + return Column(jc) + + def countDistinct(col, *cols): """Returns a new :class:`Column` for distinct count of ``col`` or ``cols``. @@ -89,18 +103,16 @@ def countDistinct(col, *cols): return Column(jc) -def approxCountDistinct(col, rsd=None): - """Returns a new :class:`Column` for approximate distinct count of ``col``. +def sparkPartitionId(): + """Returns a column for partition ID of the Spark task. - >>> df.agg(approxCountDistinct(df.age).alias('c')).collect() - [Row(c=2)] + Note that this is indeterministic because it depends on data partitioning and task scheduling. + + >>> df.repartition(1).select(sparkPartitionId().alias("pid")).collect() + [Row(pid=0), Row(pid=0)] """ sc = SparkContext._active_spark_context - if rsd is None: - jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col)) - else: - jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col), rsd) - return Column(jc) + return Column(sc._jvm.functions.sparkPartitionId()) class UserDefinedFunction(object): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala new file mode 100644 index 0000000000000..fe7607c6ac340 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.expressions + +import org.apache.spark.TaskContext +import org.apache.spark.sql.catalyst.expressions.{Row, Expression} +import org.apache.spark.sql.catalyst.trees +import org.apache.spark.sql.types.{IntegerType, DataType} + + +/** + * Expression that returns the current partition id of the Spark task. + */ +case object SparkPartitionID extends Expression with trees.LeafNode[Expression] { + self: Product => + + override type EvaluatedType = Int + + override def nullable: Boolean = false + + override def dataType: DataType = IntegerType + + override def eval(input: Row): Int = TaskContext.get().partitionId() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/package.scala new file mode 100644 index 0000000000000..568b7ac2c5987 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/package.scala @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +/** + * Package containing expressions that are specific to Spark runtime. + */ +package object expressions diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index ff91e1d74bc2c..9738fd4f93bad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -276,6 +276,13 @@ object functions { // Non-aggregate functions ////////////////////////////////////////////////////////////////////////////////////////////// + /** + * Computes the absolute value. + * + * @group normal_funcs + */ + def abs(e: Column): Column = Abs(e.expr) + /** * Returns the first column that is not null. * {{{ @@ -287,6 +294,13 @@ object functions { @scala.annotation.varargs def coalesce(e: Column*): Column = Coalesce(e.map(_.expr)) + /** + * Converts a string exprsesion to lower case. + * + * @group normal_funcs + */ + def lower(e: Column): Column = Lower(e.expr) + /** * Unary minus, i.e. negate the expression. * {{{ @@ -317,18 +331,13 @@ object functions { def not(e: Column): Column = !e /** - * Converts a string expression to upper case. + * Partition ID of the Spark task. * - * @group normal_funcs - */ - def upper(e: Column): Column = Upper(e.expr) - - /** - * Converts a string exprsesion to lower case. + * Note that this is indeterministic because it depends on data partitioning and task scheduling. * * @group normal_funcs */ - def lower(e: Column): Column = Lower(e.expr) + def sparkPartitionId(): Column = execution.expressions.SparkPartitionID /** * Computes the square root of the specified float value. @@ -338,11 +347,11 @@ object functions { def sqrt(e: Column): Column = Sqrt(e.expr) /** - * Computes the absolutle value. + * Converts a string expression to upper case. * * @group normal_funcs */ - def abs(e: Column): Column = Abs(e.expr) + def upper(e: Column): Column = Upper(e.expr) ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index bc8fae100db6a..904073b8cb2aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -310,6 +310,14 @@ class ColumnExpressionSuite extends QueryTest { ) } + test("sparkPartitionId") { + val df = TestSQLContext.sparkContext.parallelize(1 to 1, 1).map(i => (i, i)).toDF("a", "b") + checkAnswer( + df.select(sparkPartitionId()), + Row(0) + ) + } + test("lift alias out of cast") { compareExpressions( col("1234").as("name").cast("int").expr, From d188b8bad82836bf654e57f9dd4e1ddde1d530f4 Mon Sep 17 00:00:00 2001 From: wangfei Date: Sun, 26 Apr 2015 21:08:47 -0700 Subject: [PATCH 13/39] [SQL][Minor] rename DataTypeParser.apply to DataTypeParser.parse rename DataTypeParser.apply to DataTypeParser.parse to make it more clear and readable. /cc rxin Author: wangfei Closes #5710 from scwf/apply and squashes the following commits: c319977 [wangfei] rename apply to parse --- .../org/apache/spark/sql/catalyst/planning/patterns.scala | 2 +- .../scala/org/apache/spark/sql/types/DataTypeParser.scala | 2 +- .../org/apache/spark/sql/types/DataTypeParserSuite.scala | 4 ++-- sql/core/src/main/scala/org/apache/spark/sql/Column.scala | 2 +- .../org/apache/spark/sql/hive/HiveMetastoreCatalog.scala | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 9c8c643f7d17a..4574934d910db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -92,7 +92,7 @@ object PhysicalOperation extends PredicateHelper { } def collectAliases(fields: Seq[Expression]): Map[Attribute, Expression] = fields.collect { - case a @ Alias(child, _) => a.toAttribute.asInstanceOf[Attribute] -> child + case a @ Alias(child, _) => a.toAttribute -> child }.toMap def substitute(aliases: Map[Attribute, Expression])(expr: Expression): Expression = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala index 5163f05879e42..04f3379afb38d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala @@ -108,7 +108,7 @@ private[sql] object DataTypeParser { override val lexical = new SqlLexical } - def apply(dataTypeString: String): DataType = dataTypeParser.toDataType(dataTypeString) + def parse(dataTypeString: String): DataType = dataTypeParser.toDataType(dataTypeString) } /** The exception thrown from the [[DataTypeParser]]. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala index 169125264a803..3e7cf7cbb5e63 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala @@ -23,13 +23,13 @@ class DataTypeParserSuite extends FunSuite { def checkDataType(dataTypeString: String, expectedDataType: DataType): Unit = { test(s"parse ${dataTypeString.replace("\n", "")}") { - assert(DataTypeParser(dataTypeString) === expectedDataType) + assert(DataTypeParser.parse(dataTypeString) === expectedDataType) } } def unsupported(dataTypeString: String): Unit = { test(s"$dataTypeString is not supported") { - intercept[DataTypeException](DataTypeParser(dataTypeString)) + intercept[DataTypeException](DataTypeParser.parse(dataTypeString)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index edb229c059e6b..33f9d0b37d006 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -647,7 +647,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * * @group expr_ops */ - def cast(to: String): Column = cast(DataTypeParser(to)) + def cast(to: String): Column = cast(DataTypeParser.parse(to)) /** * Returns an ordering used in sorting. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index f1c0bd92aa23d..4d222cf88e5e8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -871,7 +871,7 @@ private[hive] case class MetastoreRelation private[hive] object HiveMetastoreTypes { - def toDataType(metastoreType: String): DataType = DataTypeParser(metastoreType) + def toDataType(metastoreType: String): DataType = DataTypeParser.parse(metastoreType) def toMetastoreType(dt: DataType): String = dt match { case ArrayType(elementType, _) => s"array<${toMetastoreType(elementType)}>" From 82bb7fd41a2c7992e0aea69623c504bd439744f7 Mon Sep 17 00:00:00 2001 From: baishuo Date: Mon, 27 Apr 2015 14:08:05 +0800 Subject: [PATCH 14/39] [SPARK-6505] [SQL] Remove the reflection call in HiveFunctionWrapper MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit according liancheng‘s comment in https://issues.apache.org/jira/browse/SPARK-6505, this patch remove the reflection call in HiveFunctionWrapper, and implement the functions named "deserializeObjectByKryo" and "serializeObjectByKryo" according the functions with the save name in org.apache.hadoop.hive.ql.exec.Utilities.java Author: baishuo Closes #5660 from baishuo/SPARK-6505-20150423 and squashes the following commits: ae61ec4 [baishuo] modify code style 78d9fa3 [baishuo] modify code style 0b522a7 [baishuo] modify code style a5ff9c7 [baishuo] Remove the reflection call in HiveFunctionWrapper --- .../org/apache/spark/sql/hive/Shim13.scala | 44 +++++++++---------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala index d331c210e8939..dbc5e029e2047 100644 --- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala +++ b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala @@ -19,11 +19,15 @@ package org.apache.spark.sql.hive import java.rmi.server.UID import java.util.{Properties, ArrayList => JArrayList} +import java.io.{OutputStream, InputStream} import scala.collection.JavaConversions._ import scala.language.implicitConversions +import scala.reflect.ClassTag import com.esotericsoftware.kryo.Kryo +import com.esotericsoftware.kryo.io.Input +import com.esotericsoftware.kryo.io.Output import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.common.StatsSetupConst @@ -46,6 +50,7 @@ import org.apache.hadoop.{io => hadoopIo} import org.apache.spark.Logging import org.apache.spark.sql.types.{Decimal, DecimalType, UTF8String} +import org.apache.spark.util.Utils._ /** * This class provides the UDF creation and also the UDF instance serialization and @@ -61,39 +66,34 @@ private[hive] case class HiveFunctionWrapper(var functionClassName: String) // for Serialization def this() = this(null) - import org.apache.spark.util.Utils._ - @transient - private val methodDeSerialize = { - val method = classOf[Utilities].getDeclaredMethod( - "deserializeObjectByKryo", - classOf[Kryo], - classOf[java.io.InputStream], - classOf[Class[_]]) - method.setAccessible(true) - - method + def deserializeObjectByKryo[T: ClassTag]( + kryo: Kryo, + in: InputStream, + clazz: Class[_]): T = { + val inp = new Input(in) + val t: T = kryo.readObject(inp,clazz).asInstanceOf[T] + inp.close() + t } @transient - private val methodSerialize = { - val method = classOf[Utilities].getDeclaredMethod( - "serializeObjectByKryo", - classOf[Kryo], - classOf[Object], - classOf[java.io.OutputStream]) - method.setAccessible(true) - - method + def serializeObjectByKryo( + kryo: Kryo, + plan: Object, + out: OutputStream ) { + val output: Output = new Output(out) + kryo.writeObject(output, plan) + output.close() } def deserializePlan[UDFType](is: java.io.InputStream, clazz: Class[_]): UDFType = { - methodDeSerialize.invoke(null, Utilities.runtimeSerializationKryo.get(), is, clazz) + deserializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), is, clazz) .asInstanceOf[UDFType] } def serializePlan(function: AnyRef, out: java.io.OutputStream): Unit = { - methodSerialize.invoke(null, Utilities.runtimeSerializationKryo.get(), function, out) + serializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), function, out) } private var instance: AnyRef = null From 998aac21f0a0588a70f8cf123ae4080163c612fb Mon Sep 17 00:00:00 2001 From: Misha Chernetsov Date: Mon, 27 Apr 2015 11:27:56 -0700 Subject: [PATCH 15/39] [SPARK-4925] Publish Spark SQL hive-thriftserver maven artifact turned on hive-thriftserver profile in release script Author: Misha Chernetsov Closes #5429 from chernetsov/master and squashes the following commits: 9cc36af [Misha Chernetsov] [SPARK-4925] Publish Spark SQL hive-thriftserver maven artifact turned on hive-thriftserver profile in release script for scala 2.10 --- dev/create-release/create-release.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index b5a67dd783b93..3dbb35f7054a2 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -119,7 +119,7 @@ if [[ ! "$@" =~ --skip-publish ]]; then rm -rf $SPARK_REPO build/mvn -DskipTests -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ - -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ + -Pyarn -Phive -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ clean install ./dev/change-version-to-2.11.sh From 7078f6028bf012235c664b02ec3541cbb0a248a7 Mon Sep 17 00:00:00 2001 From: Jeff Harrison Date: Mon, 27 Apr 2015 13:38:25 -0700 Subject: [PATCH 16/39] [SPARK-6856] [R] Make RDD information more useful in SparkR Author: Jeff Harrison Closes #5667 from His-name-is-Joof/joofspark and squashes the following commits: f8814a6 [Jeff Harrison] newline added after RDD show() output 4d9d972 [Jeff Harrison] Merge branch 'master' into joofspark 9d2295e [Jeff Harrison] parallelize with 1:10 878b830 [Jeff Harrison] Merge branch 'master' into joofspark c8c0b80 [Jeff Harrison] add test for RDD function show() 123be65 [Jeff Harrison] SPARK-6856 --- R/pkg/R/RDD.R | 5 +++++ R/pkg/inst/tests/test_rdd.R | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index 1662d6bb3b1ac..f90c26b253455 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -66,6 +66,11 @@ setMethod("initialize", "RDD", function(.Object, jrdd, serializedMode, .Object }) +setMethod("show", "RDD", + function(.Object) { + cat(paste(callJMethod(.Object@jrdd, "toString"), "\n", sep="")) + }) + setMethod("initialize", "PipelinedRDD", function(.Object, prev, func, jrdd_val) { .Object@env <- new.env() .Object@env$isCached <- FALSE diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R index d55af93e3e50a..03207353c31c6 100644 --- a/R/pkg/inst/tests/test_rdd.R +++ b/R/pkg/inst/tests/test_rdd.R @@ -759,6 +759,11 @@ test_that("collectAsMap() on a pairwise RDD", { expect_equal(vals, list(`1` = "a", `2` = "b")) }) +test_that("show()", { + rdd <- parallelize(sc, list(1:10)) + expect_output(show(rdd), "ParallelCollectionRDD\\[\\d+\\] at parallelize at RRDD\\.scala:\\d+") +}) + test_that("sampleByKey() on pairwise RDDs", { rdd <- parallelize(sc, 1:2000) pairsRDD <- lapply(rdd, function(x) { if (x %% 2 == 0) list("a", x) else list("b", x) }) From ef82bddc11d1aea42e22d2f85613a869cbe9a990 Mon Sep 17 00:00:00 2001 From: tedyu Date: Mon, 27 Apr 2015 14:42:40 -0700 Subject: [PATCH 17/39] SPARK-7107 Add parameter for zookeeper.znode.parent to hbase_inputformat... ....py Author: tedyu Closes #5673 from tedyu/master and squashes the following commits: ab7c72b [tedyu] SPARK-7107 Adjust indentation to pass Python style tests 6e25939 [tedyu] Adjust line length to be shorter than 100 characters 18d172a [tedyu] SPARK-7107 Add parameter for zookeeper.znode.parent to hbase_inputformat.py --- examples/src/main/python/hbase_inputformat.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/src/main/python/hbase_inputformat.py b/examples/src/main/python/hbase_inputformat.py index e17819d5feb76..5b82a14fba413 100644 --- a/examples/src/main/python/hbase_inputformat.py +++ b/examples/src/main/python/hbase_inputformat.py @@ -54,8 +54,9 @@ Run with example jar: ./bin/spark-submit --driver-class-path /path/to/example/jar \ - /path/to/examples/hbase_inputformat.py + /path/to/examples/hbase_inputformat.py
[] Assumes you have some data in HBase already, running on , in
+ optionally, you can specify parent znode for your hbase cluster - """, file=sys.stderr) exit(-1) @@ -64,6 +65,9 @@ sc = SparkContext(appName="HBaseInputFormat") conf = {"hbase.zookeeper.quorum": host, "hbase.mapreduce.inputtable": table} + if len(sys.argv) > 3: + conf = {"hbase.zookeeper.quorum": host, "zookeeper.znode.parent": sys.argv[3], + "hbase.mapreduce.inputtable": table} keyConv = "org.apache.spark.examples.pythonconverters.ImmutableBytesWritableToStringConverter" valueConv = "org.apache.spark.examples.pythonconverters.HBaseResultToStringConverter" From ca9f4ebb8e510e521bf4df0331375ddb385fb9d2 Mon Sep 17 00:00:00 2001 From: hlin09 Date: Mon, 27 Apr 2015 15:04:37 -0700 Subject: [PATCH 18/39] [SPARK-6991] [SPARKR] Adds support for zipPartitions. Author: hlin09 Closes #5568 from hlin09/zipPartitions and squashes the following commits: 12c08a5 [hlin09] Fix comments d2d32db [hlin09] Merge branch 'master' into zipPartitions ec56d2f [hlin09] Fix test. 27655d3 [hlin09] Adds support for zipPartitions. --- R/pkg/NAMESPACE | 1 + R/pkg/R/RDD.R | 46 +++++++++++++++++++++++++ R/pkg/R/generics.R | 5 +++ R/pkg/inst/tests/test_binary_function.R | 33 ++++++++++++++++++ 4 files changed, 85 insertions(+) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 80283643861ac..e077eace74375 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -71,6 +71,7 @@ exportMethods( "unpersist", "value", "values", + "zipPartitions", "zipRDD", "zipWithIndex", "zipWithUniqueId" diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index f90c26b253455..a3a0421a0746d 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -1595,3 +1595,49 @@ setMethod("intersection", keys(filterRDD(cogroup(rdd1, rdd2, numPartitions = numPartitions), filterFunction)) }) + +#' Zips an RDD's partitions with one (or more) RDD(s). +#' Same as zipPartitions in Spark. +#' +#' @param ... RDDs to be zipped. +#' @param func A function to transform zipped partitions. +#' @return A new RDD by applying a function to the zipped partitions. +#' Assumes that all the RDDs have the *same number of partitions*, but +#' does *not* require them to have the same number of elements in each partition. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2 +#' rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4 +#' rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6 +#' collect(zipPartitions(rdd1, rdd2, rdd3, +#' func = function(x, y, z) { list(list(x, y, z))} )) +#' # list(list(1, c(1,2), c(1,2,3)), list(2, c(3,4), c(4,5,6))) +#'} +#' @rdname zipRDD +#' @aliases zipPartitions,RDD +setMethod("zipPartitions", + "RDD", + function(..., func) { + rrdds <- list(...) + if (length(rrdds) == 1) { + return(rrdds[[1]]) + } + nPart <- sapply(rrdds, numPartitions) + if (length(unique(nPart)) != 1) { + stop("Can only zipPartitions RDDs which have the same number of partitions.") + } + + rrdds <- lapply(rrdds, function(rdd) { + mapPartitionsWithIndex(rdd, function(partIndex, part) { + print(length(part)) + list(list(partIndex, part)) + }) + }) + union.rdd <- Reduce(unionRDD, rrdds) + zipped.rdd <- values(groupByKey(union.rdd, numPartitions = nPart[1])) + res <- mapPartitions(zipped.rdd, function(plist) { + do.call(func, plist[[1]]) + }) + res + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 34dbe84051c50..e88729387ef95 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -217,6 +217,11 @@ setGeneric("unpersist", function(x, ...) { standardGeneric("unpersist") }) #' @export setGeneric("zipRDD", function(x, other) { standardGeneric("zipRDD") }) +#' @rdname zipRDD +#' @export +setGeneric("zipPartitions", function(..., func) { standardGeneric("zipPartitions") }, + signature = "...") + #' @rdname zipWithIndex #' @seealso zipWithUniqueId #' @export diff --git a/R/pkg/inst/tests/test_binary_function.R b/R/pkg/inst/tests/test_binary_function.R index c15553ba28517..6785a7bdae8cb 100644 --- a/R/pkg/inst/tests/test_binary_function.R +++ b/R/pkg/inst/tests/test_binary_function.R @@ -66,3 +66,36 @@ test_that("cogroup on two RDDs", { expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) }) + +test_that("zipPartitions() on RDDs", { + rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2 + rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4 + rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6 + actual <- collect(zipPartitions(rdd1, rdd2, rdd3, + func = function(x, y, z) { list(list(x, y, z))} )) + expect_equal(actual, + list(list(1, c(1,2), c(1,2,3)), list(2, c(3,4), c(4,5,6)))) + + mockFile = c("Spark is pretty.", "Spark is awesome.") + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName, 1) + actual <- collect(zipPartitions(rdd, rdd, + func = function(x, y) { list(paste(x, y, sep = "\n")) })) + expected <- list(paste(mockFile, mockFile, sep = "\n")) + expect_equal(actual, expected) + + rdd1 <- parallelize(sc, 0:1, 1) + actual <- collect(zipPartitions(rdd1, rdd, + func = function(x, y) { list(x + nchar(y)) })) + expected <- list(0:1 + nchar(mockFile)) + expect_equal(actual, expected) + + rdd <- map(rdd, function(x) { x }) + actual <- collect(zipPartitions(rdd, rdd1, + func = function(x, y) { list(y + nchar(x)) })) + expect_equal(actual, expected) + + unlink(fileName) +}) From b9de9e040aff371c6acf9b3f3d1ff8b360c0cd56 Mon Sep 17 00:00:00 2001 From: Steven She Date: Mon, 27 Apr 2015 18:55:02 -0400 Subject: [PATCH 19/39] [SPARK-7103] Fix crash with SparkContext.union when RDD has no partitioner Added a check to the SparkContext.union method to check that a partitioner is defined on all RDDs when instantiating a PartitionerAwareUnionRDD. Author: Steven She Closes #5679 from stevencanopy/SPARK-7103 and squashes the following commits: 5a3d846 [Steven She] SPARK-7103: Fix crash with SparkContext.union when at least one RDD has no partitioner --- .../scala/org/apache/spark/SparkContext.scala | 2 +- .../spark/rdd/PartitionerAwareUnionRDD.scala | 1 + .../scala/org/apache/spark/rdd/RDDSuite.scala | 21 +++++++++++++++++++ 3 files changed, 23 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 86269eac52db0..ea4ddcc2e265d 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1055,7 +1055,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli /** Build the union of a list of RDDs. */ def union[T: ClassTag](rdds: Seq[RDD[T]]): RDD[T] = { val partitioners = rdds.flatMap(_.partitioner).toSet - if (partitioners.size == 1) { + if (rdds.forall(_.partitioner.isDefined) && partitioners.size == 1) { new PartitionerAwareUnionRDD(this, rdds) } else { new UnionRDD(this, rdds) diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala index 92b0641d0fb6e..7598ff617b399 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala @@ -60,6 +60,7 @@ class PartitionerAwareUnionRDD[T: ClassTag]( var rdds: Seq[RDD[T]] ) extends RDD[T](sc, rdds.map(x => new OneToOneDependency(x))) { require(rdds.length > 0) + require(rdds.forall(_.partitioner.isDefined)) require(rdds.flatMap(_.partitioner).toSet.size == 1, "Parent RDDs have different partitioners: " + rdds.flatMap(_.partitioner)) diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index df42faab64505..ef8c36a28655b 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -99,6 +99,27 @@ class RDDSuite extends FunSuite with SharedSparkContext { assert(sc.union(Seq(nums, nums)).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4)) } + test("SparkContext.union creates UnionRDD if at least one RDD has no partitioner") { + val rddWithPartitioner = sc.parallelize(Seq(1->true)).partitionBy(new HashPartitioner(1)) + val rddWithNoPartitioner = sc.parallelize(Seq(2->true)) + val unionRdd = sc.union(rddWithNoPartitioner, rddWithPartitioner) + assert(unionRdd.isInstanceOf[UnionRDD[_]]) + } + + test("SparkContext.union creates PartitionAwareUnionRDD if all RDDs have partitioners") { + val rddWithPartitioner = sc.parallelize(Seq(1->true)).partitionBy(new HashPartitioner(1)) + val unionRdd = sc.union(rddWithPartitioner, rddWithPartitioner) + assert(unionRdd.isInstanceOf[PartitionerAwareUnionRDD[_]]) + } + + test("PartitionAwareUnionRDD raises exception if at least one RDD has no partitioner") { + val rddWithPartitioner = sc.parallelize(Seq(1->true)).partitionBy(new HashPartitioner(1)) + val rddWithNoPartitioner = sc.parallelize(Seq(2->true)) + intercept[IllegalArgumentException] { + new PartitionerAwareUnionRDD(sc, Seq(rddWithNoPartitioner, rddWithPartitioner)) + } + } + test("partitioner aware union") { def makeRDDWithPartitioner(seq: Seq[Int]): RDD[Int] = { sc.makeRDD(seq, 1) From 8e1c00dbf4b60962908626dead744e5d73c8085e Mon Sep 17 00:00:00 2001 From: Hong Shen Date: Mon, 27 Apr 2015 18:57:31 -0400 Subject: [PATCH 20/39] [SPARK-6738] [CORE] Improve estimate the size of a large array Currently, SizeEstimator.visitArray is not correct in the follow case, ``` array size > 200, elem has the share object ``` when I add a debug log in SizeTracker.scala: ``` System.err.println(s"numUpdates:$numUpdates, size:$ts, bytesPerUpdate:$bytesPerUpdate, cost time:$b") ``` I get the following log: ``` numUpdates:1, size:262448, bytesPerUpdate:0.0, cost time:35 numUpdates:2, size:420698, bytesPerUpdate:158250.0, cost time:35 numUpdates:4, size:420754, bytesPerUpdate:28.0, cost time:32 numUpdates:7, size:420754, bytesPerUpdate:0.0, cost time:27 numUpdates:12, size:420754, bytesPerUpdate:0.0, cost time:28 numUpdates:20, size:420754, bytesPerUpdate:0.0, cost time:25 numUpdates:32, size:420754, bytesPerUpdate:0.0, cost time:21 numUpdates:52, size:420754, bytesPerUpdate:0.0, cost time:20 numUpdates:84, size:420754, bytesPerUpdate:0.0, cost time:20 numUpdates:135, size:420754, bytesPerUpdate:0.0, cost time:20 numUpdates:216, size:420754, bytesPerUpdate:0.0, cost time:11 numUpdates:346, size:420754, bytesPerUpdate:0.0, cost time:6 numUpdates:554, size:488911, bytesPerUpdate:327.67788461538464, cost time:8 numUpdates:887, size:2312259426, bytesPerUpdate:6942253.798798799, cost time:198 15/04/21 14:27:26 INFO collection.ExternalAppendOnlyMap: Thread 51 spilling in-memory map of 3.0 GB to disk (1 time so far) 15/04/21 14:27:26 INFO collection.ExternalAppendOnlyMap: /data11/yarnenv/local/usercache/spark/appcache/application_1426746631567_11745/spark-local-20150421142719-c001/30/temp_local_066af981-c2fc-4b70-a00e-110e23006fbc ``` But in fact the file size is only 162K: ``` $ ll -h /data11/yarnenv/local/usercache/spark/appcache/application_1426746631567_11745/spark-local-20150421142719-c001/30/temp_local_066af981-c2fc-4b70-a00e-110e23006fbc -rw-r----- 1 spark users 162K Apr 21 14:27 /data11/yarnenv/local/usercache/spark/appcache/application_1426746631567_11745/spark-local-20150421142719-c001/30/temp_local_066af981-c2fc-4b70-a00e-110e23006fbc ``` In order to test case, I change visitArray to: ``` var size = 0l for (i <- 0 until length) { val obj = JArray.get(array, i) size += SizeEstimator.estimate(obj, state.visited).toLong } state.size += size ``` I get the following log: ``` ... 14895 277016088 566.9046118590662 time:8470 23832 281840544 552.3308270676691 time:8031 38132 289891824 539.8294729775092 time:7897 61012 302803640 563.0265734265735 time:13044 97620 322904416 564.3276223776223 time:13554 15/04/14 11:46:43 INFO collection.ExternalAppendOnlyMap: Thread 51 spilling in-memory map of 314.5 MB to disk (1 time so far) 15/04/14 11:46:43 INFO collection.ExternalAppendOnlyMap: /data1/yarnenv/local/usercache/spark/appcache/application_1426746631567_8477/spark-local-20150414114020-2fcb/14/temp_local_5b6b98d5-5bfa-47e2-8216-059482ccbda0 ``` the file size is 85M. ``` $ ll -h /data1/yarnenv/local/usercache/spark/appcache/application_1426746631567_8477/spark- local-20150414114020-2fcb/14/ total 85M -rw-r----- 1 spark users 85M Apr 14 11:46 temp_local_5b6b98d5-5bfa-47e2-8216-059482ccbda0 ``` The following log is when I use this patch, ``` .... numUpdates:32, size:365484, bytesPerUpdate:0.0, cost time:7 numUpdates:52, size:365484, bytesPerUpdate:0.0, cost time:5 numUpdates:84, size:365484, bytesPerUpdate:0.0, cost time:5 numUpdates:135, size:372208, bytesPerUpdate:131.84313725490196, cost time:86 numUpdates:216, size:379020, bytesPerUpdate:84.09876543209876, cost time:21 numUpdates:346, size:1865208, bytesPerUpdate:11432.215384615385, cost time:23 numUpdates:554, size:2052380, bytesPerUpdate:899.8653846153846, cost time:16 numUpdates:887, size:2142820, bytesPerUpdate:271.59159159159157, cost time:15 .. numUpdates:14895, size:251675500, bytesPerUpdate:438.5263157894737, cost time:13 numUpdates:23832, size:257010268, bytesPerUpdate:596.9305135951662, cost time:14 numUpdates:38132, size:263922396, bytesPerUpdate:483.3655944055944, cost time:15 numUpdates:61012, size:268962596, bytesPerUpdate:220.28846153846155, cost time:24 numUpdates:97620, size:286980644, bytesPerUpdate:492.1888111888112, cost time:22 15/04/21 14:45:12 INFO collection.ExternalAppendOnlyMap: Thread 53 spilling in-memory map of 328.7 MB to disk (1 time so far) 15/04/21 14:45:12 INFO collection.ExternalAppendOnlyMap: /data4/yarnenv/local/usercache/spark/appcache/application_1426746631567_11758/spark-local-20150421144456-a2a5/2a/temp_local_9c109510-af16-4468-8f23-48cad04da88f ``` the file size is 88M. ``` $ ll -h /data4/yarnenv/local/usercache/spark/appcache/application_1426746631567_11758/spark-local-20150421144456-a2a5/2a/ total 88M -rw-r----- 1 spark users 88M Apr 21 14:45 temp_local_9c109510-af16-4468-8f23-48cad04da88f ``` Author: Hong Shen Closes #5608 from shenh062326/my_change5 and squashes the following commits: 5506bae [Hong Shen] Fix compile error c275dd3 [Hong Shen] Alter code style fe202a2 [Hong Shen] Change the code style and add documentation. a9fca84 [Hong Shen] Add test case for SizeEstimator 4877eee [Hong Shen] Improve estimate the size of a large array a2ea7ac [Hong Shen] Alter code style 4c28e36 [Hong Shen] Improve estimate the size of a large array --- .../org/apache/spark/util/SizeEstimator.scala | 45 ++++++++++++------- .../spark/util/SizeEstimatorSuite.scala | 18 ++++++++ 2 files changed, 48 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index 26ffbf9350388..4dd7ab9e0767b 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -179,7 +179,7 @@ private[spark] object SizeEstimator extends Logging { } // Estimate the size of arrays larger than ARRAY_SIZE_FOR_SAMPLING by sampling. - private val ARRAY_SIZE_FOR_SAMPLING = 200 + private val ARRAY_SIZE_FOR_SAMPLING = 400 private val ARRAY_SAMPLE_SIZE = 100 // should be lower than ARRAY_SIZE_FOR_SAMPLING private def visitArray(array: AnyRef, arrayClass: Class[_], state: SearchState) { @@ -204,25 +204,40 @@ private[spark] object SizeEstimator extends Logging { } } else { // Estimate the size of a large array by sampling elements without replacement. - var size = 0.0 + // To exclude the shared objects that the array elements may link, sample twice + // and use the min one to caculate array size. val rand = new Random(42) - val drawn = new OpenHashSet[Int](ARRAY_SAMPLE_SIZE) - var numElementsDrawn = 0 - while (numElementsDrawn < ARRAY_SAMPLE_SIZE) { - var index = 0 - do { - index = rand.nextInt(length) - } while (drawn.contains(index)) - drawn.add(index) - val elem = ScalaRunTime.array_apply(array, index).asInstanceOf[AnyRef] - size += SizeEstimator.estimate(elem, state.visited) - numElementsDrawn += 1 - } - state.size += ((length / (ARRAY_SAMPLE_SIZE * 1.0)) * size).toLong + val drawn = new OpenHashSet[Int](2 * ARRAY_SAMPLE_SIZE) + val s1 = sampleArray(array, state, rand, drawn, length) + val s2 = sampleArray(array, state, rand, drawn, length) + val size = math.min(s1, s2) + state.size += math.max(s1, s2) + + (size * ((length - ARRAY_SAMPLE_SIZE) / (ARRAY_SAMPLE_SIZE))).toLong } } } + private def sampleArray( + array: AnyRef, + state: SearchState, + rand: Random, + drawn: OpenHashSet[Int], + length: Int): Long = { + var size = 0L + for (i <- 0 until ARRAY_SAMPLE_SIZE) { + var index = 0 + do { + index = rand.nextInt(length) + } while (drawn.contains(index)) + drawn.add(index) + val obj = ScalaRunTime.array_apply(array, index).asInstanceOf[AnyRef] + if (obj != null) { + size += SizeEstimator.estimate(obj, state.visited).toLong + } + } + size + } + private def primitiveSize(cls: Class[_]): Long = { if (cls == classOf[Byte]) { BYTE_SIZE diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala index 67a9f75ff2187..28915bd53354e 100644 --- a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.util +import scala.collection.mutable.ArrayBuffer + import org.scalatest.{BeforeAndAfterEach, BeforeAndAfterAll, FunSuite, PrivateMethodTester} class DummyClass1 {} @@ -96,6 +98,22 @@ class SizeEstimatorSuite // Past size 100, our samples 100 elements, but we should still get the right size. assertResult(28016)(SizeEstimator.estimate(Array.fill(1000)(new DummyClass3))) + + val arr = new Array[Char](100000) + assertResult(200016)(SizeEstimator.estimate(arr)) + assertResult(480032)(SizeEstimator.estimate(Array.fill(10000)(new DummyString(arr)))) + + val buf = new ArrayBuffer[DummyString]() + for (i <- 0 until 5000) { + buf.append(new DummyString(new Array[Char](10))) + } + assertResult(340016)(SizeEstimator.estimate(buf.toArray)) + + for (i <- 0 until 5000) { + buf.append(new DummyString(arr)) + } + assertResult(683912)(SizeEstimator.estimate(buf.toArray)) + // If an array contains the *same* element many times, we should only count it once. val d1 = new DummyClass1 // 10 pointers plus 8-byte object From 5d45e1f60059e2f2fc8ad64778b9ddcc8887c570 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 27 Apr 2015 19:46:17 -0400 Subject: [PATCH 21/39] [SPARK-3090] [CORE] Stop SparkContext if user forgets to. Set up a shutdown hook to try to stop the Spark context in case the user forgets to do it. The main effect is that any open logs files are flushed and closed, which is particularly interesting for event logs. Author: Marcelo Vanzin Closes #5696 from vanzin/SPARK-3090 and squashes the following commits: 3b554b5 [Marcelo Vanzin] [SPARK-3090] [core] Stop SparkContext if user forgets to. --- .../scala/org/apache/spark/SparkContext.scala | 38 ++++++++++++------- .../scala/org/apache/spark/util/Utils.scala | 10 ++++- .../spark/deploy/yarn/ApplicationMaster.scala | 10 +---- 3 files changed, 35 insertions(+), 23 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index ea4ddcc2e265d..65b903a55d5bd 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -223,6 +223,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli private var _listenerBusStarted: Boolean = false private var _jars: Seq[String] = _ private var _files: Seq[String] = _ + private var _shutdownHookRef: AnyRef = _ /* ------------------------------------------------------------------------------------- * | Accessors and public fields. These provide access to the internal state of the | @@ -517,6 +518,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _taskScheduler.postStartHook() _env.metricsSystem.registerSource(new DAGSchedulerSource(dagScheduler)) _env.metricsSystem.registerSource(new BlockManagerSource(_env.blockManager)) + + // Make sure the context is stopped if the user forgets about it. This avoids leaving + // unfinished event logs around after the JVM exits cleanly. It doesn't help if the JVM + // is killed, though. + _shutdownHookRef = Utils.addShutdownHook(Utils.SPARK_CONTEXT_SHUTDOWN_PRIORITY) { () => + logInfo("Invoking stop() from shutdown hook") + stop() + } } catch { case NonFatal(e) => logError("Error initializing SparkContext.", e) @@ -1481,6 +1490,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli logInfo("SparkContext already stopped.") return } + if (_shutdownHookRef != null) { + Utils.removeShutdownHook(_shutdownHookRef) + } postApplicationEnd() _ui.foreach(_.stop()) @@ -1891,7 +1903,7 @@ object SparkContext extends Logging { * * Access to this field is guarded by SPARK_CONTEXT_CONSTRUCTOR_LOCK. */ - private val activeContext: AtomicReference[SparkContext] = + private val activeContext: AtomicReference[SparkContext] = new AtomicReference[SparkContext](null) /** @@ -1944,11 +1956,11 @@ object SparkContext extends Logging { } /** - * This function may be used to get or instantiate a SparkContext and register it as a - * singleton object. Because we can only have one active SparkContext per JVM, - * this is useful when applications may wish to share a SparkContext. + * This function may be used to get or instantiate a SparkContext and register it as a + * singleton object. Because we can only have one active SparkContext per JVM, + * this is useful when applications may wish to share a SparkContext. * - * Note: This function cannot be used to create multiple SparkContext instances + * Note: This function cannot be used to create multiple SparkContext instances * even if multiple contexts are allowed. */ def getOrCreate(config: SparkConf): SparkContext = { @@ -1961,17 +1973,17 @@ object SparkContext extends Logging { activeContext.get() } } - + /** - * This function may be used to get or instantiate a SparkContext and register it as a - * singleton object. Because we can only have one active SparkContext per JVM, + * This function may be used to get or instantiate a SparkContext and register it as a + * singleton object. Because we can only have one active SparkContext per JVM, * this is useful when applications may wish to share a SparkContext. - * + * * This method allows not passing a SparkConf (useful if just retrieving). - * - * Note: This function cannot be used to create multiple SparkContext instances - * even if multiple contexts are allowed. - */ + * + * Note: This function cannot be used to create multiple SparkContext instances + * even if multiple contexts are allowed. + */ def getOrCreate(): SparkContext = { getOrCreate(new SparkConf()) } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index c6c6df7cfa56e..342bc9a06db47 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -67,6 +67,12 @@ private[spark] object Utils extends Logging { val DEFAULT_SHUTDOWN_PRIORITY = 100 + /** + * The shutdown priority of the SparkContext instance. This is lower than the default + * priority, so that by default hooks are run before the context is shut down. + */ + val SPARK_CONTEXT_SHUTDOWN_PRIORITY = 50 + private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 @volatile private var localRootDirs: Array[String] = null @@ -2116,7 +2122,7 @@ private[spark] object Utils extends Logging { * @return A handle that can be used to unregister the shutdown hook. */ def addShutdownHook(hook: () => Unit): AnyRef = { - addShutdownHook(DEFAULT_SHUTDOWN_PRIORITY, hook) + addShutdownHook(DEFAULT_SHUTDOWN_PRIORITY)(hook) } /** @@ -2126,7 +2132,7 @@ private[spark] object Utils extends Logging { * @param hook The code to run during shutdown. * @return A handle that can be used to unregister the shutdown hook. */ - def addShutdownHook(priority: Int, hook: () => Unit): AnyRef = { + def addShutdownHook(priority: Int)(hook: () => Unit): AnyRef = { shutdownHooks.add(priority, hook) } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 93ae45133ce24..70cb57ffd8c69 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -95,14 +95,8 @@ private[spark] class ApplicationMaster( val fs = FileSystem.get(yarnConf) - Utils.addShutdownHook { () => - // If the SparkContext is still registered, shut it down as a best case effort in case - // users do not call sc.stop or do System.exit(). - val sc = sparkContextRef.get() - if (sc != null) { - logInfo("Invoking sc stop from shutdown hook") - sc.stop() - } + // This shutdown hook should run *after* the SparkContext is shut down. + Utils.addShutdownHook(Utils.SPARK_CONTEXT_SHUTDOWN_PRIORITY - 1) { () => val maxAppAttempts = client.getMaxRegAttempts(sparkConf, yarnConf) val isLastAttempt = client.getAttemptId().getAttemptId() >= maxAppAttempts From ab5adb7a973eec9d95c7575c864cba9f8d83a0fd Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 27 Apr 2015 19:50:55 -0400 Subject: [PATCH 22/39] [SPARK-7145] [CORE] commons-lang (2.x) classes used instead of commons-lang3 (3.x); commons-io used without dependency Remove use of commons-lang in favor of commons-lang3 classes; remove commons-io use in favor of Guava Author: Sean Owen Closes #5703 from srowen/SPARK-7145 and squashes the following commits: 21fbe03 [Sean Owen] Remove use of commons-lang in favor of commons-lang3 classes; remove commons-io use in favor of Guava --- .../test/scala/org/apache/spark/FileServerSuite.scala | 7 +++---- .../apache/spark/metrics/InputOutputMetricsSuite.scala | 4 ++-- .../netty/NettyBlockTransferSecuritySuite.scala | 10 +++++++--- external/flume-sink/pom.xml | 4 ++++ .../flume/sink/SparkAvroCallbackHandler.scala | 4 ++-- .../main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala | 6 +++++- .../sql/hive/thriftserver/AbstractSparkSQLDriver.scala | 4 ++-- .../sql/hive/thriftserver/SparkSQLCLIDriver.scala | 8 +++----- .../apache/spark/sql/hive/execution/UDFListString.java | 6 +++--- .../spark/sql/hive/MetastoreDataSourcesSuite.scala | 9 ++++----- 10 files changed, 35 insertions(+), 27 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala index a69e9b761f9a7..c0439f934813e 100644 --- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala @@ -22,8 +22,7 @@ import java.net.URI import java.util.jar.{JarEntry, JarOutputStream} import javax.net.ssl.SSLException -import com.google.common.io.ByteStreams -import org.apache.commons.io.{FileUtils, IOUtils} +import com.google.common.io.{ByteStreams, Files} import org.apache.commons.lang3.RandomUtils import org.scalatest.FunSuite @@ -239,7 +238,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { def fileTransferTest(server: HttpFileServer, sm: SecurityManager = null): Unit = { val randomContent = RandomUtils.nextBytes(100) val file = File.createTempFile("FileServerSuite", "sslTests", tmpDir) - FileUtils.writeByteArrayToFile(file, randomContent) + Files.write(randomContent, file) server.addFile(file) val uri = new URI(server.serverUri + "/files/" + file.getName) @@ -254,7 +253,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { Utils.setupSecureURLConnection(connection, sm) } - val buf = IOUtils.toByteArray(connection.getInputStream) + val buf = ByteStreams.toByteArray(connection.getInputStream) assert(buf === randomContent) } diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index 190b08d950a02..ef3e213f1fcce 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -21,7 +21,7 @@ import java.io.{File, FileWriter, PrintWriter} import scala.collection.mutable.ArrayBuffer -import org.apache.commons.lang.math.RandomUtils +import org.apache.commons.lang3.RandomUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.io.{LongWritable, Text} @@ -60,7 +60,7 @@ class InputOutputMetricsSuite extends FunSuite with SharedSparkContext tmpFile = new File(testTempDir, getClass.getSimpleName + ".txt") val pw = new PrintWriter(new FileWriter(tmpFile)) for (x <- 1 to numRecords) { - pw.println(RandomUtils.nextInt(numBuckets)) + pw.println(RandomUtils.nextInt(0, numBuckets)) } pw.close() diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index 94bfa67451892..46d2e5173acae 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -17,14 +17,16 @@ package org.apache.spark.network.netty +import java.io.InputStreamReader import java.nio._ +import java.nio.charset.Charset import java.util.concurrent.TimeUnit import scala.concurrent.duration._ import scala.concurrent.{Await, Promise} import scala.util.{Failure, Success, Try} -import org.apache.commons.io.IOUtils +import com.google.common.io.CharStreams import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.shuffle.BlockFetchingListener import org.apache.spark.network.{BlockDataManager, BlockTransferService} @@ -32,7 +34,7 @@ import org.apache.spark.storage.{BlockId, ShuffleBlockId} import org.apache.spark.{SecurityManager, SparkConf} import org.mockito.Mockito._ import org.scalatest.mock.MockitoSugar -import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite, ShouldMatchers} +import org.scalatest.{FunSuite, ShouldMatchers} class NettyBlockTransferSecuritySuite extends FunSuite with MockitoSugar with ShouldMatchers { test("security default off") { @@ -113,7 +115,9 @@ class NettyBlockTransferSecuritySuite extends FunSuite with MockitoSugar with Sh val result = fetchBlock(exec0, exec1, "1", blockId) match { case Success(buf) => - IOUtils.toString(buf.createInputStream()) should equal(blockString) + val actualString = CharStreams.toString( + new InputStreamReader(buf.createInputStream(), Charset.forName("UTF-8"))) + actualString should equal(blockString) buf.release() Success() case Failure(t) => diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 67907bbfb6d1b..1f3e619d97a24 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -35,6 +35,10 @@ http://spark.apache.org/ + + org.apache.commons + commons-lang3 + org.apache.flume flume-ng-sdk diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala index 4373be443e67d..fd01807fc3ac4 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala @@ -21,9 +21,9 @@ import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable -import org.apache.flume.Channel -import org.apache.commons.lang.RandomStringUtils import com.google.common.util.concurrent.ThreadFactoryBuilder +import org.apache.flume.Channel +import org.apache.commons.lang3.RandomStringUtils /** * Class that implements the SparkFlumeProtocol, that is used by the Avro Netty Server to process diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index f326510042122..f3b5455574d1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.jdbc import java.sql.{Connection, DriverManager, ResultSet, ResultSetMetaData, SQLException} import java.util.Properties -import org.apache.commons.lang.StringEscapeUtils.escapeSql +import org.apache.commons.lang3.StringUtils + import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{Row, SpecificMutableRow} @@ -239,6 +240,9 @@ private[sql] class JDBCRDD( case _ => value } + private def escapeSql(value: String): String = + if (value == null) null else StringUtils.replace(value, "'", "''") + /** * Turns a single Filter into a String representing a SQL expression. * Returns null for an unhandled filter. diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala index 59f3a75768082..48ac9062af96a 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.hive.thriftserver import scala.collection.JavaConversions._ -import org.apache.commons.lang.exception.ExceptionUtils +import org.apache.commons.lang3.exception.ExceptionUtils import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema} import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse @@ -61,7 +61,7 @@ private[hive] abstract class AbstractSparkSQLDriver( } catch { case cause: Throwable => logError(s"Failed in [$command]", cause) - new CommandProcessorResponse(1, ExceptionUtils.getFullStackTrace(cause), null) + new CommandProcessorResponse(1, ExceptionUtils.getStackTrace(cause), null) } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 7e307bb4ad1e8..b7b6925aa87f7 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -24,18 +24,16 @@ import java.util.{ArrayList => JArrayList} import jline.{ConsoleReader, History} -import org.apache.commons.lang.StringUtils +import org.apache.commons.lang3.StringUtils import org.apache.commons.logging.LogFactory import org.apache.hadoop.conf.Configuration import org.apache.hadoop.hive.cli.{CliDriver, CliSessionState, OptionsProcessor} -import org.apache.hadoop.hive.common.LogUtils.LogInitializationException -import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils, LogUtils} +import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.exec.Utilities -import org.apache.hadoop.hive.ql.processors.{AddResourceProcessor, SetProcessor, CommandProcessor, CommandProcessorFactory} +import org.apache.hadoop.hive.ql.processors.{AddResourceProcessor, SetProcessor, CommandProcessor} import org.apache.hadoop.hive.ql.session.SessionState -import org.apache.hadoop.hive.shims.ShimLoader import org.apache.thrift.transport.TSocket import org.apache.spark.Logging diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java index efd34df293c88..f33210ebdae1b 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java @@ -17,10 +17,10 @@ package org.apache.spark.sql.hive.execution; -import org.apache.hadoop.hive.ql.exec.UDF; - import java.util.List; -import org.apache.commons.lang.StringUtils; + +import org.apache.commons.lang3.StringUtils; +import org.apache.hadoop.hive.ql.exec.UDF; public class UDFListString extends UDF { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index e09c702c8969e..0538aa203c5a0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -23,7 +23,6 @@ import scala.collection.mutable.ArrayBuffer import org.scalatest.BeforeAndAfterEach -import org.apache.commons.io.FileUtils import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.metastore.TableType import org.apache.hadoop.hive.ql.metadata.Table @@ -174,7 +173,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { sql("SELECT * FROM jsonTable"), Row("a", "b")) - FileUtils.deleteDirectory(tempDir) + Utils.deleteRecursively(tempDir) sparkContext.parallelize(("a1", "b1", "c1") :: Nil).toDF() .toJSON.saveAsTextFile(tempDir.getCanonicalPath) @@ -190,7 +189,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { checkAnswer( sql("SELECT * FROM jsonTable"), Row("a1", "b1", "c1")) - FileUtils.deleteDirectory(tempDir) + Utils.deleteRecursively(tempDir) } test("drop, change, recreate") { @@ -212,7 +211,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { sql("SELECT * FROM jsonTable"), Row("a", "b")) - FileUtils.deleteDirectory(tempDir) + Utils.deleteRecursively(tempDir) sparkContext.parallelize(("a", "b", "c") :: Nil).toDF() .toJSON.saveAsTextFile(tempDir.getCanonicalPath) @@ -231,7 +230,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { checkAnswer( sql("SELECT * FROM jsonTable"), Row("a", "b", "c")) - FileUtils.deleteDirectory(tempDir) + Utils.deleteRecursively(tempDir) } test("invalidate cache and reload") { From 62888a4ded91b3c2cbb05936c374c7ebfc10799e Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Mon, 27 Apr 2015 19:52:41 -0400 Subject: [PATCH 23/39] [SPARK-7162] [YARN] Launcher error in yarn-client jira: https://issues.apache.org/jira/browse/SPARK-7162 Author: GuoQiang Li Closes #5716 from witgo/SPARK-7162 and squashes the following commits: b64564c [GuoQiang Li] Launcher error in yarn-client --- yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 019afbd1a1743..741239c953794 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -354,7 +354,7 @@ private[spark] class Client( val dir = new File(path) if (dir.isDirectory()) { dir.listFiles().foreach { file => - if (!hadoopConfFiles.contains(file.getName())) { + if (file.isFile && !hadoopConfFiles.contains(file.getName())) { hadoopConfFiles(file.getName()) = file } } From 4d9e560b5470029143926827b1cb9d72a0bfbeff Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Mon, 27 Apr 2015 19:02:51 -0700 Subject: [PATCH 24/39] [SPARK-7090] [MLLIB] Introduce LDAOptimizer to LDA to further improve extensibility jira: https://issues.apache.org/jira/browse/SPARK-7090 LDA was implemented with extensibility in mind. And with the development of OnlineLDA and Gibbs Sampling, we are collecting more detailed requirements from different algorithms. As Joseph Bradley jkbradley proposed in https://github.com/apache/spark/pull/4807 and with some further discussion, we'd like to adjust the code structure a little to present the common interface and extension point clearly. Basically class LDA would be a common entrance for LDA computing. And each LDA object will refer to a LDAOptimizer for the concrete algorithm implementation. Users can customize LDAOptimizer with specific parameters and assign it to LDA. Concrete changes: 1. Add a trait `LDAOptimizer`, which defines the common iterface for concrete implementations. Each subClass is a wrapper for a specific LDA algorithm. 2. Move EMOptimizer to file LDAOptimizer and inherits from LDAOptimizer, rename to EMLDAOptimizer. (in case a more generic EMOptimizer comes in the future) -adjust the constructor of EMOptimizer, since all the parameters should be passed in through initialState method. This can avoid unwanted confusion or overwrite. -move the code from LDA.initalState to initalState of EMLDAOptimizer 3. Add property ldaOptimizer to LDA and its getter/setter, and EMLDAOptimizer is the default Optimizer. 4. Change the return type of LDA.run from DistributedLDAModel to LDAModel. Further work: add OnlineLDAOptimizer and other possible Optimizers once ready. Author: Yuhao Yang Closes #5661 from hhbyyh/ldaRefactor and squashes the following commits: 0e2e006 [Yuhao Yang] respond to review comments 08a45da [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaRefactor e756ce4 [Yuhao Yang] solve mima exception d74fd8f [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaRefactor 0bb8400 [Yuhao Yang] refactor LDA with Optimizer ec2f857 [Yuhao Yang] protoptype for discussion --- .../spark/examples/mllib/JavaLDAExample.java | 2 +- .../spark/examples/mllib/LDAExample.scala | 4 +- .../apache/spark/mllib/clustering/LDA.scala | 181 +++------------ .../spark/mllib/clustering/LDAModel.scala | 2 +- .../spark/mllib/clustering/LDAOptimizer.scala | 210 ++++++++++++++++++ .../spark/mllib/clustering/JavaLDASuite.java | 2 +- .../spark/mllib/clustering/LDASuite.scala | 2 +- project/MimaExcludes.scala | 4 + 8 files changed, 256 insertions(+), 151 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java index 36207ae38d9a9..fd53c81cc4974 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java @@ -58,7 +58,7 @@ public Tuple2 call(Tuple2 doc_id) { corpus.cache(); // Cluster the documents into three topics using LDA - DistributedLDAModel ldaModel = new LDA().setK(3).run(corpus); + DistributedLDAModel ldaModel = (DistributedLDAModel)new LDA().setK(3).run(corpus); // Output topics. Each is a distribution over words (matching word count vectors) System.out.println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala index 08a93595a2e17..a1850390c0a86 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala @@ -26,7 +26,7 @@ import scopt.OptionParser import org.apache.log4j.{Level, Logger} import org.apache.spark.{SparkContext, SparkConf} -import org.apache.spark.mllib.clustering.LDA +import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.rdd.RDD @@ -137,7 +137,7 @@ object LDAExample { sc.setCheckpointDir(params.checkpointDir.get) } val startTime = System.nanoTime() - val ldaModel = lda.run(corpus) + val ldaModel = lda.run(corpus).asInstanceOf[DistributedLDAModel] val elapsed = (System.nanoTime() - startTime) / 1e9 println(s"Finished training LDA model. Summary:") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index d006b39acb213..37bf88b73b911 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -17,16 +17,11 @@ package org.apache.spark.mllib.clustering -import java.util.Random - -import breeze.linalg.{DenseVector => BDV, normalize} - +import breeze.linalg.{DenseVector => BDV} import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaPairRDD import org.apache.spark.graphx._ -import org.apache.spark.graphx.impl.GraphImpl -import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer import org.apache.spark.mllib.linalg.Vector import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils @@ -42,16 +37,9 @@ import org.apache.spark.util.Utils * - "token": instance of a term appearing in a document * - "topic": multinomial distribution over words representing some concept * - * Currently, the underlying implementation uses Expectation-Maximization (EM), implemented - * according to the Asuncion et al. (2009) paper referenced below. - * * References: * - Original LDA paper (journal version): * Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003. - * - This class implements their "smoothed" LDA model. - * - Paper which clearly explains several algorithms, including EM: - * Asuncion, Welling, Smyth, and Teh. - * "On Smoothing and Inference for Topic Models." UAI, 2009. * * @see [[http://en.wikipedia.org/wiki/Latent_Dirichlet_allocation Latent Dirichlet allocation * (Wikipedia)]] @@ -63,10 +51,11 @@ class LDA private ( private var docConcentration: Double, private var topicConcentration: Double, private var seed: Long, - private var checkpointInterval: Int) extends Logging { + private var checkpointInterval: Int, + private var ldaOptimizer: LDAOptimizer) extends Logging { def this() = this(k = 10, maxIterations = 20, docConcentration = -1, topicConcentration = -1, - seed = Utils.random.nextLong(), checkpointInterval = 10) + seed = Utils.random.nextLong(), checkpointInterval = 10, ldaOptimizer = new EMLDAOptimizer) /** * Number of topics to infer. I.e., the number of soft cluster centers. @@ -220,6 +209,32 @@ class LDA private ( this } + + /** LDAOptimizer used to perform the actual calculation */ + def getOptimizer: LDAOptimizer = ldaOptimizer + + /** + * LDAOptimizer used to perform the actual calculation (default = EMLDAOptimizer) + */ + def setOptimizer(optimizer: LDAOptimizer): this.type = { + this.ldaOptimizer = optimizer + this + } + + /** + * Set the LDAOptimizer used to perform the actual calculation by algorithm name. + * Currently "em" is supported. + */ + def setOptimizer(optimizerName: String): this.type = { + this.ldaOptimizer = + optimizerName.toLowerCase match { + case "em" => new EMLDAOptimizer + case other => + throw new IllegalArgumentException(s"Only em is supported but got $other.") + } + this + } + /** * Learn an LDA model using the given dataset. * @@ -229,9 +244,9 @@ class LDA private ( * Document IDs must be unique and >= 0. * @return Inferred LDA model */ - def run(documents: RDD[(Long, Vector)]): DistributedLDAModel = { - val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration, seed, - checkpointInterval) + def run(documents: RDD[(Long, Vector)]): LDAModel = { + val state = ldaOptimizer.initialState(documents, k, getDocConcentration, getTopicConcentration, + seed, checkpointInterval) var iter = 0 val iterationTimes = Array.fill[Double](maxIterations)(0) while (iter < maxIterations) { @@ -241,12 +256,11 @@ class LDA private ( iterationTimes(iter) = elapsedSeconds iter += 1 } - state.graphCheckpointer.deleteAllCheckpoints() - new DistributedLDAModel(state, iterationTimes) + state.getLDAModel(iterationTimes) } /** Java-friendly version of [[run()]] */ - def run(documents: JavaPairRDD[java.lang.Long, Vector]): DistributedLDAModel = { + def run(documents: JavaPairRDD[java.lang.Long, Vector]): LDAModel = { run(documents.rdd.asInstanceOf[RDD[(Long, Vector)]]) } } @@ -320,88 +334,10 @@ private[clustering] object LDA { private[clustering] def isTermVertex(v: (VertexId, _)): Boolean = v._1 < 0 - /** - * Optimizer for EM algorithm which stores data + parameter graph, plus algorithm parameters. - * - * @param graph EM graph, storing current parameter estimates in vertex descriptors and - * data (token counts) in edge descriptors. - * @param k Number of topics - * @param vocabSize Number of unique terms - * @param docConcentration "alpha" - * @param topicConcentration "beta" or "eta" - */ - private[clustering] class EMOptimizer( - var graph: Graph[TopicCounts, TokenCount], - val k: Int, - val vocabSize: Int, - val docConcentration: Double, - val topicConcentration: Double, - checkpointInterval: Int) { - - private[LDA] val graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount]( - graph, checkpointInterval) - - def next(): EMOptimizer = { - val eta = topicConcentration - val W = vocabSize - val alpha = docConcentration - - val N_k = globalTopicTotals - val sendMsg: EdgeContext[TopicCounts, TokenCount, (Boolean, TopicCounts)] => Unit = - (edgeContext) => { - // Compute N_{wj} gamma_{wjk} - val N_wj = edgeContext.attr - // E-STEP: Compute gamma_{wjk} (smoothed topic distributions), scaled by token count - // N_{wj}. - val scaledTopicDistribution: TopicCounts = - computePTopic(edgeContext.srcAttr, edgeContext.dstAttr, N_k, W, eta, alpha) *= N_wj - edgeContext.sendToDst((false, scaledTopicDistribution)) - edgeContext.sendToSrc((false, scaledTopicDistribution)) - } - // This is a hack to detect whether we could modify the values in-place. - // TODO: Add zero/seqOp/combOp option to aggregateMessages. (SPARK-5438) - val mergeMsg: ((Boolean, TopicCounts), (Boolean, TopicCounts)) => (Boolean, TopicCounts) = - (m0, m1) => { - val sum = - if (m0._1) { - m0._2 += m1._2 - } else if (m1._1) { - m1._2 += m0._2 - } else { - m0._2 + m1._2 - } - (true, sum) - } - // M-STEP: Aggregation computes new N_{kj}, N_{wk} counts. - val docTopicDistributions: VertexRDD[TopicCounts] = - graph.aggregateMessages[(Boolean, TopicCounts)](sendMsg, mergeMsg) - .mapValues(_._2) - // Update the vertex descriptors with the new counts. - val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges) - graph = newGraph - graphCheckpointer.updateGraph(newGraph) - globalTopicTotals = computeGlobalTopicTotals() - this - } - - /** - * Aggregate distributions over topics from all term vertices. - * - * Note: This executes an action on the graph RDDs. - */ - var globalTopicTotals: TopicCounts = computeGlobalTopicTotals() - - private def computeGlobalTopicTotals(): TopicCounts = { - val numTopics = k - graph.vertices.filter(isTermVertex).values.fold(BDV.zeros[Double](numTopics))(_ += _) - } - - } - /** * Compute gamma_{wjk}, a distribution over topics k. */ - private def computePTopic( + private[clustering] def computePTopic( docTopicCounts: TopicCounts, termTopicCounts: TopicCounts, totalTopicCounts: TopicCounts, @@ -427,49 +363,4 @@ private[clustering] object LDA { // normalize BDV(gamma_wj) /= sum } - - /** - * Compute bipartite term/doc graph. - */ - private def initialState( - docs: RDD[(Long, Vector)], - k: Int, - docConcentration: Double, - topicConcentration: Double, - randomSeed: Long, - checkpointInterval: Int): EMOptimizer = { - // For each document, create an edge (Document -> Term) for each unique term in the document. - val edges: RDD[Edge[TokenCount]] = docs.flatMap { case (docID: Long, termCounts: Vector) => - // Add edges for terms with non-zero counts. - termCounts.toBreeze.activeIterator.filter(_._2 != 0.0).map { case (term, cnt) => - Edge(docID, term2index(term), cnt) - } - } - - val vocabSize = docs.take(1).head._2.size - - // Create vertices. - // Initially, we use random soft assignments of tokens to topics (random gamma). - def createVertices(): RDD[(VertexId, TopicCounts)] = { - val verticesTMP: RDD[(VertexId, TopicCounts)] = - edges.mapPartitionsWithIndex { case (partIndex, partEdges) => - val random = new Random(partIndex + randomSeed) - partEdges.flatMap { edge => - val gamma = normalize(BDV.fill[Double](k)(random.nextDouble()), 1.0) - val sum = gamma * edge.attr - Seq((edge.srcId, sum), (edge.dstId, sum)) - } - } - verticesTMP.reduceByKey(_ + _) - } - - val docTermVertices = createVertices() - - // Partition such that edges are grouped by document - val graph = Graph(docTermVertices, edges) - .partitionBy(PartitionStrategy.EdgePartition1D) - - new EMOptimizer(graph, k, vocabSize, docConcentration, topicConcentration, checkpointInterval) - } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 0a3f21ecee0dc..6cf26445f20a0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -203,7 +203,7 @@ class DistributedLDAModel private ( import LDA._ - private[clustering] def this(state: LDA.EMOptimizer, iterationTimes: Array[Double]) = { + private[clustering] def this(state: EMLDAOptimizer, iterationTimes: Array[Double]) = { this(state.graph, state.globalTopicTotals, state.k, state.vocabSize, state.docConcentration, state.topicConcentration, iterationTimes) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala new file mode 100644 index 0000000000000..ffd72a294c6c6 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import java.util.Random + +import breeze.linalg.{DenseVector => BDV, normalize} + +import org.apache.spark.annotation.Experimental +import org.apache.spark.graphx._ +import org.apache.spark.graphx.impl.GraphImpl +import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.rdd.RDD + +/** + * :: Experimental :: + * + * An LDAOptimizer specifies which optimization/learning/inference algorithm to use, and it can + * hold optimizer-specific parameters for users to set. + */ +@Experimental +trait LDAOptimizer{ + + /* + DEVELOPERS NOTE: + + An LDAOptimizer contains an algorithm for LDA and performs the actual computation, which + stores internal data structure (Graph or Matrix) and other parameters for the algorithm. + The interface is isolated to improve the extensibility of LDA. + */ + + /** + * Initializer for the optimizer. LDA passes the common parameters to the optimizer and + * the internal structure can be initialized properly. + */ + private[clustering] def initialState( + docs: RDD[(Long, Vector)], + k: Int, + docConcentration: Double, + topicConcentration: Double, + randomSeed: Long, + checkpointInterval: Int): LDAOptimizer + + private[clustering] def next(): LDAOptimizer + + private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel +} + +/** + * :: Experimental :: + * + * Optimizer for EM algorithm which stores data + parameter graph, plus algorithm parameters. + * + * Currently, the underlying implementation uses Expectation-Maximization (EM), implemented + * according to the Asuncion et al. (2009) paper referenced below. + * + * References: + * - Original LDA paper (journal version): + * Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003. + * - This class implements their "smoothed" LDA model. + * - Paper which clearly explains several algorithms, including EM: + * Asuncion, Welling, Smyth, and Teh. + * "On Smoothing and Inference for Topic Models." UAI, 2009. + * + */ +@Experimental +class EMLDAOptimizer extends LDAOptimizer{ + + import LDA._ + + /** + * Following fields will only be initialized through initialState method + */ + private[clustering] var graph: Graph[TopicCounts, TokenCount] = null + private[clustering] var k: Int = 0 + private[clustering] var vocabSize: Int = 0 + private[clustering] var docConcentration: Double = 0 + private[clustering] var topicConcentration: Double = 0 + private[clustering] var checkpointInterval: Int = 10 + private var graphCheckpointer: PeriodicGraphCheckpointer[TopicCounts, TokenCount] = null + + /** + * Compute bipartite term/doc graph. + */ + private[clustering] override def initialState( + docs: RDD[(Long, Vector)], + k: Int, + docConcentration: Double, + topicConcentration: Double, + randomSeed: Long, + checkpointInterval: Int): LDAOptimizer = { + // For each document, create an edge (Document -> Term) for each unique term in the document. + val edges: RDD[Edge[TokenCount]] = docs.flatMap { case (docID: Long, termCounts: Vector) => + // Add edges for terms with non-zero counts. + termCounts.toBreeze.activeIterator.filter(_._2 != 0.0).map { case (term, cnt) => + Edge(docID, term2index(term), cnt) + } + } + + val vocabSize = docs.take(1).head._2.size + + // Create vertices. + // Initially, we use random soft assignments of tokens to topics (random gamma). + def createVertices(): RDD[(VertexId, TopicCounts)] = { + val verticesTMP: RDD[(VertexId, TopicCounts)] = + edges.mapPartitionsWithIndex { case (partIndex, partEdges) => + val random = new Random(partIndex + randomSeed) + partEdges.flatMap { edge => + val gamma = normalize(BDV.fill[Double](k)(random.nextDouble()), 1.0) + val sum = gamma * edge.attr + Seq((edge.srcId, sum), (edge.dstId, sum)) + } + } + verticesTMP.reduceByKey(_ + _) + } + + val docTermVertices = createVertices() + + // Partition such that edges are grouped by document + this.graph = Graph(docTermVertices, edges).partitionBy(PartitionStrategy.EdgePartition1D) + this.k = k + this.vocabSize = vocabSize + this.docConcentration = docConcentration + this.topicConcentration = topicConcentration + this.checkpointInterval = checkpointInterval + this.graphCheckpointer = new + PeriodicGraphCheckpointer[TopicCounts, TokenCount](graph, checkpointInterval) + this.globalTopicTotals = computeGlobalTopicTotals() + this + } + + private[clustering] override def next(): EMLDAOptimizer = { + require(graph != null, "graph is null, EMLDAOptimizer not initialized.") + + val eta = topicConcentration + val W = vocabSize + val alpha = docConcentration + + val N_k = globalTopicTotals + val sendMsg: EdgeContext[TopicCounts, TokenCount, (Boolean, TopicCounts)] => Unit = + (edgeContext) => { + // Compute N_{wj} gamma_{wjk} + val N_wj = edgeContext.attr + // E-STEP: Compute gamma_{wjk} (smoothed topic distributions), scaled by token count + // N_{wj}. + val scaledTopicDistribution: TopicCounts = + computePTopic(edgeContext.srcAttr, edgeContext.dstAttr, N_k, W, eta, alpha) *= N_wj + edgeContext.sendToDst((false, scaledTopicDistribution)) + edgeContext.sendToSrc((false, scaledTopicDistribution)) + } + // This is a hack to detect whether we could modify the values in-place. + // TODO: Add zero/seqOp/combOp option to aggregateMessages. (SPARK-5438) + val mergeMsg: ((Boolean, TopicCounts), (Boolean, TopicCounts)) => (Boolean, TopicCounts) = + (m0, m1) => { + val sum = + if (m0._1) { + m0._2 += m1._2 + } else if (m1._1) { + m1._2 += m0._2 + } else { + m0._2 + m1._2 + } + (true, sum) + } + // M-STEP: Aggregation computes new N_{kj}, N_{wk} counts. + val docTopicDistributions: VertexRDD[TopicCounts] = + graph.aggregateMessages[(Boolean, TopicCounts)](sendMsg, mergeMsg) + .mapValues(_._2) + // Update the vertex descriptors with the new counts. + val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges) + graph = newGraph + graphCheckpointer.updateGraph(newGraph) + globalTopicTotals = computeGlobalTopicTotals() + this + } + + /** + * Aggregate distributions over topics from all term vertices. + * + * Note: This executes an action on the graph RDDs. + */ + private[clustering] var globalTopicTotals: TopicCounts = null + + private def computeGlobalTopicTotals(): TopicCounts = { + val numTopics = k + graph.vertices.filter(isTermVertex).values.fold(BDV.zeros[Double](numTopics))(_ += _) + } + + private[clustering] override def getLDAModel(iterationTimes: Array[Double]): LDAModel = { + require(graph != null, "graph is null, EMLDAOptimizer not initialized.") + this.graphCheckpointer.deleteAllCheckpoints() + new DistributedLDAModel(this, iterationTimes) + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java index dc10aa67c7c1f..fbe171b4b1ab1 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java @@ -88,7 +88,7 @@ public void distributedLDAModel() { .setMaxIterations(5) .setSeed(12345); - DistributedLDAModel model = lda.run(corpus); + DistributedLDAModel model = (DistributedLDAModel)lda.run(corpus); // Check: basic parameters LocalLDAModel localModel = model.toLocal(); diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index cc747dabb9968..41ec794146c69 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -68,7 +68,7 @@ class LDASuite extends FunSuite with MLlibTestSparkContext { .setSeed(12345) val corpus = sc.parallelize(tinyCorpus, 2) - val model: DistributedLDAModel = lda.run(corpus) + val model: DistributedLDAModel = lda.run(corpus).asInstanceOf[DistributedLDAModel] // Check: basic parameters val localModel = model.toLocal diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 7ef363a2f07ad..967961c2bf5c3 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -72,6 +72,10 @@ object MimaExcludes { // SPARK-6703 Add getOrCreate method to SparkContext ProblemFilters.exclude[IncompatibleResultTypeProblem] ("org.apache.spark.SparkContext.org$apache$spark$SparkContext$$activeContext") + )++ Seq( + // SPARK-7090 Introduce LDAOptimizer to LDA to further improve extensibility + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.mllib.clustering.LDA$EMOptimizer") ) case v if v.startsWith("1.3") => From 874a2ca93d095a0dfa1acfdacf0e9d80388c4422 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 27 Apr 2015 21:45:40 -0700 Subject: [PATCH 25/39] [SPARK-7174][Core] Move calling `TaskScheduler.executorHeartbeatReceived` to another thread `HeartbeatReceiver` will call `TaskScheduler.executorHeartbeatReceived`, which is a blocking operation because `TaskScheduler.executorHeartbeatReceived` will call ```Scala blockManagerMaster.driverEndpoint.askWithReply[Boolean]( BlockManagerHeartbeat(blockManagerId), 600 seconds) ``` finally. Even if it asks from a local Actor, it may block the current Akka thread. E.g., the reply may be dispatched to the same thread of the ask operation. So the reply cannot be processed. An extreme case is setting the thread number of Akka dispatch thread pool to 1. jstack log: ``` "sparkDriver-akka.actor.default-dispatcher-14" daemon prio=10 tid=0x00007f2a8c02d000 nid=0x725 waiting on condition [0x00007f2b1d6d0000] java.lang.Thread.State: TIMED_WAITING (parking) at sun.misc.Unsafe.park(Native Method) - parking to wait for <0x00000006197a0868> (a scala.concurrent.impl.Promise$CompletionLatch) at java.util.concurrent.locks.LockSupport.parkNanos(LockSupport.java:226) at java.util.concurrent.locks.AbstractQueuedSynchronizer.doAcquireSharedNanos(AbstractQueuedSynchronizer.java:1033) at java.util.concurrent.locks.AbstractQueuedSynchronizer.tryAcquireSharedNanos(AbstractQueuedSynchronizer.java:1326) at scala.concurrent.impl.Promise$DefaultPromise.tryAwait(Promise.scala:208) at scala.concurrent.impl.Promise$DefaultPromise.ready(Promise.scala:218) at scala.concurrent.impl.Promise$DefaultPromise.result(Promise.scala:223) at scala.concurrent.Await$$anonfun$result$1.apply(package.scala:107) at akka.dispatch.MonitorableThreadFactory$AkkaForkJoinWorkerThread$$anon$3.block(ThreadPoolBuilder.scala:169) at scala.concurrent.forkjoin.ForkJoinPool.managedBlock(ForkJoinPool.java:3640) at akka.dispatch.MonitorableThreadFactory$AkkaForkJoinWorkerThread.blockOn(ThreadPoolBuilder.scala:167) at scala.concurrent.Await$.result(package.scala:107) at org.apache.spark.rpc.RpcEndpointRef.askWithReply(RpcEnv.scala:355) at org.apache.spark.scheduler.DAGScheduler.executorHeartbeatReceived(DAGScheduler.scala:169) at org.apache.spark.scheduler.TaskSchedulerImpl.executorHeartbeatReceived(TaskSchedulerImpl.scala:367) at org.apache.spark.HeartbeatReceiver$$anonfun$receiveAndReply$1.applyOrElse(HeartbeatReceiver.scala:103) at org.apache.spark.rpc.akka.AkkaRpcEnv.org$apache$spark$rpc$akka$AkkaRpcEnv$$processMessage(AkkaRpcEnv.scala:182) at org.apache.spark.rpc.akka.AkkaRpcEnv$$anonfun$actorRef$lzycompute$1$1$$anon$1$$anonfun$receiveWithLogging$1$$anonfun$applyOrElse$4.apply$mcV$sp(AkkaRpcEnv.scala:128) at org.apache.spark.rpc.akka.AkkaRpcEnv.org$apache$spark$rpc$akka$AkkaRpcEnv$$safelyCall(AkkaRpcEnv.scala:203) at org.apache.spark.rpc.akka.AkkaRpcEnv$$anonfun$actorRef$lzycompute$1$1$$anon$1$$anonfun$receiveWithLogging$1.applyOrElse(AkkaRpcEnv.scala:127) at scala.runtime.AbstractPartialFunction$mcVL$sp.apply$mcVL$sp(AbstractPartialFunction.scala:33) at scala.runtime.AbstractPartialFunction$mcVL$sp.apply(AbstractPartialFunction.scala:33) at scala.runtime.AbstractPartialFunction$mcVL$sp.apply(AbstractPartialFunction.scala:25) at org.apache.spark.util.ActorLogReceive$$anon$1.apply(ActorLogReceive.scala:59) at org.apache.spark.util.ActorLogReceive$$anon$1.apply(ActorLogReceive.scala:42) at scala.PartialFunction$class.applyOrElse(PartialFunction.scala:118) at org.apache.spark.util.ActorLogReceive$$anon$1.applyOrElse(ActorLogReceive.scala:42) at akka.actor.Actor$class.aroundReceive(Actor.scala:465) at org.apache.spark.rpc.akka.AkkaRpcEnv$$anonfun$actorRef$lzycompute$1$1$$anon$1.aroundReceive(AkkaRpcEnv.scala:94) at akka.actor.ActorCell.receiveMessage(ActorCell.scala:516) at akka.actor.ActorCell.invoke(ActorCell.scala:487) at akka.dispatch.Mailbox.processMailbox(Mailbox.scala:238) at akka.dispatch.Mailbox.run(Mailbox.scala:220) at akka.dispatch.ForkJoinExecutorConfigurator$AkkaForkJoinTask.exec(AbstractDispatcher.scala:393) at scala.concurrent.forkjoin.ForkJoinTask.doExec(ForkJoinTask.java:260) at scala.concurrent.forkjoin.ForkJoinPool$WorkQueue.runTask(ForkJoinPool.java:1339) at scala.concurrent.forkjoin.ForkJoinPool.runWorker(ForkJoinPool.java:1979) at scala.concurrent.forkjoin.ForkJoinWorkerThread.run(ForkJoinWorkerThread.java:107) ``` This PR moved this blocking operation to a separated thread. Author: zsxwing Closes #5723 from zsxwing/SPARK-7174 and squashes the following commits: 98bfe48 [zsxwing] Use a single thread for checking timeout and reporting executorHeartbeatReceived 5b3b545 [zsxwing] Move calling `TaskScheduler.executorHeartbeatReceived` to another thread to avoid blocking the Akka thread pool --- .../org/apache/spark/HeartbeatReceiver.scala | 26 ++++++++++++------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 68d05d5b02537..f2b024ff6cb67 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -76,13 +76,15 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) private var timeoutCheckingTask: ScheduledFuture[_] = null - private val timeoutCheckingThread = - ThreadUtils.newDaemonSingleThreadScheduledExecutor("heartbeat-timeout-checking-thread") + // "eventLoopThread" is used to run some pretty fast actions. The actions running in it should not + // block the thread for a long time. + private val eventLoopThread = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("heartbeat-receiver-event-loop-thread") private val killExecutorThread = ThreadUtils.newDaemonSingleThreadExecutor("kill-executor-thread") override def onStart(): Unit = { - timeoutCheckingTask = timeoutCheckingThread.scheduleAtFixedRate(new Runnable { + timeoutCheckingTask = eventLoopThread.scheduleAtFixedRate(new Runnable { override def run(): Unit = Utils.tryLogNonFatalError { Option(self).foreach(_.send(ExpireDeadHosts)) } @@ -99,11 +101,15 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case heartbeat @ Heartbeat(executorId, taskMetrics, blockManagerId) => if (scheduler != null) { - val unknownExecutor = !scheduler.executorHeartbeatReceived( - executorId, taskMetrics, blockManagerId) - val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor) executorLastSeen(executorId) = System.currentTimeMillis() - context.reply(response) + eventLoopThread.submit(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + val unknownExecutor = !scheduler.executorHeartbeatReceived( + executorId, taskMetrics, blockManagerId) + val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor) + context.reply(response) + } + }) } else { // Because Executor will sleep several seconds before sending the first "Heartbeat", this // case rarely happens. However, if it really happens, log it and ask the executor to @@ -125,7 +131,9 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) if (sc.supportDynamicAllocation) { // Asynchronously kill the executor to avoid blocking the current thread killExecutorThread.submit(new Runnable { - override def run(): Unit = sc.killExecutor(executorId) + override def run(): Unit = Utils.tryLogNonFatalError { + sc.killExecutor(executorId) + } }) } executorLastSeen.remove(executorId) @@ -137,7 +145,7 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) if (timeoutCheckingTask != null) { timeoutCheckingTask.cancel(true) } - timeoutCheckingThread.shutdownNow() + eventLoopThread.shutdownNow() killExecutorThread.shutdownNow() } } From 29576e786072bd4218e10036ddfc8d367b1c1446 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 27 Apr 2015 23:10:14 -0700 Subject: [PATCH 26/39] [SPARK-6829] Added math functions for DataFrames Implemented almost all math functions found in scala.math (max, min and abs were already present). cc mengxr marmbrus Author: Burak Yavuz Closes #5616 from brkyvz/math-udfs and squashes the following commits: fb27153 [Burak Yavuz] reverted exception message 836a098 [Burak Yavuz] fixed test and addressed small comment e5f0d13 [Burak Yavuz] addressed code review v2.2 b26c5fb [Burak Yavuz] addressed review v2.1 2761f08 [Burak Yavuz] addressed review v2 6588a5b [Burak Yavuz] fixed merge conflicts b084e10 [Burak Yavuz] Addressed code review 029e739 [Burak Yavuz] fixed atan2 test 534cc11 [Burak Yavuz] added more tests, addressed comments fa68dbe [Burak Yavuz] added double specific test data 937d5a5 [Burak Yavuz] use doubles instead of ints 8e28fff [Burak Yavuz] Added apache header 7ec8f7f [Burak Yavuz] Added math functions for DataFrames --- .../catalyst/analysis/HiveTypeCoercion.scala | 19 + .../sql/catalyst/expressions/Expression.scala | 10 + .../expressions/mathfuncs/binary.scala | 93 +++ .../expressions/mathfuncs/unary.scala | 168 ++++++ .../ExpressionEvaluationSuite.scala | 165 +++++ .../org/apache/spark/sql/mathfunctions.scala | 562 ++++++++++++++++++ .../apache/spark/sql/JavaDataFrameSuite.java | 9 + .../spark/sql/ColumnExpressionSuite.scala | 1 - .../spark/sql/MathExpressionsSuite.scala | 233 ++++++++ 9 files changed, 1259 insertions(+), 1 deletion(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 35c7f00d4e42a..73c9a1c7afdad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -79,6 +79,7 @@ trait HiveTypeCoercion { CaseWhenCoercion :: Division :: PropagateTypes :: + ExpectedInputConversion :: Nil /** @@ -643,4 +644,22 @@ trait HiveTypeCoercion { } } + /** + * Casts types according to the expected input types for Expressions that have the trait + * `ExpectsInputTypes`. + */ + object ExpectedInputConversion extends Rule[LogicalPlan] { + + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + // Skip nodes who's children have not been resolved yet. + case e if !e.childrenResolved => e + + case e: ExpectsInputTypes if e.children.map(_.dataType) != e.expectedChildTypes => + val newC = (e.children, e.children.map(_.dataType), e.expectedChildTypes).zipped.map { + case (child, actual, expected) => + if (actual == expected) child else Cast(child, expected) + } + e.withNewChildren(newC) + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 4e3bbc06a5b4c..1d71c1b4b0c7c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -109,3 +109,13 @@ case class GroupExpression(children: Seq[Expression]) extends Expression { override def foldable: Boolean = false override def dataType: DataType = throw new UnsupportedOperationException } + +/** + * Expressions that require a specific `DataType` as input should implement this trait + * so that the proper type conversions can be performed in the analyzer. + */ +trait ExpectsInputTypes { + + def expectedChildTypes: Seq[DataType] + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala new file mode 100644 index 0000000000000..5b4d912a64f71 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.mathfuncs + +import org.apache.spark.sql.catalyst.analysis.UnresolvedException +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, BinaryExpression, Expression, Row} +import org.apache.spark.sql.types._ + +/** + * A binary expression specifically for math functions that take two `Double`s as input and returns + * a `Double`. + * @param f The math function. + * @param name The short name of the function + */ +abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) + extends BinaryExpression with Serializable with ExpectsInputTypes { self: Product => + type EvaluatedType = Any + override def symbol: String = null + override def expectedChildTypes: Seq[DataType] = Seq(DoubleType, DoubleType) + + override def nullable: Boolean = left.nullable || right.nullable + override def toString: String = s"$name($left, $right)" + + override lazy val resolved = + left.resolved && right.resolved && + left.dataType == right.dataType && + !DecimalType.isFixed(left.dataType) + + override def dataType: DataType = { + if (!resolved) { + throw new UnresolvedException(this, + s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") + } + left.dataType + } + + override def eval(input: Row): Any = { + val evalE1 = left.eval(input) + if (evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + val result = f(evalE1.asInstanceOf[Double], evalE2.asInstanceOf[Double]) + if (result.isNaN) null else result + } + } + } +} + +case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER") + +case class Hypot( + left: Expression, + right: Expression) extends BinaryMathExpression(math.hypot, "HYPOT") + +case class Atan2( + left: Expression, + right: Expression) extends BinaryMathExpression(math.atan2, "ATAN2") { + override def eval(input: Row): Any = { + val evalE1 = left.eval(input) + if (evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + // With codegen, the values returned by -0.0 and 0.0 are different. Handled with +0.0 + val result = math.atan2(evalE1.asInstanceOf[Double] + 0.0, + evalE2.asInstanceOf[Double] + 0.0) + if (result.isNaN) null else result + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala new file mode 100644 index 0000000000000..96cb77d487529 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.mathfuncs + +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, Row, UnaryExpression} +import org.apache.spark.sql.types._ + +/** + * A unary expression specifically for math functions. Math Functions expect a specific type of + * input format, therefore these functions extend `ExpectsInputTypes`. + * @param name The short name of the function + */ +abstract class MathematicalExpression(name: String) + extends UnaryExpression with Serializable with ExpectsInputTypes { + self: Product => + type EvaluatedType = Any + + override def dataType: DataType = DoubleType + override def foldable: Boolean = child.foldable + override def nullable: Boolean = true + override def toString: String = s"$name($child)" +} + +/** + * A unary expression specifically for math functions that take a `Double` as input and return + * a `Double`. + * @param f The math function. + * @param name The short name of the function + */ +abstract class MathematicalExpressionForDouble(f: Double => Double, name: String) + extends MathematicalExpression(name) { self: Product => + + override def expectedChildTypes: Seq[DataType] = Seq(DoubleType) + + override def eval(input: Row): Any = { + val evalE = child.eval(input) + if (evalE == null) { + null + } else { + val result = f(evalE.asInstanceOf[Double]) + if (result.isNaN) null else result + } + } +} + +/** + * A unary expression specifically for math functions that take an `Int` as input and return + * an `Int`. + * @param f The math function. + * @param name The short name of the function + */ +abstract class MathematicalExpressionForInt(f: Int => Int, name: String) + extends MathematicalExpression(name) { self: Product => + + override def dataType: DataType = IntegerType + override def expectedChildTypes: Seq[DataType] = Seq(IntegerType) + + override def eval(input: Row): Any = { + val evalE = child.eval(input) + if (evalE == null) null else f(evalE.asInstanceOf[Int]) + } +} + +/** + * A unary expression specifically for math functions that take a `Float` as input and return + * a `Float`. + * @param f The math function. + * @param name The short name of the function + */ +abstract class MathematicalExpressionForFloat(f: Float => Float, name: String) + extends MathematicalExpression(name) { self: Product => + + override def dataType: DataType = FloatType + override def expectedChildTypes: Seq[DataType] = Seq(FloatType) + + override def eval(input: Row): Any = { + val evalE = child.eval(input) + if (evalE == null) { + null + } else { + val result = f(evalE.asInstanceOf[Float]) + if (result.isNaN) null else result + } + } +} + +/** + * A unary expression specifically for math functions that take a `Long` as input and return + * a `Long`. + * @param f The math function. + * @param name The short name of the function + */ +abstract class MathematicalExpressionForLong(f: Long => Long, name: String) + extends MathematicalExpression(name) { self: Product => + + override def dataType: DataType = LongType + override def expectedChildTypes: Seq[DataType] = Seq(LongType) + + override def eval(input: Row): Any = { + val evalE = child.eval(input) + if (evalE == null) null else f(evalE.asInstanceOf[Long]) + } +} + +case class Sin(child: Expression) extends MathematicalExpressionForDouble(math.sin, "SIN") + +case class Asin(child: Expression) extends MathematicalExpressionForDouble(math.asin, "ASIN") + +case class Sinh(child: Expression) extends MathematicalExpressionForDouble(math.sinh, "SINH") + +case class Cos(child: Expression) extends MathematicalExpressionForDouble(math.cos, "COS") + +case class Acos(child: Expression) extends MathematicalExpressionForDouble(math.acos, "ACOS") + +case class Cosh(child: Expression) extends MathematicalExpressionForDouble(math.cosh, "COSH") + +case class Tan(child: Expression) extends MathematicalExpressionForDouble(math.tan, "TAN") + +case class Atan(child: Expression) extends MathematicalExpressionForDouble(math.atan, "ATAN") + +case class Tanh(child: Expression) extends MathematicalExpressionForDouble(math.tanh, "TANH") + +case class Ceil(child: Expression) extends MathematicalExpressionForDouble(math.ceil, "CEIL") + +case class Floor(child: Expression) extends MathematicalExpressionForDouble(math.floor, "FLOOR") + +case class Rint(child: Expression) extends MathematicalExpressionForDouble(math.rint, "ROUND") + +case class Cbrt(child: Expression) extends MathematicalExpressionForDouble(math.cbrt, "CBRT") + +case class Signum(child: Expression) extends MathematicalExpressionForDouble(math.signum, "SIGNUM") + +case class ISignum(child: Expression) extends MathematicalExpressionForInt(math.signum, "ISIGNUM") + +case class FSignum(child: Expression) extends MathematicalExpressionForFloat(math.signum, "FSIGNUM") + +case class LSignum(child: Expression) extends MathematicalExpressionForLong(math.signum, "LSIGNUM") + +case class ToDegrees(child: Expression) + extends MathematicalExpressionForDouble(math.toDegrees, "DEGREES") + +case class ToRadians(child: Expression) + extends MathematicalExpressionForDouble(math.toRadians, "RADIANS") + +case class Log(child: Expression) extends MathematicalExpressionForDouble(math.log, "LOG") + +case class Log10(child: Expression) extends MathematicalExpressionForDouble(math.log10, "LOG10") + +case class Log1p(child: Expression) extends MathematicalExpressionForDouble(math.log1p, "LOG1P") + +case class Exp(child: Expression) extends MathematicalExpressionForDouble(math.exp, "EXP") + +case class Expm1(child: Expression) extends MathematicalExpressionForDouble(math.expm1, "EXPM1") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 76298f03c94ae..5390ce43c6639 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -28,6 +28,7 @@ import org.scalatest.Matchers._ import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.mathfuncs._ import org.apache.spark.sql.types._ @@ -1152,6 +1153,170 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(c1 ^ c2, 3, row) checkEvaluation(~c1, -2, row) } + + /** + * Used for testing math functions for DataFrames. + * @param c The DataFrame function + * @param f The functions in scala.math + * @param domain The set of values to run the function with + * @param expectNull Whether the given values should return null or not + * @tparam T Generic type for primitives + */ + def unaryMathFunctionEvaluation[@specialized(Int, Double, Float, Long) T]( + c: Expression => Expression, + f: T => T, + domain: Iterable[T] = (-20 to 20).map(_ * 0.1), + expectNull: Boolean = false): Unit = { + if (expectNull) { + domain.foreach { value => + checkEvaluation(c(Literal(value)), null, EmptyRow) + } + } else { + domain.foreach { value => + checkEvaluation(c(Literal(value)), f(value), EmptyRow) + } + } + checkEvaluation(c(Literal.create(null, DoubleType)), null, create_row(null)) + } + + test("sin") { + unaryMathFunctionEvaluation(Sin, math.sin) + } + + test("asin") { + unaryMathFunctionEvaluation(Asin, math.asin, (-10 to 10).map(_ * 0.1)) + unaryMathFunctionEvaluation(Asin, math.asin, (11 to 20).map(_ * 0.1), true) + } + + test("sinh") { + unaryMathFunctionEvaluation(Sinh, math.sinh) + } + + test("cos") { + unaryMathFunctionEvaluation(Cos, math.cos) + } + + test("acos") { + unaryMathFunctionEvaluation(Acos, math.acos, (-10 to 10).map(_ * 0.1)) + unaryMathFunctionEvaluation(Acos, math.acos, (11 to 20).map(_ * 0.1), true) + } + + test("cosh") { + unaryMathFunctionEvaluation(Cosh, math.cosh) + } + + test("tan") { + unaryMathFunctionEvaluation(Tan, math.tan) + } + + test("atan") { + unaryMathFunctionEvaluation(Atan, math.atan) + } + + test("tanh") { + unaryMathFunctionEvaluation(Tanh, math.tanh) + } + + test("toDeg") { + unaryMathFunctionEvaluation(ToDegrees, math.toDegrees) + } + + test("toRad") { + unaryMathFunctionEvaluation(ToRadians, math.toRadians) + } + + test("cbrt") { + unaryMathFunctionEvaluation(Cbrt, math.cbrt) + } + + test("ceil") { + unaryMathFunctionEvaluation(Ceil, math.ceil) + } + + test("floor") { + unaryMathFunctionEvaluation(Floor, math.floor) + } + + test("rint") { + unaryMathFunctionEvaluation(Rint, math.rint) + } + + test("exp") { + unaryMathFunctionEvaluation(Exp, math.exp) + } + + test("expm1") { + unaryMathFunctionEvaluation(Expm1, math.expm1) + } + + test("signum") { + unaryMathFunctionEvaluation[Double](Signum, math.signum) + } + + test("isignum") { + unaryMathFunctionEvaluation[Int](ISignum, math.signum, (-5 to 5)) + } + + test("fsignum") { + unaryMathFunctionEvaluation[Float](FSignum, math.signum, (-5 to 5).map(_.toFloat)) + } + + test("lsignum") { + unaryMathFunctionEvaluation[Long](LSignum, math.signum, (5 to 5).map(_.toLong)) + } + + test("log") { + unaryMathFunctionEvaluation(Log, math.log, (0 to 20).map(_ * 0.1)) + unaryMathFunctionEvaluation(Log, math.log, (-5 to -1).map(_ * 0.1), true) + } + + test("log10") { + unaryMathFunctionEvaluation(Log10, math.log10, (0 to 20).map(_ * 0.1)) + unaryMathFunctionEvaluation(Log10, math.log10, (-5 to -1).map(_ * 0.1), true) + } + + test("log1p") { + unaryMathFunctionEvaluation(Log1p, math.log1p, (-1 to 20).map(_ * 0.1)) + unaryMathFunctionEvaluation(Log1p, math.log1p, (-10 to -2).map(_ * 1.0), true) + } + + /** + * Used for testing math functions for DataFrames. + * @param c The DataFrame function + * @param f The functions in scala.math + * @param domain The set of values to run the function with + */ + def binaryMathFunctionEvaluation( + c: (Expression, Expression) => Expression, + f: (Double, Double) => Double, + domain: Iterable[(Double, Double)] = (-20 to 20).map(v => (v * 0.1, v * -0.1)), + expectNull: Boolean = false): Unit = { + if (expectNull) { + domain.foreach { case (v1, v2) => + checkEvaluation(c(v1, v2), null, create_row(null)) + } + } else { + domain.foreach { case (v1, v2) => + checkEvaluation(c(v1, v2), f(v1 + 0.0, v2 + 0.0), EmptyRow) + checkEvaluation(c(v2, v1), f(v2 + 0.0, v1 + 0.0), EmptyRow) + } + } + checkEvaluation(c(Literal.create(null, DoubleType), 1.0), null, create_row(null)) + checkEvaluation(c(1.0, Literal.create(null, DoubleType)), null, create_row(null)) + } + + test("pow") { + binaryMathFunctionEvaluation(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0))) + binaryMathFunctionEvaluation(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), true) + } + + test("hypot") { + binaryMathFunctionEvaluation(Hypot, math.hypot) + } + + test("atan2") { + binaryMathFunctionEvaluation(Atan2, math.atan2) + } } // TODO: Make the tests work with codegen. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala new file mode 100644 index 0000000000000..84f62bf47f955 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala @@ -0,0 +1,562 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.language.implicitConversions + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.mathfuncs._ +import org.apache.spark.sql.functions.lit + +/** + * :: Experimental :: + * Mathematical Functions available for [[DataFrame]]. + * + * @groupname double_funcs Functions that require DoubleType as an input + * @groupname int_funcs Functions that require IntegerType as an input + * @groupname float_funcs Functions that require FloatType as an input + * @groupname long_funcs Functions that require LongType as an input + */ +@Experimental +// scalastyle:off +object mathfunctions { +// scalastyle:on + + private[this] implicit def toColumn(expr: Expression): Column = Column(expr) + + /** + * Computes the sine of the given value. + * + * @group double_funcs + */ + def sin(e: Column): Column = Sin(e.expr) + + /** + * Computes the sine of the given column. + * + * @group double_funcs + */ + def sin(columnName: String): Column = sin(Column(columnName)) + + /** + * Computes the sine inverse of the given value; the returned angle is in the range + * -pi/2 through pi/2. + * + * @group double_funcs + */ + def asin(e: Column): Column = Asin(e.expr) + + /** + * Computes the sine inverse of the given column; the returned angle is in the range + * -pi/2 through pi/2. + * + * @group double_funcs + */ + def asin(columnName: String): Column = asin(Column(columnName)) + + /** + * Computes the hyperbolic sine of the given value. + * + * @group double_funcs + */ + def sinh(e: Column): Column = Sinh(e.expr) + + /** + * Computes the hyperbolic sine of the given column. + * + * @group double_funcs + */ + def sinh(columnName: String): Column = sinh(Column(columnName)) + + /** + * Computes the cosine of the given value. + * + * @group double_funcs + */ + def cos(e: Column): Column = Cos(e.expr) + + /** + * Computes the cosine of the given column. + * + * @group double_funcs + */ + def cos(columnName: String): Column = cos(Column(columnName)) + + /** + * Computes the cosine inverse of the given value; the returned angle is in the range + * 0.0 through pi. + * + * @group double_funcs + */ + def acos(e: Column): Column = Acos(e.expr) + + /** + * Computes the cosine inverse of the given column; the returned angle is in the range + * 0.0 through pi. + * + * @group double_funcs + */ + def acos(columnName: String): Column = acos(Column(columnName)) + + /** + * Computes the hyperbolic cosine of the given value. + * + * @group double_funcs + */ + def cosh(e: Column): Column = Cosh(e.expr) + + /** + * Computes the hyperbolic cosine of the given column. + * + * @group double_funcs + */ + def cosh(columnName: String): Column = cosh(Column(columnName)) + + /** + * Computes the tangent of the given value. + * + * @group double_funcs + */ + def tan(e: Column): Column = Tan(e.expr) + + /** + * Computes the tangent of the given column. + * + * @group double_funcs + */ + def tan(columnName: String): Column = tan(Column(columnName)) + + /** + * Computes the tangent inverse of the given value. + * + * @group double_funcs + */ + def atan(e: Column): Column = Atan(e.expr) + + /** + * Computes the tangent inverse of the given column. + * + * @group double_funcs + */ + def atan(columnName: String): Column = atan(Column(columnName)) + + /** + * Computes the hyperbolic tangent of the given value. + * + * @group double_funcs + */ + def tanh(e: Column): Column = Tanh(e.expr) + + /** + * Computes the hyperbolic tangent of the given column. + * + * @group double_funcs + */ + def tanh(columnName: String): Column = tanh(Column(columnName)) + + /** + * Converts an angle measured in radians to an approximately equivalent angle measured in degrees. + * + * @group double_funcs + */ + def toDeg(e: Column): Column = ToDegrees(e.expr) + + /** + * Converts an angle measured in radians to an approximately equivalent angle measured in degrees. + * + * @group double_funcs + */ + def toDeg(columnName: String): Column = toDeg(Column(columnName)) + + /** + * Converts an angle measured in degrees to an approximately equivalent angle measured in radians. + * + * @group double_funcs + */ + def toRad(e: Column): Column = ToRadians(e.expr) + + /** + * Converts an angle measured in degrees to an approximately equivalent angle measured in radians. + * + * @group double_funcs + */ + def toRad(columnName: String): Column = toRad(Column(columnName)) + + /** + * Computes the ceiling of the given value. + * + * @group double_funcs + */ + def ceil(e: Column): Column = Ceil(e.expr) + + /** + * Computes the ceiling of the given column. + * + * @group double_funcs + */ + def ceil(columnName: String): Column = ceil(Column(columnName)) + + /** + * Computes the floor of the given value. + * + * @group double_funcs + */ + def floor(e: Column): Column = Floor(e.expr) + + /** + * Computes the floor of the given column. + * + * @group double_funcs + */ + def floor(columnName: String): Column = floor(Column(columnName)) + + /** + * Returns the double value that is closest in value to the argument and + * is equal to a mathematical integer. + * + * @group double_funcs + */ + def rint(e: Column): Column = Rint(e.expr) + + /** + * Returns the double value that is closest in value to the argument and + * is equal to a mathematical integer. + * + * @group double_funcs + */ + def rint(columnName: String): Column = rint(Column(columnName)) + + /** + * Computes the cube-root of the given value. + * + * @group double_funcs + */ + def cbrt(e: Column): Column = Cbrt(e.expr) + + /** + * Computes the cube-root of the given column. + * + * @group double_funcs + */ + def cbrt(columnName: String): Column = cbrt(Column(columnName)) + + /** + * Computes the signum of the given value. + * + * @group double_funcs + */ + def signum(e: Column): Column = Signum(e.expr) + + /** + * Computes the signum of the given column. + * + * @group double_funcs + */ + def signum(columnName: String): Column = signum(Column(columnName)) + + /** + * Computes the signum of the given value. For IntegerType. + * + * @group int_funcs + */ + def isignum(e: Column): Column = ISignum(e.expr) + + /** + * Computes the signum of the given column. For IntegerType. + * + * @group int_funcs + */ + def isignum(columnName: String): Column = isignum(Column(columnName)) + + /** + * Computes the signum of the given value. For FloatType. + * + * @group float_funcs + */ + def fsignum(e: Column): Column = FSignum(e.expr) + + /** + * Computes the signum of the given column. For FloatType. + * + * @group float_funcs + */ + def fsignum(columnName: String): Column = fsignum(Column(columnName)) + + /** + * Computes the signum of the given value. For LongType. + * + * @group long_funcs + */ + def lsignum(e: Column): Column = LSignum(e.expr) + + /** + * Computes the signum of the given column. For FloatType. + * + * @group long_funcs + */ + def lsignum(columnName: String): Column = lsignum(Column(columnName)) + + /** + * Computes the natural logarithm of the given value. + * + * @group double_funcs + */ + def log(e: Column): Column = Log(e.expr) + + /** + * Computes the natural logarithm of the given column. + * + * @group double_funcs + */ + def log(columnName: String): Column = log(Column(columnName)) + + /** + * Computes the logarithm of the given value in Base 10. + * + * @group double_funcs + */ + def log10(e: Column): Column = Log10(e.expr) + + /** + * Computes the logarithm of the given value in Base 10. + * + * @group double_funcs + */ + def log10(columnName: String): Column = log10(Column(columnName)) + + /** + * Computes the natural logarithm of the given value plus one. + * + * @group double_funcs + */ + def log1p(e: Column): Column = Log1p(e.expr) + + /** + * Computes the natural logarithm of the given column plus one. + * + * @group double_funcs + */ + def log1p(columnName: String): Column = log1p(Column(columnName)) + + /** + * Computes the exponential of the given value. + * + * @group double_funcs + */ + def exp(e: Column): Column = Exp(e.expr) + + /** + * Computes the exponential of the given column. + * + * @group double_funcs + */ + def exp(columnName: String): Column = exp(Column(columnName)) + + /** + * Computes the exponential of the given value minus one. + * + * @group double_funcs + */ + def expm1(e: Column): Column = Expm1(e.expr) + + /** + * Computes the exponential of the given column. + * + * @group double_funcs + */ + def expm1(columnName: String): Column = expm1(Column(columnName)) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group double_funcs + */ + def pow(l: Column, r: Column): Column = Pow(l.expr, r.expr) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group double_funcs + */ + def pow(l: Column, rightName: String): Column = pow(l, Column(rightName)) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group double_funcs + */ + def pow(leftName: String, r: Column): Column = pow(Column(leftName), r) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group double_funcs + */ + def pow(leftName: String, rightName: String): Column = pow(Column(leftName), Column(rightName)) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group double_funcs + */ + def pow(l: Column, r: Double): Column = pow(l, lit(r).expr) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group double_funcs + */ + def pow(leftName: String, r: Double): Column = pow(Column(leftName), r) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group double_funcs + */ + def pow(l: Double, r: Column): Column = pow(lit(l).expr, r) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group double_funcs + */ + def pow(l: Double, rightName: String): Column = pow(l, Column(rightName)) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * @group double_funcs + */ + def hypot(l: Column, r: Column): Column = Hypot(l.expr, r.expr) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * @group double_funcs + */ + def hypot(l: Column, rightName: String): Column = hypot(l, Column(rightName)) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * @group double_funcs + */ + def hypot(leftName: String, r: Column): Column = hypot(Column(leftName), r) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * @group double_funcs + */ + def hypot(leftName: String, rightName: String): Column = + hypot(Column(leftName), Column(rightName)) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * @group double_funcs + */ + def hypot(l: Column, r: Double): Column = hypot(l, lit(r).expr) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * @group double_funcs + */ + def hypot(leftName: String, r: Double): Column = hypot(Column(leftName), r) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * @group double_funcs + */ + def hypot(l: Double, r: Column): Column = hypot(lit(l).expr, r) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * @group double_funcs + */ + def hypot(l: Double, rightName: String): Column = hypot(l, Column(rightName)) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group double_funcs + */ + def atan2(l: Column, r: Column): Column = Atan2(l.expr, r.expr) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group double_funcs + */ + def atan2(l: Column, rightName: String): Column = atan2(l, Column(rightName)) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group double_funcs + */ + def atan2(leftName: String, r: Column): Column = atan2(Column(leftName), r) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group double_funcs + */ + def atan2(leftName: String, rightName: String): Column = + atan2(Column(leftName), Column(rightName)) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group double_funcs + */ + def atan2(l: Column, r: Double): Column = atan2(l, lit(r).expr) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group double_funcs + */ + def atan2(leftName: String, r: Double): Column = atan2(Column(leftName), r) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group double_funcs + */ + def atan2(l: Double, r: Column): Column = atan2(lit(l).expr, r) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group double_funcs + */ + def atan2(l: Double, rightName: String): Column = atan2(l, Column(rightName)) +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index e02c84872c628..e5c9504d21042 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -41,6 +41,7 @@ import java.util.Map; import static org.apache.spark.sql.functions.*; +import static org.apache.spark.sql.mathfunctions.*; public class JavaDataFrameSuite { private transient JavaSparkContext jsc; @@ -98,6 +99,14 @@ public void testVarargMethods() { df.groupBy().agg(countDistinct("key", "value")); df.groupBy().agg(countDistinct(col("key"), col("value"))); df.select(coalesce(col("key"))); + + // Varargs with mathfunctions + DataFrame df2 = context.table("testData2"); + df2.select(exp("a"), exp("b")); + df2.select(exp(log("a"))); + df2.select(pow("a", "a"), pow("b", 2.0)); + df2.select(pow(col("a"), col("b")), exp("b")); + df2.select(sin("a"), acos("b")); } @Ignore diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 904073b8cb2aa..680b5c636960d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -22,7 +22,6 @@ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ - class ColumnExpressionSuite extends QueryTest { import org.apache.spark.sql.TestData._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala new file mode 100644 index 0000000000000..561553cc925cb --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.lang.{Double => JavaDouble} + +import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.mathfunctions._ +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext.implicits._ + +private[this] object MathExpressionsTestData { + + case class DoubleData(a: JavaDouble, b: JavaDouble) + val doubleData = TestSQLContext.sparkContext.parallelize( + (1 to 10).map(i => DoubleData(i * 0.2 - 1, i * -0.2 + 1))).toDF() + + val nnDoubleData = TestSQLContext.sparkContext.parallelize( + (1 to 10).map(i => DoubleData(i * 0.1, i * -0.1))).toDF() + + case class NullDoubles(a: JavaDouble) + val nullDoubles = + TestSQLContext.sparkContext.parallelize( + NullDoubles(1.0) :: + NullDoubles(2.0) :: + NullDoubles(3.0) :: + NullDoubles(null) :: Nil + ).toDF() +} + +class MathExpressionsSuite extends QueryTest { + + import MathExpressionsTestData._ + + def testOneToOneMathFunction[@specialized(Int, Long, Float, Double) T]( + c: Column => Column, + f: T => T): Unit = { + checkAnswer( + doubleData.select(c('a)), + (1 to 10).map(n => Row(f((n * 0.2 - 1).asInstanceOf[T]))) + ) + + checkAnswer( + doubleData.select(c('b)), + (1 to 10).map(n => Row(f((-n * 0.2 + 1).asInstanceOf[T]))) + ) + + checkAnswer( + doubleData.select(c(lit(null))), + (1 to 10).map(_ => Row(null)) + ) + } + + def testOneToOneNonNegativeMathFunction(c: Column => Column, f: Double => Double): Unit = { + checkAnswer( + nnDoubleData.select(c('a)), + (1 to 10).map(n => Row(f(n * 0.1))) + ) + + if (f(-1) === math.log1p(-1)) { + checkAnswer( + nnDoubleData.select(c('b)), + (1 to 9).map(n => Row(f(n * -0.1))) :+ Row(Double.NegativeInfinity) + ) + } else { + checkAnswer( + nnDoubleData.select(c('b)), + (1 to 10).map(n => Row(null)) + ) + } + + checkAnswer( + nnDoubleData.select(c(lit(null))), + (1 to 10).map(_ => Row(null)) + ) + } + + def testTwoToOneMathFunction( + c: (Column, Column) => Column, + d: (Column, Double) => Column, + f: (Double, Double) => Double): Unit = { + checkAnswer( + nnDoubleData.select(c('a, 'a)), + nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), r.getDouble(0)))) + ) + + checkAnswer( + nnDoubleData.select(c('a, 'b)), + nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), r.getDouble(1)))) + ) + + checkAnswer( + nnDoubleData.select(d('a, 2.0)), + nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), 2.0))) + ) + + checkAnswer( + nnDoubleData.select(d('a, -0.5)), + nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), -0.5))) + ) + + val nonNull = nullDoubles.collect().toSeq.filter(r => r.get(0) != null) + + checkAnswer( + nullDoubles.select(c('a, 'a)).orderBy('a.asc), + Row(null) +: nonNull.map(r => Row(f(r.getDouble(0), r.getDouble(0)))) + ) + } + + test("sin") { + testOneToOneMathFunction(sin, math.sin) + } + + test("asin") { + testOneToOneMathFunction(asin, math.asin) + } + + test("sinh") { + testOneToOneMathFunction(sinh, math.sinh) + } + + test("cos") { + testOneToOneMathFunction(cos, math.cos) + } + + test("acos") { + testOneToOneMathFunction(acos, math.acos) + } + + test("cosh") { + testOneToOneMathFunction(cosh, math.cosh) + } + + test("tan") { + testOneToOneMathFunction(tan, math.tan) + } + + test("atan") { + testOneToOneMathFunction(atan, math.atan) + } + + test("tanh") { + testOneToOneMathFunction(tanh, math.tanh) + } + + test("toDeg") { + testOneToOneMathFunction(toDeg, math.toDegrees) + } + + test("toRad") { + testOneToOneMathFunction(toRad, math.toRadians) + } + + test("cbrt") { + testOneToOneMathFunction(cbrt, math.cbrt) + } + + test("ceil") { + testOneToOneMathFunction(ceil, math.ceil) + } + + test("floor") { + testOneToOneMathFunction(floor, math.floor) + } + + test("rint") { + testOneToOneMathFunction(rint, math.rint) + } + + test("exp") { + testOneToOneMathFunction(exp, math.exp) + } + + test("expm1") { + testOneToOneMathFunction(expm1, math.expm1) + } + + test("signum") { + testOneToOneMathFunction[Double](signum, math.signum) + } + + test("isignum") { + testOneToOneMathFunction[Int](isignum, math.signum) + } + + test("fsignum") { + testOneToOneMathFunction[Float](fsignum, math.signum) + } + + test("lsignum") { + testOneToOneMathFunction[Long](lsignum, math.signum) + } + + test("pow") { + testTwoToOneMathFunction(pow, pow, math.pow) + } + + test("hypot") { + testTwoToOneMathFunction(hypot, hypot, math.hypot) + } + + test("atan2") { + testTwoToOneMathFunction(atan2, atan2, math.atan2) + } + + test("log") { + testOneToOneNonNegativeMathFunction(log, math.log) + } + + test("log10") { + testOneToOneNonNegativeMathFunction(log10, math.log10) + } + + test("log1p") { + testOneToOneNonNegativeMathFunction(log1p, math.log1p) + } + +} From 9e4e82b7bca1129bcd5e0274b9ae1b1be3fb93da Mon Sep 17 00:00:00 2001 From: jerryshao Date: Mon, 27 Apr 2015 23:48:02 -0700 Subject: [PATCH 27/39] [SPARK-5946] [STREAMING] Add Python API for direct Kafka stream Currently only added `createDirectStream` API, I'm not sure if `createRDD` is also needed, since some Java object needs to be wrapped in Python. Please help to review, thanks a lot. Author: jerryshao Author: Saisai Shao Closes #4723 from jerryshao/direct-kafka-python-api and squashes the following commits: a1fe97c [jerryshao] Fix rebase issue eebf333 [jerryshao] Address the comments da40f4e [jerryshao] Fix Python 2.6 Syntax error issue 5c0ee85 [jerryshao] Style fix 4aeac18 [jerryshao] Fix bug in example code 7146d86 [jerryshao] Add unit test bf3bdd6 [jerryshao] Add more APIs and address the comments f5b3801 [jerryshao] Small style fix 8641835 [Saisai Shao] Rebase and update the code 589c05b [Saisai Shao] Fix the style d6fcb6a [Saisai Shao] Address the comments dfda902 [Saisai Shao] Style fix 0f7d168 [Saisai Shao] Add the doc and fix some style issues 67e6880 [Saisai Shao] Fix test bug 917b0db [Saisai Shao] Add Python createRDD API for Kakfa direct stream c3fc11d [jerryshao] Modify the docs 2c00936 [Saisai Shao] address the comments 3360f44 [jerryshao] Fix code style e0e0f0d [jerryshao] Code clean and bug fix 338c41f [Saisai Shao] Add python API and example for direct kafka stream --- .../streaming/direct_kafka_wordcount.py | 55 ++++++ .../spark/streaming/kafka/KafkaUtils.scala | 92 +++++++++- python/pyspark/streaming/kafka.py | 167 +++++++++++++++++- python/pyspark/streaming/tests.py | 84 ++++++++- 4 files changed, 383 insertions(+), 15 deletions(-) create mode 100644 examples/src/main/python/streaming/direct_kafka_wordcount.py diff --git a/examples/src/main/python/streaming/direct_kafka_wordcount.py b/examples/src/main/python/streaming/direct_kafka_wordcount.py new file mode 100644 index 0000000000000..6ef188a220c51 --- /dev/null +++ b/examples/src/main/python/streaming/direct_kafka_wordcount.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Counts words in UTF8 encoded, '\n' delimited text directly received from Kafka in every 2 seconds. + Usage: direct_kafka_wordcount.py + + To run this on your local machine, you need to setup Kafka and create a producer first, see + http://kafka.apache.org/documentation.html#quickstart + + and then run the example + `$ bin/spark-submit --jars external/kafka-assembly/target/scala-*/\ + spark-streaming-kafka-assembly-*.jar \ + examples/src/main/python/streaming/direct_kafka_wordcount.py \ + localhost:9092 test` +""" + +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext +from pyspark.streaming.kafka import KafkaUtils + +if __name__ == "__main__": + if len(sys.argv) != 3: + print >> sys.stderr, "Usage: direct_kafka_wordcount.py " + exit(-1) + + sc = SparkContext(appName="PythonStreamingDirectKafkaWordCount") + ssc = StreamingContext(sc, 2) + + brokers, topic = sys.argv[1:] + kvs = KafkaUtils.createDirectStream(ssc, [topic], {"metadata.broker.list": brokers}) + lines = kvs.map(lambda x: x[1]) + counts = lines.flatMap(lambda line: line.split(" ")) \ + .map(lambda word: (word, 1)) \ + .reduceByKey(lambda a, b: a+b) + counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index 5a9bd4214cf51..0721ddaf7055a 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -21,6 +21,7 @@ import java.lang.{Integer => JInt} import java.lang.{Long => JLong} import java.util.{Map => JMap} import java.util.{Set => JSet} +import java.util.{List => JList} import scala.reflect.ClassTag import scala.collection.JavaConversions._ @@ -234,7 +235,6 @@ object KafkaUtils { new KafkaRDD[K, V, KD, VD, R](sc, kafkaParams, offsetRanges, leaderMap, messageHandler) } - /** * Create a RDD from Kafka using offset ranges for each topic and partition. * @@ -558,4 +558,94 @@ private class KafkaUtilsPythonHelper { topics, storageLevel) } + + def createRDD( + jsc: JavaSparkContext, + kafkaParams: JMap[String, String], + offsetRanges: JList[OffsetRange], + leaders: JMap[TopicAndPartition, Broker]): JavaPairRDD[Array[Byte], Array[Byte]] = { + val messageHandler = new JFunction[MessageAndMetadata[Array[Byte], Array[Byte]], + (Array[Byte], Array[Byte])] { + def call(t1: MessageAndMetadata[Array[Byte], Array[Byte]]): (Array[Byte], Array[Byte]) = + (t1.key(), t1.message()) + } + + val jrdd = KafkaUtils.createRDD[ + Array[Byte], + Array[Byte], + DefaultDecoder, + DefaultDecoder, + (Array[Byte], Array[Byte])]( + jsc, + classOf[Array[Byte]], + classOf[Array[Byte]], + classOf[DefaultDecoder], + classOf[DefaultDecoder], + classOf[(Array[Byte], Array[Byte])], + kafkaParams, + offsetRanges.toArray(new Array[OffsetRange](offsetRanges.size())), + leaders, + messageHandler + ) + new JavaPairRDD(jrdd.rdd) + } + + def createDirectStream( + jssc: JavaStreamingContext, + kafkaParams: JMap[String, String], + topics: JSet[String], + fromOffsets: JMap[TopicAndPartition, JLong] + ): JavaPairInputDStream[Array[Byte], Array[Byte]] = { + + if (!fromOffsets.isEmpty) { + import scala.collection.JavaConversions._ + val topicsFromOffsets = fromOffsets.keySet().map(_.topic) + if (topicsFromOffsets != topics.toSet) { + throw new IllegalStateException(s"The specified topics: ${topics.toSet.mkString(" ")} " + + s"do not equal to the topic from offsets: ${topicsFromOffsets.mkString(" ")}") + } + } + + if (fromOffsets.isEmpty) { + KafkaUtils.createDirectStream[Array[Byte], Array[Byte], DefaultDecoder, DefaultDecoder]( + jssc, + classOf[Array[Byte]], + classOf[Array[Byte]], + classOf[DefaultDecoder], + classOf[DefaultDecoder], + kafkaParams, + topics) + } else { + val messageHandler = new JFunction[MessageAndMetadata[Array[Byte], Array[Byte]], + (Array[Byte], Array[Byte])] { + def call(t1: MessageAndMetadata[Array[Byte], Array[Byte]]): (Array[Byte], Array[Byte]) = + (t1.key(), t1.message()) + } + + val jstream = KafkaUtils.createDirectStream[ + Array[Byte], + Array[Byte], + DefaultDecoder, + DefaultDecoder, + (Array[Byte], Array[Byte])]( + jssc, + classOf[Array[Byte]], + classOf[Array[Byte]], + classOf[DefaultDecoder], + classOf[DefaultDecoder], + classOf[(Array[Byte], Array[Byte])], + kafkaParams, + fromOffsets, + messageHandler) + new JavaPairInputDStream(jstream.inputDStream) + } + } + + def createOffsetRange(topic: String, partition: JInt, fromOffset: JLong, untilOffset: JLong + ): OffsetRange = OffsetRange.create(topic, partition, fromOffset, untilOffset) + + def createTopicAndPartition(topic: String, partition: JInt): TopicAndPartition = + TopicAndPartition(topic, partition) + + def createBroker(host: String, port: JInt): Broker = Broker(host, port) } diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index 8d610d6569b4a..e278b29003f69 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -17,11 +17,12 @@ from py4j.java_gateway import Py4JJavaError +from pyspark.rdd import RDD from pyspark.storagelevel import StorageLevel from pyspark.serializers import PairDeserializer, NoOpSerializer from pyspark.streaming import DStream -__all__ = ['KafkaUtils', 'utf8_decoder'] +__all__ = ['Broker', 'KafkaUtils', 'OffsetRange', 'TopicAndPartition', 'utf8_decoder'] def utf8_decoder(s): @@ -67,7 +68,104 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={}, except Py4JJavaError as e: # TODO: use --jar once it also work on driver if 'ClassNotFoundException' in str(e.java_exception): - print(""" + KafkaUtils._printErrorMsg(ssc.sparkContext) + raise e + ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) + stream = DStream(jstream, ssc, ser) + return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) + + @staticmethod + def createDirectStream(ssc, topics, kafkaParams, fromOffsets={}, + keyDecoder=utf8_decoder, valueDecoder=utf8_decoder): + """ + .. note:: Experimental + + Create an input stream that directly pulls messages from a Kafka Broker and specific offset. + + This is not a receiver based Kafka input stream, it directly pulls the message from Kafka + in each batch duration and processed without storing. + + This does not use Zookeeper to store offsets. The consumed offsets are tracked + by the stream itself. For interoperability with Kafka monitoring tools that depend on + Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. + You can access the offsets used in each batch from the generated RDDs (see + + To recover from driver failures, you have to enable checkpointing in the StreamingContext. + The information on consumed offset can be recovered from the checkpoint. + See the programming guide for details (constraints, etc.). + + :param ssc: StreamingContext object. + :param topics: list of topic_name to consume. + :param kafkaParams: Additional params for Kafka. + :param fromOffsets: Per-topic/partition Kafka offsets defining the (inclusive) starting + point of the stream. + :param keyDecoder: A function used to decode key (default is utf8_decoder). + :param valueDecoder: A function used to decode value (default is utf8_decoder). + :return: A DStream object + """ + if not isinstance(topics, list): + raise TypeError("topics should be list") + if not isinstance(kafkaParams, dict): + raise TypeError("kafkaParams should be dict") + + try: + helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper") + helper = helperClass.newInstance() + + jfromOffsets = dict([(k._jTopicAndPartition(helper), + v) for (k, v) in fromOffsets.items()]) + jstream = helper.createDirectStream(ssc._jssc, kafkaParams, set(topics), jfromOffsets) + except Py4JJavaError as e: + if 'ClassNotFoundException' in str(e.java_exception): + KafkaUtils._printErrorMsg(ssc.sparkContext) + raise e + + ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) + stream = DStream(jstream, ssc, ser) + return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) + + @staticmethod + def createRDD(sc, kafkaParams, offsetRanges, leaders={}, + keyDecoder=utf8_decoder, valueDecoder=utf8_decoder): + """ + .. note:: Experimental + + Create a RDD from Kafka using offset ranges for each topic and partition. + :param sc: SparkContext object + :param kafkaParams: Additional params for Kafka + :param offsetRanges: list of offsetRange to specify topic:partition:[start, end) to consume + :param leaders: Kafka brokers for each TopicAndPartition in offsetRanges. May be an empty + map, in which case leaders will be looked up on the driver. + :param keyDecoder: A function used to decode key (default is utf8_decoder) + :param valueDecoder: A function used to decode value (default is utf8_decoder) + :return: A RDD object + """ + if not isinstance(kafkaParams, dict): + raise TypeError("kafkaParams should be dict") + if not isinstance(offsetRanges, list): + raise TypeError("offsetRanges should be list") + + try: + helperClass = sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper") + helper = helperClass.newInstance() + joffsetRanges = [o._jOffsetRange(helper) for o in offsetRanges] + jleaders = dict([(k._jTopicAndPartition(helper), + v._jBroker(helper)) for (k, v) in leaders.items()]) + jrdd = helper.createRDD(sc._jsc, kafkaParams, joffsetRanges, jleaders) + except Py4JJavaError as e: + if 'ClassNotFoundException' in str(e.java_exception): + KafkaUtils._printErrorMsg(sc) + raise e + + ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) + rdd = RDD(jrdd, sc, ser) + return rdd.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) + + @staticmethod + def _printErrorMsg(sc): + print(""" ________________________________________________________________________________________________ Spark Streaming's Kafka libraries not found in class path. Try one of the following. @@ -85,8 +183,63 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={}, ________________________________________________________________________________________________ -""" % (ssc.sparkContext.version, ssc.sparkContext.version)) - raise e - ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) - stream = DStream(jstream, ssc, ser) - return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) +""" % (sc.version, sc.version)) + + +class OffsetRange(object): + """ + Represents a range of offsets from a single Kafka TopicAndPartition. + """ + + def __init__(self, topic, partition, fromOffset, untilOffset): + """ + Create a OffsetRange to represent range of offsets + :param topic: Kafka topic name. + :param partition: Kafka partition id. + :param fromOffset: Inclusive starting offset. + :param untilOffset: Exclusive ending offset. + """ + self._topic = topic + self._partition = partition + self._fromOffset = fromOffset + self._untilOffset = untilOffset + + def _jOffsetRange(self, helper): + return helper.createOffsetRange(self._topic, self._partition, self._fromOffset, + self._untilOffset) + + +class TopicAndPartition(object): + """ + Represents a specific top and partition for Kafka. + """ + + def __init__(self, topic, partition): + """ + Create a Python TopicAndPartition to map to the Java related object + :param topic: Kafka topic name. + :param partition: Kafka partition id. + """ + self._topic = topic + self._partition = partition + + def _jTopicAndPartition(self, helper): + return helper.createTopicAndPartition(self._topic, self._partition) + + +class Broker(object): + """ + Represent the host and port info for a Kafka broker. + """ + + def __init__(self, host, port): + """ + Create a Python Broker to map to the Java related object. + :param host: Broker's hostname. + :param port: Broker's port. + """ + self._host = host + self._port = port + + def _jBroker(self, helper): + return helper.createBroker(self._host, self._port) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 5fa1e5ef081ab..7c06c203455d9 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -21,6 +21,7 @@ import time import operator import tempfile +import random import struct from functools import reduce @@ -35,7 +36,7 @@ from pyspark.context import SparkConf, SparkContext, RDD from pyspark.streaming.context import StreamingContext -from pyspark.streaming.kafka import KafkaUtils +from pyspark.streaming.kafka import Broker, KafkaUtils, OffsetRange, TopicAndPartition class PySparkStreamingTestCase(unittest.TestCase): @@ -590,9 +591,27 @@ def tearDown(self): super(KafkaStreamTests, self).tearDown() + def _randomTopic(self): + return "topic-%d" % random.randint(0, 10000) + + def _validateStreamResult(self, sendData, stream): + result = {} + for i in chain.from_iterable(self._collect(stream.map(lambda x: x[1]), + sum(sendData.values()))): + result[i] = result.get(i, 0) + 1 + + self.assertEqual(sendData, result) + + def _validateRddResult(self, sendData, rdd): + result = {} + for i in rdd.map(lambda x: x[1]).collect(): + result[i] = result.get(i, 0) + 1 + + self.assertEqual(sendData, result) + def test_kafka_stream(self): """Test the Python Kafka stream API.""" - topic = "topic1" + topic = self._randomTopic() sendData = {"a": 3, "b": 5, "c": 10} self._kafkaTestUtils.createTopic(topic) @@ -601,13 +620,64 @@ def test_kafka_stream(self): stream = KafkaUtils.createStream(self.ssc, self._kafkaTestUtils.zkAddress(), "test-streaming-consumer", {topic: 1}, {"auto.offset.reset": "smallest"}) + self._validateStreamResult(sendData, stream) - result = {} - for i in chain.from_iterable(self._collect(stream.map(lambda x: x[1]), - sum(sendData.values()))): - result[i] = result.get(i, 0) + 1 + def test_kafka_direct_stream(self): + """Test the Python direct Kafka stream API.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2, "c": 3} + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(), + "auto.offset.reset": "smallest"} - self.assertEqual(sendData, result) + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + + stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams) + self._validateStreamResult(sendData, stream) + + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_direct_stream_from_offset(self): + """Test the Python direct Kafka stream API with start offset specified.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2, "c": 3} + fromOffsets = {TopicAndPartition(topic, 0): long(0)} + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress()} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + + stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams, fromOffsets) + self._validateStreamResult(sendData, stream) + + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_rdd(self): + """Test the Python direct Kafka RDD API.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2} + offsetRanges = [OffsetRange(topic, 0, long(0), long(sum(sendData.values())))] + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress()} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + + rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges) + self._validateRddResult(sendData, rdd) + + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_rdd_with_leaders(self): + """Test the Python direct Kafka RDD API with leaders.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2, "c": 3} + offsetRanges = [OffsetRange(topic, 0, long(0), long(sum(sendData.values())))] + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress()} + address = self._kafkaTestUtils.brokerAddress().split(":") + leaders = {TopicAndPartition(topic, 0): Broker(address[0], int(address[1]))} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + + rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges, leaders) + self._validateRddResult(sendData, rdd) if __name__ == "__main__": unittest.main() From bf35edd9d4b8b11df9f47b6ff43831bc95f06322 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 28 Apr 2015 00:38:14 -0700 Subject: [PATCH 28/39] [SPARK-7187] SerializationDebugger should not crash user code rxin Author: Andrew Or Closes #5734 from andrewor14/ser-deb and squashes the following commits: e8aad6c [Andrew Or] NonFatal 57d0ef4 [Andrew Or] try catch improveException --- .../spark/serializer/SerializationDebugger.scala | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala index cecb992579655..5abfa467c0ec8 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala @@ -23,6 +23,7 @@ import java.security.AccessController import scala.annotation.tailrec import scala.collection.mutable +import scala.util.control.NonFatal import org.apache.spark.Logging @@ -35,8 +36,15 @@ private[serializer] object SerializationDebugger extends Logging { */ def improveException(obj: Any, e: NotSerializableException): NotSerializableException = { if (enableDebugging && reflect != null) { - new NotSerializableException( - e.getMessage + "\nSerialization stack:\n" + find(obj).map("\t- " + _).mkString("\n")) + try { + new NotSerializableException( + e.getMessage + "\nSerialization stack:\n" + find(obj).map("\t- " + _).mkString("\n")) + } catch { + case NonFatal(t) => + // Fall back to old exception + logWarning("Exception in serialization debugger", t) + e + } } else { e } From d94cd1a733d5715792e6c4eac87f0d5c81aebbe2 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 28 Apr 2015 00:39:08 -0700 Subject: [PATCH 29/39] [SPARK-7135][SQL] DataFrame expression for monotonically increasing IDs. Author: Reynold Xin Closes #5709 from rxin/inc-id and squashes the following commits: 7853611 [Reynold Xin] private sql. a9fda0d [Reynold Xin] Missed a few numbers. 343d896 [Reynold Xin] Self review feedback. a7136cb [Reynold Xin] [SPARK-7135][SQL] DataFrame expression for monotonically increasing IDs. --- python/pyspark/sql/functions.py | 22 +++++++- .../MonotonicallyIncreasingID.scala | 53 +++++++++++++++++++ .../expressions/SparkPartitionID.scala | 6 +-- .../org/apache/spark/sql/functions.scala | 16 ++++++ .../spark/sql/ColumnExpressionSuite.scala | 11 ++++ 5 files changed, 103 insertions(+), 5 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index f48b7b5d10af7..7b86655d9c82f 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -103,8 +103,28 @@ def countDistinct(col, *cols): return Column(jc) +def monotonicallyIncreasingId(): + """A column that generates monotonically increasing 64-bit integers. + + The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. + The current implementation puts the partition ID in the upper 31 bits, and the record number + within each partition in the lower 33 bits. The assumption is that the data frame has + less than 1 billion partitions, and each partition has less than 8 billion records. + + As an example, consider a [[DataFrame]] with two partitions, each with 3 records. + This expression would return the following IDs: + 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. + + >>> df0 = sc.parallelize(range(2), 2).mapPartitions(lambda x: [(1,), (2,), (3,)]).toDF(['col1']) + >>> df0.select(monotonicallyIncreasingId().alias('id')).collect() + [Row(id=0), Row(id=1), Row(id=2), Row(id=8589934592), Row(id=8589934593), Row(id=8589934594)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.monotonicallyIncreasingId()) + + def sparkPartitionId(): - """Returns a column for partition ID of the Spark task. + """A column for partition ID of the Spark task. Note that this is indeterministic because it depends on data partitioning and task scheduling. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala new file mode 100644 index 0000000000000..9ac732b55b188 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.expressions + +import org.apache.spark.TaskContext +import org.apache.spark.sql.catalyst.expressions.{Row, LeafExpression} +import org.apache.spark.sql.types.{LongType, DataType} + +/** + * Returns monotonically increasing 64-bit integers. + * + * The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. + * The current implementation puts the partition ID in the upper 31 bits, and the lower 33 bits + * represent the record number within each partition. The assumption is that the data frame has + * less than 1 billion partitions, and each partition has less than 8 billion records. + * + * Since this expression is stateful, it cannot be a case object. + */ +private[sql] case class MonotonicallyIncreasingID() extends LeafExpression { + + /** + * Record ID within each partition. By being transient, count's value is reset to 0 every time + * we serialize and deserialize it. + */ + @transient private[this] var count: Long = 0L + + override type EvaluatedType = Long + + override def nullable: Boolean = false + + override def dataType: DataType = LongType + + override def eval(input: Row): Long = { + val currentCount = count + count += 1 + (TaskContext.get().partitionId().toLong << 33) + currentCount + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala index fe7607c6ac340..c2c6cbd491598 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala @@ -18,16 +18,14 @@ package org.apache.spark.sql.execution.expressions import org.apache.spark.TaskContext -import org.apache.spark.sql.catalyst.expressions.{Row, Expression} -import org.apache.spark.sql.catalyst.trees +import org.apache.spark.sql.catalyst.expressions.{LeafExpression, Row} import org.apache.spark.sql.types.{IntegerType, DataType} /** * Expression that returns the current partition id of the Spark task. */ -case object SparkPartitionID extends Expression with trees.LeafNode[Expression] { - self: Product => +private[sql] case object SparkPartitionID extends LeafExpression { override type EvaluatedType = Int diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 9738fd4f93bad..aa31d04a0cbe4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -301,6 +301,22 @@ object functions { */ def lower(e: Column): Column = Lower(e.expr) + /** + * A column expression that generates monotonically increasing 64-bit integers. + * + * The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. + * The current implementation puts the partition ID in the upper 31 bits, and the record number + * within each partition in the lower 33 bits. The assumption is that the data frame has + * less than 1 billion partitions, and each partition has less than 8 billion records. + * + * As an example, consider a [[DataFrame]] with two partitions, each with 3 records. + * This expression would return the following IDs: + * 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. + * + * @group normal_funcs + */ + def monotonicallyIncreasingId(): Column = execution.expressions.MonotonicallyIncreasingID() + /** * Unary minus, i.e. negate the expression. * {{{ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 680b5c636960d..2ba5fc21ff57c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -309,6 +309,17 @@ class ColumnExpressionSuite extends QueryTest { ) } + test("monotonicallyIncreasingId") { + // Make sure we have 2 partitions, each with 2 records. + val df = TestSQLContext.sparkContext.parallelize(1 to 2, 2).mapPartitions { iter => + Iterator(Tuple1(1), Tuple1(2)) + }.toDF("a") + checkAnswer( + df.select(monotonicallyIncreasingId()), + Row(0L) :: Row(1L) :: Row((1L << 33) + 0L) :: Row((1L << 33) + 1L) :: Nil + ) + } + test("sparkPartitionId") { val df = TestSQLContext.sparkContext.parallelize(1 to 1, 1).map(i => (i, i)).toDF("a", "b") checkAnswer( From e13cd86567a43672297bb488088dd8f40ec799bf Mon Sep 17 00:00:00 2001 From: Pei-Lun Lee Date: Tue, 28 Apr 2015 16:50:18 +0800 Subject: [PATCH 30/39] [SPARK-6352] [SQL] Custom parquet output committer Add new config "spark.sql.parquet.output.committer.class" to allow custom parquet output committer and an output committer class specific to use on s3. Fix compilation error introduced by https://github.com/apache/spark/pull/5042. Respect ParquetOutputFormat.ENABLE_JOB_SUMMARY flag. Author: Pei-Lun Lee Closes #5525 from ypcat/spark-6352 and squashes the following commits: 54c6b15 [Pei-Lun Lee] error handling 472870e [Pei-Lun Lee] add back custom parquet output committer ddd0f69 [Pei-Lun Lee] Merge branch 'master' of https://github.com/apache/spark into spark-6352 9ece5c5 [Pei-Lun Lee] compatibility with hadoop 1.x 8413fcd [Pei-Lun Lee] Merge branch 'master' of https://github.com/apache/spark into spark-6352 fe65915 [Pei-Lun Lee] add support for parquet config parquet.enable.summary-metadata e17bf47 [Pei-Lun Lee] Merge branch 'master' of https://github.com/apache/spark into spark-6352 9ae7545 [Pei-Lun Lee] [SPARL-6352] [SQL] Change to allow custom parquet output committer. 0d540b9 [Pei-Lun Lee] [SPARK-6352] [SQL] add license c42468c [Pei-Lun Lee] [SPARK-6352] [SQL] add test case 0fc03ca [Pei-Lun Lee] [SPARK-6532] [SQL] hide class DirectParquetOutputCommitter 769bd67 [Pei-Lun Lee] DirectParquetOutputCommitter f75e261 [Pei-Lun Lee] DirectParquetOutputCommitter --- .../DirectParquetOutputCommitter.scala | 73 +++++++++++++++++++ .../sql/parquet/ParquetTableOperations.scala | 21 ++++++ .../spark/sql/parquet/ParquetIOSuite.scala | 22 ++++++ 3 files changed, 116 insertions(+) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala new file mode 100644 index 0000000000000..f5ce2718bec4a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parquet + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter + +import parquet.Log +import parquet.hadoop.util.ContextUtil +import parquet.hadoop.{ParquetFileReader, ParquetFileWriter, ParquetOutputCommitter, ParquetOutputFormat} + +private[parquet] class DirectParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) + extends ParquetOutputCommitter(outputPath, context) { + val LOG = Log.getLog(classOf[ParquetOutputCommitter]) + + override def getWorkPath(): Path = outputPath + override def abortTask(taskContext: TaskAttemptContext): Unit = {} + override def commitTask(taskContext: TaskAttemptContext): Unit = {} + override def needsTaskCommit(taskContext: TaskAttemptContext): Boolean = true + override def setupJob(jobContext: JobContext): Unit = {} + override def setupTask(taskContext: TaskAttemptContext): Unit = {} + + override def commitJob(jobContext: JobContext) { + val configuration = ContextUtil.getConfiguration(jobContext) + val fileSystem = outputPath.getFileSystem(configuration) + + if (configuration.getBoolean(ParquetOutputFormat.ENABLE_JOB_SUMMARY, true)) { + try { + val outputStatus = fileSystem.getFileStatus(outputPath) + val footers = ParquetFileReader.readAllFootersInParallel(configuration, outputStatus) + try { + ParquetFileWriter.writeMetadataFile(configuration, outputPath, footers) + } catch { + case e: Exception => { + LOG.warn("could not write summary file for " + outputPath, e) + val metadataPath = new Path(outputPath, ParquetFileWriter.PARQUET_METADATA_FILE) + if (fileSystem.exists(metadataPath)) { + fileSystem.delete(metadataPath, true) + } + } + } + } catch { + case e: Exception => LOG.warn("could not write summary file for " + outputPath, e) + } + } + + if (configuration.getBoolean("mapreduce.fileoutputcommitter.marksuccessfuljobs", true)) { + try { + val successPath = new Path(outputPath, FileOutputCommitter.SUCCEEDED_FILE_NAME) + fileSystem.create(successPath).close() + } catch { + case e: Exception => LOG.warn("could not write success file for " + outputPath, e) + } + } + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index a938b77578686..aded126ea0615 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -381,6 +381,7 @@ private[parquet] class AppendingParquetOutputFormat(offset: Int) extends parquet.hadoop.ParquetOutputFormat[Row] { // override to accept existing directories as valid output directory override def checkOutputSpecs(job: JobContext): Unit = {} + var committer: OutputCommitter = null // override to choose output filename so not overwrite existing ones override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { @@ -403,6 +404,26 @@ private[parquet] class AppendingParquetOutputFormat(offset: Int) private def getTaskAttemptID(context: TaskAttemptContext): TaskAttemptID = { context.getClass.getMethod("getTaskAttemptID").invoke(context).asInstanceOf[TaskAttemptID] } + + // override to create output committer from configuration + override def getOutputCommitter(context: TaskAttemptContext): OutputCommitter = { + if (committer == null) { + val output = getOutputPath(context) + val cls = context.getConfiguration.getClass("spark.sql.parquet.output.committer.class", + classOf[ParquetOutputCommitter], classOf[ParquetOutputCommitter]) + val ctor = cls.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext]) + committer = ctor.newInstance(output, context).asInstanceOf[ParquetOutputCommitter] + } + committer + } + + // FileOutputFormat.getOutputPath takes JobConf in hadoop-1 but JobContext in hadoop-2 + private def getOutputPath(context: TaskAttemptContext): Path = { + context.getConfiguration().get("mapred.output.dir") match { + case null => null + case name => new Path(name) + } + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index 97c0f439acf13..b504842053690 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -381,6 +381,28 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { } } } + + test("SPARK-6352 DirectParquetOutputCommitter") { + // Write to a parquet file and let it fail. + // _temporary should be missing if direct output committer works. + try { + configuration.set("spark.sql.parquet.output.committer.class", + "org.apache.spark.sql.parquet.DirectParquetOutputCommitter") + sqlContext.udf.register("div0", (x: Int) => x / 0) + withTempPath { dir => + intercept[org.apache.spark.SparkException] { + sqlContext.sql("select div0(1)").saveAsParquetFile(dir.getCanonicalPath) + } + val path = new Path(dir.getCanonicalPath, "_temporary") + val fs = path.getFileSystem(configuration) + assert(!fs.exists(path)) + } + } + finally { + configuration.set("spark.sql.parquet.output.committer.class", + "parquet.hadoop.ParquetOutputCommitter") + } + } } class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll { From 7f3b3b7eb7d14767124a28ec0062c4d60d6c16fc Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 28 Apr 2015 07:48:34 -0400 Subject: [PATCH 31/39] [SPARK-7168] [BUILD] Update plugin versions in Maven build and centralize versions Update Maven build plugin versions and centralize plugin version management Author: Sean Owen Closes #5720 from srowen/SPARK-7168 and squashes the following commits: 98a8947 [Sean Owen] Make install, deploy plugin versions explicit 4ecf3b2 [Sean Owen] Update Maven build plugin versions and centralize plugin version management --- assembly/pom.xml | 1 - core/pom.xml | 1 - network/common/pom.xml | 1 - pom.xml | 44 ++++++++++++++++++++++++++++++------------ sql/hive/pom.xml | 1 - 5 files changed, 32 insertions(+), 16 deletions(-) diff --git a/assembly/pom.xml b/assembly/pom.xml index 20593e710dedb..2b4d0a990bf22 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -194,7 +194,6 @@ org.apache.maven.plugins maven-assembly-plugin - 2.4 dist diff --git a/core/pom.xml b/core/pom.xml index 5e89d548cd47f..459ef66712c36 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -478,7 +478,6 @@ org.codehaus.mojo exec-maven-plugin - 1.3.2 sparkr-pkg diff --git a/network/common/pom.xml b/network/common/pom.xml index 22c738bde6d42..0c3147761cfc5 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -95,7 +95,6 @@ org.apache.maven.plugins maven-jar-plugin - 2.2 test-jar-on-test-compile diff --git a/pom.xml b/pom.xml index 9fbce1d639d8b..928f5d0f5efad 100644 --- a/pom.xml +++ b/pom.xml @@ -1082,7 +1082,7 @@ org.apache.maven.plugins maven-enforcer-plugin - 1.3.1 + 1.4 enforce-versions @@ -1105,7 +1105,7 @@ org.codehaus.mojo build-helper-maven-plugin - 1.8 + 1.9.1 net.alchim31.maven @@ -1176,7 +1176,7 @@ org.apache.maven.plugins maven-compiler-plugin - 3.1 + 3.3 ${java.version} ${java.version} @@ -1189,7 +1189,7 @@ org.apache.maven.plugins maven-surefire-plugin - 2.18 + 2.18.1 @@ -1260,17 +1260,17 @@ org.apache.maven.plugins maven-jar-plugin - 2.4 + 2.6 org.apache.maven.plugins maven-antrun-plugin - 1.7 + 1.8 org.apache.maven.plugins maven-source-plugin - 2.2.1 + 2.4 true @@ -1287,7 +1287,7 @@ org.apache.maven.plugins maven-clean-plugin - 2.5 + 2.6.1 @@ -1305,7 +1305,27 @@ org.apache.maven.plugins maven-javadoc-plugin - 2.10.1 + 2.10.3 + + + org.codehaus.mojo + exec-maven-plugin + 1.4.0 + + + org.apache.maven.plugins + maven-assembly-plugin + 2.5.3 + + + org.apache.maven.plugins + maven-install-plugin + 2.5.2 + + + org.apache.maven.plugins + maven-deploy-plugin + 2.8.2 @@ -1315,7 +1335,7 @@ org.apache.maven.plugins maven-dependency-plugin - 2.9 + 2.10 test-compile @@ -1334,7 +1354,7 @@ org.codehaus.gmavenplus gmavenplus-plugin - 1.2 + 1.5 process-test-classes @@ -1359,7 +1379,7 @@ org.apache.maven.plugins maven-shade-plugin - 2.2 + 2.3 false diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 21dce8d8a565a..e322340094e6f 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -183,7 +183,6 @@ org.apache.maven.plugins maven-dependency-plugin - 2.4 copy-dependencies From 75905c57cd57bc5b650ac5f486580ef8a229b260 Mon Sep 17 00:00:00 2001 From: Jim Carroll Date: Tue, 28 Apr 2015 07:51:02 -0400 Subject: [PATCH 32/39] [SPARK-7100] [MLLIB] Fix persisted RDD leak in GradientBoostTrees This fixes a leak of a persisted RDD where GradientBoostTrees can call persist but never unpersists. Jira: https://issues.apache.org/jira/browse/SPARK-7100 Discussion: http://apache-spark-developers-list.1001551.n3.nabble.com/GradientBoostTrees-leaks-a-persisted-RDD-td11750.html Author: Jim Carroll Closes #5669 from jimfcarroll/gb-unpersist-fix and squashes the following commits: 45f4b03 [Jim Carroll] [SPARK-7100][MLLib] Fix persisted RDD leak in GradientBoostTrees --- .../apache/spark/mllib/tree/GradientBoostedTrees.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index 0e31c7ed58df8..deac390130128 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -177,9 +177,10 @@ object GradientBoostedTrees extends Logging { treeStrategy.assertValid() // Cache input - if (input.getStorageLevel == StorageLevel.NONE) { + val persistedInput = if (input.getStorageLevel == StorageLevel.NONE) { input.persist(StorageLevel.MEMORY_AND_DISK) - } + true + } else false timer.stop("init") @@ -265,6 +266,9 @@ object GradientBoostedTrees extends Logging { logInfo("Internal timing for DecisionTree:") logInfo(s"$timer") + + if (persistedInput) input.unpersist() + if (validate) { new GradientBoostedTreesModel( boostingStrategy.treeStrategy.algo, From 268c419f1586110b90e68f98cd000a782d18828c Mon Sep 17 00:00:00 2001 From: Masayoshi TSUZUKI Date: Tue, 28 Apr 2015 07:55:21 -0400 Subject: [PATCH 33/39] [SPARK-6435] spark-shell --jars option does not add all jars to classpath Modified to accept double-quotated args properly in spark-shell.cmd. Author: Masayoshi TSUZUKI Closes #5227 from tsudukim/feature/SPARK-6435-2 and squashes the following commits: ac55787 [Masayoshi TSUZUKI] removed unnecessary argument. 60789a7 [Masayoshi TSUZUKI] Merge branch 'master' of https://github.com/apache/spark into feature/SPARK-6435-2 1fee420 [Masayoshi TSUZUKI] fixed test code for escaping '='. 0d4dc41 [Masayoshi TSUZUKI] - escaped comman and semicolon in CommandBuilderUtils.java - added random string to the temporary filename - double-quotation followed by `cmd /c` did not worked properly - no need to escape `=` by `^` - if double-quoted string ended with `\` like classpath, the last `\` is parsed as the escape charactor and the closing `"` didn't work properly 2a332e5 [Masayoshi TSUZUKI] Merge branch 'master' into feature/SPARK-6435-2 04f4291 [Masayoshi TSUZUKI] [SPARK-6435] spark-shell --jars option does not add all jars to classpath --- bin/spark-class2.cmd | 5 ++++- .../org/apache/spark/launcher/CommandBuilderUtils.java | 9 ++++----- .../src/main/java/org/apache/spark/launcher/Main.java | 6 +----- .../apache/spark/launcher/CommandBuilderUtilsSuite.java | 5 ++++- 4 files changed, 13 insertions(+), 12 deletions(-) diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd index 3d068dd3a2739..db09fa27e51a6 100644 --- a/bin/spark-class2.cmd +++ b/bin/spark-class2.cmd @@ -61,7 +61,10 @@ if not "x%JAVA_HOME%"=="x" set RUNNER=%JAVA_HOME%\bin\java rem The launcher library prints the command to be executed in a single line suitable for being rem executed by the batch interpreter. So read all the output of the launcher into a variable. -for /f "tokens=*" %%i in ('cmd /C ""%RUNNER%" -cp %LAUNCH_CLASSPATH% org.apache.spark.launcher.Main %*"') do ( +set LAUNCHER_OUTPUT=%temp%\spark-class-launcher-output-%RANDOM%.txt +"%RUNNER%" -cp %LAUNCH_CLASSPATH% org.apache.spark.launcher.Main %* > %LAUNCHER_OUTPUT% +for /f "tokens=*" %%i in (%LAUNCHER_OUTPUT%) do ( set SPARK_CMD=%%i ) +del %LAUNCHER_OUTPUT% %SPARK_CMD% diff --git a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java index 8028e42ffb483..261402856ac5e 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java +++ b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java @@ -244,7 +244,7 @@ static String quoteForBatchScript(String arg) { boolean needsQuotes = false; for (int i = 0; i < arg.length(); i++) { int c = arg.codePointAt(i); - if (Character.isWhitespace(c) || c == '"' || c == '=') { + if (Character.isWhitespace(c) || c == '"' || c == '=' || c == ',' || c == ';') { needsQuotes = true; break; } @@ -261,15 +261,14 @@ static String quoteForBatchScript(String arg) { quoted.append('"'); break; - case '=': - quoted.append('^'); - break; - default: break; } quoted.appendCodePoint(cp); } + if (arg.codePointAt(arg.length() - 1) == '\\') { + quoted.append("\\"); + } quoted.append("\""); return quoted.toString(); } diff --git a/launcher/src/main/java/org/apache/spark/launcher/Main.java b/launcher/src/main/java/org/apache/spark/launcher/Main.java index 206acfb514d86..929b29a49ed70 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/Main.java +++ b/launcher/src/main/java/org/apache/spark/launcher/Main.java @@ -101,12 +101,9 @@ public static void main(String[] argsArray) throws Exception { * The method quotes all arguments so that spaces are handled as expected. Quotes within arguments * are "double quoted" (which is batch for escaping a quote). This page has more details about * quoting and other batch script fun stuff: http://ss64.com/nt/syntax-esc.html - * - * The command is executed using "cmd /c" and formatted in single line, since that's the - * easiest way to consume this from a batch script (see spark-class2.cmd). */ private static String prepareWindowsCommand(List cmd, Map childEnv) { - StringBuilder cmdline = new StringBuilder("cmd /c \""); + StringBuilder cmdline = new StringBuilder(); for (Map.Entry e : childEnv.entrySet()) { cmdline.append(String.format("set %s=%s", e.getKey(), e.getValue())); cmdline.append(" && "); @@ -115,7 +112,6 @@ private static String prepareWindowsCommand(List cmd, Map Date: Tue, 28 Apr 2015 09:46:08 -0700 Subject: [PATCH 34/39] [SPARK-5253] [ML] LinearRegression with L1/L2 (ElasticNet) using OWLQN Author: DB Tsai Author: DB Tsai Closes #4259 from dbtsai/lir and squashes the following commits: a81c201 [DB Tsai] add import org.apache.spark.util.Utils back 9fc48ed [DB Tsai] rebase 2178b63 [DB Tsai] add comments 9988ca8 [DB Tsai] addressed feedback and fixed a bug. TODO: documentation and build another synthetic dataset which can catch the bug fixed in this commit. fcbaefe [DB Tsai] Refactoring 4eb078d [DB Tsai] first commit --- .../ml/param/shared/SharedParamsCodeGen.scala | 4 +- .../spark/ml/param/shared/sharedParams.scala | 34 ++ .../ml/regression/LinearRegression.scala | 304 ++++++++++++++++-- .../apache/spark/mllib/linalg/Vectors.scala | 8 +- .../spark/mllib/optimization/Gradient.scala | 6 +- .../spark/mllib/optimization/LBFGS.scala | 15 +- .../mllib/util/LinearDataGenerator.scala | 43 ++- .../ml/regression/LinearRegressionSuite.scala | 158 +++++++-- 8 files changed, 508 insertions(+), 64 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index e88c48741e99f..3f7e8f5a0b22c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -46,7 +46,9 @@ private[shared] object SharedParamsCodeGen { ParamDesc[String]("outputCol", "output column name"), ParamDesc[Int]("checkpointInterval", "checkpoint interval"), ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")), - ParamDesc[Long]("seed", "random seed", Some("Utils.random.nextLong()"))) + ParamDesc[Long]("seed", "random seed", Some("Utils.random.nextLong()")), + ParamDesc[Double]("elasticNetParam", "the ElasticNet mixing parameter"), + ParamDesc[Double]("tol", "the convergence tolerance for iterative algorithms")) val code = genSharedParams(params) val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala" diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index a860b8834cff9..7d2c76d6c62c8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -276,4 +276,38 @@ trait HasSeed extends Params { /** @group getParam */ final def getSeed: Long = getOrDefault(seed) } + +/** + * :: DeveloperApi :: + * Trait for shared param elasticNetParam. + */ +@DeveloperApi +trait HasElasticNetParam extends Params { + + /** + * Param for the ElasticNet mixing parameter. + * @group param + */ + final val elasticNetParam: DoubleParam = new DoubleParam(this, "elasticNetParam", "the ElasticNet mixing parameter") + + /** @group getParam */ + final def getElasticNetParam: Double = getOrDefault(elasticNetParam) +} + +/** + * :: DeveloperApi :: + * Trait for shared param tol. + */ +@DeveloperApi +trait HasTol extends Params { + + /** + * Param for the convergence tolerance for iterative algorithms. + * @group param + */ + final val tol: DoubleParam = new DoubleParam(this, "tol", "the convergence tolerance for iterative algorithms") + + /** @group getParam */ + final def getTol: Double = getOrDefault(tol) +} // scalastyle:on diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 26ca7459c4fdf..f92c6816eb54c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -17,21 +17,29 @@ package org.apache.spark.ml.regression +import scala.collection.mutable + +import breeze.linalg.{norm => brzNorm, DenseVector => BDV} +import breeze.optimize.{LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} +import breeze.optimize.{CachedDiffFunction, DiffFunction} + import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param.{Params, ParamMap} -import org.apache.spark.ml.param.shared._ -import org.apache.spark.mllib.linalg.{BLAS, Vector} -import org.apache.spark.mllib.regression.LinearRegressionWithSGD +import org.apache.spark.ml.param.shared.{HasElasticNetParam, HasMaxIter, HasRegParam, HasTol} +import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.BLAS._ +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame import org.apache.spark.storage.StorageLevel - +import org.apache.spark.util.StatCounter /** * Params for linear regression. */ private[regression] trait LinearRegressionParams extends RegressorParams - with HasRegParam with HasMaxIter - + with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol /** * :: AlphaComponent :: @@ -42,34 +50,119 @@ private[regression] trait LinearRegressionParams extends RegressorParams class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel] with LinearRegressionParams { - setDefault(regParam -> 0.1, maxIter -> 100) - - /** @group setParam */ + /** + * Set the regularization parameter. + * Default is 0.0. + * @group setParam + */ def setRegParam(value: Double): this.type = set(regParam, value) + setDefault(regParam -> 0.0) + + /** + * Set the ElasticNet mixing parameter. + * For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty. + * For 0 < alpha < 1, the penalty is a combination of L1 and L2. + * Default is 0.0 which is an L2 penalty. + * @group setParam + */ + def setElasticNetParam(value: Double): this.type = set(elasticNetParam, value) + setDefault(elasticNetParam -> 0.0) - /** @group setParam */ + /** + * Set the maximal number of iterations. + * Default is 100. + * @group setParam + */ def setMaxIter(value: Int): this.type = set(maxIter, value) + setDefault(maxIter -> 100) + + /** + * Set the convergence tolerance of iterations. + * Smaller value will lead to higher accuracy with the cost of more iterations. + * Default is 1E-6. + * @group setParam + */ + def setTol(value: Double): this.type = set(tol, value) + setDefault(tol -> 1E-6) override protected def train(dataset: DataFrame, paramMap: ParamMap): LinearRegressionModel = { - // Extract columns from data. If dataset is persisted, do not persist oldDataset. - val oldDataset = extractLabeledPoints(dataset, paramMap) + // Extract columns from data. If dataset is persisted, do not persist instances. + val instances = extractLabeledPoints(dataset, paramMap).map { + case LabeledPoint(label: Double, features: Vector) => (label, features) + } val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE if (handlePersistence) { - oldDataset.persist(StorageLevel.MEMORY_AND_DISK) + instances.persist(StorageLevel.MEMORY_AND_DISK) + } + + val (summarizer, statCounter) = instances.treeAggregate( + (new MultivariateOnlineSummarizer, new StatCounter))( { + case ((summarizer: MultivariateOnlineSummarizer, statCounter: StatCounter), + (label: Double, features: Vector)) => + (summarizer.add(features), statCounter.merge(label)) + }, { + case ((summarizer1: MultivariateOnlineSummarizer, statCounter1: StatCounter), + (summarizer2: MultivariateOnlineSummarizer, statCounter2: StatCounter)) => + (summarizer1.merge(summarizer2), statCounter1.merge(statCounter2)) + }) + + val numFeatures = summarizer.mean.size + val yMean = statCounter.mean + val yStd = math.sqrt(statCounter.variance) + + val featuresMean = summarizer.mean.toArray + val featuresStd = summarizer.variance.toArray.map(math.sqrt) + + // Since we implicitly do the feature scaling when we compute the cost function + // to improve the convergence, the effective regParam will be changed. + val effectiveRegParam = paramMap(regParam) / yStd + val effectiveL1RegParam = paramMap(elasticNetParam) * effectiveRegParam + val effectiveL2RegParam = (1.0 - paramMap(elasticNetParam)) * effectiveRegParam + + val costFun = new LeastSquaresCostFun(instances, yStd, yMean, + featuresStd, featuresMean, effectiveL2RegParam) + + val optimizer = if (paramMap(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) { + new BreezeLBFGS[BDV[Double]](paramMap(maxIter), 10, paramMap(tol)) + } else { + new BreezeOWLQN[Int, BDV[Double]](paramMap(maxIter), 10, effectiveL1RegParam, paramMap(tol)) + } + + val initialWeights = Vectors.zeros(numFeatures) + val states = + optimizer.iterations(new CachedDiffFunction(costFun), initialWeights.toBreeze.toDenseVector) + + var state = states.next() + val lossHistory = mutable.ArrayBuilder.make[Double] + + while (states.hasNext) { + lossHistory += state.value + state = states.next() + } + lossHistory += state.value + + // TODO: Based on the sparsity of weights, we may convert the weights to the sparse vector. + // The weights are trained in the scaled space; we're converting them back to + // the original space. + val weights = { + val rawWeights = state.x.toArray.clone() + var i = 0 + while (i < rawWeights.length) { + rawWeights(i) *= { if (featuresStd(i) != 0.0) yStd / featuresStd(i) else 0.0 } + i += 1 + } + Vectors.dense(rawWeights) } - // Train model - val lr = new LinearRegressionWithSGD() - lr.optimizer - .setRegParam(paramMap(regParam)) - .setNumIterations(paramMap(maxIter)) - val model = lr.run(oldDataset) - val lrm = new LinearRegressionModel(this, paramMap, model.weights, model.intercept) + // The intercept in R's GLMNET is computed using closed form after the coefficients are + // converged. See the following discussion for detail. + // http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet + val intercept = yMean - dot(weights, Vectors.dense(featuresMean)) if (handlePersistence) { - oldDataset.unpersist() + instances.unpersist() } - lrm + new LinearRegressionModel(this, paramMap, weights, intercept) } } @@ -88,7 +181,7 @@ class LinearRegressionModel private[ml] ( with LinearRegressionParams { override protected def predict(features: Vector): Double = { - BLAS.dot(features, weights) + intercept + dot(features, weights) + intercept } override protected def copy(): LinearRegressionModel = { @@ -97,3 +190,168 @@ class LinearRegressionModel private[ml] ( m } } + +/** + * LeastSquaresAggregator computes the gradient and loss for a Least-squared loss function, + * as used in linear regression for samples in sparse or dense vector in a online fashion. + * + * Two LeastSquaresAggregator can be merged together to have a summary of loss and gradient of + * the corresponding joint dataset. + * + + * * Compute gradient and loss for a Least-squared loss function, as used in linear regression. + * This is correct for the averaged least squares loss function (mean squared error) + * L = 1/2n ||A weights-y||^2 + * See also the documentation for the precise formulation. + * + * @param weights weights/coefficients corresponding to features + * + * @param updater Updater to be used to update weights after every iteration. + */ +private class LeastSquaresAggregator( + weights: Vector, + labelStd: Double, + labelMean: Double, + featuresStd: Array[Double], + featuresMean: Array[Double]) extends Serializable { + + private var totalCnt: Long = 0 + private var lossSum = 0.0 + private var diffSum = 0.0 + + private val (effectiveWeightsArray: Array[Double], offset: Double, dim: Int) = { + val weightsArray = weights.toArray.clone() + var sum = 0.0 + var i = 0 + while (i < weightsArray.length) { + if (featuresStd(i) != 0.0) { + weightsArray(i) /= featuresStd(i) + sum += weightsArray(i) * featuresMean(i) + } else { + weightsArray(i) = 0.0 + } + i += 1 + } + (weightsArray, -sum + labelMean / labelStd, weightsArray.length) + } + private val effectiveWeightsVector = Vectors.dense(effectiveWeightsArray) + + private val gradientSumArray: Array[Double] = Array.ofDim[Double](dim) + + /** + * Add a new training data to this LeastSquaresAggregator, and update the loss and gradient + * of the objective function. + * + * @param label The label for this data point. + * @param data The features for one data point in dense/sparse vector format to be added + * into this aggregator. + * @return This LeastSquaresAggregator object. + */ + def add(label: Double, data: Vector): this.type = { + require(dim == data.size, s"Dimensions mismatch when adding new sample." + + s" Expecting $dim but got ${data.size}.") + + val diff = dot(data, effectiveWeightsVector) - label / labelStd + offset + + if (diff != 0) { + val localGradientSumArray = gradientSumArray + data.foreachActive { (index, value) => + if (featuresStd(index) != 0.0 && value != 0.0) { + localGradientSumArray(index) += diff * value / featuresStd(index) + } + } + lossSum += diff * diff / 2.0 + diffSum += diff + } + + totalCnt += 1 + this + } + + /** + * Merge another LeastSquaresAggregator, and update the loss and gradient + * of the objective function. + * (Note that it's in place merging; as a result, `this` object will be modified.) + * + * @param other The other LeastSquaresAggregator to be merged. + * @return This LeastSquaresAggregator object. + */ + def merge(other: LeastSquaresAggregator): this.type = { + require(dim == other.dim, s"Dimensions mismatch when merging with another " + + s"LeastSquaresAggregator. Expecting $dim but got ${other.dim}.") + + if (other.totalCnt != 0) { + totalCnt += other.totalCnt + lossSum += other.lossSum + diffSum += other.diffSum + + var i = 0 + val localThisGradientSumArray = this.gradientSumArray + val localOtherGradientSumArray = other.gradientSumArray + while (i < dim) { + localThisGradientSumArray(i) += localOtherGradientSumArray(i) + i += 1 + } + } + this + } + + def count: Long = totalCnt + + def loss: Double = lossSum / totalCnt + + def gradient: Vector = { + val result = Vectors.dense(gradientSumArray.clone()) + + val correction = { + val temp = effectiveWeightsArray.clone() + var i = 0 + while (i < temp.length) { + temp(i) *= featuresMean(i) + i += 1 + } + Vectors.dense(temp) + } + + axpy(-diffSum, correction, result) + scal(1.0 / totalCnt, result) + result + } +} + +/** + * LeastSquaresCostFun implements Breeze's DiffFunction[T] for Least Squares cost. + * It returns the loss and gradient with L2 regularization at a particular point (weights). + * It's used in Breeze's convex optimization routines. + */ +private class LeastSquaresCostFun( + data: RDD[(Double, Vector)], + labelStd: Double, + labelMean: Double, + featuresStd: Array[Double], + featuresMean: Array[Double], + effectiveL2regParam: Double) extends DiffFunction[BDV[Double]] { + + override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = { + val w = Vectors.fromBreeze(weights) + + val leastSquaresAggregator = data.treeAggregate(new LeastSquaresAggregator(w, labelStd, + labelMean, featuresStd, featuresMean))( + seqOp = (c, v) => (c, v) match { + case (aggregator, (label, features)) => aggregator.add(label, features) + }, + combOp = (c1, c2) => (c1, c2) match { + case (aggregator1, aggregator2) => aggregator1.merge(aggregator2) + }) + + // regVal is the sum of weight squares for L2 regularization + val norm = brzNorm(weights, 2.0) + val regVal = 0.5 * effectiveL2regParam * norm * norm + + val loss = leastSquaresAggregator.loss + regVal + val gradient = leastSquaresAggregator.gradient + axpy(effectiveL2regParam, w, gradient) + + (loss, gradient.toBreeze.asInstanceOf[BDV[Double]]) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 166c00cff634d..af0cfe22ca10d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -85,7 +85,7 @@ sealed trait Vector extends Serializable { /** * Converts the instance to a breeze vector. */ - private[mllib] def toBreeze: BV[Double] + private[spark] def toBreeze: BV[Double] /** * Gets the value of the ith element. @@ -284,7 +284,7 @@ object Vectors { /** * Creates a vector instance from a breeze vector. */ - private[mllib] def fromBreeze(breezeVector: BV[Double]): Vector = { + private[spark] def fromBreeze(breezeVector: BV[Double]): Vector = { breezeVector match { case v: BDV[Double] => if (v.offset == 0 && v.stride == 1 && v.length == v.data.length) { @@ -483,7 +483,7 @@ class DenseVector(val values: Array[Double]) extends Vector { override def toArray: Array[Double] = values - private[mllib] override def toBreeze: BV[Double] = new BDV[Double](values) + private[spark] override def toBreeze: BV[Double] = new BDV[Double](values) override def apply(i: Int): Double = values(i) @@ -543,7 +543,7 @@ class SparseVector( new SparseVector(size, indices.clone(), values.clone()) } - private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size) + private[spark] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size) private[spark] override def foreachActive(f: (Int, Double) => Unit) = { var i = 0 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala index 8bfa0d2b64995..240baeb5a158b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala @@ -37,7 +37,11 @@ abstract class Gradient extends Serializable { * * @return (gradient: Vector, loss: Double) */ - def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) + def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = { + val gradient = Vectors.zeros(weights.size) + val loss = compute(data, label, weights, gradient) + (gradient, loss) + } /** * Compute the gradient and loss given the features of a single data point, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index ef6eccd90711a..efedc112d380e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -17,6 +17,7 @@ package org.apache.spark.mllib.optimization +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import breeze.linalg.{DenseVector => BDV} @@ -164,7 +165,7 @@ object LBFGS extends Logging { regParam: Double, initialWeights: Vector): (Vector, Array[Double]) = { - val lossHistory = new ArrayBuffer[Double](maxNumIterations) + val lossHistory = mutable.ArrayBuilder.make[Double] val numExamples = data.count() @@ -181,17 +182,19 @@ object LBFGS extends Logging { * and regVal is the regularization value computed in the previous iteration as well. */ var state = states.next() - while(states.hasNext) { - lossHistory.append(state.value) + while (states.hasNext) { + lossHistory += state.value state = states.next() } - lossHistory.append(state.value) + lossHistory += state.value val weights = Vectors.fromBreeze(state.x) + val lossHistoryArray = lossHistory.result() + logInfo("LBFGS.runLBFGS finished. Last 10 losses %s".format( - lossHistory.takeRight(10).mkString(", "))) + lossHistoryArray.takeRight(10).mkString(", "))) - (weights, lossHistory.toArray) + (weights, lossHistoryArray) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala index c9d33787b0bb5..d7bb943e84f53 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala @@ -56,6 +56,10 @@ object LinearDataGenerator { } /** + * For compatibility, the generated data without specifying the mean and variance + * will have zero mean and variance of (1.0/3.0) since the original output range is + * [-1, 1] with uniform distribution, and the variance of uniform distribution + * is (b - a)^2^ / 12 which will be (1.0/3.0) * * @param intercept Data intercept * @param weights Weights to be applied. @@ -70,10 +74,47 @@ object LinearDataGenerator { nPoints: Int, seed: Int, eps: Double = 0.1): Seq[LabeledPoint] = { + generateLinearInput(intercept, weights, + Array.fill[Double](weights.size)(0.0), + Array.fill[Double](weights.size)(1.0 / 3.0), + nPoints, seed, eps)} + + /** + * + * @param intercept Data intercept + * @param weights Weights to be applied. + * @param xMean the mean of the generated features. Lots of time, if the features are not properly + * standardized, the algorithm with poor implementation will have difficulty + * to converge. + * @param xVariance the variance of the generated features. + * @param nPoints Number of points in sample. + * @param seed Random seed + * @param eps Epsilon scaling factor. + * @return Seq of input. + */ + def generateLinearInput( + intercept: Double, + weights: Array[Double], + xMean: Array[Double], + xVariance: Array[Double], + nPoints: Int, + seed: Int, + eps: Double): Seq[LabeledPoint] = { val rnd = new Random(seed) val x = Array.fill[Array[Double]](nPoints)( - Array.fill[Double](weights.length)(2 * rnd.nextDouble - 1.0)) + Array.fill[Double](weights.length)(rnd.nextDouble)) + + x.map(vector => { + // This doesn't work if `vector` is a sparse vector. + val vectorArray = vector.toArray + var i = 0 + while (i < vectorArray.size) { + vectorArray(i) = (vectorArray(i) - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i) + i += 1 + } + }) + val y = x.map { xi => blas.ddot(weights.length, xi, 1, weights, 1) + intercept + eps * rnd.nextGaussian() } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index bbb44c3e2dfc2..80323ef5201a6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -19,47 +19,149 @@ package org.apache.spark.ml.regression import org.scalatest.FunSuite -import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.mllib.linalg.DenseVector +import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{Row, SQLContext, DataFrame} class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { @transient var sqlContext: SQLContext = _ @transient var dataset: DataFrame = _ + /** + * In `LinearRegressionSuite`, we will make sure that the model trained by SparkML + * is the same as the one trained by R's glmnet package. The following instruction + * describes how to reproduce the data in R. + * + * import org.apache.spark.mllib.util.LinearDataGenerator + * val data = + * sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2), 10000, 42), 2) + * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).saveAsTextFile("path") + */ override def beforeAll(): Unit = { super.beforeAll() sqlContext = new SQLContext(sc) dataset = sqlContext.createDataFrame( - sc.parallelize(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42), 2)) + sc.parallelize(LinearDataGenerator.generateLinearInput( + 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2)) } - test("linear regression: default params") { - val lr = new LinearRegression - assert(lr.getLabelCol == "label") - val model = lr.fit(dataset) - model.transform(dataset) - .select("label", "prediction") - .collect() - // Check defaults - assert(model.getFeaturesCol == "features") - assert(model.getPredictionCol == "prediction") + test("linear regression with intercept without regularization") { + val trainer = new LinearRegression + val model = trainer.fit(dataset) + + /** + * Using the following R code to load the data and train the model using glmnet package. + * + * library("glmnet") + * data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) + * features <- as.matrix(data.frame(as.numeric(data$V2), as.numeric(data$V3))) + * label <- as.numeric(data$V1) + * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0)) + * > weights + * 3 x 1 sparse Matrix of class "dgCMatrix" + * s0 + * (Intercept) 6.300528 + * as.numeric.data.V2. 4.701024 + * as.numeric.data.V3. 7.198257 + */ + val interceptR = 6.298698 + val weightsR = Array(4.700706, 7.199082) + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.weights(0) ~== weightsR(0) relTol 1E-3) + assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } + + test("linear regression with intercept with L1 regularization") { + val trainer = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57) + val model = trainer.fit(dataset) + + /** + * weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57)) + * > weights + * 3 x 1 sparse Matrix of class "dgCMatrix" + * s0 + * (Intercept) 6.311546 + * as.numeric.data.V2. 2.123522 + * as.numeric.data.V3. 4.605651 + */ + val interceptR = 6.243000 + val weightsR = Array(4.024821, 6.679841) + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.weights(0) ~== weightsR(0) relTol 1E-3) + assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } } - test("linear regression with setters") { - // Set params, train, and check as many as we can. - val lr = new LinearRegression() - .setMaxIter(10) - .setRegParam(1.0) - val model = lr.fit(dataset) - assert(model.fittingParamMap.get(lr.maxIter).get === 10) - assert(model.fittingParamMap.get(lr.regParam).get === 1.0) - - // Call fit() with new params, and check as many as we can. - val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.predictionCol -> "thePred") - assert(model2.fittingParamMap.get(lr.maxIter).get === 5) - assert(model2.fittingParamMap.get(lr.regParam).get === 0.1) - assert(model2.getPredictionCol == "thePred") + test("linear regression with intercept with L2 regularization") { + val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3) + val model = trainer.fit(dataset) + + /** + * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3)) + * > weights + * 3 x 1 sparse Matrix of class "dgCMatrix" + * s0 + * (Intercept) 6.328062 + * as.numeric.data.V2. 3.222034 + * as.numeric.data.V3. 4.926260 + */ + val interceptR = 5.269376 + val weightsR = Array(3.736216, 5.712356) + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.weights(0) ~== weightsR(0) relTol 1E-3) + assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } + + test("linear regression with intercept with ElasticNet regularization") { + val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6) + val model = trainer.fit(dataset) + + /** + * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6)) + * > weights + * 3 x 1 sparse Matrix of class "dgCMatrix" + * s0 + * (Intercept) 6.324108 + * as.numeric.data.V2. 3.168435 + * as.numeric.data.V3. 5.200403 + */ + val interceptR = 5.696056 + val weightsR = Array(3.670489, 6.001122) + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.weights(0) ~== weightsR(0) relTol 1E-3) + assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } } } From b14cd2364932e504695bcc49486ffb4518fdf33d Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 28 Apr 2015 09:59:36 -0700 Subject: [PATCH 35/39] [SPARK-7140] [MLLIB] only scan the first 16 entries in Vector.hashCode The Python SerDe calls `Object.hashCode`, which is very expensive for Vectors. It is not necessary to scan the whole vector, especially for large ones. In this PR, we only scan the first 16 nonzeros. srowen Author: Xiangrui Meng Closes #5697 from mengxr/SPARK-7140 and squashes the following commits: 2abc86d [Xiangrui Meng] typo 8fb7d74 [Xiangrui Meng] update impl 1ebad60 [Xiangrui Meng] only scan the first 16 nonzeros in Vector.hashCode --- .../apache/spark/mllib/linalg/Vectors.scala | 88 ++++++++++++++----- 1 file changed, 67 insertions(+), 21 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index af0cfe22ca10d..34833e90d4af0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -52,7 +52,7 @@ sealed trait Vector extends Serializable { override def equals(other: Any): Boolean = { other match { - case v2: Vector => { + case v2: Vector => if (this.size != v2.size) return false (this, v2) match { case (s1: SparseVector, s2: SparseVector) => @@ -63,20 +63,28 @@ sealed trait Vector extends Serializable { Vectors.equals(0 until d1.size, d1.values, s1.indices, s1.values) case (_, _) => util.Arrays.equals(this.toArray, v2.toArray) } - } case _ => false } } + /** + * Returns a hash code value for the vector. The hash code is based on its size and its nonzeros + * in the first 16 entries, using a hash algorithm similar to [[java.util.Arrays.hashCode]]. + */ override def hashCode(): Int = { - var result: Int = size + 31 - this.foreachActive { case (index, value) => - // ignore explict 0 for comparison between sparse and dense - if (value != 0) { - result = 31 * result + index - // refer to {@link java.util.Arrays.equals} for hash algorithm - val bits = java.lang.Double.doubleToLongBits(value) - result = 31 * result + (bits ^ (bits >>> 32)).toInt + // This is a reference implementation. It calls return in foreachActive, which is slow. + // Subclasses should override it with optimized implementation. + var result: Int = 31 + size + this.foreachActive { (index, value) => + if (index < 16) { + // ignore explicit 0 for comparison between sparse and dense + if (value != 0) { + result = 31 * result + index + val bits = java.lang.Double.doubleToLongBits(value) + result = 31 * result + (bits ^ (bits >>> 32)).toInt + } + } else { + return result } } result @@ -317,7 +325,7 @@ object Vectors { case SparseVector(n, ids, vs) => vs case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) } - val size = values.size + val size = values.length if (p == 1) { var sum = 0.0 @@ -371,8 +379,8 @@ object Vectors { val v1Indices = v1.indices val v2Values = v2.values val v2Indices = v2.indices - val nnzv1 = v1Indices.size - val nnzv2 = v2Indices.size + val nnzv1 = v1Indices.length + val nnzv2 = v2Indices.length var kv1 = 0 var kv2 = 0 @@ -401,7 +409,7 @@ object Vectors { case (DenseVector(vv1), DenseVector(vv2)) => var kv = 0 - val sz = vv1.size + val sz = vv1.length while (kv < sz) { val score = vv1(kv) - vv2(kv) squaredDistance += score * score @@ -422,7 +430,7 @@ object Vectors { var kv2 = 0 val indices = v1.indices var squaredDistance = 0.0 - val nnzv1 = indices.size + val nnzv1 = indices.length val nnzv2 = v2.size var iv1 = if (nnzv1 > 0) indices(kv1) else -1 @@ -451,8 +459,8 @@ object Vectors { v1Values: Array[Double], v2Indices: IndexedSeq[Int], v2Values: Array[Double]): Boolean = { - val v1Size = v1Values.size - val v2Size = v2Values.size + val v1Size = v1Values.length + val v2Size = v2Values.length var k1 = 0 var k2 = 0 var allEqual = true @@ -493,7 +501,7 @@ class DenseVector(val values: Array[Double]) extends Vector { private[spark] override def foreachActive(f: (Int, Double) => Unit) = { var i = 0 - val localValuesSize = values.size + val localValuesSize = values.length val localValues = values while (i < localValuesSize) { @@ -501,6 +509,22 @@ class DenseVector(val values: Array[Double]) extends Vector { i += 1 } } + + override def hashCode(): Int = { + var result: Int = 31 + size + var i = 0 + val end = math.min(values.length, 16) + while (i < end) { + val v = values(i) + if (v != 0.0) { + result = 31 * result + i + val bits = java.lang.Double.doubleToLongBits(values(i)) + result = 31 * result + (bits ^ (bits >>> 32)).toInt + } + i += 1 + } + result + } } object DenseVector { @@ -522,8 +546,8 @@ class SparseVector( val values: Array[Double]) extends Vector { require(indices.length == values.length, "Sparse vectors require that the dimension of the" + - s" indices match the dimension of the values. You provided ${indices.size} indices and " + - s" ${values.size} values.") + s" indices match the dimension of the values. You provided ${indices.length} indices and " + + s" ${values.length} values.") override def toString: String = s"($size,${indices.mkString("[", ",", "]")},${values.mkString("[", ",", "]")})" @@ -547,7 +571,7 @@ class SparseVector( private[spark] override def foreachActive(f: (Int, Double) => Unit) = { var i = 0 - val localValuesSize = values.size + val localValuesSize = values.length val localIndices = indices val localValues = values @@ -556,6 +580,28 @@ class SparseVector( i += 1 } } + + override def hashCode(): Int = { + var result: Int = 31 + size + val end = values.length + var continue = true + var k = 0 + while ((k < end) & continue) { + val i = indices(k) + if (i < 16) { + val v = values(k) + if (v != 0.0) { + result = 31 * result + i + val bits = java.lang.Double.doubleToLongBits(v) + result = 31 * result + (bits ^ (bits >>> 32)).toInt + } + } else { + continue = false + } + k += 1 + } + result + } } object SparseVector { From 52ccf1d3739694826915cdf01642bab02958eb78 Mon Sep 17 00:00:00 2001 From: "Zhang, Liye" Date: Tue, 28 Apr 2015 10:24:00 -0700 Subject: [PATCH 36/39] [Core][test][minor] replace try finally block with tryWithSafeFinally Author: Zhang, Liye Closes #5739 from liyezhang556520/trySafeFinally and squashes the following commits: 55683e5 [Zhang, Liye] replace try finally block with tryWithSafeFinally --- .../apache/spark/deploy/history/FsHistoryProviderSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index fcae603c7d18e..9e367a0d9af0d 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -224,9 +224,9 @@ class FsHistoryProviderSuite extends FunSuite with BeforeAndAfter with Matchers EventLoggingListener.initEventLog(new FileOutputStream(file)) } val writer = new OutputStreamWriter(bstream, "UTF-8") - try { + Utils.tryWithSafeFinally { events.foreach(e => writer.write(compact(render(JsonProtocol.sparkEventToJson(e))) + "\n")) - } finally { + } { writer.close() } } From 8aab94d8984e9d12194dbda47b2e7d9dbc036889 Mon Sep 17 00:00:00 2001 From: Iulian Dragos Date: Tue, 28 Apr 2015 12:08:18 -0700 Subject: [PATCH 37/39] [SPARK-4286] Add an external shuffle service that can be run as a daemon. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This allows Mesos deployments to use the shuffle service (and implicitly dynamic allocation). It does so by adding a new "main" class and two corresponding scripts in `sbin`: - `sbin/start-shuffle-service.sh` - `sbin/stop-shuffle-service.sh` Specific options can be passed in `SPARK_SHUFFLE_OPTS`. This is picking up work from #3861 /cc tnachen Author: Iulian Dragos Closes #4990 from dragos/feature/external-shuffle-service and squashes the following commits: 6c2b148 [Iulian Dragos] Import order and wrong name fixup. 07804ad [Iulian Dragos] Moved ExternalShuffleService to the `deploy` package + other minor tweaks. 4dc1f91 [Iulian Dragos] Reviewer’s comments: 8145429 [Iulian Dragos] Add an external shuffle service that can be run as a daemon. --- conf/spark-env.sh.template | 3 +- ...ice.scala => ExternalShuffleService.scala} | 59 ++++++++++++++++--- .../apache/spark/deploy/worker/Worker.scala | 13 ++-- docs/job-scheduling.md | 2 +- .../launcher/SparkClassCommandBuilder.java | 4 ++ sbin/start-shuffle-service.sh | 33 +++++++++++ sbin/stop-shuffle-service.sh | 25 ++++++++ 7 files changed, 124 insertions(+), 15 deletions(-) rename core/src/main/scala/org/apache/spark/deploy/{worker/StandaloneWorkerShuffleService.scala => ExternalShuffleService.scala} (59%) create mode 100755 sbin/start-shuffle-service.sh create mode 100755 sbin/stop-shuffle-service.sh diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index 67f81d33361e1..43c4288912b18 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -3,7 +3,7 @@ # This file is sourced when running various Spark programs. # Copy it as spark-env.sh and edit that to configure Spark for your site. -# Options read when launching programs locally with +# Options read when launching programs locally with # ./bin/run-example or ./bin/spark-submit # - HADOOP_CONF_DIR, to point Spark towards Hadoop configuration files # - SPARK_LOCAL_IP, to set the IP address Spark binds to on this node @@ -39,6 +39,7 @@ # - SPARK_WORKER_DIR, to set the working directory of worker processes # - SPARK_WORKER_OPTS, to set config properties only for the worker (e.g. "-Dx=y") # - SPARK_HISTORY_OPTS, to set config properties only for the history server (e.g. "-Dx=y") +# - SPARK_SHUFFLE_OPTS, to set config properties only for the external shuffle service (e.g. "-Dx=y") # - SPARK_DAEMON_JAVA_OPTS, to set config properties for all daemons (e.g. "-Dx=y") # - SPARK_PUBLIC_DNS, to set the public dns name of the master or workers diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala similarity index 59% rename from core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala rename to core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index b9798963bab0a..cd16f992a3c0a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -15,7 +15,9 @@ * limitations under the License. */ -package org.apache.spark.deploy.worker +package org.apache.spark.deploy + +import java.util.concurrent.CountDownLatch import org.apache.spark.{Logging, SparkConf, SecurityManager} import org.apache.spark.network.TransportContext @@ -23,6 +25,7 @@ import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.sasl.SaslRpcHandler import org.apache.spark.network.server.TransportServer import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler +import org.apache.spark.util.Utils /** * Provides a server from which Executors can read shuffle files (rather than reading directly from @@ -31,8 +34,8 @@ import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler * * Optionally requires SASL authentication in order to read. See [[SecurityManager]]. */ -private[worker] -class StandaloneWorkerShuffleService(sparkConf: SparkConf, securityManager: SecurityManager) +private[deploy] +class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityManager) extends Logging { private val enabled = sparkConf.getBoolean("spark.shuffle.service.enabled", false) @@ -51,16 +54,58 @@ class StandaloneWorkerShuffleService(sparkConf: SparkConf, securityManager: Secu /** Starts the external shuffle service if the user has configured us to. */ def startIfEnabled() { if (enabled) { - require(server == null, "Shuffle server already started") - logInfo(s"Starting shuffle service on port $port with useSasl = $useSasl") - server = transportContext.createServer(port) + start() } } + /** Start the external shuffle service */ + def start() { + require(server == null, "Shuffle server already started") + logInfo(s"Starting shuffle service on port $port with useSasl = $useSasl") + server = transportContext.createServer(port) + } + def stop() { - if (enabled && server != null) { + if (server != null) { server.close() server = null } } } + +/** + * A main class for running the external shuffle service. + */ +object ExternalShuffleService extends Logging { + @volatile + private var server: ExternalShuffleService = _ + + private val barrier = new CountDownLatch(1) + + def main(args: Array[String]): Unit = { + val sparkConf = new SparkConf + Utils.loadDefaultSparkProperties(sparkConf) + val securityManager = new SecurityManager(sparkConf) + + // we override this value since this service is started from the command line + // and we assume the user really wants it to be running + sparkConf.set("spark.shuffle.service.enabled", "true") + server = new ExternalShuffleService(sparkConf, securityManager) + server.start() + + installShutdownHook() + + // keep running until the process is terminated + barrier.await() + } + + private def installShutdownHook(): Unit = { + Runtime.getRuntime.addShutdownHook(new Thread("External Shuffle Service shutdown thread") { + override def run() { + logInfo("Shutting down shuffle service.") + server.stop() + barrier.countDown() + } + }) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 3ee2eb69e8a4e..8f3cc54051048 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -34,6 +34,7 @@ import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.{Command, ExecutorDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ +import org.apache.spark.deploy.ExternalShuffleService import org.apache.spark.deploy.master.{DriverState, Master} import org.apache.spark.deploy.worker.ui.WorkerWebUI import org.apache.spark.metrics.MetricsSystem @@ -61,7 +62,7 @@ private[worker] class Worker( assert (port > 0) // For worker and executor IDs - private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") + private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // Send a heartbeat every (heartbeat timeout) / 4 milliseconds private val HEARTBEAT_MILLIS = conf.getLong("spark.worker.timeout", 60) * 1000 / 4 @@ -85,10 +86,10 @@ private[worker] class Worker( private val CLEANUP_ENABLED = conf.getBoolean("spark.worker.cleanup.enabled", false) // How often worker will clean up old app folders - private val CLEANUP_INTERVAL_MILLIS = + private val CLEANUP_INTERVAL_MILLIS = conf.getLong("spark.worker.cleanup.interval", 60 * 30) * 1000 // TTL for app folders/data; after TTL expires it will be cleaned up - private val APP_DATA_RETENTION_SECS = + private val APP_DATA_RETENTION_SECS = conf.getLong("spark.worker.cleanup.appDataTtl", 7 * 24 * 3600) private val testing: Boolean = sys.props.contains("spark.testing") @@ -112,7 +113,7 @@ private[worker] class Worker( } else { new File(sys.env.get("SPARK_HOME").getOrElse(".")) } - + var workDir: File = null val finishedExecutors = new HashMap[String, ExecutorRunner] val drivers = new HashMap[String, DriverRunner] @@ -122,7 +123,7 @@ private[worker] class Worker( val finishedApps = new HashSet[String] // The shuffle service is not actually started unless configured. - private val shuffleService = new StandaloneWorkerShuffleService(conf, securityMgr) + private val shuffleService = new ExternalShuffleService(conf, securityMgr) private val publicAddress = { val envVar = conf.getenv("SPARK_PUBLIC_DNS") @@ -134,7 +135,7 @@ private[worker] class Worker( private val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf, securityMgr) private val workerSource = new WorkerSource(this) - + private var registrationRetryTimer: Option[Cancellable] = None var coresUsed = 0 diff --git a/docs/job-scheduling.md b/docs/job-scheduling.md index 963e88a3e1d8f..8d9c2ba2041b2 100644 --- a/docs/job-scheduling.md +++ b/docs/job-scheduling.md @@ -32,7 +32,7 @@ Resource allocation can be configured as follows, based on the cluster type: * **Standalone mode:** By default, applications submitted to the standalone mode cluster will run in FIFO (first-in-first-out) order, and each application will try to use all available nodes. You can limit the number of nodes an application uses by setting the `spark.cores.max` configuration property in it, - or change the default for applications that don't set this setting through `spark.deploy.defaultCores`. + or change the default for applications that don't set this setting through `spark.deploy.defaultCores`. Finally, in addition to controlling cores, each application's `spark.executor.memory` setting controls its memory use. * **Mesos:** To use static partitioning on Mesos, set the `spark.mesos.coarse` configuration property to `true`, diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java index e601a0a19f368..d80abf2a8676e 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java @@ -69,6 +69,10 @@ public List buildCommand(Map env) throws IOException { } else if (className.equals("org.apache.spark.executor.MesosExecutorBackend")) { javaOptsKeys.add("SPARK_EXECUTOR_OPTS"); memKey = "SPARK_EXECUTOR_MEMORY"; + } else if (className.equals("org.apache.spark.deploy.ExternalShuffleService")) { + javaOptsKeys.add("SPARK_DAEMON_JAVA_OPTS"); + javaOptsKeys.add("SPARK_SHUFFLE_OPTS"); + memKey = "SPARK_DAEMON_MEMORY"; } else if (className.startsWith("org.apache.spark.tools.")) { String sparkHome = getSparkHome(); File toolsDir = new File(join(File.separator, sparkHome, "tools", "target", diff --git a/sbin/start-shuffle-service.sh b/sbin/start-shuffle-service.sh new file mode 100755 index 0000000000000..4fddcf7f95d40 --- /dev/null +++ b/sbin/start-shuffle-service.sh @@ -0,0 +1,33 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Starts the external shuffle server on the machine this script is executed on. +# +# Usage: start-shuffle-server.sh +# +# Use the SPARK_SHUFFLE_OPTS environment variable to set shuffle server configuration. +# + +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" + +. "$sbin/spark-config.sh" +. "$SPARK_PREFIX/bin/load-spark-env.sh" + +exec "$sbin"/spark-daemon.sh start org.apache.spark.deploy.ExternalShuffleService 1 diff --git a/sbin/stop-shuffle-service.sh b/sbin/stop-shuffle-service.sh new file mode 100755 index 0000000000000..4cb6891ae27fa --- /dev/null +++ b/sbin/stop-shuffle-service.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Stops the external shuffle service on the machine this script is executed on. + +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" + +"$sbin"/spark-daemon.sh stop org.apache.spark.deploy.ExternalShuffleService 1 From 2d222fb39dd978e5a33cde6ceb59307cbdf7b171 Mon Sep 17 00:00:00 2001 From: Ilya Ganelin Date: Tue, 28 Apr 2015 12:18:55 -0700 Subject: [PATCH 38/39] [SPARK-5932] [CORE] Use consistent naming for size properties I've added an interface to JavaUtils to do byte conversion and added hooks within Utils.scala to handle conversion within Spark code (like for time strings). I've added matching tests for size conversion, and then updated all deprecated configs and documentation as per SPARK-5933. Author: Ilya Ganelin Closes #5574 from ilganeli/SPARK-5932 and squashes the following commits: 11f6999 [Ilya Ganelin] Nit fixes 49a8720 [Ilya Ganelin] Whitespace fix 2ab886b [Ilya Ganelin] Scala style fc85733 [Ilya Ganelin] Got rid of floating point math 852a407 [Ilya Ganelin] [SPARK-5932] Added much improved overflow handling. Can now handle sizes up to Long.MAX_VALUE Petabytes instead of being capped at Long.MAX_VALUE Bytes 9ee779c [Ilya Ganelin] Simplified fraction matches 22413b1 [Ilya Ganelin] Made MAX private 3dfae96 [Ilya Ganelin] Fixed some nits. Added automatic conversion of old paramter for kryoserializer.mb to new values. e428049 [Ilya Ganelin] resolving merge conflict 8b43748 [Ilya Ganelin] Fixed error in pattern matching for doubles 84a2581 [Ilya Ganelin] Added smoother handling of fractional values for size parameters. This now throws an exception and added a warning for old spark.kryoserializer.buffer d3d09b6 [Ilya Ganelin] [SPARK-5932] Fixing error in KryoSerializer fe286b4 [Ilya Ganelin] Resolved merge conflict c7803cd [Ilya Ganelin] Empty lines 54b78b4 [Ilya Ganelin] Simplified byteUnit class 69e2f20 [Ilya Ganelin] Updates to code f32bc01 [Ilya Ganelin] [SPARK-5932] Fixed error in API in SparkConf.scala where Kb conversion wasn't being done properly (was Mb). Added test cases for both timeUnit and ByteUnit conversion f15f209 [Ilya Ganelin] Fixed conversion of kryo buffer size 0f4443e [Ilya Ganelin] Merge remote-tracking branch 'upstream/master' into SPARK-5932 35a7fa7 [Ilya Ganelin] Minor formatting 928469e [Ilya Ganelin] [SPARK-5932] Converted some longs to ints 5d29f90 [Ilya Ganelin] [SPARK-5932] Finished documentation updates 7a6c847 [Ilya Ganelin] [SPARK-5932] Updated spark.shuffle.file.buffer afc9a38 [Ilya Ganelin] [SPARK-5932] Updated spark.broadcast.blockSize and spark.storage.memoryMapThreshold ae7e9f6 [Ilya Ganelin] [SPARK-5932] Updated spark.io.compression.snappy.block.size 2d15681 [Ilya Ganelin] [SPARK-5932] Updated spark.executor.logs.rolling.size.maxBytes 1fbd435 [Ilya Ganelin] [SPARK-5932] Updated spark.broadcast.blockSize eba4de6 [Ilya Ganelin] [SPARK-5932] Updated spark.shuffle.file.buffer.kb b809a78 [Ilya Ganelin] [SPARK-5932] Updated spark.kryoserializer.buffer.max 0cdff35 [Ilya Ganelin] [SPARK-5932] Updated to use bibibytes in method names. Updated spark.kryoserializer.buffer.mb and spark.reducer.maxMbInFlight 475370a [Ilya Ganelin] [SPARK-5932] Simplified ByteUnit code, switched to using longs. Updated docs to clarify that we use kibi, mebi etc instead of kilo, mega 851d691 [Ilya Ganelin] [SPARK-5932] Updated memoryStringToMb to use new interfaces a9f4fcf [Ilya Ganelin] [SPARK-5932] Added unit tests for unit conversion 747393a [Ilya Ganelin] [SPARK-5932] Added unit tests for ByteString conversion 09ea450 [Ilya Ganelin] [SPARK-5932] Added byte string conversion to Jav utils 5390fd9 [Ilya Ganelin] Merge remote-tracking branch 'upstream/master' into SPARK-5932 db9a963 [Ilya Ganelin] Closing second spark context 1dc0444 [Ilya Ganelin] Added ref equality check 8c884fa [Ilya Ganelin] Made getOrCreate synchronized cb0c6b7 [Ilya Ganelin] Doc updates and code cleanup 270cfe3 [Ilya Ganelin] [SPARK-6703] Documentation fixes 15e8dea [Ilya Ganelin] Updated comments and added MiMa Exclude 0e1567c [Ilya Ganelin] Got rid of unecessary option for AtomicReference dfec4da [Ilya Ganelin] Changed activeContext to AtomicReference 733ec9f [Ilya Ganelin] Fixed some bugs in test code 8be2f83 [Ilya Ganelin] Replaced match with if e92caf7 [Ilya Ganelin] [SPARK-6703] Added test to ensure that getOrCreate both allows creation, retrieval, and a second context if desired a99032f [Ilya Ganelin] Spacing fix d7a06b8 [Ilya Ganelin] Updated SparkConf class to add getOrCreate method. Started test suite implementation --- .../scala/org/apache/spark/SparkConf.scala | 90 ++++++++++++++- .../spark/broadcast/TorrentBroadcast.scala | 3 +- .../apache/spark/io/CompressionCodec.scala | 8 +- .../spark/serializer/KryoSerializer.scala | 17 +-- .../shuffle/FileShuffleBlockManager.scala | 3 +- .../hash/BlockStoreShuffleFetcher.scala | 3 +- .../org/apache/spark/storage/DiskStore.scala | 3 +- .../scala/org/apache/spark/util/Utils.scala | 53 ++++++--- .../collection/ExternalAppendOnlyMap.scala | 6 +- .../util/collection/ExternalSorter.scala | 4 +- .../util/logging/RollingFileAppender.scala | 2 +- .../org/apache/spark/DistributedSuite.scala | 2 +- .../org/apache/spark/SparkConfSuite.scala | 19 ++++ .../KryoSerializerResizableOutputSuite.scala | 8 +- .../serializer/KryoSerializerSuite.scala | 2 +- .../BlockManagerReplicationSuite.scala | 2 +- .../spark/storage/BlockManagerSuite.scala | 6 +- .../org/apache/spark/util/UtilsSuite.scala | 100 +++++++++++++++- docs/configuration.md | 60 ++++++---- docs/tuning.md | 2 +- .../spark/examples/mllib/MovieLensALS.scala | 2 +- .../apache/spark/network/util/ByteUnit.java | 67 +++++++++++ .../apache/spark/network/util/JavaUtils.java | 107 ++++++++++++++++-- 23 files changed, 488 insertions(+), 81 deletions(-) create mode 100644 network/common/src/main/java/org/apache/spark/network/util/ByteUnit.java diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index c1996e08756a6..a8fc90ad2050e 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -211,7 +211,74 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { Utils.timeStringAsMs(get(key, defaultValue)) } + /** + * Get a size parameter as bytes; throws a NoSuchElementException if it's not set. If no + * suffix is provided then bytes are assumed. + * @throws NoSuchElementException + */ + def getSizeAsBytes(key: String): Long = { + Utils.byteStringAsBytes(get(key)) + } + + /** + * Get a size parameter as bytes, falling back to a default if not set. If no + * suffix is provided then bytes are assumed. + */ + def getSizeAsBytes(key: String, defaultValue: String): Long = { + Utils.byteStringAsBytes(get(key, defaultValue)) + } + + /** + * Get a size parameter as Kibibytes; throws a NoSuchElementException if it's not set. If no + * suffix is provided then Kibibytes are assumed. + * @throws NoSuchElementException + */ + def getSizeAsKb(key: String): Long = { + Utils.byteStringAsKb(get(key)) + } + + /** + * Get a size parameter as Kibibytes, falling back to a default if not set. If no + * suffix is provided then Kibibytes are assumed. + */ + def getSizeAsKb(key: String, defaultValue: String): Long = { + Utils.byteStringAsKb(get(key, defaultValue)) + } + + /** + * Get a size parameter as Mebibytes; throws a NoSuchElementException if it's not set. If no + * suffix is provided then Mebibytes are assumed. + * @throws NoSuchElementException + */ + def getSizeAsMb(key: String): Long = { + Utils.byteStringAsMb(get(key)) + } + + /** + * Get a size parameter as Mebibytes, falling back to a default if not set. If no + * suffix is provided then Mebibytes are assumed. + */ + def getSizeAsMb(key: String, defaultValue: String): Long = { + Utils.byteStringAsMb(get(key, defaultValue)) + } + + /** + * Get a size parameter as Gibibytes; throws a NoSuchElementException if it's not set. If no + * suffix is provided then Gibibytes are assumed. + * @throws NoSuchElementException + */ + def getSizeAsGb(key: String): Long = { + Utils.byteStringAsGb(get(key)) + } + /** + * Get a size parameter as Gibibytes, falling back to a default if not set. If no + * suffix is provided then Gibibytes are assumed. + */ + def getSizeAsGb(key: String, defaultValue: String): Long = { + Utils.byteStringAsGb(get(key, defaultValue)) + } + /** Get a parameter as an Option */ def getOption(key: String): Option[String] = { Option(settings.get(key)).orElse(getDeprecatedConfig(key, this)) @@ -407,7 +474,13 @@ private[spark] object SparkConf extends Logging { "The spark.cache.class property is no longer being used! Specify storage levels using " + "the RDD.persist() method instead."), DeprecatedConfig("spark.yarn.user.classpath.first", "1.3", - "Please use spark.{driver,executor}.userClassPathFirst instead.")) + "Please use spark.{driver,executor}.userClassPathFirst instead."), + DeprecatedConfig("spark.kryoserializer.buffer.mb", "1.4", + "Please use spark.kryoserializer.buffer instead. The default value for " + + "spark.kryoserializer.buffer.mb was previously specified as '0.064'. Fractional values " + + "are no longer accepted. To specify the equivalent now, one may use '64k'.") + ) + Map(configs.map { cfg => (cfg.key -> cfg) }:_*) } @@ -432,6 +505,21 @@ private[spark] object SparkConf extends Logging { AlternateConfig("spark.yarn.applicationMaster.waitTries", "1.3", // Translate old value to a duration, with 10s wait time per try. translation = s => s"${s.toLong * 10}s")), + "spark.reducer.maxSizeInFlight" -> Seq( + AlternateConfig("spark.reducer.maxMbInFlight", "1.4")), + "spark.kryoserializer.buffer" -> + Seq(AlternateConfig("spark.kryoserializer.buffer.mb", "1.4", + translation = s => s"${s.toDouble * 1000}k")), + "spark.kryoserializer.buffer.max" -> Seq( + AlternateConfig("spark.kryoserializer.buffer.max.mb", "1.4")), + "spark.shuffle.file.buffer" -> Seq( + AlternateConfig("spark.shuffle.file.buffer.kb", "1.4")), + "spark.executor.logs.rolling.maxSize" -> Seq( + AlternateConfig("spark.executor.logs.rolling.size.maxBytes", "1.4")), + "spark.io.compression.snappy.blockSize" -> Seq( + AlternateConfig("spark.io.compression.snappy.block.size", "1.4")), + "spark.io.compression.lz4.blockSize" -> Seq( + AlternateConfig("spark.io.compression.lz4.block.size", "1.4")), "spark.rpc.numRetries" -> Seq( AlternateConfig("spark.akka.num.retries", "1.4")), "spark.rpc.retry.wait" -> Seq( diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 23b02e60338fb..a0c9b5e63c744 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -74,7 +74,8 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) } else { None } - blockSize = conf.getInt("spark.broadcast.blockSize", 4096) * 1024 + // Note: use getSizeAsKb (not bytes) to maintain compatiblity if no units are provided + blockSize = conf.getSizeAsKb("spark.broadcast.blockSize", "4m").toInt * 1024 } setConf(SparkEnv.get.conf) diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 0709b6d689e86..0756cdb2ed8e6 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -97,7 +97,7 @@ private[spark] object CompressionCodec { /** * :: DeveloperApi :: * LZ4 implementation of [[org.apache.spark.io.CompressionCodec]]. - * Block size can be configured by `spark.io.compression.lz4.block.size`. + * Block size can be configured by `spark.io.compression.lz4.blockSize`. * * Note: The wire protocol for this codec is not guaranteed to be compatible across versions * of Spark. This is intended for use as an internal compression utility within a single Spark @@ -107,7 +107,7 @@ private[spark] object CompressionCodec { class LZ4CompressionCodec(conf: SparkConf) extends CompressionCodec { override def compressedOutputStream(s: OutputStream): OutputStream = { - val blockSize = conf.getInt("spark.io.compression.lz4.block.size", 32768) + val blockSize = conf.getSizeAsBytes("spark.io.compression.lz4.blockSize", "32k").toInt new LZ4BlockOutputStream(s, blockSize) } @@ -137,7 +137,7 @@ class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec { /** * :: DeveloperApi :: * Snappy implementation of [[org.apache.spark.io.CompressionCodec]]. - * Block size can be configured by `spark.io.compression.snappy.block.size`. + * Block size can be configured by `spark.io.compression.snappy.blockSize`. * * Note: The wire protocol for this codec is not guaranteed to be compatible across versions * of Spark. This is intended for use as an internal compression utility within a single Spark @@ -153,7 +153,7 @@ class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec { } override def compressedOutputStream(s: OutputStream): OutputStream = { - val blockSize = conf.getInt("spark.io.compression.snappy.block.size", 32768) + val blockSize = conf.getSizeAsBytes("spark.io.compression.snappy.blockSize", "32k").toInt new SnappyOutputStream(s, blockSize) } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 579fb6624e692..754832b8a4ca7 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -49,16 +49,17 @@ class KryoSerializer(conf: SparkConf) with Logging with Serializable { - private val bufferSizeMb = conf.getDouble("spark.kryoserializer.buffer.mb", 0.064) - if (bufferSizeMb >= 2048) { - throw new IllegalArgumentException("spark.kryoserializer.buffer.mb must be less than " + - s"2048 mb, got: + $bufferSizeMb mb.") + private val bufferSizeKb = conf.getSizeAsKb("spark.kryoserializer.buffer", "64k") + + if (bufferSizeKb >= 2048) { + throw new IllegalArgumentException("spark.kryoserializer.buffer must be less than " + + s"2048 mb, got: + $bufferSizeKb mb.") } - private val bufferSize = (bufferSizeMb * 1024 * 1024).toInt + private val bufferSize = (bufferSizeKb * 1024).toInt - val maxBufferSizeMb = conf.getInt("spark.kryoserializer.buffer.max.mb", 64) + val maxBufferSizeMb = conf.getSizeAsMb("spark.kryoserializer.buffer.max", "64m").toInt if (maxBufferSizeMb >= 2048) { - throw new IllegalArgumentException("spark.kryoserializer.buffer.max.mb must be less than " + + throw new IllegalArgumentException("spark.kryoserializer.buffer.max must be less than " + s"2048 mb, got: + $maxBufferSizeMb mb.") } private val maxBufferSize = maxBufferSizeMb * 1024 * 1024 @@ -173,7 +174,7 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ } catch { case e: KryoException if e.getMessage.startsWith("Buffer overflow") => throw new SparkException(s"Kryo serialization failed: ${e.getMessage}. To avoid this, " + - "increase spark.kryoserializer.buffer.max.mb value.") + "increase spark.kryoserializer.buffer.max value.") } ByteBuffer.wrap(output.toBytes) } diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala index 538e150ead05a..e9b4e2b955dc8 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala @@ -78,7 +78,8 @@ class FileShuffleBlockManager(conf: SparkConf) private val consolidateShuffleFiles = conf.getBoolean("spark.shuffle.consolidateFiles", false) - private val bufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 + // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided + private val bufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 /** * Contains all the state related to a particular shuffle. This includes a pool of unused diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index 7a2c5ae32d98b..80374adc44296 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -79,7 +79,8 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { blockManager, blocksByAddress, serializer, - SparkEnv.get.conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024) + // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility + SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) val itr = blockFetcherItr.flatMap(unpackBlock) val completionIter = CompletionIterator[T, Iterator[T]](itr, { diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index 4b232ae7d3180..1f45956282166 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -31,8 +31,7 @@ import org.apache.spark.util.Utils private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManager) extends BlockStore(blockManager) with Logging { - val minMemoryMapBytes = blockManager.conf.getLong( - "spark.storage.memoryMapThreshold", 2 * 1024L * 1024L) + val minMemoryMapBytes = blockManager.conf.getSizeAsBytes("spark.storage.memoryMapThreshold", "2m") override def getSize(blockId: BlockId): Long = { diskManager.getFile(blockId.name).length diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 342bc9a06db47..4c028c06a5138 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1020,21 +1020,48 @@ private[spark] object Utils extends Logging { } /** - * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of megabytes. + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to bytes for internal use. + * + * If no suffix is provided, the passed number is assumed to be in bytes. + */ + def byteStringAsBytes(str: String): Long = { + JavaUtils.byteStringAsBytes(str) + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to kibibytes for internal use. + * + * If no suffix is provided, the passed number is assumed to be in kibibytes. + */ + def byteStringAsKb(str: String): Long = { + JavaUtils.byteStringAsKb(str) + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to mebibytes for internal use. + * + * If no suffix is provided, the passed number is assumed to be in mebibytes. + */ + def byteStringAsMb(str: String): Long = { + JavaUtils.byteStringAsMb(str) + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m, 500g) to gibibytes for internal use. + * + * If no suffix is provided, the passed number is assumed to be in gibibytes. + */ + def byteStringAsGb(str: String): Long = { + JavaUtils.byteStringAsGb(str) + } + + /** + * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of mebibytes. */ def memoryStringToMb(str: String): Int = { - val lower = str.toLowerCase - if (lower.endsWith("k")) { - (lower.substring(0, lower.length-1).toLong / 1024).toInt - } else if (lower.endsWith("m")) { - lower.substring(0, lower.length-1).toInt - } else if (lower.endsWith("g")) { - lower.substring(0, lower.length-1).toInt * 1024 - } else if (lower.endsWith("t")) { - lower.substring(0, lower.length-1).toInt * 1024 * 1024 - } else {// no suffix, so it's just a number in bytes - (lower.toLong / 1024 / 1024).toInt - } + // Convert to bytes, rather than directly to MB, because when no units are specified the unit + // is assumed to be bytes + (JavaUtils.byteStringAsBytes(str) / 1024 / 1024).toInt } /** diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 30dd7f22e494f..f912049563906 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -89,8 +89,10 @@ class ExternalAppendOnlyMap[K, V, C]( // Number of bytes spilled in total private var _diskBytesSpilled = 0L - - private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 + + // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided + private val fileBufferSize = + sparkConf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 // Write metrics for current spill private var curWriteMetrics: ShuffleWriteMetrics = _ diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 79a695fb62086..ef3cac622505e 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -108,7 +108,9 @@ private[spark] class ExternalSorter[K, V, C]( private val conf = SparkEnv.get.conf private val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true) - private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 + + // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided + private val fileBufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 private val transferToEnabled = conf.getBoolean("spark.file.transferTo", true) // Size of object batches when reading/writing from serializers. diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala index e579421676343..7138b4b8e4533 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala @@ -138,7 +138,7 @@ private[spark] object RollingFileAppender { val STRATEGY_DEFAULT = "" val INTERVAL_PROPERTY = "spark.executor.logs.rolling.time.interval" val INTERVAL_DEFAULT = "daily" - val SIZE_PROPERTY = "spark.executor.logs.rolling.size.maxBytes" + val SIZE_PROPERTY = "spark.executor.logs.rolling.maxSize" val SIZE_DEFAULT = (1024 * 1024).toString val RETAINED_FILES_PROPERTY = "spark.executor.logs.rolling.maxRetainedFiles" val DEFAULT_BUFFER_SIZE = 8192 diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 97ea3578aa8ba..96a9c207ad022 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -77,7 +77,7 @@ class DistributedSuite extends FunSuite with Matchers with LocalSparkContext { } test("groupByKey where map output sizes exceed maxMbInFlight") { - val conf = new SparkConf().set("spark.reducer.maxMbInFlight", "1") + val conf = new SparkConf().set("spark.reducer.maxSizeInFlight", "1m") sc = new SparkContext(clusterUrl, "test", conf) // This data should be around 20 MB, so even with 4 mappers and 2 reducers, each map output // file should be about 2.5 MB diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index 272e6af0514e4..68d08e32f9aa4 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -24,11 +24,30 @@ import scala.language.postfixOps import scala.util.{Try, Random} import org.scalatest.FunSuite +import org.apache.spark.network.util.ByteUnit import org.apache.spark.serializer.{KryoRegistrator, KryoSerializer} import org.apache.spark.util.{RpcUtils, ResetSystemProperties} import com.esotericsoftware.kryo.Kryo class SparkConfSuite extends FunSuite with LocalSparkContext with ResetSystemProperties { + test("Test byteString conversion") { + val conf = new SparkConf() + // Simply exercise the API, we don't need a complete conversion test since that's handled in + // UtilsSuite.scala + assert(conf.getSizeAsBytes("fake","1k") === ByteUnit.KiB.toBytes(1)) + assert(conf.getSizeAsKb("fake","1k") === ByteUnit.KiB.toKiB(1)) + assert(conf.getSizeAsMb("fake","1k") === ByteUnit.KiB.toMiB(1)) + assert(conf.getSizeAsGb("fake","1k") === ByteUnit.KiB.toGiB(1)) + } + + test("Test timeString conversion") { + val conf = new SparkConf() + // Simply exercise the API, we don't need a complete conversion test since that's handled in + // UtilsSuite.scala + assert(conf.getTimeAsMs("fake","1ms") === TimeUnit.MILLISECONDS.toMillis(1)) + assert(conf.getTimeAsSeconds("fake","1000ms") === TimeUnit.MILLISECONDS.toSeconds(1000)) + } + test("loading from system properties") { System.setProperty("spark.test.testProperty", "2") val conf = new SparkConf() diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala index 967c9e9899c9d..da98d09184735 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala @@ -33,8 +33,8 @@ class KryoSerializerResizableOutputSuite extends FunSuite { test("kryo without resizable output buffer should fail on large array") { val conf = new SparkConf(false) conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - conf.set("spark.kryoserializer.buffer.mb", "1") - conf.set("spark.kryoserializer.buffer.max.mb", "1") + conf.set("spark.kryoserializer.buffer", "1m") + conf.set("spark.kryoserializer.buffer.max", "1m") val sc = new SparkContext("local", "test", conf) intercept[SparkException](sc.parallelize(x).collect()) LocalSparkContext.stop(sc) @@ -43,8 +43,8 @@ class KryoSerializerResizableOutputSuite extends FunSuite { test("kryo with resizable output buffer should succeed on large array") { val conf = new SparkConf(false) conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - conf.set("spark.kryoserializer.buffer.mb", "1") - conf.set("spark.kryoserializer.buffer.max.mb", "2") + conf.set("spark.kryoserializer.buffer", "1m") + conf.set("spark.kryoserializer.buffer.max", "2m") val sc = new SparkContext("local", "test", conf) assert(sc.parallelize(x).collect() === x) LocalSparkContext.stop(sc) diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index b070a54aa989b..1b13559e77cb8 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -269,7 +269,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { test("serialization buffer overflow reporting") { import org.apache.spark.SparkException - val kryoBufferMaxProperty = "spark.kryoserializer.buffer.max.mb" + val kryoBufferMaxProperty = "spark.kryoserializer.buffer.max" val largeObject = (1 to 1000000).toArray diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index ffa5162a31841..f647200402ecb 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -50,7 +50,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd val allStores = new ArrayBuffer[BlockManager] // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test - conf.set("spark.kryoserializer.buffer.mb", "1") + conf.set("spark.kryoserializer.buffer", "1m") val serializer = new KryoSerializer(conf) // Implicitly convert strings to BlockIds for test clarity. diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 7d82a7c66ad1a..6957bc72e9903 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -55,7 +55,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach val shuffleManager = new HashShuffleManager(conf) // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test - conf.set("spark.kryoserializer.buffer.mb", "1") + conf.set("spark.kryoserializer.buffer", "1m") val serializer = new KryoSerializer(conf) // Implicitly convert strings to BlockIds for test clarity. @@ -814,14 +814,14 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach // be nice to refactor classes involved in disk storage in a way that // allows for easier testing. val blockManager = mock(classOf[BlockManager]) - when(blockManager.conf).thenReturn(conf.clone.set(confKey, 0.toString)) + when(blockManager.conf).thenReturn(conf.clone.set(confKey, "0")) val diskBlockManager = new DiskBlockManager(blockManager, conf) val diskStoreMapped = new DiskStore(blockManager, diskBlockManager) diskStoreMapped.putBytes(blockId, byteBuffer, StorageLevel.DISK_ONLY) val mapped = diskStoreMapped.getBytes(blockId).get - when(blockManager.conf).thenReturn(conf.clone.set(confKey, (1000 * 1000).toString)) + when(blockManager.conf).thenReturn(conf.clone.set(confKey, "1m")) val diskStoreNotMapped = new DiskStore(blockManager, diskBlockManager) diskStoreNotMapped.putBytes(blockId, byteBuffer, StorageLevel.DISK_ONLY) val notMapped = diskStoreNotMapped.getBytes(blockId).get diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 1ba99803f5a0e..62a3cbcdf69ea 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -23,7 +23,6 @@ import java.nio.{ByteBuffer, ByteOrder} import java.text.DecimalFormatSymbols import java.util.concurrent.TimeUnit import java.util.Locale -import java.util.PriorityQueue import scala.collection.mutable.ListBuffer import scala.util.Random @@ -35,6 +34,7 @@ import org.scalatest.FunSuite import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.spark.network.util.ByteUnit import org.apache.spark.SparkConf class UtilsSuite extends FunSuite with ResetSystemProperties { @@ -65,6 +65,10 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { assert(Utils.timeStringAsMs("1d") === TimeUnit.DAYS.toMillis(1)) // Test invalid strings + intercept[NumberFormatException] { + Utils.timeStringAsMs("600l") + } + intercept[NumberFormatException] { Utils.timeStringAsMs("This breaks 600s") } @@ -82,6 +86,100 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { } } + test("Test byteString conversion") { + // Test zero + assert(Utils.byteStringAsBytes("0") === 0) + + assert(Utils.byteStringAsGb("1") === 1) + assert(Utils.byteStringAsGb("1g") === 1) + assert(Utils.byteStringAsGb("1023m") === 0) + assert(Utils.byteStringAsGb("1024m") === 1) + assert(Utils.byteStringAsGb("1048575k") === 0) + assert(Utils.byteStringAsGb("1048576k") === 1) + assert(Utils.byteStringAsGb("1k") === 0) + assert(Utils.byteStringAsGb("1t") === ByteUnit.TiB.toGiB(1)) + assert(Utils.byteStringAsGb("1p") === ByteUnit.PiB.toGiB(1)) + + assert(Utils.byteStringAsMb("1") === 1) + assert(Utils.byteStringAsMb("1m") === 1) + assert(Utils.byteStringAsMb("1048575b") === 0) + assert(Utils.byteStringAsMb("1048576b") === 1) + assert(Utils.byteStringAsMb("1023k") === 0) + assert(Utils.byteStringAsMb("1024k") === 1) + assert(Utils.byteStringAsMb("3645k") === 3) + assert(Utils.byteStringAsMb("1024gb") === 1048576) + assert(Utils.byteStringAsMb("1g") === ByteUnit.GiB.toMiB(1)) + assert(Utils.byteStringAsMb("1t") === ByteUnit.TiB.toMiB(1)) + assert(Utils.byteStringAsMb("1p") === ByteUnit.PiB.toMiB(1)) + + assert(Utils.byteStringAsKb("1") === 1) + assert(Utils.byteStringAsKb("1k") === 1) + assert(Utils.byteStringAsKb("1m") === ByteUnit.MiB.toKiB(1)) + assert(Utils.byteStringAsKb("1g") === ByteUnit.GiB.toKiB(1)) + assert(Utils.byteStringAsKb("1t") === ByteUnit.TiB.toKiB(1)) + assert(Utils.byteStringAsKb("1p") === ByteUnit.PiB.toKiB(1)) + + assert(Utils.byteStringAsBytes("1") === 1) + assert(Utils.byteStringAsBytes("1k") === ByteUnit.KiB.toBytes(1)) + assert(Utils.byteStringAsBytes("1m") === ByteUnit.MiB.toBytes(1)) + assert(Utils.byteStringAsBytes("1g") === ByteUnit.GiB.toBytes(1)) + assert(Utils.byteStringAsBytes("1t") === ByteUnit.TiB.toBytes(1)) + assert(Utils.byteStringAsBytes("1p") === ByteUnit.PiB.toBytes(1)) + + // Overflow handling, 1073741824p exceeds Long.MAX_VALUE if converted straight to Bytes + // This demonstrates that we can have e.g 1024^3 PB without overflowing. + assert(Utils.byteStringAsGb("1073741824p") === ByteUnit.PiB.toGiB(1073741824)) + assert(Utils.byteStringAsMb("1073741824p") === ByteUnit.PiB.toMiB(1073741824)) + + // Run this to confirm it doesn't throw an exception + assert(Utils.byteStringAsBytes("9223372036854775807") === 9223372036854775807L) + assert(ByteUnit.PiB.toPiB(9223372036854775807L) === 9223372036854775807L) + + // Test overflow exception + intercept[IllegalArgumentException] { + // This value exceeds Long.MAX when converted to bytes + Utils.byteStringAsBytes("9223372036854775808") + } + + // Test overflow exception + intercept[IllegalArgumentException] { + // This value exceeds Long.MAX when converted to TB + ByteUnit.PiB.toTiB(9223372036854775807L) + } + + // Test fractional string + intercept[NumberFormatException] { + Utils.byteStringAsMb("0.064") + } + + // Test fractional string + intercept[NumberFormatException] { + Utils.byteStringAsMb("0.064m") + } + + // Test invalid strings + intercept[NumberFormatException] { + Utils.byteStringAsBytes("500ub") + } + + // Test invalid strings + intercept[NumberFormatException] { + Utils.byteStringAsBytes("This breaks 600b") + } + + intercept[NumberFormatException] { + Utils.byteStringAsBytes("This breaks 600") + } + + intercept[NumberFormatException] { + Utils.byteStringAsBytes("600gb This breaks") + } + + intercept[NumberFormatException] { + Utils.byteStringAsBytes("This 123mb breaks") + } + } + test("bytesToString") { assert(Utils.bytesToString(10) === "10.0 B") assert(Utils.bytesToString(1500) === "1500.0 B") diff --git a/docs/configuration.md b/docs/configuration.md index d587b91124cb8..72105feba4919 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -48,6 +48,17 @@ The following format is accepted: 5d (days) 1y (years) + +Properties that specify a byte size should be configured with a unit of size. +The following format is accepted: + + 1b (bytes) + 1k or 1kb (kibibytes = 1024 bytes) + 1m or 1mb (mebibytes = 1024 kibibytes) + 1g or 1gb (gibibytes = 1024 mebibytes) + 1t or 1tb (tebibytes = 1024 gibibytes) + 1p or 1pb (pebibytes = 1024 tebibytes) + ## Dynamically Loading Spark Properties In some cases, you may want to avoid hard-coding certain configurations in a `SparkConf`. For instance, if you'd like to run the same application with different masters or different @@ -272,12 +283,11 @@ Apart from these, the following properties are also available, and may be useful - + @@ -366,10 +376,10 @@ Apart from these, the following properties are also available, and may be useful
spark.executor.logs.rolling.size.maxBytesspark.executor.logs.rolling.maxSize (none) Set the max size of the file by which the executor logs will be rolled over. - Rolling is disabled by default. Value is set in terms of bytes. - See spark.executor.logs.rolling.maxRetainedFiles + Rolling is disabled by default. See spark.executor.logs.rolling.maxRetainedFiles for automatic cleaning of old logs.
- - + + @@ -403,10 +413,10 @@ Apart from these, the following properties are also available, and may be useful - - + + @@ -582,18 +592,18 @@ Apart from these, the following properties are also available, and may be useful - - + + - - + + @@ -641,19 +651,19 @@ Apart from these, the following properties are also available, and may be useful - - + + - - + + @@ -698,9 +708,9 @@ Apart from these, the following properties are also available, and may be useful - + @@ -816,9 +826,9 @@ Apart from these, the following properties are also available, and may be useful - + diff --git a/docs/tuning.md b/docs/tuning.md index cbd227868b248..1cb223e74f382 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -60,7 +60,7 @@ val sc = new SparkContext(conf) The [Kryo documentation](https://github.com/EsotericSoftware/kryo) describes more advanced registration options, such as adding custom serialization code. -If your objects are large, you may also need to increase the `spark.kryoserializer.buffer.mb` +If your objects are large, you may also need to increase the `spark.kryoserializer.buffer` config property. The default is 2, but this value needs to be large enough to hold the *largest* object you will serialize. diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala index 0bc36ea65e1ab..99588b0984ab2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala @@ -100,7 +100,7 @@ object MovieLensALS { val conf = new SparkConf().setAppName(s"MovieLensALS with $params") if (params.kryo) { conf.registerKryoClasses(Array(classOf[mutable.BitSet], classOf[Rating])) - .set("spark.kryoserializer.buffer.mb", "8") + .set("spark.kryoserializer.buffer", "8m") } val sc = new SparkContext(conf) diff --git a/network/common/src/main/java/org/apache/spark/network/util/ByteUnit.java b/network/common/src/main/java/org/apache/spark/network/util/ByteUnit.java new file mode 100644 index 0000000000000..36d655017fb0d --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/ByteUnit.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.network.util; + +public enum ByteUnit { + BYTE (1), + KiB (1024L), + MiB ((long) Math.pow(1024L, 2L)), + GiB ((long) Math.pow(1024L, 3L)), + TiB ((long) Math.pow(1024L, 4L)), + PiB ((long) Math.pow(1024L, 5L)); + + private ByteUnit(long multiplier) { + this.multiplier = multiplier; + } + + // Interpret the provided number (d) with suffix (u) as this unit type. + // E.g. KiB.interpret(1, MiB) interprets 1MiB as its KiB representation = 1024k + public long convertFrom(long d, ByteUnit u) { + return u.convertTo(d, this); + } + + // Convert the provided number (d) interpreted as this unit type to unit type (u). + public long convertTo(long d, ByteUnit u) { + if (multiplier > u.multiplier) { + long ratio = multiplier / u.multiplier; + if (Long.MAX_VALUE / ratio < d) { + throw new IllegalArgumentException("Conversion of " + d + " exceeds Long.MAX_VALUE in " + + name() + ". Try a larger unit (e.g. MiB instead of KiB)"); + } + return d * ratio; + } else { + // Perform operations in this order to avoid potential overflow + // when computing d * multiplier + return d / (u.multiplier / multiplier); + } + } + + public double toBytes(long d) { + if (d < 0) { + throw new IllegalArgumentException("Negative size value. Size must be positive: " + d); + } + return d * multiplier; + } + + public long toKiB(long d) { return convertTo(d, KiB); } + public long toMiB(long d) { return convertTo(d, MiB); } + public long toGiB(long d) { return convertTo(d, GiB); } + public long toTiB(long d) { return convertTo(d, TiB); } + public long toPiB(long d) { return convertTo(d, PiB); } + + private final long multiplier; +} diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java index b6fbace509a0e..6b514aaa1290d 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -126,7 +126,7 @@ private static boolean isSymlink(File file) throws IOException { return !fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile()); } - private static ImmutableMap timeSuffixes = + private static final ImmutableMap timeSuffixes = ImmutableMap.builder() .put("us", TimeUnit.MICROSECONDS) .put("ms", TimeUnit.MILLISECONDS) @@ -137,6 +137,21 @@ private static boolean isSymlink(File file) throws IOException { .put("d", TimeUnit.DAYS) .build(); + private static final ImmutableMap byteSuffixes = + ImmutableMap.builder() + .put("b", ByteUnit.BYTE) + .put("k", ByteUnit.KiB) + .put("kb", ByteUnit.KiB) + .put("m", ByteUnit.MiB) + .put("mb", ByteUnit.MiB) + .put("g", ByteUnit.GiB) + .put("gb", ByteUnit.GiB) + .put("t", ByteUnit.TiB) + .put("tb", ByteUnit.TiB) + .put("p", ByteUnit.PiB) + .put("pb", ByteUnit.PiB) + .build(); + /** * Convert a passed time string (e.g. 50s, 100ms, or 250us) to a time count for * internal use. If no suffix is provided a direct conversion is attempted. @@ -145,16 +160,14 @@ private static long parseTimeString(String str, TimeUnit unit) { String lower = str.toLowerCase().trim(); try { - String suffix; - long val; Matcher m = Pattern.compile("(-?[0-9]+)([a-z]+)?").matcher(lower); - if (m.matches()) { - val = Long.parseLong(m.group(1)); - suffix = m.group(2); - } else { + if (!m.matches()) { throw new NumberFormatException("Failed to parse time string: " + str); } + long val = Long.parseLong(m.group(1)); + String suffix = m.group(2); + // Check for invalid suffixes if (suffix != null && !timeSuffixes.containsKey(suffix)) { throw new NumberFormatException("Invalid suffix: \"" + suffix + "\""); @@ -164,7 +177,7 @@ private static long parseTimeString(String str, TimeUnit unit) { return unit.convert(val, suffix != null ? timeSuffixes.get(suffix) : unit); } catch (NumberFormatException e) { String timeError = "Time must be specified as seconds (s), " + - "milliseconds (ms), microseconds (us), minutes (m or min) hour (h), or day (d). " + + "milliseconds (ms), microseconds (us), minutes (m or min), hour (h), or day (d). " + "E.g. 50s, 100ms, or 250us."; throw new NumberFormatException(timeError + "\n" + e.getMessage()); @@ -186,5 +199,83 @@ public static long timeStringAsMs(String str) { public static long timeStringAsSec(String str) { return parseTimeString(str, TimeUnit.SECONDS); } + + /** + * Convert a passed byte string (e.g. 50b, 100kb, or 250mb) to a ByteUnit for + * internal use. If no suffix is provided a direct conversion of the provided default is + * attempted. + */ + private static long parseByteString(String str, ByteUnit unit) { + String lower = str.toLowerCase().trim(); + + try { + Matcher m = Pattern.compile("([0-9]+)([a-z]+)?").matcher(lower); + Matcher fractionMatcher = Pattern.compile("([0-9]+\\.[0-9]+)([a-z]+)?").matcher(lower); + + if (m.matches()) { + long val = Long.parseLong(m.group(1)); + String suffix = m.group(2); + + // Check for invalid suffixes + if (suffix != null && !byteSuffixes.containsKey(suffix)) { + throw new NumberFormatException("Invalid suffix: \"" + suffix + "\""); + } + + // If suffix is valid use that, otherwise none was provided and use the default passed + return unit.convertFrom(val, suffix != null ? byteSuffixes.get(suffix) : unit); + } else if (fractionMatcher.matches()) { + throw new NumberFormatException("Fractional values are not supported. Input was: " + + fractionMatcher.group(1)); + } else { + throw new NumberFormatException("Failed to parse byte string: " + str); + } + + } catch (NumberFormatException e) { + String timeError = "Size must be specified as bytes (b), " + + "kibibytes (k), mebibytes (m), gibibytes (g), tebibytes (t), or pebibytes(p). " + + "E.g. 50b, 100k, or 250m."; + throw new NumberFormatException(timeError + "\n" + e.getMessage()); + } + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to bytes for + * internal use. + * + * If no suffix is provided, the passed number is assumed to be in bytes. + */ + public static long byteStringAsBytes(String str) { + return parseByteString(str, ByteUnit.BYTE); + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to kibibytes for + * internal use. + * + * If no suffix is provided, the passed number is assumed to be in kibibytes. + */ + public static long byteStringAsKb(String str) { + return parseByteString(str, ByteUnit.KiB); + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to mebibytes for + * internal use. + * + * If no suffix is provided, the passed number is assumed to be in mebibytes. + */ + public static long byteStringAsMb(String str) { + return parseByteString(str, ByteUnit.MiB); + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to gibibytes for + * internal use. + * + * If no suffix is provided, the passed number is assumed to be in gibibytes. + */ + public static long byteStringAsGb(String str) { + return parseByteString(str, ByteUnit.GiB); + } } From 80098109d908b738b43d397e024756ff617d0af4 Mon Sep 17 00:00:00 2001 From: "Zhang, Liye" Date: Tue, 28 Apr 2015 12:33:48 -0700 Subject: [PATCH 39/39] [SPARK-6314] [CORE] handle JsonParseException for history server This is handled in the same way with [SPARK-6197](https://issues.apache.org/jira/browse/SPARK-6197). The result of this PR is that exception showed in history server log will be replaced by a warning, and the application that with un-complete history log file will be listed on history server webUI Author: Zhang, Liye Closes #5736 from liyezhang556520/SPARK-6314 and squashes the following commits: b8d2d88 [Zhang, Liye] handle JsonParseException for history server --- .../org/apache/spark/deploy/history/FsHistoryProvider.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index a94ebf6e53750..fb2cbbcccc54b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -333,8 +333,9 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis } try { val appListener = new ApplicationEventListener + val appCompleted = isApplicationCompleted(eventLog) bus.addListener(appListener) - bus.replay(logInput, logPath.toString) + bus.replay(logInput, logPath.toString, !appCompleted) new FsApplicationHistoryInfo( logPath.getName(), appListener.appId.getOrElse(logPath.getName()), @@ -343,7 +344,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis appListener.endTime.getOrElse(-1L), getModificationTime(eventLog).get, appListener.sparkUser.getOrElse(NOT_STARTED), - isApplicationCompleted(eventLog)) + appCompleted) } finally { logInput.close() }
Property NameDefaultMeaning
spark.reducer.maxMbInFlight48spark.reducer.maxSizeInFlight48m - Maximum size (in megabytes) of map outputs to fetch simultaneously from each reduce task. Since + Maximum size of map outputs to fetch simultaneously from each reduce task. Since each output requires us to create a buffer to receive it, this represents a fixed memory overhead per reduce task, so keep it small unless you have a large amount of memory.
spark.shuffle.file.buffer.kb32spark.shuffle.file.buffer32k - Size of the in-memory buffer for each shuffle file output stream, in kilobytes. These buffers + Size of the in-memory buffer for each shuffle file output stream. These buffers reduce the number of disk seeks and system calls made in creating intermediate shuffle files.
spark.io.compression.lz4.block.size32768spark.io.compression.lz4.blockSize32k - Block size (in bytes) used in LZ4 compression, in the case when LZ4 compression codec + Block size used in LZ4 compression, in the case when LZ4 compression codec is used. Lowering this block size will also lower shuffle memory usage when LZ4 is used.
spark.io.compression.snappy.block.size32768spark.io.compression.snappy.blockSize32k - Block size (in bytes) used in Snappy compression, in the case when Snappy compression codec + Block size used in Snappy compression, in the case when Snappy compression codec is used. Lowering this block size will also lower shuffle memory usage when Snappy is used.
spark.kryoserializer.buffer.max.mb64spark.kryoserializer.buffer.max64m - Maximum allowable size of Kryo serialization buffer, in megabytes. This must be larger than any + Maximum allowable size of Kryo serialization buffer. This must be larger than any object you attempt to serialize. Increase this if you get a "buffer limit exceeded" exception inside Kryo.
spark.kryoserializer.buffer.mb0.064spark.kryoserializer.buffer64k - Initial size of Kryo's serialization buffer, in megabytes. Note that there will be one buffer + Initial size of Kryo's serialization buffer. Note that there will be one buffer per core on each worker. This buffer will grow up to spark.kryoserializer.buffer.max.mb if needed.
Property NameDefaultMeaning
spark.broadcast.blockSize40964m - Size of each piece of a block in kilobytes for TorrentBroadcastFactory. + Size of each piece of a block for TorrentBroadcastFactory. Too large a value decreases parallelism during broadcast (makes it slower); however, if it is too small, BlockManager might take a performance hit.
spark.storage.memoryMapThreshold20971522m - Size of a block, in bytes, above which Spark memory maps when reading a block from disk. + Size of a block above which Spark memory maps when reading a block from disk. This prevents Spark from memory mapping very small blocks. In general, memory mapping has high overhead for blocks close to or below the page size of the operating system.