diff --git a/.rat-excludes b/.rat-excludes index c0f81b57fe09d..8f2722cbd001f 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -82,3 +82,4 @@ local-1426633911242/* local-1430917381534/* DESCRIPTION NAMESPACE +test_support/* diff --git a/LICENSE b/LICENSE index d6b9ccf07d999..d0cd0dcb4bdb7 100644 --- a/LICENSE +++ b/LICENSE @@ -853,6 +853,52 @@ and Vis.js may be distributed under either license. +======================================================================== +For dagre-d3 (core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js): +======================================================================== +Copyright (c) 2013 Chris Pettitt + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. + +======================================================================== +For graphlib-dot (core/src/main/resources/org/apache/spark/ui/static/graphlib-dot.min.js): +======================================================================== +Copyright (c) 2012-2013 Chris Pettitt + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. + ======================================================================== BSD-style licenses ======================================================================== @@ -861,7 +907,7 @@ The following components are provided under a BSD-style license. See project lin (BSD 3 Clause) core (com.github.fommil.netlib:core:1.1.2 - https://github.com/fommil/netlib-java/core) (BSD 3 Clause) JPMML-Model (org.jpmml:pmml-model:1.1.15 - https://github.com/jpmml/jpmml-model) - (BSD 3-clause style license) jblas (org.jblas:jblas:1.2.3 - http://jblas.org/) + (BSD 3-clause style license) jblas (org.jblas:jblas:1.2.4 - http://jblas.org/) (BSD License) AntLR Parser Generator (antlr:antlr:2.7.7 - http://www.antlr.org/) (BSD License) Javolution (javolution:javolution:5.5.1 - http://javolution.org) (BSD licence) ANTLR ST4 4.0.4 (org.antlr:ST4:4.0.4 - http://www.stringtemplate.org) diff --git a/R/README.md b/R/README.md index a6970e39b55f3..d7d65b4f0eca5 100644 --- a/R/README.md +++ b/R/README.md @@ -52,7 +52,7 @@ The SparkR documentation (Rd files and HTML files) are not a part of the source SparkR comes with several sample programs in the `examples/src/main/r` directory. To run one of them, use `./bin/sparkR `. For example: - ./bin/sparkR examples/src/main/r/pi.R local[2] + ./bin/sparkR examples/src/main/r/dataframe.R You can also run the unit-tests for SparkR by running (you need to install the [testthat](http://cran.r-project.org/web/packages/testthat/index.html) package first): @@ -63,5 +63,5 @@ You can also run the unit-tests for SparkR by running (you need to install the [ The `./bin/spark-submit` and `./bin/sparkR` can also be used to submit jobs to YARN clusters. You will need to set YARN conf dir before doing so. For example on CDH you can run ``` export YARN_CONF_DIR=/etc/hadoop/conf -./bin/spark-submit --master yarn examples/src/main/r/pi.R 4 +./bin/spark-submit --master yarn examples/src/main/r/dataframe.R ``` diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 819e9a24e5c0e..f9447f6c3288d 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -1,6 +1,9 @@ # Imports from base R importFrom(methods, setGeneric, setMethod, setOldClass) -useDynLib(SparkR, stringHashCode) + +# Disable native libraries till we figure out how to package it +# See SPARKR-7839 +#useDynLib(SparkR, stringHashCode) # S3 methods exported export("sparkR.init") @@ -16,9 +19,11 @@ exportMethods("arrange", "count", "describe", "distinct", + "dropna", "dtypes", "except", "explain", + "fillna", "filter", "first", "group_by", @@ -37,7 +42,7 @@ exportMethods("arrange", "registerTempTable", "rename", "repartition", - "sampleDF", + "sample", "sample_frac", "saveAsParquetFile", "saveAsTable", @@ -53,38 +58,62 @@ exportMethods("arrange", "unpersist", "where", "withColumn", - "withColumnRenamed") + "withColumnRenamed", + "write.df") exportClasses("Column") exportMethods("abs", + "acos", "alias", "approxCountDistinct", "asc", + "asin", + "atan", + "atan2", "avg", "cast", + "cbrt", + "ceiling", "contains", + "cos", + "cosh", "countDistinct", "desc", "endsWith", + "exp", + "expm1", + "floor", "getField", "getItem", + "hypot", "isNotNull", "isNull", "last", "like", + "log", + "log10", + "log1p", "lower", "max", "mean", "min", "n", "n_distinct", + "rint", "rlike", + "sign", + "sin", + "sinh", "sqrt", "startsWith", "substr", "sum", "sumDistinct", + "tan", + "tanh", + "toDegrees", + "toRadians", "upper") exportClasses("GroupedData") @@ -101,6 +130,7 @@ export("cacheTable", "jsonFile", "loadDF", "parquetFile", + "read.df", "sql", "table", "tableNames", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 2705817531019..0af5cb8881e35 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -65,9 +65,9 @@ dataFrame <- function(sdf, isCached = FALSE) { #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' printSchema(df) #'} setMethod("printSchema", @@ -88,9 +88,9 @@ setMethod("printSchema", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' dfSchema <- schema(df) #'} setMethod("schema", @@ -110,9 +110,9 @@ setMethod("schema", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' explain(df, TRUE) #'} setMethod("explain", @@ -139,9 +139,9 @@ setMethod("explain", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' isLocal(df) #'} setMethod("isLocal", @@ -162,9 +162,9 @@ setMethod("isLocal", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' showDF(df) #'} setMethod("showDF", @@ -185,9 +185,9 @@ setMethod("showDF", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' df #'} setMethod("show", "DataFrame", @@ -210,9 +210,9 @@ setMethod("show", "DataFrame", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' dtypes(df) #'} setMethod("dtypes", @@ -234,9 +234,9 @@ setMethod("dtypes", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' columns(df) #'} setMethod("columns", @@ -267,11 +267,11 @@ setMethod("names", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' registerTempTable(df, "json_df") -#' new_df <- sql(sqlCtx, "SELECT * FROM json_df") +#' new_df <- sql(sqlContext, "SELECT * FROM json_df") #'} setMethod("registerTempTable", signature(x = "DataFrame", tableName = "character"), @@ -293,9 +293,9 @@ setMethod("registerTempTable", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) -#' df <- loadDF(sqlCtx, path, "parquet") -#' df2 <- loadDF(sqlCtx, path2, "parquet") +#' sqlContext <- sparkRSQL.init(sc) +#' df <- read.df(sqlContext, path, "parquet") +#' df2 <- read.df(sqlContext, path2, "parquet") #' registerTempTable(df, "table1") #' insertInto(df2, "table1", overwrite = TRUE) #'} @@ -316,9 +316,9 @@ setMethod("insertInto", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' cache(df) #'} setMethod("cache", @@ -341,9 +341,9 @@ setMethod("cache", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' persist(df, "MEMORY_AND_DISK") #'} setMethod("persist", @@ -366,9 +366,9 @@ setMethod("persist", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' persist(df, "MEMORY_AND_DISK") #' unpersist(df) #'} @@ -391,9 +391,9 @@ setMethod("unpersist", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' newDF <- repartition(df, 2L) #'} setMethod("repartition", @@ -415,9 +415,9 @@ setMethod("repartition", # @examples #\dontrun{ # sc <- sparkR.init() -# sqlCtx <- sparkRSQL.init(sc) +# sqlContext <- sparkRSQL.init(sc) # path <- "path/to/file.json" -# df <- jsonFile(sqlCtx, path) +# df <- jsonFile(sqlContext, path) # newRDD <- toJSON(df) #} setMethod("toJSON", @@ -440,9 +440,9 @@ setMethod("toJSON", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' saveAsParquetFile(df, "/tmp/sparkr-tmp/") #'} setMethod("saveAsParquetFile", @@ -461,9 +461,9 @@ setMethod("saveAsParquetFile", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' distinctDF <- distinct(df) #'} setMethod("distinct", @@ -473,26 +473,26 @@ setMethod("distinct", dataFrame(sdf) }) -#' SampleDF +#' Sample #' #' Return a sampled subset of this DataFrame using a random seed. #' #' @param x A SparkSQL DataFrame #' @param withReplacement Sampling with replacement or not #' @param fraction The (rough) sample target fraction -#' @rdname sampleDF +#' @rdname sample #' @aliases sample_frac #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) -#' collect(sampleDF(df, FALSE, 0.5)) -#' collect(sampleDF(df, TRUE, 0.5)) +#' df <- jsonFile(sqlContext, path) +#' collect(sample(df, FALSE, 0.5)) +#' collect(sample(df, TRUE, 0.5)) #'} -setMethod("sampleDF", +setMethod("sample", # TODO : Figure out how to send integer as java.lang.Long to JVM so # we can send seed as an argument through callJMethod signature(x = "DataFrame", withReplacement = "logical", @@ -503,13 +503,13 @@ setMethod("sampleDF", dataFrame(sdf) }) -#' @rdname sampleDF -#' @aliases sampleDF +#' @rdname sample +#' @aliases sample setMethod("sample_frac", signature(x = "DataFrame", withReplacement = "logical", fraction = "numeric"), function(x, withReplacement, fraction) { - sampleDF(x, withReplacement, fraction) + sample(x, withReplacement, fraction) }) #' Count @@ -523,9 +523,9 @@ setMethod("sample_frac", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' count(df) #' } setMethod("count", @@ -545,9 +545,9 @@ setMethod("count", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' collected <- collect(df) #' firstName <- collected[[1]]$name #' } @@ -580,9 +580,9 @@ setMethod("collect", #' @examples #' \dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' limitedDF <- limit(df, 10) #' } setMethod("limit", @@ -599,9 +599,9 @@ setMethod("limit", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' take(df, 2) #' } setMethod("take", @@ -626,9 +626,9 @@ setMethod("take", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' head(df) #' } setMethod("head", @@ -647,9 +647,9 @@ setMethod("head", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' first(df) #' } setMethod("first", @@ -669,9 +669,9 @@ setMethod("first", # @examples #\dontrun{ # sc <- sparkR.init() -# sqlCtx <- sparkRSQL.init(sc) +# sqlContext <- sparkRSQL.init(sc) # path <- "path/to/file.json" -# df <- jsonFile(sqlCtx, path) +# df <- jsonFile(sqlContext, path) # rdd <- toRDD(df) # } setMethod("toRDD", @@ -938,9 +938,9 @@ setMethod("select", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' selectExpr(df, "col1", "(col2 * 5) as newCol") #' } setMethod("selectExpr", @@ -964,9 +964,9 @@ setMethod("selectExpr", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' newDF <- withColumn(df, "newCol", df$col1 * 5) #' } setMethod("withColumn", @@ -988,9 +988,9 @@ setMethod("withColumn", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' newDF <- mutate(df, newCol = df$col1 * 5, newCol2 = df$col1 * 2) #' names(newDF) # Will contain newCol, newCol2 #' } @@ -1024,9 +1024,9 @@ setMethod("mutate", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' newDF <- withColumnRenamed(df, "col1", "newCol1") #' } setMethod("withColumnRenamed", @@ -1055,9 +1055,9 @@ setMethod("withColumnRenamed", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' newDF <- rename(df, col1 = df$newCol1) #' } setMethod("rename", @@ -1095,9 +1095,9 @@ setClassUnion("characterOrColumn", c("character", "Column")) #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' arrange(df, df$col1) #' arrange(df, "col1") #' arrange(df, asc(df$col1), desc(abs(df$col2))) @@ -1137,9 +1137,9 @@ setMethod("orderBy", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' filter(df, "col1 > 0") #' filter(df, df$col2 != "abcdefg") #' } @@ -1177,9 +1177,9 @@ setMethod("where", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) -#' df1 <- jsonFile(sqlCtx, path) -#' df2 <- jsonFile(sqlCtx, path2) +#' sqlContext <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlContext, path) +#' df2 <- jsonFile(sqlContext, path2) #' join(df1, df2) # Performs a Cartesian #' join(df1, df2, df1$col1 == df2$col2) # Performs an inner join based on expression #' join(df1, df2, df1$col1 == df2$col2, "right_outer") @@ -1219,9 +1219,9 @@ setMethod("join", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) -#' df1 <- jsonFile(sqlCtx, path) -#' df2 <- jsonFile(sqlCtx, path2) +#' sqlContext <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlContext, path) +#' df2 <- jsonFile(sqlContext, path2) #' unioned <- unionAll(df, df2) #' } setMethod("unionAll", @@ -1244,9 +1244,9 @@ setMethod("unionAll", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) -#' df1 <- jsonFile(sqlCtx, path) -#' df2 <- jsonFile(sqlCtx, path2) +#' sqlContext <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlContext, path) +#' df2 <- jsonFile(sqlContext, path2) #' intersectDF <- intersect(df, df2) #' } setMethod("intersect", @@ -1269,9 +1269,9 @@ setMethod("intersect", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) -#' df1 <- jsonFile(sqlCtx, path) -#' df2 <- jsonFile(sqlCtx, path2) +#' sqlContext <- sparkRSQL.init(sc) +#' df1 <- jsonFile(sqlContext, path) +#' df2 <- jsonFile(sqlContext, path2) #' exceptDF <- except(df, df2) #' } #' @rdname except @@ -1303,23 +1303,22 @@ setMethod("except", #' @param source A name for external data source #' @param mode One of 'append', 'overwrite', 'error', 'ignore' #' -#' @rdname saveAsTable +#' @rdname write.df #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) -#' saveAsTable(df, "myfile") +#' df <- jsonFile(sqlContext, path) +#' write.df(df, "myfile", "parquet", "overwrite") #' } -setMethod("saveDF", - signature(df = "DataFrame", path = 'character', source = 'character', - mode = 'character'), - function(df, path = NULL, source = NULL, mode = "append", ...){ +setMethod("write.df", + signature(df = "DataFrame", path = 'character'), + function(df, path, source = NULL, mode = "append", ...){ if (is.null(source)) { - sqlCtx <- get(".sparkRSQLsc", envir = .sparkREnv) - source <- callJMethod(sqlCtx, "getConf", "spark.sql.sources.default", + sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) + source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", "org.apache.spark.sql.parquet") } allModes <- c("append", "overwrite", "error", "ignore") @@ -1334,6 +1333,14 @@ setMethod("saveDF", callJMethod(df@sdf, "save", source, jmode, options) }) +#' @rdname write.df +#' @aliases saveDF +#' @export +setMethod("saveDF", + signature(df = "DataFrame", path = 'character'), + function(df, path, source = NULL, mode = "append", ...){ + write.df(df, path, source, mode, ...) + }) #' saveAsTable #' @@ -1362,9 +1369,9 @@ setMethod("saveDF", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' saveAsTable(df, "myfile") #' } setMethod("saveAsTable", @@ -1372,8 +1379,8 @@ setMethod("saveAsTable", mode = 'character'), function(df, tableName, source = NULL, mode="append", ...){ if (is.null(source)) { - sqlCtx <- get(".sparkRSQLsc", envir = .sparkREnv) - source <- callJMethod(sqlCtx, "getConf", "spark.sql.sources.default", + sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) + source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", "org.apache.spark.sql.parquet") } allModes <- c("append", "overwrite", "error", "ignore") @@ -1399,9 +1406,9 @@ setMethod("saveAsTable", #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' describe(df) #' describe(df, "col1") #' describe(df, "col1", "col2") @@ -1422,3 +1429,128 @@ setMethod("describe", sdf <- callJMethod(x@sdf, "describe", listToSeq(colList)) dataFrame(sdf) }) + +#' dropna +#' +#' Returns a new DataFrame omitting rows with null values. +#' +#' @param x A SparkSQL DataFrame. +#' @param how "any" or "all". +#' if "any", drop a row if it contains any nulls. +#' if "all", drop a row only if all its values are null. +#' if minNonNulls is specified, how is ignored. +#' @param minNonNulls If specified, drop rows that have less than +#' minNonNulls non-null values. +#' This overwrites the how parameter. +#' @param cols Optional list of column names to consider. +#' @return A DataFrame +#' +#' @rdname nafunctions +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' dropna(df) +#' } +setMethod("dropna", + signature(x = "DataFrame"), + function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { + how <- match.arg(how) + if (is.null(cols)) { + cols <- columns(x) + } + if (is.null(minNonNulls)) { + minNonNulls <- if (how == "any") { length(cols) } else { 1 } + } + + naFunctions <- callJMethod(x@sdf, "na") + sdf <- callJMethod(naFunctions, "drop", + as.integer(minNonNulls), listToSeq(as.list(cols))) + dataFrame(sdf) + }) + +#' @aliases dropna +#' @export +setMethod("na.omit", + signature(x = "DataFrame"), + function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { + dropna(x, how, minNonNulls, cols) + }) + +#' fillna +#' +#' Replace null values. +#' +#' @param x A SparkSQL DataFrame. +#' @param value Value to replace null values with. +#' Should be an integer, numeric, character or named list. +#' If the value is a named list, then cols is ignored and +#' value must be a mapping from column name (character) to +#' replacement value. The replacement value must be an +#' integer, numeric or character. +#' @param cols optional list of column names to consider. +#' Columns specified in cols that do not have matching data +#' type are ignored. For example, if value is a character, and +#' subset contains a non-character column, then the non-character +#' column is simply ignored. +#' @return A DataFrame +#' +#' @rdname nafunctions +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' path <- "path/to/file.json" +#' df <- jsonFile(sqlCtx, path) +#' fillna(df, 1) +#' fillna(df, list("age" = 20, "name" = "unknown")) +#' } +setMethod("fillna", + signature(x = "DataFrame"), + function(x, value, cols = NULL) { + if (!(class(value) %in% c("integer", "numeric", "character", "list"))) { + stop("value should be an integer, numeric, charactor or named list.") + } + + if (class(value) == "list") { + # Check column names in the named list + colNames <- names(value) + if (length(colNames) == 0 || !all(colNames != "")) { + stop("value should be an a named list with each name being a column name.") + } + + # Convert to the named list to an environment to be passed to JVM + valueMap <- new.env() + for (col in colNames) { + # Check each item in the named list is of valid type + v <- value[[col]] + if (!(class(v) %in% c("integer", "numeric", "character"))) { + stop("Each item in value should be an integer, numeric or charactor.") + } + valueMap[[col]] <- v + } + + # When value is a named list, caller is expected not to pass in cols + if (!is.null(cols)) { + warning("When value is a named list, cols is ignored!") + cols <- NULL + } + + value <- valueMap + } else if (is.integer(value)) { + # Cast an integer to a numeric + value <- as.numeric(value) + } + + naFunctions <- callJMethod(x@sdf, "na") + sdf <- if (length(cols) == 0) { + callJMethod(naFunctions, "fill", value) + } else { + callJMethod(naFunctions, "fill", value, listToSeq(as.list(cols))) + } + dataFrame(sdf) + }) diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index 9138629cac9c0..0513299515644 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -239,7 +239,7 @@ setMethod("cache", # @aliases persist,RDD-method setMethod("persist", signature(x = "RDD", newLevel = "character"), - function(x, newLevel) { + function(x, newLevel = "MEMORY_ONLY") { callJMethod(getJRDD(x), "persist", getStorageLevel(newLevel)) x@env$isCached <- TRUE x @@ -927,7 +927,7 @@ setMethod("takeSample", signature(x = "RDD", withReplacement = "logical", MAXINT))))) # TODO(zongheng): investigate if this call is an in-place shuffle? - sample(samples)[1:total] + base::sample(samples)[1:total] }) # Creates tuples of the elements in this RDD by applying a function. @@ -996,7 +996,7 @@ setMethod("coalesce", if (shuffle || numPartitions > SparkR:::numPartitions(x)) { func <- function(partIndex, part) { set.seed(partIndex) # partIndex as seed - start <- as.integer(sample(numPartitions, 1) - 1) + start <- as.integer(base::sample(numPartitions, 1) - 1) lapply(seq_along(part), function(i) { pos <- (start + i) %% numPartitions diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index cae06e6af2bff..88e1a508f37c4 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -69,7 +69,7 @@ infer_type <- function(x) { #' #' Converts an RDD to a DataFrame by infer the types. #' -#' @param sqlCtx A SQLContext +#' @param sqlContext A SQLContext #' @param data An RDD or list or data.frame #' @param schema a list of column names or named list (StructType), optional #' @return an DataFrame @@ -77,13 +77,13 @@ infer_type <- function(x) { #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' rdd <- lapply(parallelize(sc, 1:10), function(x) list(a=x, b=as.character(x))) -#' df <- createDataFrame(sqlCtx, rdd) +#' df <- createDataFrame(sqlContext, rdd) #' } # TODO(davies): support sampling and infer type from NA -createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) { +createDataFrame <- function(sqlContext, data, schema = NULL, samplingRatio = 1.0) { if (is.data.frame(data)) { # get the names of columns, they will be put into RDD schema <- names(data) @@ -102,7 +102,7 @@ createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) { }) } if (is.list(data)) { - sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sqlCtx) + sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sqlContext) rdd <- parallelize(sc, data) } else if (inherits(data, "RDD")) { rdd <- data @@ -146,7 +146,7 @@ createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) { jrdd <- getJRDD(lapply(rdd, function(x) x), "row") srdd <- callJMethod(jrdd, "rdd") sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "createDF", - srdd, schema$jobj, sqlCtx) + srdd, schema$jobj, sqlContext) dataFrame(sdf) } @@ -161,7 +161,7 @@ createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) { # @examples #\dontrun{ # sc <- sparkR.init() -# sqlCtx <- sparkRSQL.init(sc) +# sqlContext <- sparkRSQL.init(sc) # rdd <- lapply(parallelize(sc, 1:10), function(x) list(a=x, b=as.character(x))) # df <- toDF(rdd) # } @@ -170,14 +170,14 @@ setGeneric("toDF", function(x, ...) { standardGeneric("toDF") }) setMethod("toDF", signature(x = "RDD"), function(x, ...) { - sqlCtx <- if (exists(".sparkRHivesc", envir = .sparkREnv)) { + sqlContext <- if (exists(".sparkRHivesc", envir = .sparkREnv)) { get(".sparkRHivesc", envir = .sparkREnv) } else if (exists(".sparkRSQLsc", envir = .sparkREnv)) { get(".sparkRSQLsc", envir = .sparkREnv) } else { stop("no SQL context available") } - createDataFrame(sqlCtx, x, ...) + createDataFrame(sqlContext, x, ...) }) #' Create a DataFrame from a JSON file. @@ -185,24 +185,24 @@ setMethod("toDF", signature(x = "RDD"), #' Loads a JSON file (one object per line), returning the result as a DataFrame #' It goes through the entire dataset once to determine the schema. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param path Path of file to read. A vector of multiple paths is allowed. #' @return DataFrame #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' } -jsonFile <- function(sqlCtx, path) { +jsonFile <- function(sqlContext, path) { # Allow the user to have a more flexible definiton of the text file path path <- normalizePath(path) # Convert a string vector of paths to a string containing comma separated paths path <- paste(path, collapse = ",") - sdf <- callJMethod(sqlCtx, "jsonFile", path) + sdf <- callJMethod(sqlContext, "jsonFile", path) dataFrame(sdf) } @@ -211,7 +211,7 @@ jsonFile <- function(sqlCtx, path) { # # Loads an RDD storing one JSON object per string as a DataFrame. # -# @param sqlCtx SQLContext to use +# @param sqlContext SQLContext to use # @param rdd An RDD of JSON string # @param schema A StructType object to use as schema # @param samplingRatio The ratio of simpling used to infer the schema @@ -220,16 +220,16 @@ jsonFile <- function(sqlCtx, path) { # @examples #\dontrun{ # sc <- sparkR.init() -# sqlCtx <- sparkRSQL.init(sc) +# sqlContext <- sparkRSQL.init(sc) # rdd <- texFile(sc, "path/to/json") -# df <- jsonRDD(sqlCtx, rdd) +# df <- jsonRDD(sqlContext, rdd) # } # TODO: support schema -jsonRDD <- function(sqlCtx, rdd, schema = NULL, samplingRatio = 1.0) { +jsonRDD <- function(sqlContext, rdd, schema = NULL, samplingRatio = 1.0) { rdd <- serializeToString(rdd) if (is.null(schema)) { - sdf <- callJMethod(sqlCtx, "jsonRDD", callJMethod(getJRDD(rdd), "rdd"), samplingRatio) + sdf <- callJMethod(sqlContext, "jsonRDD", callJMethod(getJRDD(rdd), "rdd"), samplingRatio) dataFrame(sdf) } else { stop("not implemented") @@ -241,64 +241,63 @@ jsonRDD <- function(sqlCtx, rdd, schema = NULL, samplingRatio = 1.0) { #' #' Loads a Parquet file, returning the result as a DataFrame. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param ... Path(s) of parquet file(s) to read. #' @return DataFrame #' @export # TODO: Implement saveasParquetFile and write examples for both -parquetFile <- function(sqlCtx, ...) { +parquetFile <- function(sqlContext, ...) { # Allow the user to have a more flexible definiton of the text file path paths <- lapply(list(...), normalizePath) - sdf <- callJMethod(sqlCtx, "parquetFile", paths) + sdf <- callJMethod(sqlContext, "parquetFile", paths) dataFrame(sdf) } #' SQL Query -#' +#' #' Executes a SQL query using Spark, returning the result as a DataFrame. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param sqlQuery A character vector containing the SQL query #' @return DataFrame #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' registerTempTable(df, "table") -#' new_df <- sql(sqlCtx, "SELECT * FROM table") +#' new_df <- sql(sqlContext, "SELECT * FROM table") #' } -sql <- function(sqlCtx, sqlQuery) { - sdf <- callJMethod(sqlCtx, "sql", sqlQuery) - dataFrame(sdf) +sql <- function(sqlContext, sqlQuery) { + sdf <- callJMethod(sqlContext, "sql", sqlQuery) + dataFrame(sdf) } - #' Create a DataFrame from a SparkSQL Table #' #' Returns the specified Table as a DataFrame. The Table must have already been registered #' in the SQLContext. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param tableName The SparkSQL Table to convert to a DataFrame. #' @return DataFrame #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' registerTempTable(df, "table") -#' new_df <- table(sqlCtx, "table") +#' new_df <- table(sqlContext, "table") #' } -table <- function(sqlCtx, tableName) { - sdf <- callJMethod(sqlCtx, "table", tableName) +table <- function(sqlContext, tableName) { + sdf <- callJMethod(sqlContext, "table", tableName) dataFrame(sdf) } @@ -307,22 +306,22 @@ table <- function(sqlCtx, tableName) { #' #' Returns a DataFrame containing names of tables in the given database. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param databaseName name of the database #' @return a DataFrame #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) -#' tables(sqlCtx, "hive") +#' sqlContext <- sparkRSQL.init(sc) +#' tables(sqlContext, "hive") #' } -tables <- function(sqlCtx, databaseName = NULL) { +tables <- function(sqlContext, databaseName = NULL) { jdf <- if (is.null(databaseName)) { - callJMethod(sqlCtx, "tables") + callJMethod(sqlContext, "tables") } else { - callJMethod(sqlCtx, "tables", databaseName) + callJMethod(sqlContext, "tables", databaseName) } dataFrame(jdf) } @@ -332,22 +331,22 @@ tables <- function(sqlCtx, databaseName = NULL) { #' #' Returns the names of tables in the given database as an array. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param databaseName name of the database #' @return a list of table names #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) -#' tableNames(sqlCtx, "hive") +#' sqlContext <- sparkRSQL.init(sc) +#' tableNames(sqlContext, "hive") #' } -tableNames <- function(sqlCtx, databaseName = NULL) { +tableNames <- function(sqlContext, databaseName = NULL) { if (is.null(databaseName)) { - callJMethod(sqlCtx, "tableNames") + callJMethod(sqlContext, "tableNames") } else { - callJMethod(sqlCtx, "tableNames", databaseName) + callJMethod(sqlContext, "tableNames", databaseName) } } @@ -356,58 +355,58 @@ tableNames <- function(sqlCtx, databaseName = NULL) { #' #' Caches the specified table in-memory. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param tableName The name of the table being cached #' @return DataFrame #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' registerTempTable(df, "table") -#' cacheTable(sqlCtx, "table") +#' cacheTable(sqlContext, "table") #' } -cacheTable <- function(sqlCtx, tableName) { - callJMethod(sqlCtx, "cacheTable", tableName) +cacheTable <- function(sqlContext, tableName) { + callJMethod(sqlContext, "cacheTable", tableName) } #' Uncache Table #' #' Removes the specified table from the in-memory cache. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param tableName The name of the table being uncached #' @return DataFrame #' @export #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #' path <- "path/to/file.json" -#' df <- jsonFile(sqlCtx, path) +#' df <- jsonFile(sqlContext, path) #' registerTempTable(df, "table") -#' uncacheTable(sqlCtx, "table") +#' uncacheTable(sqlContext, "table") #' } -uncacheTable <- function(sqlCtx, tableName) { - callJMethod(sqlCtx, "uncacheTable", tableName) +uncacheTable <- function(sqlContext, tableName) { + callJMethod(sqlContext, "uncacheTable", tableName) } #' Clear Cache #' #' Removes all cached tables from the in-memory cache. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @examples #' \dontrun{ -#' clearCache(sqlCtx) +#' clearCache(sqlContext) #' } -clearCache <- function(sqlCtx) { - callJMethod(sqlCtx, "clearCache") +clearCache <- function(sqlContext) { + callJMethod(sqlContext, "clearCache") } #' Drop Temporary Table @@ -415,22 +414,22 @@ clearCache <- function(sqlCtx) { #' Drops the temporary table with the given table name in the catalog. #' If the table has been cached/persisted before, it's also unpersisted. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param tableName The name of the SparkSQL table to be dropped. #' @examples #' \dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) -#' df <- loadDF(sqlCtx, path, "parquet") +#' sqlContext <- sparkRSQL.init(sc) +#' df <- read.df(sqlContext, path, "parquet") #' registerTempTable(df, "table") -#' dropTempTable(sqlCtx, "table") +#' dropTempTable(sqlContext, "table") #' } -dropTempTable <- function(sqlCtx, tableName) { +dropTempTable <- function(sqlContext, tableName) { if (class(tableName) != "character") { stop("tableName must be a string.") } - callJMethod(sqlCtx, "dropTempTable", tableName) + callJMethod(sqlContext, "dropTempTable", tableName) } #' Load an DataFrame @@ -441,7 +440,7 @@ dropTempTable <- function(sqlCtx, tableName) { #' If `source` is not specified, the default data source configured by #' "spark.sql.sources.default" will be used. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param path The path of files to load #' @param source the name of external data source #' @return DataFrame @@ -449,19 +448,31 @@ dropTempTable <- function(sqlCtx, tableName) { #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) -#' df <- load(sqlCtx, "path/to/file.json", source = "json") +#' sqlContext <- sparkRSQL.init(sc) +#' df <- read.df(sqlContext, "path/to/file.json", source = "json") #' } -loadDF <- function(sqlCtx, path = NULL, source = NULL, ...) { +read.df <- function(sqlContext, path = NULL, source = NULL, ...) { options <- varargsToEnv(...) if (!is.null(path)) { options[['path']] <- path } - sdf <- callJMethod(sqlCtx, "load", source, options) + if (is.null(source)) { + sqlContext <- get(".sparkRSQLsc", envir = .sparkREnv) + source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", + "org.apache.spark.sql.parquet") + } + sdf <- callJMethod(sqlContext, "load", source, options) dataFrame(sdf) } +#' @aliases loadDF +#' @export + +loadDF <- function(sqlContext, path = NULL, source = NULL, ...) { + read.df(sqlContext, path, source, ...) +} + #' Create an external table #' #' Creates an external table based on the dataset in a data source, @@ -471,7 +482,7 @@ loadDF <- function(sqlCtx, path = NULL, source = NULL, ...) { #' If `source` is not specified, the default data source configured by #' "spark.sql.sources.default" will be used. #' -#' @param sqlCtx SQLContext to use +#' @param sqlContext SQLContext to use #' @param tableName A name of the table #' @param path The path of files to load #' @param source the name of external data source @@ -480,15 +491,15 @@ loadDF <- function(sqlCtx, path = NULL, source = NULL, ...) { #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) -#' df <- sparkRSQL.createExternalTable(sqlCtx, "myjson", path="path/to/json", source="json") +#' sqlContext <- sparkRSQL.init(sc) +#' df <- sparkRSQL.createExternalTable(sqlContext, "myjson", path="path/to/json", source="json") #' } -createExternalTable <- function(sqlCtx, tableName, path = NULL, source = NULL, ...) { +createExternalTable <- function(sqlContext, tableName, path = NULL, source = NULL, ...) { options <- varargsToEnv(...) if (!is.null(path)) { options[['path']] <- path } - sdf <- callJMethod(sqlCtx, "createExternalTable", tableName, source, options) + sdf <- callJMethod(sqlContext, "createExternalTable", tableName, source, options) dataFrame(sdf) } diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 9a68445ab451a..80e92d3105a36 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -55,12 +55,17 @@ operators <- list( "+" = "plus", "-" = "minus", "*" = "multiply", "/" = "divide", "%%" = "mod", "==" = "equalTo", ">" = "gt", "<" = "lt", "!=" = "notEqual", "<=" = "leq", ">=" = "geq", # we can not override `&&` and `||`, so use `&` and `|` instead - "&" = "and", "|" = "or" #, "!" = "unary_$bang" + "&" = "and", "|" = "or", #, "!" = "unary_$bang" + "^" = "pow" ) column_functions1 <- c("asc", "desc", "isNull", "isNotNull") column_functions2 <- c("like", "rlike", "startsWith", "endsWith", "getField", "getItem", "contains") functions <- c("min", "max", "sum", "avg", "mean", "count", "abs", "sqrt", - "first", "last", "lower", "upper", "sumDistinct") + "first", "last", "lower", "upper", "sumDistinct", + "acos", "asin", "atan", "cbrt", "ceiling", "cos", "cosh", "exp", + "expm1", "floor", "log", "log10", "log1p", "rint", "sign", + "sin", "sinh", "tan", "tanh", "toDegrees", "toRadians") +binary_mathfunctions<- c("atan2", "hypot") createOperator <- function(op) { setMethod(op, @@ -76,7 +81,11 @@ createOperator <- function(op) { if (class(e2) == "Column") { e2 <- e2@jc } - callJMethod(e1@jc, operators[[op]], e2) + if (op == "^") { + jc <- callJStatic("org.apache.spark.sql.functions", operators[[op]], e1@jc, e2) + } else { + callJMethod(e1@jc, operators[[op]], e2) + } } column(jc) }) @@ -106,11 +115,29 @@ createStaticFunction <- function(name) { setMethod(name, signature(x = "Column"), function(x) { + if (name == "ceiling") { + name <- "ceil" + } + if (name == "sign") { + name <- "signum" + } jc <- callJStatic("org.apache.spark.sql.functions", name, x@jc) column(jc) }) } +createBinaryMathfunctions <- function(name) { + setMethod(name, + signature(y = "Column"), + function(y, x) { + if (class(x) == "Column") { + x <- x@jc + } + jc <- callJStatic("org.apache.spark.sql.functions", name, y@jc, x) + column(jc) + }) +} + createMethods <- function() { for (op in names(operators)) { createOperator(op) @@ -124,6 +151,9 @@ createMethods <- function() { for (x in functions) { createStaticFunction(x) } + for (name in binary_mathfunctions) { + createBinaryMathfunctions(name) + } } createMethods() diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 557128a419f19..12e09176c9f92 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -396,6 +396,20 @@ setGeneric("columns", function(x) {standardGeneric("columns") }) #' @export setGeneric("describe", function(x, col, ...) { standardGeneric("describe") }) +#' @rdname nafunctions +#' @export +setGeneric("dropna", + function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { + standardGeneric("dropna") + }) + +#' @rdname nafunctions +#' @export +setGeneric("na.omit", + function(x, how = c("any", "all"), minNonNulls = NULL, cols = NULL) { + standardGeneric("na.omit") + }) + #' @rdname schema #' @export setGeneric("dtypes", function(x) { standardGeneric("dtypes") }) @@ -408,6 +422,10 @@ setGeneric("explain", function(x, ...) { standardGeneric("explain") }) #' @export setGeneric("except", function(x, y) { standardGeneric("except") }) +#' @rdname nafunctions +#' @export +setGeneric("fillna", function(x, value, cols = NULL) { standardGeneric("fillna") }) + #' @rdname filter #' @export setGeneric("filter", function(x, condition) { standardGeneric("filter") }) @@ -456,19 +474,19 @@ setGeneric("rename", function(x, ...) { standardGeneric("rename") }) #' @export setGeneric("registerTempTable", function(x, tableName) { standardGeneric("registerTempTable") }) -#' @rdname sampleDF +#' @rdname sample #' @export -setGeneric("sample_frac", +setGeneric("sample", function(x, withReplacement, fraction, seed) { - standardGeneric("sample_frac") - }) + standardGeneric("sample") + }) -#' @rdname sampleDF +#' @rdname sample #' @export -setGeneric("sampleDF", +setGeneric("sample_frac", function(x, withReplacement, fraction, seed) { - standardGeneric("sampleDF") - }) + standardGeneric("sample_frac") + }) #' @rdname saveAsParquetFile #' @export @@ -480,9 +498,13 @@ setGeneric("saveAsTable", function(df, tableName, source, mode, ...) { standardGeneric("saveAsTable") }) -#' @rdname saveAsTable +#' @rdname write.df +#' @export +setGeneric("write.df", function(df, path, ...) { standardGeneric("write.df") }) + +#' @rdname write.df #' @export -setGeneric("saveDF", function(df, path, source, mode, ...) { standardGeneric("saveDF") }) +setGeneric("saveDF", function(df, path, ...) { standardGeneric("saveDF") }) #' @rdname schema #' @export @@ -548,6 +570,10 @@ setGeneric("avg", function(x, ...) { standardGeneric("avg") }) #' @export setGeneric("cast", function(x, dataType) { standardGeneric("cast") }) +#' @rdname column +#' @export +setGeneric("cbrt", function(x) { standardGeneric("cbrt") }) + #' @rdname column #' @export setGeneric("contains", function(x, ...) { standardGeneric("contains") }) @@ -571,6 +597,10 @@ setGeneric("getField", function(x, ...) { standardGeneric("getField") }) #' @export setGeneric("getItem", function(x, ...) { standardGeneric("getItem") }) +#' @rdname column +#' @export +setGeneric("hypot", function(y, x) { standardGeneric("hypot") }) + #' @rdname column #' @export setGeneric("isNull", function(x) { standardGeneric("isNull") }) @@ -599,6 +629,10 @@ setGeneric("n", function(x) { standardGeneric("n") }) #' @export setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") }) +#' @rdname column +#' @export +setGeneric("rint", function(x, ...) { standardGeneric("rint") }) + #' @rdname column #' @export setGeneric("rlike", function(x, ...) { standardGeneric("rlike") }) @@ -611,6 +645,14 @@ setGeneric("startsWith", function(x, ...) { standardGeneric("startsWith") }) #' @export setGeneric("sumDistinct", function(x) { standardGeneric("sumDistinct") }) +#' @rdname column +#' @export +setGeneric("toDegrees", function(x) { standardGeneric("toDegrees") }) + +#' @rdname column +#' @export +setGeneric("toRadians", function(x) { standardGeneric("toRadians") }) + #' @rdname column #' @export setGeneric("upper", function(x) { standardGeneric("upper") }) diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index 7694652856da5..1e24286dbcae2 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -329,7 +329,7 @@ setMethod("reduceByKey", convertEnvsToList(keys, vals) } locallyReduced <- lapplyPartition(x, reduceVals) - shuffled <- partitionBy(locallyReduced, numPartitions) + shuffled <- partitionBy(locallyReduced, numToInt(numPartitions)) lapplyPartition(shuffled, reduceVals) }) @@ -436,7 +436,7 @@ setMethod("combineByKey", convertEnvsToList(keys, combiners) } locallyCombined <- lapplyPartition(x, combineLocally) - shuffled <- partitionBy(locallyCombined, numPartitions) + shuffled <- partitionBy(locallyCombined, numToInt(numPartitions)) mergeAfterShuffle <- function(part) { combiners <- new.env() keys <- new.env() diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R index c53d0a961016f..2081786e6f833 100644 --- a/R/pkg/R/serialize.R +++ b/R/pkg/R/serialize.R @@ -160,6 +160,14 @@ writeList <- function(con, arr) { } } +# Used to pass arrays where the elements can be of different types +writeGenericList <- function(con, list) { + writeInt(con, length(list)) + for (elem in list) { + writeObject(con, elem) + } +} + # Used to pass in hash maps required on Java side. writeEnv <- function(con, env) { len <- length(env) @@ -168,7 +176,7 @@ writeEnv <- function(con, env) { if (len > 0) { writeList(con, as.list(ls(env))) vals <- lapply(ls(env), function(x) { env[[x]] }) - writeList(con, as.list(vals)) + writeGenericList(con, as.list(vals)) } } diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index bc82df01f0fff..5ced7c688f98a 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -222,19 +222,26 @@ sparkR.init <- function( #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRSQL.init(sc) +#' sqlContext <- sparkRSQL.init(sc) #'} -sparkRSQL.init <- function(jsc) { +sparkRSQL.init <- function(jsc = NULL) { if (exists(".sparkRSQLsc", envir = .sparkREnv)) { return(get(".sparkRSQLsc", envir = .sparkREnv)) } - sqlCtx <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", - "createSQLContext", - jsc) - assign(".sparkRSQLsc", sqlCtx, envir = .sparkREnv) - sqlCtx + # If jsc is NULL, create a Spark Context + sc <- if (is.null(jsc)) { + sparkR.init() + } else { + jsc + } + + sqlContext <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", + "createSQLContext", + sc) + assign(".sparkRSQLsc", sqlContext, envir = .sparkREnv) + sqlContext } #' Initialize a new HiveContext. @@ -246,15 +253,22 @@ sparkRSQL.init <- function(jsc) { #' @examples #'\dontrun{ #' sc <- sparkR.init() -#' sqlCtx <- sparkRHive.init(sc) +#' sqlContext <- sparkRHive.init(sc) #'} -sparkRHive.init <- function(jsc) { +sparkRHive.init <- function(jsc = NULL) { if (exists(".sparkRHivesc", envir = .sparkREnv)) { return(get(".sparkRHivesc", envir = .sparkREnv)) } - ssc <- callJMethod(jsc, "sc") + # If jsc is NULL, create a Spark Context + sc <- if (is.null(jsc)) { + sparkR.init() + } else { + jsc + } + + ssc <- callJMethod(sc, "sc") hiveCtx <- tryCatch({ newJObject("org.apache.spark.sql.hive.HiveContext", ssc) }, error = function(err) { diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 0e7b7bd5a5b34..69b2700191c9a 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -122,13 +122,49 @@ hashCode <- function(key) { intBits <- packBits(rawToBits(rawVec), "integer") as.integer(bitwXor(intBits[2], intBits[1])) } else if (class(key) == "character") { - .Call("stringHashCode", key) + # TODO: SPARK-7839 means we might not have the native library available + if (is.loaded("stringHashCode")) { + .Call("stringHashCode", key) + } else { + n <- nchar(key) + if (n == 0) { + 0L + } else { + asciiVals <- sapply(charToRaw(key), function(x) { strtoi(x, 16L) }) + hashC <- 0 + for (k in 1:length(asciiVals)) { + hashC <- mult31AndAdd(hashC, asciiVals[k]) + } + as.integer(hashC) + } + } } else { warning(paste("Could not hash object, returning 0", sep = "")) as.integer(0) } } +# Helper function used to wrap a 'numeric' value to integer bounds. +# Useful for implementing C-like integer arithmetic +wrapInt <- function(value) { + if (value > .Machine$integer.max) { + value <- value - 2 * .Machine$integer.max - 2 + } else if (value < -1 * .Machine$integer.max) { + value <- 2 * .Machine$integer.max + value + 2 + } + value +} + +# Multiply `val` by 31 and add `addVal` to the result. Ensures that +# integer-overflows are handled at every step. +mult31AndAdd <- function(val, addVal) { + vec <- c(bitwShiftL(val, c(4,3,2,1,0)), addVal) + Reduce(function(a, b) { + wrapInt(as.numeric(a) + as.numeric(b)) + }, + vec) +} + # Create a new RDD with serializedMode == "byte". # Return itself if already in "byte" format. serializeToBytes <- function(rdd) { diff --git a/R/pkg/inst/profile/shell.R b/R/pkg/inst/profile/shell.R index 33478d9e29995..ca94f1d4e7fd5 100644 --- a/R/pkg/inst/profile/shell.R +++ b/R/pkg/inst/profile/shell.R @@ -26,8 +26,8 @@ sc <- SparkR::sparkR.init(Sys.getenv("MASTER", unset = "")) assign("sc", sc, envir=.GlobalEnv) - sqlCtx <- SparkR::sparkRSQL.init(sc) - assign("sqlCtx", sqlCtx, envir=.GlobalEnv) + sqlContext <- SparkR::sparkRSQL.init(sc) + assign("sqlContext", sqlContext, envir=.GlobalEnv) cat("\n Welcome to SparkR!") - cat("\n Spark context is available as sc, SQL context is available as sqlCtx\n") + cat("\n Spark context is available as sc, SQL context is available as sqlContext\n") } diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 99c28830c6237..d2d82e791e876 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -23,7 +23,7 @@ context("SparkSQL functions") sc <- sparkR.init() -sqlCtx <- sparkRSQL.init(sc) +sqlContext <- sparkRSQL.init(sc) mockLines <- c("{\"name\":\"Michael\"}", "{\"name\":\"Andy\", \"age\":30}", @@ -32,6 +32,15 @@ jsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") parquetPath <- tempfile(pattern="sparkr-test", fileext=".parquet") writeLines(mockLines, jsonPath) +# For test nafunctions, like dropna(), fillna(),... +mockLinesNa <- c("{\"name\":\"Bob\",\"age\":16,\"height\":176.5}", + "{\"name\":\"Alice\",\"age\":null,\"height\":164.3}", + "{\"name\":\"David\",\"age\":60,\"height\":null}", + "{\"name\":\"Amy\",\"age\":null,\"height\":null}", + "{\"name\":null,\"age\":null,\"height\":null}") +jsonPathNa <- tempfile(pattern="sparkr-test", fileext=".tmp") +writeLines(mockLinesNa, jsonPathNa) + test_that("infer types", { expect_equal(infer_type(1L), "integer") expect_equal(infer_type(1.0), "double") @@ -67,25 +76,25 @@ test_that("structType and structField", { test_that("create DataFrame from RDD", { rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) - df <- createDataFrame(sqlCtx, rdd, list("a", "b")) + df <- createDataFrame(sqlContext, rdd, list("a", "b")) expect_true(inherits(df, "DataFrame")) expect_true(count(df) == 10) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) - df <- createDataFrame(sqlCtx, rdd) + df <- createDataFrame(sqlContext, rdd) expect_true(inherits(df, "DataFrame")) expect_equal(columns(df), c("_1", "_2")) schema <- structType(structField(x = "a", type = "integer", nullable = TRUE), structField(x = "b", type = "string", nullable = TRUE)) - df <- createDataFrame(sqlCtx, rdd, schema) + df <- createDataFrame(sqlContext, rdd, schema) expect_true(inherits(df, "DataFrame")) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) - df <- createDataFrame(sqlCtx, rdd) + df <- createDataFrame(sqlContext, rdd) expect_true(inherits(df, "DataFrame")) expect_true(count(df) == 10) expect_equal(columns(df), c("a", "b")) @@ -121,17 +130,17 @@ test_that("toDF", { test_that("create DataFrame from list or data.frame", { l <- list(list(1, 2), list(3, 4)) - df <- createDataFrame(sqlCtx, l, c("a", "b")) + df <- createDataFrame(sqlContext, l, c("a", "b")) expect_equal(columns(df), c("a", "b")) l <- list(list(a=1, b=2), list(a=3, b=4)) - df <- createDataFrame(sqlCtx, l) + df <- createDataFrame(sqlContext, l) expect_equal(columns(df), c("a", "b")) a <- 1:3 b <- c("a", "b", "c") ldf <- data.frame(a, b) - df <- createDataFrame(sqlCtx, ldf) + df <- createDataFrame(sqlContext, ldf) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) expect_equal(count(df), 3) @@ -142,7 +151,7 @@ test_that("create DataFrame from list or data.frame", { test_that("create DataFrame with different data types", { l <- list(a = 1L, b = 2, c = TRUE, d = "ss", e = as.Date("2012-12-13"), f = as.POSIXct("2015-03-15 12:13:14.056")) - df <- createDataFrame(sqlCtx, list(l)) + df <- createDataFrame(sqlContext, list(l)) expect_equal(dtypes(df), list(c("a", "int"), c("b", "double"), c("c", "boolean"), c("d", "string"), c("e", "date"), c("f", "timestamp"))) expect_equal(count(df), 1) @@ -154,7 +163,7 @@ test_that("create DataFrame with different data types", { # e <- new.env() # assign("n", 3L, envir = e) # l <- list(1:10, list("a", "b"), e, list(a="aa", b=3L)) -# df <- createDataFrame(sqlCtx, list(l), c("a", "b", "c", "d")) +# df <- createDataFrame(sqlContext, list(l), c("a", "b", "c", "d")) # expect_equal(dtypes(df), list(c("a", "array"), c("b", "array"), # c("c", "map"), c("d", "struct"))) # expect_equal(count(df), 1) @@ -163,7 +172,7 @@ test_that("create DataFrame with different data types", { #}) test_that("jsonFile() on a local file returns a DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) expect_true(inherits(df, "DataFrame")) expect_true(count(df) == 3) }) @@ -171,88 +180,88 @@ test_that("jsonFile() on a local file returns a DataFrame", { test_that("jsonRDD() on a RDD with json string", { rdd <- parallelize(sc, mockLines) expect_true(count(rdd) == 3) - df <- jsonRDD(sqlCtx, rdd) + df <- jsonRDD(sqlContext, rdd) expect_true(inherits(df, "DataFrame")) expect_true(count(df) == 3) rdd2 <- flatMap(rdd, function(x) c(x, x)) - df <- jsonRDD(sqlCtx, rdd2) + df <- jsonRDD(sqlContext, rdd2) expect_true(inherits(df, "DataFrame")) expect_true(count(df) == 6) }) test_that("test cache, uncache and clearCache", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) registerTempTable(df, "table1") - cacheTable(sqlCtx, "table1") - uncacheTable(sqlCtx, "table1") - clearCache(sqlCtx) - dropTempTable(sqlCtx, "table1") + cacheTable(sqlContext, "table1") + uncacheTable(sqlContext, "table1") + clearCache(sqlContext) + dropTempTable(sqlContext, "table1") }) test_that("test tableNames and tables", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) registerTempTable(df, "table1") - expect_true(length(tableNames(sqlCtx)) == 1) - df <- tables(sqlCtx) + expect_true(length(tableNames(sqlContext)) == 1) + df <- tables(sqlContext) expect_true(count(df) == 1) - dropTempTable(sqlCtx, "table1") + dropTempTable(sqlContext, "table1") }) test_that("registerTempTable() results in a queryable table and sql() results in a new DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) registerTempTable(df, "table1") - newdf <- sql(sqlCtx, "SELECT * FROM table1 where name = 'Michael'") + newdf <- sql(sqlContext, "SELECT * FROM table1 where name = 'Michael'") expect_true(inherits(newdf, "DataFrame")) expect_true(count(newdf) == 1) - dropTempTable(sqlCtx, "table1") + dropTempTable(sqlContext, "table1") }) test_that("insertInto() on a registered table", { - df <- loadDF(sqlCtx, jsonPath, "json") - saveDF(df, parquetPath, "parquet", "overwrite") - dfParquet <- loadDF(sqlCtx, parquetPath, "parquet") + df <- read.df(sqlContext, jsonPath, "json") + write.df(df, parquetPath, "parquet", "overwrite") + dfParquet <- read.df(sqlContext, parquetPath, "parquet") lines <- c("{\"name\":\"Bob\", \"age\":24}", "{\"name\":\"James\", \"age\":35}") jsonPath2 <- tempfile(pattern="jsonPath2", fileext=".tmp") parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") writeLines(lines, jsonPath2) - df2 <- loadDF(sqlCtx, jsonPath2, "json") - saveDF(df2, parquetPath2, "parquet", "overwrite") - dfParquet2 <- loadDF(sqlCtx, parquetPath2, "parquet") + df2 <- read.df(sqlContext, jsonPath2, "json") + write.df(df2, parquetPath2, "parquet", "overwrite") + dfParquet2 <- read.df(sqlContext, parquetPath2, "parquet") registerTempTable(dfParquet, "table1") insertInto(dfParquet2, "table1") - expect_true(count(sql(sqlCtx, "select * from table1")) == 5) - expect_true(first(sql(sqlCtx, "select * from table1 order by age"))$name == "Michael") - dropTempTable(sqlCtx, "table1") + expect_true(count(sql(sqlContext, "select * from table1")) == 5) + expect_true(first(sql(sqlContext, "select * from table1 order by age"))$name == "Michael") + dropTempTable(sqlContext, "table1") registerTempTable(dfParquet, "table1") insertInto(dfParquet2, "table1", overwrite = TRUE) - expect_true(count(sql(sqlCtx, "select * from table1")) == 2) - expect_true(first(sql(sqlCtx, "select * from table1 order by age"))$name == "Bob") - dropTempTable(sqlCtx, "table1") + expect_true(count(sql(sqlContext, "select * from table1")) == 2) + expect_true(first(sql(sqlContext, "select * from table1 order by age"))$name == "Bob") + dropTempTable(sqlContext, "table1") }) test_that("table() returns a new DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) registerTempTable(df, "table1") - tabledf <- table(sqlCtx, "table1") + tabledf <- table(sqlContext, "table1") expect_true(inherits(tabledf, "DataFrame")) expect_true(count(tabledf) == 3) - dropTempTable(sqlCtx, "table1") + dropTempTable(sqlContext, "table1") }) test_that("toRDD() returns an RRDD", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) testRDD <- toRDD(df) expect_true(inherits(testRDD, "RDD")) expect_true(count(testRDD) == 3) }) test_that("union on two RDDs created from DataFrames returns an RRDD", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) RDD1 <- toRDD(df) RDD2 <- toRDD(df) unioned <- unionRDD(RDD1, RDD2) @@ -274,7 +283,7 @@ test_that("union on mixed serialization types correctly returns a byte RRDD", { writeLines(textLines, textPath) textRDD <- textFile(sc, textPath) - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) dfRDD <- toRDD(df) unionByte <- unionRDD(rdd, dfRDD) @@ -292,7 +301,7 @@ test_that("union on mixed serialization types correctly returns a byte RRDD", { test_that("objectFile() works with row serialization", { objectPath <- tempfile(pattern="spark-test", fileext=".tmp") - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) dfRDD <- toRDD(df) saveAsObjectFile(coalesce(dfRDD, 1L), objectPath) objectIn <- objectFile(sc, objectPath) @@ -303,7 +312,7 @@ test_that("objectFile() works with row serialization", { }) test_that("lapply() on a DataFrame returns an RDD with the correct columns", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) testRDD <- lapply(df, function(row) { row$newCol <- row$age + 5 row @@ -315,7 +324,7 @@ test_that("lapply() on a DataFrame returns an RDD with the correct columns", { }) test_that("collect() returns a data.frame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) rdf <- collect(df) expect_true(is.data.frame(rdf)) expect_true(names(rdf)[1] == "age") @@ -324,20 +333,20 @@ test_that("collect() returns a data.frame", { }) test_that("limit() returns DataFrame with the correct number of rows", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) dfLimited <- limit(df, 2) expect_true(inherits(dfLimited, "DataFrame")) expect_true(count(dfLimited) == 2) }) test_that("collect() and take() on a DataFrame return the same number of rows and columns", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) expect_true(nrow(collect(df)) == nrow(take(df, 10))) expect_true(ncol(collect(df)) == ncol(take(df, 10))) }) test_that("multiple pipeline transformations starting with a DataFrame result in an RDD with the correct values", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) first <- lapply(df, function(row) { row$age <- row$age + 5 row @@ -354,7 +363,7 @@ test_that("multiple pipeline transformations starting with a DataFrame result in }) test_that("cache(), persist(), and unpersist() on a DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) expect_false(df@env$isCached) cache(df) expect_true(df@env$isCached) @@ -373,7 +382,7 @@ test_that("cache(), persist(), and unpersist() on a DataFrame", { }) test_that("schema(), dtypes(), columns(), names() return the correct values/format", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) testSchema <- schema(df) expect_true(length(testSchema$fields()) == 2) expect_true(testSchema$fields()[[1]]$dataType.toString() == "LongType") @@ -394,7 +403,7 @@ test_that("schema(), dtypes(), columns(), names() return the correct values/form }) test_that("head() and first() return the correct data", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) testHead <- head(df) expect_true(nrow(testHead) == 3) expect_true(ncol(testHead) == 2) @@ -415,18 +424,18 @@ test_that("distinct() on DataFrames", { jsonPathWithDup <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(lines, jsonPathWithDup) - df <- jsonFile(sqlCtx, jsonPathWithDup) + df <- jsonFile(sqlContext, jsonPathWithDup) uniques <- distinct(df) expect_true(inherits(uniques, "DataFrame")) expect_true(count(uniques) == 3) }) -test_that("sampleDF on a DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) - sampled <- sampleDF(df, FALSE, 1.0) +test_that("sample on a DataFrame", { + df <- jsonFile(sqlContext, jsonPath) + sampled <- sample(df, FALSE, 1.0) expect_equal(nrow(collect(sampled)), count(df)) expect_true(inherits(sampled, "DataFrame")) - sampled2 <- sampleDF(df, FALSE, 0.1) + sampled2 <- sample(df, FALSE, 0.1) expect_true(count(sampled2) < 3) # Also test sample_frac @@ -435,7 +444,7 @@ test_that("sampleDF on a DataFrame", { }) test_that("select operators", { - df <- select(jsonFile(sqlCtx, jsonPath), "name", "age") + df <- select(jsonFile(sqlContext, jsonPath), "name", "age") expect_true(inherits(df$name, "Column")) expect_true(inherits(df[[2]], "Column")) expect_true(inherits(df[["age"]], "Column")) @@ -461,7 +470,7 @@ test_that("select operators", { }) test_that("select with column", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) df1 <- select(df, "name") expect_true(columns(df1) == c("name")) expect_true(count(df1) == 3) @@ -472,7 +481,7 @@ test_that("select with column", { }) test_that("selectExpr() on a DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) selected <- selectExpr(df, "age * 2") expect_true(names(selected) == "(age * 2)") expect_equal(collect(selected), collect(select(df, df$age * 2L))) @@ -483,7 +492,7 @@ test_that("selectExpr() on a DataFrame", { }) test_that("column calculation", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) d <- collect(select(df, alias(df$age + 1, "age2"))) expect_true(names(d) == c("age2")) df2 <- select(df, lower(df$name), abs(df$age)) @@ -491,16 +500,16 @@ test_that("column calculation", { expect_true(count(df2) == 3) }) -test_that("load() from json file", { - df <- loadDF(sqlCtx, jsonPath, "json") +test_that("read.df() from json file", { + df <- read.df(sqlContext, jsonPath, "json") expect_true(inherits(df, "DataFrame")) expect_true(count(df) == 3) }) -test_that("save() as parquet file", { - df <- loadDF(sqlCtx, jsonPath, "json") - saveDF(df, parquetPath, "parquet", mode="overwrite") - df2 <- loadDF(sqlCtx, parquetPath, "parquet") +test_that("write.df() as parquet file", { + df <- read.df(sqlContext, jsonPath, "json") + write.df(df, parquetPath, "parquet", mode="overwrite") + df2 <- read.df(sqlContext, parquetPath, "parquet") expect_true(inherits(df2, "DataFrame")) expect_true(count(df2) == 3) }) @@ -530,6 +539,7 @@ test_that("column operators", { c2 <- (- c + 1 - 2) * 3 / 4.0 c3 <- (c + c2 - c2) * c2 %% c2 c4 <- (c > c2) & (c2 <= c3) | (c == c2) & (c2 != c3) + c5 <- c2 ^ c3 ^ c4 }) test_that("column functions", { @@ -538,10 +548,33 @@ test_that("column functions", { c3 <- lower(c) + upper(c) + first(c) + last(c) c4 <- approxCountDistinct(c) + countDistinct(c) + cast(c, "string") c5 <- n(c) + n_distinct(c) + c5 <- acos(c) + asin(c) + atan(c) + cbrt(c) + c6 <- ceiling(c) + cos(c) + cosh(c) + exp(c) + expm1(c) + c7 <- floor(c) + log(c) + log10(c) + log1p(c) + rint(c) + c8 <- sign(c) + sin(c) + sinh(c) + tan(c) + tanh(c) + c9 <- toDegrees(c) + toRadians(c) +}) + +test_that("column binary mathfunctions", { + lines <- c("{\"a\":1, \"b\":5}", + "{\"a\":2, \"b\":6}", + "{\"a\":3, \"b\":7}", + "{\"a\":4, \"b\":8}") + jsonPathWithDup <- tempfile(pattern="sparkr-test", fileext=".tmp") + writeLines(lines, jsonPathWithDup) + df <- jsonFile(sqlContext, jsonPathWithDup) + expect_equal(collect(select(df, atan2(df$a, df$b)))[1, "ATAN2(a, b)"], atan2(1, 5)) + expect_equal(collect(select(df, atan2(df$a, df$b)))[2, "ATAN2(a, b)"], atan2(2, 6)) + expect_equal(collect(select(df, atan2(df$a, df$b)))[3, "ATAN2(a, b)"], atan2(3, 7)) + expect_equal(collect(select(df, atan2(df$a, df$b)))[4, "ATAN2(a, b)"], atan2(4, 8)) + expect_equal(collect(select(df, hypot(df$a, df$b)))[1, "HYPOT(a, b)"], sqrt(1^2 + 5^2)) + expect_equal(collect(select(df, hypot(df$a, df$b)))[2, "HYPOT(a, b)"], sqrt(2^2 + 6^2)) + expect_equal(collect(select(df, hypot(df$a, df$b)))[3, "HYPOT(a, b)"], sqrt(3^2 + 7^2)) + expect_equal(collect(select(df, hypot(df$a, df$b)))[4, "HYPOT(a, b)"], sqrt(4^2 + 8^2)) }) test_that("string operators", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) expect_equal(count(where(df, like(df$name, "A%"))), 1) expect_equal(count(where(df, startsWith(df$name, "A"))), 1) expect_equal(first(select(df, substr(df$name, 1, 2)))[[1]], "Mi") @@ -549,7 +582,7 @@ test_that("string operators", { }) test_that("group by", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) df1 <- agg(df, name = "max", age = "sum") expect_true(1 == count(df1)) df1 <- agg(df, age2 = max(df$age)) @@ -586,7 +619,7 @@ test_that("group by", { }) test_that("arrange() and orderBy() on a DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) sorted <- arrange(df, df$age) expect_true(collect(sorted)[1,2] == "Michael") @@ -603,7 +636,7 @@ test_that("arrange() and orderBy() on a DataFrame", { }) test_that("filter() on a DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) filtered <- filter(df, "age > 20") expect_true(count(filtered) == 1) expect_true(collect(filtered)$name == "Andy") @@ -613,7 +646,7 @@ test_that("filter() on a DataFrame", { }) test_that("join() on a DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) mockLines2 <- c("{\"name\":\"Michael\", \"test\": \"yes\"}", "{\"name\":\"Andy\", \"test\": \"no\"}", @@ -621,7 +654,7 @@ test_that("join() on a DataFrame", { "{\"name\":\"Bob\", \"test\": \"yes\"}") jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(mockLines2, jsonPath2) - df2 <- jsonFile(sqlCtx, jsonPath2) + df2 <- jsonFile(sqlContext, jsonPath2) joined <- join(df, df2) expect_equal(names(joined), c("age", "name", "name", "test")) @@ -644,7 +677,7 @@ test_that("join() on a DataFrame", { }) test_that("toJSON() returns an RDD of the correct values", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) testRDD <- toJSON(df) expect_true(inherits(testRDD, "RDD")) expect_true(SparkR:::getSerializedMode(testRDD) == "string") @@ -652,25 +685,25 @@ test_that("toJSON() returns an RDD of the correct values", { }) test_that("showDF()", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) s <- capture.output(showDF(df)) expect_output(s , "+----+-------+\n| age| name|\n+----+-------+\n|null|Michael|\n| 30| Andy|\n| 19| Justin|\n+----+-------+\n") }) test_that("isLocal()", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) expect_false(isLocal(df)) }) test_that("unionAll(), except(), and intersect() on a DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) lines <- c("{\"name\":\"Bob\", \"age\":24}", "{\"name\":\"Andy\", \"age\":30}", "{\"name\":\"James\", \"age\":35}") jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(lines, jsonPath2) - df2 <- loadDF(sqlCtx, jsonPath2, "json") + df2 <- read.df(sqlContext, jsonPath2, "json") unioned <- arrange(unionAll(df, df2), df$age) expect_true(inherits(unioned, "DataFrame")) @@ -689,7 +722,7 @@ test_that("unionAll(), except(), and intersect() on a DataFrame", { }) test_that("withColumn() and withColumnRenamed()", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) newDF <- withColumn(df, "newAge", df$age + 2) expect_true(length(columns(newDF)) == 3) expect_true(columns(newDF)[3] == "newAge") @@ -701,7 +734,7 @@ test_that("withColumn() and withColumnRenamed()", { }) test_that("mutate() and rename()", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) newDF <- mutate(df, newAge = df$age + 2) expect_true(length(columns(newDF)) == 3) expect_true(columns(newDF)[3] == "newAge") @@ -712,34 +745,134 @@ test_that("mutate() and rename()", { expect_true(columns(newDF2)[1] == "newerAge") }) -test_that("saveDF() on DataFrame and works with parquetFile", { - df <- jsonFile(sqlCtx, jsonPath) - saveDF(df, parquetPath, "parquet", mode="overwrite") - parquetDF <- parquetFile(sqlCtx, parquetPath) +test_that("write.df() on DataFrame and works with parquetFile", { + df <- jsonFile(sqlContext, jsonPath) + write.df(df, parquetPath, "parquet", mode="overwrite") + parquetDF <- parquetFile(sqlContext, parquetPath) expect_true(inherits(parquetDF, "DataFrame")) expect_equal(count(df), count(parquetDF)) }) test_that("parquetFile works with multiple input paths", { - df <- jsonFile(sqlCtx, jsonPath) - saveDF(df, parquetPath, "parquet", mode="overwrite") + df <- jsonFile(sqlContext, jsonPath) + write.df(df, parquetPath, "parquet", mode="overwrite") parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") - saveDF(df, parquetPath2, "parquet", mode="overwrite") - parquetDF <- parquetFile(sqlCtx, parquetPath, parquetPath2) + write.df(df, parquetPath2, "parquet", mode="overwrite") + parquetDF <- parquetFile(sqlContext, parquetPath, parquetPath2) expect_true(inherits(parquetDF, "DataFrame")) expect_true(count(parquetDF) == count(df)*2) }) test_that("describe() on a DataFrame", { - df <- jsonFile(sqlCtx, jsonPath) + df <- jsonFile(sqlContext, jsonPath) stats <- describe(df, "age") - expect_true(collect(stats)[1, "summary"] == "count") - expect_true(collect(stats)[2, "age"] == 24.5) - expect_true(collect(stats)[3, "age"] == 5.5) + expect_equal(collect(stats)[1, "summary"], "count") + expect_equal(collect(stats)[2, "age"], "24.5") + expect_equal(collect(stats)[3, "age"], "5.5") stats <- describe(df) - expect_true(collect(stats)[4, "name"] == "Andy") - expect_true(collect(stats)[5, "age"] == 30.0) + expect_equal(collect(stats)[4, "name"], "Andy") + expect_equal(collect(stats)[5, "age"], "30") +}) + +test_that("dropna() on a DataFrame", { + df <- jsonFile(sqlContext, jsonPathNa) + rows <- collect(df) + + # drop with columns + + expected <- rows[!is.na(rows$name),] + actual <- collect(dropna(df, cols = "name")) + expect_true(identical(expected, actual)) + + expected <- rows[!is.na(rows$age),] + actual <- collect(dropna(df, cols = "age")) + row.names(expected) <- row.names(actual) + # identical on two dataframes does not work here. Don't know why. + # use identical on all columns as a workaround. + expect_true(identical(expected$age, actual$age)) + expect_true(identical(expected$height, actual$height)) + expect_true(identical(expected$name, actual$name)) + + expected <- rows[!is.na(rows$age) & !is.na(rows$height),] + actual <- collect(dropna(df, cols = c("age", "height"))) + expect_true(identical(expected, actual)) + + expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] + actual <- collect(dropna(df)) + expect_true(identical(expected, actual)) + + # drop with how + + expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] + actual <- collect(dropna(df)) + expect_true(identical(expected, actual)) + + expected <- rows[!is.na(rows$age) | !is.na(rows$height) | !is.na(rows$name),] + actual <- collect(dropna(df, "all")) + expect_true(identical(expected, actual)) + + expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] + actual <- collect(dropna(df, "any")) + expect_true(identical(expected, actual)) + + expected <- rows[!is.na(rows$age) & !is.na(rows$height),] + actual <- collect(dropna(df, "any", cols = c("age", "height"))) + expect_true(identical(expected, actual)) + + expected <- rows[!is.na(rows$age) | !is.na(rows$height),] + actual <- collect(dropna(df, "all", cols = c("age", "height"))) + expect_true(identical(expected, actual)) + + # drop with threshold + + expected <- rows[as.integer(!is.na(rows$age)) + as.integer(!is.na(rows$height)) >= 2,] + actual <- collect(dropna(df, minNonNulls = 2, cols = c("age", "height"))) + expect_true(identical(expected, actual)) + + expected <- rows[as.integer(!is.na(rows$age)) + + as.integer(!is.na(rows$height)) + + as.integer(!is.na(rows$name)) >= 3,] + actual <- collect(dropna(df, minNonNulls = 3, cols = c("name", "age", "height"))) + expect_true(identical(expected, actual)) +}) + +test_that("fillna() on a DataFrame", { + df <- jsonFile(sqlContext, jsonPathNa) + rows <- collect(df) + + # fill with value + + expected <- rows + expected$age[is.na(expected$age)] <- 50 + expected$height[is.na(expected$height)] <- 50.6 + actual <- collect(fillna(df, 50.6)) + expect_true(identical(expected, actual)) + + expected <- rows + expected$name[is.na(expected$name)] <- "unknown" + actual <- collect(fillna(df, "unknown")) + expect_true(identical(expected, actual)) + + expected <- rows + expected$age[is.na(expected$age)] <- 50 + actual <- collect(fillna(df, 50.6, "age")) + expect_true(identical(expected, actual)) + + expected <- rows + expected$name[is.na(expected$name)] <- "unknown" + actual <- collect(fillna(df, "unknown", c("age", "name"))) + expect_true(identical(expected, actual)) + + # fill with named list + + expected <- rows + expected$age[is.na(expected$age)] <- 50 + expected$height[is.na(expected$height)] <- 50.6 + expected$name[is.na(expected$name)] <- "unknown" + actual <- collect(fillna(df, list("age" = 50, "height" = 50.6, "name" = "unknown"))) + expect_true(identical(expected, actual)) }) unlink(parquetPath) unlink(jsonPath) +unlink(jsonPathNa) diff --git a/R/pkg/src/Makefile b/R/pkg/src-native/Makefile similarity index 100% rename from R/pkg/src/Makefile rename to R/pkg/src-native/Makefile diff --git a/R/pkg/src/Makefile.win b/R/pkg/src-native/Makefile.win similarity index 100% rename from R/pkg/src/Makefile.win rename to R/pkg/src-native/Makefile.win diff --git a/R/pkg/src/string_hash_code.c b/R/pkg/src-native/string_hash_code.c similarity index 100% rename from R/pkg/src/string_hash_code.c rename to R/pkg/src-native/string_hash_code.c diff --git a/README.md b/README.md index 9c09d40e2bdae..380422ca00dbe 100644 --- a/README.md +++ b/README.md @@ -3,8 +3,8 @@ Spark is a fast and general cluster computing system for Big Data. It provides high-level APIs in Scala, Java, and Python, and an optimized engine that supports general computation graphs for data analysis. It also supports a -rich set of higher-level tools including Spark SQL for SQL and structured -data processing, MLlib for machine learning, GraphX for graph processing, +rich set of higher-level tools including Spark SQL for SQL and DataFrames, +MLlib for machine learning, GraphX for graph processing, and Spark Streaming for stream processing. @@ -22,7 +22,7 @@ This README file only contains basic setup instructions. Spark is built using [Apache Maven](http://maven.apache.org/). To build Spark and its example programs, run: - mvn -DskipTests clean package + build/mvn -DskipTests clean package (You do not need to do this if you downloaded a pre-built package.) More detailed documentation is available from the project site, at @@ -43,7 +43,7 @@ Try the following command, which should return 1000: Alternatively, if you prefer Python, you can use the Python shell: ./bin/pyspark - + And run the following command, which should also return 1000: >>> sc.parallelize(range(1000)).count() @@ -58,9 +58,9 @@ To run one of them, use `./bin/run-example [params]`. For example: will run the Pi example locally. You can set the MASTER environment variable when running examples to submit -examples to a cluster. This can be a mesos:// or spark:// URL, -"yarn-cluster" or "yarn-client" to run on YARN, and "local" to run -locally with one thread, or "local[N]" to run locally with N threads. You +examples to a cluster. This can be a mesos:// or spark:// URL, +"yarn-cluster" or "yarn-client" to run on YARN, and "local" to run +locally with one thread, or "local[N]" to run locally with N threads. You can also use an abbreviated class name if the class is in the `examples` package. For instance: @@ -75,7 +75,7 @@ can be run using: ./dev/run-tests -Please see the guidance on how to +Please see the guidance on how to [run tests for a module, or individual tests](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools). ## A Note About Hadoop Versions diff --git a/bagel/pom.xml b/bagel/pom.xml index 1f3dec91314f2..132cd433d78a2 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -40,6 +40,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.scalacheck scalacheck_${scala.binary.version} diff --git a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala index ccb262a4ee02a..fb10d734ac74b 100644 --- a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala +++ b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.bagel -import org.scalatest.{BeforeAndAfter, FunSuite, Assertions} +import org.scalatest.{BeforeAndAfter, Assertions} import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ @@ -27,7 +27,7 @@ import org.apache.spark.storage.StorageLevel class TestVertex(val active: Boolean, val age: Int) extends Vertex with Serializable class TestMessage(val targetId: String) extends Message[String] with Serializable -class BagelSuite extends FunSuite with Assertions with BeforeAndAfter with Timeouts { +class BagelSuite extends SparkFunSuite with Assertions with BeforeAndAfter with Timeouts { var sc: SparkContext = _ diff --git a/bin/pyspark b/bin/pyspark index 8acad6113797d..7cb19c51b43a2 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -90,11 +90,7 @@ if [[ -n "$SPARK_TESTING" ]]; then unset YARN_CONF_DIR unset HADOOP_CONF_DIR export PYTHONHASHSEED=0 - if [[ -n "$PYSPARK_DOC_TEST" ]]; then - exec "$PYSPARK_DRIVER_PYTHON" -m doctest $1 - else - exec "$PYSPARK_DRIVER_PYTHON" $1 - fi + exec "$PYSPARK_DRIVER_PYTHON" -m $1 exit fi diff --git a/conf/metrics.properties.template b/conf/metrics.properties.template index 2e0cb5db170ac..7de0011a48ca8 100644 --- a/conf/metrics.properties.template +++ b/conf/metrics.properties.template @@ -126,9 +126,9 @@ #*.sink.slf4j.class=org.apache.spark.metrics.sink.Slf4jSink # Polling period for Slf4JSink -#*.sink.sl4j.period=1 +#*.sink.slf4j.period=1 -#*.sink.sl4j.unit=minutes +#*.sink.slf4j.unit=minutes # Enable jvm source for instance master, worker, driver and executor diff --git a/core/pom.xml b/core/pom.xml index 262a3320db106..a02184222e9f0 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -338,6 +338,12 @@ org.seleniumhq.selenium selenium-java + + + com.google.guava + guava + + test @@ -361,15 +367,31 @@ junit test + + org.hamcrest + hamcrest-core + test + + + org.hamcrest + hamcrest-library + test + com.novocode junit-interface test - org.spark-project + net.razorvine pyrolite 4.4 + + + net.razorvine + serpent + + net.sf.py4j @@ -459,6 +481,29 @@ + + sparkr-docs + + + + org.codehaus.mojo + exec-maven-plugin + + + sparkr-pkg-docs + compile + + exec + + + + + ..${path.separator}R${path.separator}create-docs${script.extension} + + + + + diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java new file mode 100644 index 0000000000000..d3d6280284beb --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort; + +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; + +import scala.Product2; +import scala.Tuple2; +import scala.collection.Iterator; + +import com.google.common.io.Closeables; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.Partitioner; +import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.serializer.Serializer; +import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.storage.*; +import org.apache.spark.util.Utils; + +/** + * This class implements sort-based shuffle's hash-style shuffle fallback path. This write path + * writes incoming records to separate files, one file per reduce partition, then concatenates these + * per-partition files to form a single output file, regions of which are served to reducers. + * Records are not buffered in memory. This is essentially identical to + * {@link org.apache.spark.shuffle.hash.HashShuffleWriter}, except that it writes output in a format + * that can be served / consumed via {@link org.apache.spark.shuffle.IndexShuffleBlockResolver}. + *

+ * This write path is inefficient for shuffles with large numbers of reduce partitions because it + * simultaneously opens separate serializers and file streams for all partitions. As a result, + * {@link SortShuffleManager} only selects this write path when + *

    + *
  • no Ordering is specified,
  • + *
  • no Aggregator is specific, and
  • + *
  • the number of partitions is less than + * spark.shuffle.sort.bypassMergeThreshold.
  • + *
+ * + * This code used to be part of {@link org.apache.spark.util.collection.ExternalSorter} but was + * refactored into its own class in order to reduce code complexity; see SPARK-7855 for details. + *

+ * There have been proposals to completely remove this code path; see SPARK-6026 for details. + */ +final class BypassMergeSortShuffleWriter implements SortShuffleFileWriter { + + private final Logger logger = LoggerFactory.getLogger(BypassMergeSortShuffleWriter.class); + + private final int fileBufferSize; + private final boolean transferToEnabled; + private final int numPartitions; + private final BlockManager blockManager; + private final Partitioner partitioner; + private final ShuffleWriteMetrics writeMetrics; + private final Serializer serializer; + + /** Array of file writers, one for each partition */ + private BlockObjectWriter[] partitionWriters; + + public BypassMergeSortShuffleWriter( + SparkConf conf, + BlockManager blockManager, + Partitioner partitioner, + ShuffleWriteMetrics writeMetrics, + Serializer serializer) { + // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided + this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true); + this.numPartitions = partitioner.numPartitions(); + this.blockManager = blockManager; + this.partitioner = partitioner; + this.writeMetrics = writeMetrics; + this.serializer = serializer; + } + + @Override + public void insertAll(Iterator> records) throws IOException { + assert (partitionWriters == null); + if (!records.hasNext()) { + return; + } + final SerializerInstance serInstance = serializer.newInstance(); + final long openStartTime = System.nanoTime(); + partitionWriters = new BlockObjectWriter[numPartitions]; + for (int i = 0; i < numPartitions; i++) { + final Tuple2 tempShuffleBlockIdPlusFile = + blockManager.diskBlockManager().createTempShuffleBlock(); + final File file = tempShuffleBlockIdPlusFile._2(); + final BlockId blockId = tempShuffleBlockIdPlusFile._1(); + partitionWriters[i] = + blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics).open(); + } + // Creating the file to write to and creating a disk writer both involve interacting with + // the disk, and can take a long time in aggregate when we open many files, so should be + // included in the shuffle write time. + writeMetrics.incShuffleWriteTime(System.nanoTime() - openStartTime); + + while (records.hasNext()) { + final Product2 record = records.next(); + final K key = record._1(); + partitionWriters[partitioner.getPartition(key)].write(key, record._2()); + } + + for (BlockObjectWriter writer : partitionWriters) { + writer.commitAndClose(); + } + } + + @Override + public long[] writePartitionedFile( + BlockId blockId, + TaskContext context, + File outputFile) throws IOException { + // Track location of the partition starts in the output file + final long[] lengths = new long[numPartitions]; + if (partitionWriters == null) { + // We were passed an empty iterator + return lengths; + } + + final FileOutputStream out = new FileOutputStream(outputFile, true); + final long writeStartTime = System.nanoTime(); + boolean threwException = true; + try { + for (int i = 0; i < numPartitions; i++) { + final FileInputStream in = new FileInputStream(partitionWriters[i].fileSegment().file()); + boolean copyThrewException = true; + try { + lengths[i] = Utils.copyStream(in, out, false, transferToEnabled); + copyThrewException = false; + } finally { + Closeables.close(in, copyThrewException); + } + if (!blockManager.diskBlockManager().getFile(partitionWriters[i].blockId()).delete()) { + logger.error("Unable to delete file for partition {}", i); + } + } + threwException = false; + } finally { + Closeables.close(out, threwException); + writeMetrics.incShuffleWriteTime(System.nanoTime() - writeStartTime); + } + partitionWriters = null; + return lengths; + } + + @Override + public void stop() throws IOException { + if (partitionWriters != null) { + try { + final DiskBlockManager diskBlockManager = blockManager.diskBlockManager(); + for (BlockObjectWriter writer : partitionWriters) { + // This method explicitly does _not_ throw exceptions: + writer.revertPartialWritesAndClose(); + if (!diskBlockManager.getFile(writer.blockId()).delete()) { + logger.error("Error while deleting file for block {}", writer.blockId()); + } + } + } finally { + partitionWriters = null; + } + } + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java new file mode 100644 index 0000000000000..656ea0401a144 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort; + +import java.io.File; +import java.io.IOException; + +import scala.Product2; +import scala.collection.Iterator; + +import org.apache.spark.annotation.Private; +import org.apache.spark.TaskContext; +import org.apache.spark.storage.BlockId; + +/** + * Interface for objects that {@link SortShuffleWriter} uses to write its output files. + */ +@Private +public interface SortShuffleFileWriter { + + void insertAll(Iterator> records) throws IOException; + + /** + * Write all the data added into this shuffle sorter into a file in the disk store. This is + * called by the SortShuffleWriter and can go through an efficient path of just concatenating + * binary files if we decided to avoid merge-sorting. + * + * @param blockId block ID to write to. The index file will be blockId.name + ".index". + * @param context a TaskContext for a running Spark task, for us to update shuffle metrics. + * @return array of lengths, in bytes, of each partition of the file (used by map output tracker) + */ + long[] writePartitionedFile( + BlockId blockId, + TaskContext context, + File outputFile) throws IOException; + + void stop() throws IOException; +} diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java new file mode 100644 index 0000000000000..3f746b886bc9b --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; + +import scala.reflect.ClassTag; + +import org.apache.spark.serializer.DeserializationStream; +import org.apache.spark.serializer.SerializationStream; +import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.unsafe.PlatformDependent; + +/** + * Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter. + * Our shuffle write path doesn't actually use this serializer (since we end up calling the + * `write() OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work + * around this, we pass a dummy no-op serializer. + */ +final class DummySerializerInstance extends SerializerInstance { + + public static final DummySerializerInstance INSTANCE = new DummySerializerInstance(); + + private DummySerializerInstance() { } + + @Override + public SerializationStream serializeStream(final OutputStream s) { + return new SerializationStream() { + @Override + public void flush() { + // Need to implement this because DiskObjectWriter uses it to flush the compression stream + try { + s.flush(); + } catch (IOException e) { + PlatformDependent.throwException(e); + } + } + + @Override + public SerializationStream writeObject(T t, ClassTag ev1) { + throw new UnsupportedOperationException(); + } + + @Override + public void close() { + // Need to implement this because DiskObjectWriter uses it to close the compression stream + try { + s.close(); + } catch (IOException e) { + PlatformDependent.throwException(e); + } + } + }; + } + + @Override + public ByteBuffer serialize(T t, ClassTag ev1) { + throw new UnsupportedOperationException(); + } + + @Override + public DeserializationStream deserializeStream(InputStream s) { + throw new UnsupportedOperationException(); + } + + @Override + public T deserialize(ByteBuffer bytes, ClassLoader loader, ClassTag ev1) { + throw new UnsupportedOperationException(); + } + + @Override + public T deserialize(ByteBuffer bytes, ClassTag ev1) { + throw new UnsupportedOperationException(); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java new file mode 100644 index 0000000000000..4ee6a82c0423e --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +/** + * Wrapper around an 8-byte word that holds a 24-bit partition number and 40-bit record pointer. + *

+ * Within the long, the data is laid out as follows: + *

+ *   [24 bit partition number][13 bit memory page number][27 bit offset in page]
+ * 
+ * This implies that the maximum addressable page size is 2^27 bits = 128 megabytes, assuming that + * our offsets in pages are not 8-byte-word-aligned. Since we have 2^13 pages (based off the + * 13-bit page numbers assigned by {@link org.apache.spark.unsafe.memory.TaskMemoryManager}), this + * implies that we can address 2^13 * 128 megabytes = 1 terabyte of RAM per task. + *

+ * Assuming word-alignment would allow for a 1 gigabyte maximum page size, but we leave this + * optimization to future work as it will require more careful design to ensure that addresses are + * properly aligned (e.g. by padding records). + */ +final class PackedRecordPointer { + + static final int MAXIMUM_PAGE_SIZE_BYTES = 1 << 27; // 128 megabytes + + /** + * The maximum partition identifier that can be encoded. Note that partition ids start from 0. + */ + static final int MAXIMUM_PARTITION_ID = (1 << 24) - 1; // 16777215 + + /** Bit mask for the lower 40 bits of a long. */ + private static final long MASK_LONG_LOWER_40_BITS = (1L << 40) - 1; + + /** Bit mask for the upper 24 bits of a long */ + private static final long MASK_LONG_UPPER_24_BITS = ~MASK_LONG_LOWER_40_BITS; + + /** Bit mask for the lower 27 bits of a long. */ + private static final long MASK_LONG_LOWER_27_BITS = (1L << 27) - 1; + + /** Bit mask for the lower 51 bits of a long. */ + private static final long MASK_LONG_LOWER_51_BITS = (1L << 51) - 1; + + /** Bit mask for the upper 13 bits of a long */ + private static final long MASK_LONG_UPPER_13_BITS = ~MASK_LONG_LOWER_51_BITS; + + /** + * Pack a record address and partition id into a single word. + * + * @param recordPointer a record pointer encoded by TaskMemoryManager. + * @param partitionId a shuffle partition id (maximum value of 2^24). + * @return a packed pointer that can be decoded using the {@link PackedRecordPointer} class. + */ + public static long packPointer(long recordPointer, int partitionId) { + assert (partitionId <= MAXIMUM_PARTITION_ID); + // Note that without word alignment we can address 2^27 bytes = 128 megabytes per page. + // Also note that this relies on some internals of how TaskMemoryManager encodes its addresses. + final long pageNumber = (recordPointer & MASK_LONG_UPPER_13_BITS) >>> 24; + final long compressedAddress = pageNumber | (recordPointer & MASK_LONG_LOWER_27_BITS); + return (((long) partitionId) << 40) | compressedAddress; + } + + private long packedRecordPointer; + + public void set(long packedRecordPointer) { + this.packedRecordPointer = packedRecordPointer; + } + + public int getPartitionId() { + return (int) ((packedRecordPointer & MASK_LONG_UPPER_24_BITS) >>> 40); + } + + public long getRecordPointer() { + final long pageNumber = (packedRecordPointer << 24) & MASK_LONG_UPPER_13_BITS; + final long offsetInPage = packedRecordPointer & MASK_LONG_LOWER_27_BITS; + return pageNumber | offsetInPage; + } + +} diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java new file mode 100644 index 0000000000000..7bac0dc0bbeb6 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +import java.io.File; + +import org.apache.spark.storage.TempShuffleBlockId; + +/** + * Metadata for a block of data written by {@link UnsafeShuffleExternalSorter}. + */ +final class SpillInfo { + final long[] partitionLengths; + final File file; + final TempShuffleBlockId blockId; + + public SpillInfo(int numPartitions, File file, TempShuffleBlockId blockId) { + this.partitionLengths = new long[numPartitions]; + this.file = file; + this.blockId = blockId; + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java new file mode 100644 index 0000000000000..9e9ed94b7890c --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -0,0 +1,422 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +import java.io.File; +import java.io.IOException; +import java.util.LinkedList; + +import scala.Tuple2; + +import com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.storage.*; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.util.Utils; + +/** + * An external sorter that is specialized for sort-based shuffle. + *

+ * Incoming records are appended to data pages. When all records have been inserted (or when the + * current thread's shuffle memory limit is reached), the in-memory records are sorted according to + * their partition ids (using a {@link UnsafeShuffleInMemorySorter}). The sorted records are then + * written to a single output file (or multiple files, if we've spilled). The format of the output + * files is the same as the format of the final output file written by + * {@link org.apache.spark.shuffle.sort.SortShuffleWriter}: each output partition's records are + * written as a single serialized, compressed stream that can be read with a new decompression and + * deserialization stream. + *

+ * Unlike {@link org.apache.spark.util.collection.ExternalSorter}, this sorter does not merge its + * spill files. Instead, this merging is performed in {@link UnsafeShuffleWriter}, which uses a + * specialized merge procedure that avoids extra serialization/deserialization. + */ +final class UnsafeShuffleExternalSorter { + + private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleExternalSorter.class); + + private static final int PAGE_SIZE = PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES; + @VisibleForTesting + static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024; + @VisibleForTesting + static final int MAX_RECORD_SIZE = PAGE_SIZE - 4; + + private final int initialSize; + private final int numPartitions; + private final TaskMemoryManager memoryManager; + private final ShuffleMemoryManager shuffleMemoryManager; + private final BlockManager blockManager; + private final TaskContext taskContext; + private final ShuffleWriteMetrics writeMetrics; + + /** The buffer size to use when writing spills using DiskBlockObjectWriter */ + private final int fileBufferSizeBytes; + + /** + * Memory pages that hold the records being sorted. The pages in this list are freed when + * spilling, although in principle we could recycle these pages across spills (on the other hand, + * this might not be necessary if we maintained a pool of re-usable pages in the TaskMemoryManager + * itself). + */ + private final LinkedList allocatedPages = new LinkedList(); + + private final LinkedList spills = new LinkedList(); + + // These variables are reset after spilling: + private UnsafeShuffleInMemorySorter sorter; + private MemoryBlock currentPage = null; + private long currentPagePosition = -1; + private long freeSpaceInCurrentPage = 0; + + public UnsafeShuffleExternalSorter( + TaskMemoryManager memoryManager, + ShuffleMemoryManager shuffleMemoryManager, + BlockManager blockManager, + TaskContext taskContext, + int initialSize, + int numPartitions, + SparkConf conf, + ShuffleWriteMetrics writeMetrics) throws IOException { + this.memoryManager = memoryManager; + this.shuffleMemoryManager = shuffleMemoryManager; + this.blockManager = blockManager; + this.taskContext = taskContext; + this.initialSize = initialSize; + this.numPartitions = numPartitions; + // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided + this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + + this.writeMetrics = writeMetrics; + initializeForWriting(); + } + + /** + * Allocates new sort data structures. Called when creating the sorter and after each spill. + */ + private void initializeForWriting() throws IOException { + // TODO: move this sizing calculation logic into a static method of sorter: + final long memoryRequested = initialSize * 8L; + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested); + if (memoryAcquired != memoryRequested) { + shuffleMemoryManager.release(memoryAcquired); + throw new IOException("Could not acquire " + memoryRequested + " bytes of memory"); + } + + this.sorter = new UnsafeShuffleInMemorySorter(initialSize); + } + + /** + * Sorts the in-memory records and writes the sorted records to an on-disk file. + * This method does not free the sort data structures. + * + * @param isLastFile if true, this indicates that we're writing the final output file and that the + * bytes written should be counted towards shuffle spill metrics rather than + * shuffle write metrics. + */ + private void writeSortedFile(boolean isLastFile) throws IOException { + + final ShuffleWriteMetrics writeMetricsToUse; + + if (isLastFile) { + // We're writing the final non-spill file, so we _do_ want to count this as shuffle bytes. + writeMetricsToUse = writeMetrics; + } else { + // We're spilling, so bytes written should be counted towards spill rather than write. + // Create a dummy WriteMetrics object to absorb these metrics, since we don't want to count + // them towards shuffle bytes written. + writeMetricsToUse = new ShuffleWriteMetrics(); + } + + // This call performs the actual sort. + final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator sortedRecords = + sorter.getSortedIterator(); + + // Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this + // after SPARK-5581 is fixed. + BlockObjectWriter writer; + + // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to + // be an API to directly transfer bytes from managed memory to the disk writer, we buffer + // data through a byte array. This array does not need to be large enough to hold a single + // record; + final byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE]; + + // Because this output will be read during shuffle, its compression codec must be controlled by + // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use + // createTempShuffleBlock here; see SPARK-3426 for more details. + final Tuple2 spilledFileInfo = + blockManager.diskBlockManager().createTempShuffleBlock(); + final File file = spilledFileInfo._2(); + final TempShuffleBlockId blockId = spilledFileInfo._1(); + final SpillInfo spillInfo = new SpillInfo(numPartitions, file, blockId); + + // Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter. + // Our write path doesn't actually use this serializer (since we end up calling the `write()` + // OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work + // around this, we pass a dummy no-op serializer. + final SerializerInstance ser = DummySerializerInstance.INSTANCE; + + writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse); + + int currentPartition = -1; + while (sortedRecords.hasNext()) { + sortedRecords.loadNext(); + final int partition = sortedRecords.packedRecordPointer.getPartitionId(); + assert (partition >= currentPartition); + if (partition != currentPartition) { + // Switch to the new partition + if (currentPartition != -1) { + writer.commitAndClose(); + spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length(); + } + currentPartition = partition; + writer = + blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse); + } + + final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer(); + final Object recordPage = memoryManager.getPage(recordPointer); + final long recordOffsetInPage = memoryManager.getOffsetInPage(recordPointer); + int dataRemaining = PlatformDependent.UNSAFE.getInt(recordPage, recordOffsetInPage); + long recordReadPosition = recordOffsetInPage + 4; // skip over record length + while (dataRemaining > 0) { + final int toTransfer = Math.min(DISK_WRITE_BUFFER_SIZE, dataRemaining); + PlatformDependent.copyMemory( + recordPage, + recordReadPosition, + writeBuffer, + PlatformDependent.BYTE_ARRAY_OFFSET, + toTransfer); + writer.write(writeBuffer, 0, toTransfer); + recordReadPosition += toTransfer; + dataRemaining -= toTransfer; + } + writer.recordWritten(); + } + + if (writer != null) { + writer.commitAndClose(); + // If `writeSortedFile()` was called from `closeAndGetSpills()` and no records were inserted, + // then the file might be empty. Note that it might be better to avoid calling + // writeSortedFile() in that case. + if (currentPartition != -1) { + spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length(); + spills.add(spillInfo); + } + } + + if (!isLastFile) { // i.e. this is a spill file + // The current semantics of `shuffleRecordsWritten` seem to be that it's updated when records + // are written to disk, not when they enter the shuffle sorting code. DiskBlockObjectWriter + // relies on its `recordWritten()` method being called in order to trigger periodic updates to + // `shuffleBytesWritten`. If we were to remove the `recordWritten()` call and increment that + // counter at a higher-level, then the in-progress metrics for records written and bytes + // written would get out of sync. + // + // When writing the last file, we pass `writeMetrics` directly to the DiskBlockObjectWriter; + // in all other cases, we pass in a dummy write metrics to capture metrics, then copy those + // metrics to the true write metrics here. The reason for performing this copying is so that + // we can avoid reporting spilled bytes as shuffle write bytes. + // + // Note that we intentionally ignore the value of `writeMetricsToUse.shuffleWriteTime()`. + // Consistent with ExternalSorter, we do not count this IO towards shuffle write time. + // This means that this IO time is not accounted for anywhere; SPARK-3577 will fix this. + writeMetrics.incShuffleRecordsWritten(writeMetricsToUse.shuffleRecordsWritten()); + taskContext.taskMetrics().incDiskBytesSpilled(writeMetricsToUse.shuffleBytesWritten()); + } + } + + /** + * Sort and spill the current records in response to memory pressure. + */ + @VisibleForTesting + void spill() throws IOException { + logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)", + Thread.currentThread().getId(), + Utils.bytesToString(getMemoryUsage()), + spills.size(), + spills.size() > 1 ? " times" : " time"); + + writeSortedFile(false); + final long sorterMemoryUsage = sorter.getMemoryUsage(); + sorter = null; + shuffleMemoryManager.release(sorterMemoryUsage); + final long spillSize = freeMemory(); + taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); + + initializeForWriting(); + } + + private long getMemoryUsage() { + return sorter.getMemoryUsage() + (allocatedPages.size() * (long) PAGE_SIZE); + } + + private long freeMemory() { + long memoryFreed = 0; + for (MemoryBlock block : allocatedPages) { + memoryManager.freePage(block); + shuffleMemoryManager.release(block.size()); + memoryFreed += block.size(); + } + allocatedPages.clear(); + currentPage = null; + currentPagePosition = -1; + freeSpaceInCurrentPage = 0; + return memoryFreed; + } + + /** + * Force all memory and spill files to be deleted; called by shuffle error-handling code. + */ + public void cleanupAfterError() { + freeMemory(); + for (SpillInfo spill : spills) { + if (spill.file.exists() && !spill.file.delete()) { + logger.error("Unable to delete spill file {}", spill.file.getPath()); + } + } + if (sorter != null) { + shuffleMemoryManager.release(sorter.getMemoryUsage()); + sorter = null; + } + } + + /** + * Checks whether there is enough space to insert a new record into the sorter. + * + * @param requiredSpace the required space in the data page, in bytes, including space for storing + * the record size. + + * @return true if the record can be inserted without requiring more allocations, false otherwise. + */ + private boolean haveSpaceForRecord(int requiredSpace) { + assert (requiredSpace > 0); + return (sorter.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage)); + } + + /** + * Allocates more memory in order to insert an additional record. This will request additional + * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be + * obtained. + * + * @param requiredSpace the required space in the data page, in bytes, including space for storing + * the record size. + */ + private void allocateSpaceForRecord(int requiredSpace) throws IOException { + if (!sorter.hasSpaceForAnotherRecord()) { + logger.debug("Attempting to expand sort pointer array"); + final long oldPointerArrayMemoryUsage = sorter.getMemoryUsage(); + final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2; + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray); + if (memoryAcquired < memoryToGrowPointerArray) { + shuffleMemoryManager.release(memoryAcquired); + spill(); + } else { + sorter.expandPointerArray(); + shuffleMemoryManager.release(oldPointerArrayMemoryUsage); + } + } + if (requiredSpace > freeSpaceInCurrentPage) { + logger.trace("Required space {} is less than free space in current page ({})", requiredSpace, + freeSpaceInCurrentPage); + // TODO: we should track metrics on the amount of space wasted when we roll over to a new page + // without using the free space at the end of the current page. We should also do this for + // BytesToBytesMap. + if (requiredSpace > PAGE_SIZE) { + throw new IOException("Required space " + requiredSpace + " is greater than page size (" + + PAGE_SIZE + ")"); + } else { + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); + if (memoryAcquired < PAGE_SIZE) { + shuffleMemoryManager.release(memoryAcquired); + spill(); + final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); + if (memoryAcquiredAfterSpilling != PAGE_SIZE) { + shuffleMemoryManager.release(memoryAcquiredAfterSpilling); + throw new IOException("Unable to acquire " + PAGE_SIZE + " bytes of memory"); + } + } + currentPage = memoryManager.allocatePage(PAGE_SIZE); + currentPagePosition = currentPage.getBaseOffset(); + freeSpaceInCurrentPage = PAGE_SIZE; + allocatedPages.add(currentPage); + } + } + } + + /** + * Write a record to the shuffle sorter. + */ + public void insertRecord( + Object recordBaseObject, + long recordBaseOffset, + int lengthInBytes, + int partitionId) throws IOException { + // Need 4 bytes to store the record length. + final int totalSpaceRequired = lengthInBytes + 4; + if (!haveSpaceForRecord(totalSpaceRequired)) { + allocateSpaceForRecord(totalSpaceRequired); + } + + final long recordAddress = + memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition); + final Object dataPageBaseObject = currentPage.getBaseObject(); + PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, lengthInBytes); + currentPagePosition += 4; + freeSpaceInCurrentPage -= 4; + PlatformDependent.copyMemory( + recordBaseObject, + recordBaseOffset, + dataPageBaseObject, + currentPagePosition, + lengthInBytes); + currentPagePosition += lengthInBytes; + freeSpaceInCurrentPage -= lengthInBytes; + sorter.insertRecord(recordAddress, partitionId); + } + + /** + * Close the sorter, causing any buffered data to be sorted and written out to disk. + * + * @return metadata for the spill files written by this sorter. If no records were ever inserted + * into this sorter, then this will return an empty array. + * @throws IOException + */ + public SpillInfo[] closeAndGetSpills() throws IOException { + try { + if (sorter != null) { + // Do not count the final file towards the spill count. + writeSortedFile(true); + freeMemory(); + } + return spills.toArray(new SpillInfo[spills.size()]); + } catch (IOException e) { + cleanupAfterError(); + throw e; + } + } + +} diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java new file mode 100644 index 0000000000000..5bab501da9364 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +import java.util.Comparator; + +import org.apache.spark.util.collection.Sorter; + +final class UnsafeShuffleInMemorySorter { + + private final Sorter sorter; + private static final class SortComparator implements Comparator { + @Override + public int compare(PackedRecordPointer left, PackedRecordPointer right) { + return left.getPartitionId() - right.getPartitionId(); + } + } + private static final SortComparator SORT_COMPARATOR = new SortComparator(); + + /** + * An array of record pointers and partition ids that have been encoded by + * {@link PackedRecordPointer}. The sort operates on this array instead of directly manipulating + * records. + */ + private long[] pointerArray; + + /** + * The position in the pointer array where new records can be inserted. + */ + private int pointerArrayInsertPosition = 0; + + public UnsafeShuffleInMemorySorter(int initialSize) { + assert (initialSize > 0); + this.pointerArray = new long[initialSize]; + this.sorter = new Sorter(UnsafeShuffleSortDataFormat.INSTANCE); + } + + public void expandPointerArray() { + final long[] oldArray = pointerArray; + // Guard against overflow: + final int newLength = oldArray.length * 2 > 0 ? (oldArray.length * 2) : Integer.MAX_VALUE; + pointerArray = new long[newLength]; + System.arraycopy(oldArray, 0, pointerArray, 0, oldArray.length); + } + + public boolean hasSpaceForAnotherRecord() { + return pointerArrayInsertPosition + 1 < pointerArray.length; + } + + public long getMemoryUsage() { + return pointerArray.length * 8L; + } + + /** + * Inserts a record to be sorted. + * + * @param recordPointer a pointer to the record, encoded by the task memory manager. Due to + * certain pointer compression techniques used by the sorter, the sort can + * only operate on pointers that point to locations in the first + * {@link PackedRecordPointer#MAXIMUM_PAGE_SIZE_BYTES} bytes of a data page. + * @param partitionId the partition id, which must be less than or equal to + * {@link PackedRecordPointer#MAXIMUM_PARTITION_ID}. + */ + public void insertRecord(long recordPointer, int partitionId) { + if (!hasSpaceForAnotherRecord()) { + if (pointerArray.length == Integer.MAX_VALUE) { + throw new IllegalStateException("Sort pointer array has reached maximum size"); + } else { + expandPointerArray(); + } + } + pointerArray[pointerArrayInsertPosition] = + PackedRecordPointer.packPointer(recordPointer, partitionId); + pointerArrayInsertPosition++; + } + + /** + * An iterator-like class that's used instead of Java's Iterator in order to facilitate inlining. + */ + public static final class UnsafeShuffleSorterIterator { + + private final long[] pointerArray; + private final int numRecords; + final PackedRecordPointer packedRecordPointer = new PackedRecordPointer(); + private int position = 0; + + public UnsafeShuffleSorterIterator(int numRecords, long[] pointerArray) { + this.numRecords = numRecords; + this.pointerArray = pointerArray; + } + + public boolean hasNext() { + return position < numRecords; + } + + public void loadNext() { + packedRecordPointer.set(pointerArray[position]); + position++; + } + } + + /** + * Return an iterator over record pointers in sorted order. + */ + public UnsafeShuffleSorterIterator getSortedIterator() { + sorter.sort(pointerArray, 0, pointerArrayInsertPosition, SORT_COMPARATOR); + return new UnsafeShuffleSorterIterator(pointerArrayInsertPosition, pointerArray); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java new file mode 100644 index 0000000000000..a66d74ee44782 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +import org.apache.spark.util.collection.SortDataFormat; + +final class UnsafeShuffleSortDataFormat extends SortDataFormat { + + public static final UnsafeShuffleSortDataFormat INSTANCE = new UnsafeShuffleSortDataFormat(); + + private UnsafeShuffleSortDataFormat() { } + + @Override + public PackedRecordPointer getKey(long[] data, int pos) { + // Since we re-use keys, this method shouldn't be called. + throw new UnsupportedOperationException(); + } + + @Override + public PackedRecordPointer newKey() { + return new PackedRecordPointer(); + } + + @Override + public PackedRecordPointer getKey(long[] data, int pos, PackedRecordPointer reuse) { + reuse.set(data[pos]); + return reuse; + } + + @Override + public void swap(long[] data, int pos0, int pos1) { + final long temp = data[pos0]; + data[pos0] = data[pos1]; + data[pos1] = temp; + } + + @Override + public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) { + dst[dstPos] = src[srcPos]; + } + + @Override + public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) { + System.arraycopy(src, srcPos, dst, dstPos, length); + } + + @Override + public long[] allocate(int length) { + return new long[length]; + } + +} diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java new file mode 100644 index 0000000000000..ad7eb04afcd8c --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -0,0 +1,438 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe; + +import java.io.*; +import java.nio.channels.FileChannel; +import java.util.Iterator; +import javax.annotation.Nullable; + +import scala.Option; +import scala.Product2; +import scala.collection.JavaConversions; +import scala.reflect.ClassTag; +import scala.reflect.ClassTag$; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.io.ByteStreams; +import com.google.common.io.Closeables; +import com.google.common.io.Files; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.*; +import org.apache.spark.annotation.Private; +import org.apache.spark.io.CompressionCodec; +import org.apache.spark.io.CompressionCodec$; +import org.apache.spark.io.LZFCompressionCodec; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.network.util.LimitedInputStream; +import org.apache.spark.scheduler.MapStatus; +import org.apache.spark.scheduler.MapStatus$; +import org.apache.spark.serializer.SerializationStream; +import org.apache.spark.serializer.Serializer; +import org.apache.spark.serializer.SerializerInstance; +import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.shuffle.ShuffleWriter; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.storage.TimeTrackingOutputStream; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.memory.TaskMemoryManager; + +@Private +public class UnsafeShuffleWriter extends ShuffleWriter { + + private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleWriter.class); + + private static final ClassTag OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object(); + + @VisibleForTesting + static final int INITIAL_SORT_BUFFER_SIZE = 4096; + + private final BlockManager blockManager; + private final IndexShuffleBlockResolver shuffleBlockResolver; + private final TaskMemoryManager memoryManager; + private final ShuffleMemoryManager shuffleMemoryManager; + private final SerializerInstance serializer; + private final Partitioner partitioner; + private final ShuffleWriteMetrics writeMetrics; + private final int shuffleId; + private final int mapId; + private final TaskContext taskContext; + private final SparkConf sparkConf; + private final boolean transferToEnabled; + + private MapStatus mapStatus = null; + private UnsafeShuffleExternalSorter sorter = null; + + /** Subclass of ByteArrayOutputStream that exposes `buf` directly. */ + private static final class MyByteArrayOutputStream extends ByteArrayOutputStream { + public MyByteArrayOutputStream(int size) { super(size); } + public byte[] getBuf() { return buf; } + } + + private MyByteArrayOutputStream serBuffer; + private SerializationStream serOutputStream; + + /** + * Are we in the process of stopping? Because map tasks can call stop() with success = true + * and then call stop() with success = false if they get an exception, we want to make sure + * we don't try deleting files, etc twice. + */ + private boolean stopping = false; + + public UnsafeShuffleWriter( + BlockManager blockManager, + IndexShuffleBlockResolver shuffleBlockResolver, + TaskMemoryManager memoryManager, + ShuffleMemoryManager shuffleMemoryManager, + UnsafeShuffleHandle handle, + int mapId, + TaskContext taskContext, + SparkConf sparkConf) throws IOException { + final int numPartitions = handle.dependency().partitioner().numPartitions(); + if (numPartitions > UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS()) { + throw new IllegalArgumentException( + "UnsafeShuffleWriter can only be used for shuffles with at most " + + UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS() + " reduce partitions"); + } + this.blockManager = blockManager; + this.shuffleBlockResolver = shuffleBlockResolver; + this.memoryManager = memoryManager; + this.shuffleMemoryManager = shuffleMemoryManager; + this.mapId = mapId; + final ShuffleDependency dep = handle.dependency(); + this.shuffleId = dep.shuffleId(); + this.serializer = Serializer.getSerializer(dep.serializer()).newInstance(); + this.partitioner = dep.partitioner(); + this.writeMetrics = new ShuffleWriteMetrics(); + taskContext.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics)); + this.taskContext = taskContext; + this.sparkConf = sparkConf; + this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true); + open(); + } + + /** + * This convenience method should only be called in test code. + */ + @VisibleForTesting + public void write(Iterator> records) throws IOException { + write(JavaConversions.asScalaIterator(records)); + } + + @Override + public void write(scala.collection.Iterator> records) throws IOException { + boolean success = false; + try { + while (records.hasNext()) { + insertRecordIntoSorter(records.next()); + } + closeAndWriteOutput(); + success = true; + } finally { + if (!success) { + sorter.cleanupAfterError(); + } + } + } + + private void open() throws IOException { + assert (sorter == null); + sorter = new UnsafeShuffleExternalSorter( + memoryManager, + shuffleMemoryManager, + blockManager, + taskContext, + INITIAL_SORT_BUFFER_SIZE, + partitioner.numPartitions(), + sparkConf, + writeMetrics); + serBuffer = new MyByteArrayOutputStream(1024 * 1024); + serOutputStream = serializer.serializeStream(serBuffer); + } + + @VisibleForTesting + void closeAndWriteOutput() throws IOException { + serBuffer = null; + serOutputStream = null; + final SpillInfo[] spills = sorter.closeAndGetSpills(); + sorter = null; + final long[] partitionLengths; + try { + partitionLengths = mergeSpills(spills); + } finally { + for (SpillInfo spill : spills) { + if (spill.file.exists() && ! spill.file.delete()) { + logger.error("Error while deleting spill file {}", spill.file.getPath()); + } + } + } + shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + } + + @VisibleForTesting + void insertRecordIntoSorter(Product2 record) throws IOException { + final K key = record._1(); + final int partitionId = partitioner.getPartition(key); + serBuffer.reset(); + serOutputStream.writeKey(key, OBJECT_CLASS_TAG); + serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG); + serOutputStream.flush(); + + final int serializedRecordSize = serBuffer.size(); + assert (serializedRecordSize > 0); + + sorter.insertRecord( + serBuffer.getBuf(), PlatformDependent.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId); + } + + @VisibleForTesting + void forceSorterToSpill() throws IOException { + assert (sorter != null); + sorter.spill(); + } + + /** + * Merge zero or more spill files together, choosing the fastest merging strategy based on the + * number of spills and the IO compression codec. + * + * @return the partition lengths in the merged file. + */ + private long[] mergeSpills(SpillInfo[] spills) throws IOException { + final File outputFile = shuffleBlockResolver.getDataFile(shuffleId, mapId); + final boolean compressionEnabled = sparkConf.getBoolean("spark.shuffle.compress", true); + final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf); + final boolean fastMergeEnabled = + sparkConf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true); + final boolean fastMergeIsSupported = + !compressionEnabled || compressionCodec instanceof LZFCompressionCodec; + try { + if (spills.length == 0) { + new FileOutputStream(outputFile).close(); // Create an empty file + return new long[partitioner.numPartitions()]; + } else if (spills.length == 1) { + // Here, we don't need to perform any metrics updates because the bytes written to this + // output file would have already been counted as shuffle bytes written. + Files.move(spills[0].file, outputFile); + return spills[0].partitionLengths; + } else { + final long[] partitionLengths; + // There are multiple spills to merge, so none of these spill files' lengths were counted + // towards our shuffle write count or shuffle write time. If we use the slow merge path, + // then the final output file's size won't necessarily be equal to the sum of the spill + // files' sizes. To guard against this case, we look at the output file's actual size when + // computing shuffle bytes written. + // + // We allow the individual merge methods to report their own IO times since different merge + // strategies use different IO techniques. We count IO during merge towards the shuffle + // shuffle write time, which appears to be consistent with the "not bypassing merge-sort" + // branch in ExternalSorter. + if (fastMergeEnabled && fastMergeIsSupported) { + // Compression is disabled or we are using an IO compression codec that supports + // decompression of concatenated compressed streams, so we can perform a fast spill merge + // that doesn't need to interpret the spilled bytes. + if (transferToEnabled) { + logger.debug("Using transferTo-based fast merge"); + partitionLengths = mergeSpillsWithTransferTo(spills, outputFile); + } else { + logger.debug("Using fileStream-based fast merge"); + partitionLengths = mergeSpillsWithFileStream(spills, outputFile, null); + } + } else { + logger.debug("Using slow merge"); + partitionLengths = mergeSpillsWithFileStream(spills, outputFile, compressionCodec); + } + // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has + // in-memory records, we write out the in-memory records to a file but do not count that + // final write as bytes spilled (instead, it's accounted as shuffle write). The merge needs + // to be counted as shuffle write, but this will lead to double-counting of the final + // SpillInfo's bytes. + writeMetrics.decShuffleBytesWritten(spills[spills.length - 1].file.length()); + writeMetrics.incShuffleBytesWritten(outputFile.length()); + return partitionLengths; + } + } catch (IOException e) { + if (outputFile.exists() && !outputFile.delete()) { + logger.error("Unable to delete output file {}", outputFile.getPath()); + } + throw e; + } + } + + /** + * Merges spill files using Java FileStreams. This code path is slower than the NIO-based merge, + * {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], File)}, so it's only used in + * cases where the IO compression codec does not support concatenation of compressed data, or in + * cases where users have explicitly disabled use of {@code transferTo} in order to work around + * kernel bugs. + * + * @param spills the spills to merge. + * @param outputFile the file to write the merged data to. + * @param compressionCodec the IO compression codec, or null if shuffle compression is disabled. + * @return the partition lengths in the merged file. + */ + private long[] mergeSpillsWithFileStream( + SpillInfo[] spills, + File outputFile, + @Nullable CompressionCodec compressionCodec) throws IOException { + assert (spills.length >= 2); + final int numPartitions = partitioner.numPartitions(); + final long[] partitionLengths = new long[numPartitions]; + final InputStream[] spillInputStreams = new FileInputStream[spills.length]; + OutputStream mergedFileOutputStream = null; + + boolean threwException = true; + try { + for (int i = 0; i < spills.length; i++) { + spillInputStreams[i] = new FileInputStream(spills[i].file); + } + for (int partition = 0; partition < numPartitions; partition++) { + final long initialFileLength = outputFile.length(); + mergedFileOutputStream = + new TimeTrackingOutputStream(writeMetrics, new FileOutputStream(outputFile, true)); + if (compressionCodec != null) { + mergedFileOutputStream = compressionCodec.compressedOutputStream(mergedFileOutputStream); + } + + for (int i = 0; i < spills.length; i++) { + final long partitionLengthInSpill = spills[i].partitionLengths[partition]; + if (partitionLengthInSpill > 0) { + InputStream partitionInputStream = + new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill); + if (compressionCodec != null) { + partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream); + } + ByteStreams.copy(partitionInputStream, mergedFileOutputStream); + } + } + mergedFileOutputStream.flush(); + mergedFileOutputStream.close(); + partitionLengths[partition] = (outputFile.length() - initialFileLength); + } + threwException = false; + } finally { + // To avoid masking exceptions that caused us to prematurely enter the finally block, only + // throw exceptions during cleanup if threwException == false. + for (InputStream stream : spillInputStreams) { + Closeables.close(stream, threwException); + } + Closeables.close(mergedFileOutputStream, threwException); + } + return partitionLengths; + } + + /** + * Merges spill files by using NIO's transferTo to concatenate spill partitions' bytes. + * This is only safe when the IO compression codec and serializer support concatenation of + * serialized streams. + * + * @return the partition lengths in the merged file. + */ + private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) throws IOException { + assert (spills.length >= 2); + final int numPartitions = partitioner.numPartitions(); + final long[] partitionLengths = new long[numPartitions]; + final FileChannel[] spillInputChannels = new FileChannel[spills.length]; + final long[] spillInputChannelPositions = new long[spills.length]; + FileChannel mergedFileOutputChannel = null; + + boolean threwException = true; + try { + for (int i = 0; i < spills.length; i++) { + spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel(); + } + // This file needs to opened in append mode in order to work around a Linux kernel bug that + // affects transferTo; see SPARK-3948 for more details. + mergedFileOutputChannel = new FileOutputStream(outputFile, true).getChannel(); + + long bytesWrittenToMergedFile = 0; + for (int partition = 0; partition < numPartitions; partition++) { + for (int i = 0; i < spills.length; i++) { + final long partitionLengthInSpill = spills[i].partitionLengths[partition]; + long bytesToTransfer = partitionLengthInSpill; + final FileChannel spillInputChannel = spillInputChannels[i]; + final long writeStartTime = System.nanoTime(); + while (bytesToTransfer > 0) { + final long actualBytesTransferred = spillInputChannel.transferTo( + spillInputChannelPositions[i], + bytesToTransfer, + mergedFileOutputChannel); + spillInputChannelPositions[i] += actualBytesTransferred; + bytesToTransfer -= actualBytesTransferred; + } + writeMetrics.incShuffleWriteTime(System.nanoTime() - writeStartTime); + bytesWrittenToMergedFile += partitionLengthInSpill; + partitionLengths[partition] += partitionLengthInSpill; + } + } + // Check the position after transferTo loop to see if it is in the right position and raise an + // exception if it is incorrect. The position will not be increased to the expected length + // after calling transferTo in kernel version 2.6.32. This issue is described at + // https://bugs.openjdk.java.net/browse/JDK-7052359 and SPARK-3948. + if (mergedFileOutputChannel.position() != bytesWrittenToMergedFile) { + throw new IOException( + "Current position " + mergedFileOutputChannel.position() + " does not equal expected " + + "position " + bytesWrittenToMergedFile + " after transferTo. Please check your kernel" + + " version to see if it is 2.6.32, as there is a kernel bug which will lead to " + + "unexpected behavior when using transferTo. You can set spark.file.transferTo=false " + + "to disable this NIO feature." + ); + } + threwException = false; + } finally { + // To avoid masking exceptions that caused us to prematurely enter the finally block, only + // throw exceptions during cleanup if threwException == false. + for (int i = 0; i < spills.length; i++) { + assert(spillInputChannelPositions[i] == spills[i].file.length()); + Closeables.close(spillInputChannels[i], threwException); + } + Closeables.close(mergedFileOutputChannel, threwException); + } + return partitionLengths; + } + + @Override + public Option stop(boolean success) { + try { + if (stopping) { + return Option.apply(null); + } else { + stopping = true; + if (success) { + if (mapStatus == null) { + throw new IllegalStateException("Cannot call stop(true) without having called write()"); + } + return Option.apply(mapStatus); + } else { + // The map task failed, so delete our output data. + shuffleBlockResolver.removeDataByMap(shuffleId, mapId); + return Option.apply(null); + } + } + } finally { + if (sorter != null) { + // If sorter is non-null, then this implies that we called stop() in response to an error, + // so we need to clean up memory and spill files created by the sorter + sorter.cleanupAfterError(); + } + } + } +} diff --git a/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java b/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java new file mode 100644 index 0000000000000..dc2aa30466cc6 --- /dev/null +++ b/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage; + +import java.io.IOException; +import java.io.OutputStream; + +import org.apache.spark.annotation.Private; +import org.apache.spark.executor.ShuffleWriteMetrics; + +/** + * Intercepts write calls and tracks total time spent writing in order to update shuffle write + * metrics. Not thread safe. + */ +@Private +public final class TimeTrackingOutputStream extends OutputStream { + + private final ShuffleWriteMetrics writeMetrics; + private final OutputStream outputStream; + + public TimeTrackingOutputStream(ShuffleWriteMetrics writeMetrics, OutputStream outputStream) { + this.writeMetrics = writeMetrics; + this.outputStream = outputStream; + } + + @Override + public void write(int b) throws IOException { + final long startTime = System.nanoTime(); + outputStream.write(b); + writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime); + } + + @Override + public void write(byte[] b) throws IOException { + final long startTime = System.nanoTime(); + outputStream.write(b); + writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime); + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + final long startTime = System.nanoTime(); + outputStream.write(b, off, len); + writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime); + } + + @Override + public void flush() throws IOException { + final long startTime = System.nanoTime(); + outputStream.flush(); + writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime); + } + + @Override + public void close() throws IOException { + final long startTime = System.nanoTime(); + outputStream.close(); + writeMetrics.incShuffleWriteTime(System.nanoTime() - startTime); + } +} diff --git a/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js b/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js index acf2d93b718b2..2d9262b972a59 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js +++ b/core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js @@ -20,7 +20,7 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ -module.exports={graphlib:require("./lib/graphlib"),dagre:require("./lib/dagre"),intersect:require("./lib/intersect"),render:require("./lib/render"),util:require("./lib/util"),version:require("./lib/version")}},{"./lib/dagre":8,"./lib/graphlib":9,"./lib/intersect":10,"./lib/render":23,"./lib/util":25,"./lib/version":26}],2:[function(require,module,exports){var util=require("./util");module.exports={"default":normal,normal:normal,vee:vee,undirected:undirected};function normal(parent,id,edge,type){var marker=parent.append("marker").attr("id",id).attr("viewBox","0 0 10 10").attr("refX",9).attr("refY",5).attr("markerUnits","strokeWidth").attr("markerWidth",8).attr("markerHeight",6).attr("orient","auto");var path=marker.append("path").attr("d","M 0 0 L 10 5 L 0 10 z").style("stroke-width",1).style("stroke-dasharray","1,0");util.applyStyle(path,edge[type+"Style"])}function vee(parent,id,edge,type){var marker=parent.append("marker").attr("id",id).attr("viewBox","0 0 10 10").attr("refX",9).attr("refY",5).attr("markerUnits","strokeWidth").attr("markerWidth",8).attr("markerHeight",6).attr("orient","auto");var path=marker.append("path").attr("d","M 0 0 L 10 5 L 0 10 L 4 5 z").style("stroke-width",1).style("stroke-dasharray","1,0");util.applyStyle(path,edge[type+"Style"])}function undirected(parent,id,edge,type){var marker=parent.append("marker").attr("id",id).attr("viewBox","0 0 10 10").attr("refX",9).attr("refY",5).attr("markerUnits","strokeWidth").attr("markerWidth",8).attr("markerHeight",6).attr("orient","auto");var path=marker.append("path").attr("d","M 0 5 L 10 5").style("stroke-width",1).style("stroke-dasharray","1,0");util.applyStyle(path,edge[type+"Style"])}},{"./util":25}],3:[function(require,module,exports){var _=require("./lodash"),addLabel=require("./label/add-label"),util=require("./util");module.exports=createClusters;function createClusters(selection,g){var clusters=g.nodes().filter(function(v){return util.isSubgraph(g,v)}),svgClusters=selection.selectAll("g.cluster").data(clusters,function(v){return v});var makeClusterIdentifier=function(v){return"cluster_"+v.replace(/^cluster/,"")};svgClusters.enter().append("g").attr("id",makeClusterIdentifier).attr("name",function(v){return g.node(v).label}).classed("cluster",true).style("opacity",0).append("rect");var sortedClusters=util.orderByRank(g,svgClusters.data());for(var i=0;i0}},{}],14:[function(require,module,exports){module.exports=intersectNode;function intersectNode(node,point){return node.intersect(point)}},{}],15:[function(require,module,exports){var intersectLine=require("./intersect-line");module.exports=intersectPolygon;function intersectPolygon(node,polyPoints,point){var x1=node.x;var y1=node.y;var intersections=[];var minX=Number.POSITIVE_INFINITY,minY=Number.POSITIVE_INFINITY;polyPoints.forEach(function(entry){minX=Math.min(minX,entry.x);minY=Math.min(minY,entry.y)});var left=x1-node.width/2-minX;var top=y1-node.height/2-minY;for(var i=0;i1){intersections.sort(function(p,q){var pdx=p.x-point.x,pdy=p.y-point.y,distp=Math.sqrt(pdx*pdx+pdy*pdy),qdx=q.x-point.x,qdy=q.y-point.y,distq=Math.sqrt(qdx*qdx+qdy*qdy);return distpMath.abs(dx)*h){if(dy<0){h=-h}sx=dy===0?0:h*dx/dy;sy=h}else{if(dx<0){w=-w}sx=w;sy=dx===0?0:w*dy/dx}return{x:x+sx,y:y+sy}}},{}],17:[function(require,module,exports){var util=require("../util");module.exports=addHtmlLabel;function addHtmlLabel(root,node){var fo=root.append("foreignObject").attr("width","100000");var div=fo.append("xhtml:div");var label=node.label;switch(typeof label){case"function":div.insert(label);break;case"object":div.insert(function(){return label});break;default:div.html(label)}util.applyStyle(div,node.labelStyle);div.style("display","inline-block");div.style("white-space","nowrap");var w,h;div.each(function(){w=this.clientWidth;h=this.clientHeight});fo.attr("width",w).attr("height",h);return fo}},{"../util":25}],18:[function(require,module,exports){var addTextLabel=require("./add-text-label"),addHtmlLabel=require("./add-html-label");module.exports=addLabel;function addLabel(root,node){var label=node.label;var labelSvg=root.append("g");if(typeof label!=="string"||node.labelType==="html"){addHtmlLabel(labelSvg,node)}else{addTextLabel(labelSvg,node)}var labelBBox=labelSvg.node().getBBox();labelSvg.attr("transform","translate("+-labelBBox.width/2+","+-labelBBox.height/2+")");return labelSvg}},{"./add-html-label":17,"./add-text-label":19}],19:[function(require,module,exports){var util=require("../util");module.exports=addTextLabel;function addTextLabel(root,node){var domNode=root.append("text");var lines=processEscapeSequences(node.label).split("\n");for(var i=0;imaxPadding){maxPadding=child.paddingTop}}return maxPadding}function getRank(g,v){var maxRank=0;var children=g.children(v);for(var i=0;imaxRank){maxRank=thisRank}}return maxRank}function orderByRank(g,nodes){return nodes.sort(function(x,y){return getRank(g,x)-getRank(g,y)})}function edgeToId(e){return escapeId(e.v)+":"+escapeId(e.w)+":"+escapeId(e.name)}var ID_DELIM=/:/g;function escapeId(str){return str?String(str).replace(ID_DELIM,"\\:"):""}function applyStyle(dom,styleFn){if(styleFn){dom.attr("style",styleFn)}}function applyClass(dom,classFn,otherClasses){if(classFn){dom.attr("class",classFn).attr("class",otherClasses+" "+dom.attr("class"))}}function applyTransition(selection,g){var graph=g.graph();if(_.isPlainObject(graph)){var transition=graph.transition;if(_.isFunction(transition)){return transition(selection)}}return selection}},{"./lodash":20}],26:[function(require,module,exports){module.exports="0.4.4-pre"},{}],27:[function(require,module,exports){module.exports={graphlib:require("./lib/graphlib"),layout:require("./lib/layout"),debug:require("./lib/debug"),util:{time:require("./lib/util").time,notime:require("./lib/util").notime},version:require("./lib/version")}},{"./lib/debug":32,"./lib/graphlib":33,"./lib/layout":35,"./lib/util":55,"./lib/version":56}],28:[function(require,module,exports){"use strict";var _=require("./lodash"),greedyFAS=require("./greedy-fas");module.exports={run:run,undo:undo};function run(g){var fas=g.graph().acyclicer==="greedy"?greedyFAS(g,weightFn(g)):dfsFAS(g);_.each(fas,function(e){var label=g.edge(e);g.removeEdge(e);label.forwardName=e.name;label.reversed=true;g.setEdge(e.w,e.v,label,_.uniqueId("rev"))});function weightFn(g){return function(e){return g.edge(e).weight}}}function dfsFAS(g){var fas=[],stack={},visited={};function dfs(v){if(_.has(visited,v)){return}visited[v]=true;stack[v]=true;_.each(g.outEdges(v),function(e){if(_.has(stack,e.w)){fas.push(e)}else{dfs(e.w)}});delete stack[v]}_.each(g.nodes(),dfs);return fas}function undo(g){_.each(g.edges(),function(e){var label=g.edge(e);if(label.reversed){g.removeEdge(e);var forwardName=label.forwardName;delete label.reversed;delete label.forwardName;g.setEdge(e.w,e.v,label,forwardName)}})}},{"./greedy-fas":34,"./lodash":36}],29:[function(require,module,exports){var _=require("./lodash"),util=require("./util");module.exports=addBorderSegments;function addBorderSegments(g){function dfs(v){var children=g.children(v),node=g.node(v);if(children.length){_.each(children,dfs)}if(_.has(node,"minRank")){node.borderLeft=[];node.borderRight=[];for(var rank=node.minRank,maxRank=node.maxRank+1;rank0;--i){entry=buckets[i].dequeue();if(entry){results=results.concat(removeNode(g,buckets,zeroIdx,entry,true));break}}}}return results}function removeNode(g,buckets,zeroIdx,entry,collectPredecessors){var results=collectPredecessors?[]:undefined;_.each(g.inEdges(entry.v),function(edge){var weight=g.edge(edge),uEntry=g.node(edge.v);if(collectPredecessors){results.push({v:edge.v,w:edge.w})}uEntry.out-=weight;assignBucket(buckets,zeroIdx,uEntry)});_.each(g.outEdges(entry.v),function(edge){var weight=g.edge(edge),w=edge.w,wEntry=g.node(w);wEntry["in"]-=weight;assignBucket(buckets,zeroIdx,wEntry)});g.removeNode(entry.v);return results}function buildState(g,weightFn){var fasGraph=new Graph,maxIn=0,maxOut=0;_.each(g.nodes(),function(v){fasGraph.setNode(v,{v:v,"in":0,out:0})});_.each(g.edges(),function(e){var prevWeight=fasGraph.edge(e.v,e.w)||0,weight=weightFn(e),edgeWeight=prevWeight+weight;fasGraph.setEdge(e.v,e.w,edgeWeight);maxOut=Math.max(maxOut,fasGraph.node(e.v).out+=weight);maxIn=Math.max(maxIn,fasGraph.node(e.w)["in"]+=weight)});var buckets=_.range(maxOut+maxIn+3).map(function(){return new List});var zeroIdx=maxIn+1;_.each(fasGraph.nodes(),function(v){assignBucket(buckets,zeroIdx,fasGraph.node(v))});return{graph:fasGraph,buckets:buckets,zeroIdx:zeroIdx}}function assignBucket(buckets,zeroIdx,entry){if(!entry.out){buckets[0].enqueue(entry)}else if(!entry["in"]){buckets[buckets.length-1].enqueue(entry)}else{buckets[entry.out-entry["in"]+zeroIdx].enqueue(entry)}}},{"./data/list":31,"./graphlib":33,"./lodash":36}],35:[function(require,module,exports){"use strict";var _=require("./lodash"),acyclic=require("./acyclic"),normalize=require("./normalize"),rank=require("./rank"),normalizeRanks=require("./util").normalizeRanks,parentDummyChains=require("./parent-dummy-chains"),removeEmptyRanks=require("./util").removeEmptyRanks,nestingGraph=require("./nesting-graph"),addBorderSegments=require("./add-border-segments"),coordinateSystem=require("./coordinate-system"),order=require("./order"),position=require("./position"),util=require("./util"),Graph=require("./graphlib").Graph;module.exports=layout;function layout(g,opts){var time=opts&&opts.debugTiming?util.time:util.notime;time("layout",function(){var layoutGraph=time(" buildLayoutGraph",function(){return buildLayoutGraph(g)});time(" runLayout",function(){runLayout(layoutGraph,time)});time(" updateInputGraph",function(){updateInputGraph(g,layoutGraph)})})}function runLayout(g,time){time(" makeSpaceForEdgeLabels",function(){makeSpaceForEdgeLabels(g)});time(" removeSelfEdges",function(){removeSelfEdges(g)});time(" acyclic",function(){acyclic.run(g)});time(" nestingGraph.run",function(){nestingGraph.run(g)});time(" rank",function(){rank(util.asNonCompoundGraph(g))});time(" injectEdgeLabelProxies",function(){injectEdgeLabelProxies(g)});time(" removeEmptyRanks",function(){removeEmptyRanks(g)});time(" nestingGraph.cleanup",function(){nestingGraph.cleanup(g)});time(" normalizeRanks",function(){normalizeRanks(g)});time(" assignRankMinMax",function(){assignRankMinMax(g)});time(" removeEdgeLabelProxies",function(){removeEdgeLabelProxies(g)});time(" normalize.run",function(){normalize.run(g)});time(" parentDummyChains",function(){ +module.exports={graphlib:require("./lib/graphlib"),dagre:require("./lib/dagre"),intersect:require("./lib/intersect"),render:require("./lib/render"),util:require("./lib/util"),version:require("./lib/version")}},{"./lib/dagre":8,"./lib/graphlib":9,"./lib/intersect":10,"./lib/render":23,"./lib/util":25,"./lib/version":26}],2:[function(require,module,exports){var util=require("./util");module.exports={"default":normal,normal:normal,vee:vee,undirected:undirected};function normal(parent,id,edge,type){var marker=parent.append("marker").attr("id",id).attr("viewBox","0 0 10 10").attr("refX",9).attr("refY",5).attr("markerUnits","strokeWidth").attr("markerWidth",8).attr("markerHeight",6).attr("orient","auto");var path=marker.append("path").attr("d","M 0 0 L 10 5 L 0 10 z").style("stroke-width",1).style("stroke-dasharray","1,0");util.applyStyle(path,edge[type+"Style"])}function vee(parent,id,edge,type){var marker=parent.append("marker").attr("id",id).attr("viewBox","0 0 10 10").attr("refX",9).attr("refY",5).attr("markerUnits","strokeWidth").attr("markerWidth",8).attr("markerHeight",6).attr("orient","auto");var path=marker.append("path").attr("d","M 0 0 L 10 5 L 0 10 L 4 5 z").style("stroke-width",1).style("stroke-dasharray","1,0");util.applyStyle(path,edge[type+"Style"])}function undirected(parent,id,edge,type){var marker=parent.append("marker").attr("id",id).attr("viewBox","0 0 10 10").attr("refX",9).attr("refY",5).attr("markerUnits","strokeWidth").attr("markerWidth",8).attr("markerHeight",6).attr("orient","auto");var path=marker.append("path").attr("d","M 0 5 L 10 5").style("stroke-width",1).style("stroke-dasharray","1,0");util.applyStyle(path,edge[type+"Style"])}},{"./util":25}],3:[function(require,module,exports){var _=require("./lodash"),addLabel=require("./label/add-label"),util=require("./util");module.exports=createClusters;function createClusters(selection,g){var clusters=g.nodes().filter(function(v){return util.isSubgraph(g,v)}),svgClusters=selection.selectAll("g.cluster").data(clusters,function(v){return v});var makeClusterIdentifier=function(v){return"cluster_"+v.replace(/^cluster/,"")};svgClusters.enter().append("g").attr("class",makeClusterIdentifier).attr("name",function(v){return g.node(v).label}).classed("cluster",true).style("opacity",0).append("rect");var sortedClusters=util.orderByRank(g,svgClusters.data());for(var i=0;i0}},{}],14:[function(require,module,exports){module.exports=intersectNode;function intersectNode(node,point){return node.intersect(point)}},{}],15:[function(require,module,exports){var intersectLine=require("./intersect-line");module.exports=intersectPolygon;function intersectPolygon(node,polyPoints,point){var x1=node.x;var y1=node.y;var intersections=[];var minX=Number.POSITIVE_INFINITY,minY=Number.POSITIVE_INFINITY;polyPoints.forEach(function(entry){minX=Math.min(minX,entry.x);minY=Math.min(minY,entry.y)});var left=x1-node.width/2-minX;var top=y1-node.height/2-minY;for(var i=0;i1){intersections.sort(function(p,q){var pdx=p.x-point.x,pdy=p.y-point.y,distp=Math.sqrt(pdx*pdx+pdy*pdy),qdx=q.x-point.x,qdy=q.y-point.y,distq=Math.sqrt(qdx*qdx+qdy*qdy);return distpMath.abs(dx)*h){if(dy<0){h=-h}sx=dy===0?0:h*dx/dy;sy=h}else{if(dx<0){w=-w}sx=w;sy=dx===0?0:w*dy/dx}return{x:x+sx,y:y+sy}}},{}],17:[function(require,module,exports){var util=require("../util");module.exports=addHtmlLabel;function addHtmlLabel(root,node){var fo=root.append("foreignObject").attr("width","100000");var div=fo.append("xhtml:div");var label=node.label;switch(typeof label){case"function":div.insert(label);break;case"object":div.insert(function(){return label});break;default:div.html(label)}util.applyStyle(div,node.labelStyle);div.style("display","inline-block");div.style("white-space","nowrap");var w,h;div.each(function(){w=this.clientWidth;h=this.clientHeight});fo.attr("width",w).attr("height",h);return fo}},{"../util":25}],18:[function(require,module,exports){var addTextLabel=require("./add-text-label"),addHtmlLabel=require("./add-html-label");module.exports=addLabel;function addLabel(root,node){var label=node.label;var labelSvg=root.append("g");if(typeof label!=="string"||node.labelType==="html"){addHtmlLabel(labelSvg,node)}else{addTextLabel(labelSvg,node)}var labelBBox=labelSvg.node().getBBox();labelSvg.attr("transform","translate("+-labelBBox.width/2+","+-labelBBox.height/2+")");return labelSvg}},{"./add-html-label":17,"./add-text-label":19}],19:[function(require,module,exports){var util=require("../util");module.exports=addTextLabel;function addTextLabel(root,node){var domNode=root.append("text");var lines=processEscapeSequences(node.label).split("\n");for(var i=0;imaxPadding){maxPadding=child.paddingTop}}return maxPadding}function getRank(g,v){var maxRank=0;var children=g.children(v);for(var i=0;imaxRank){maxRank=thisRank}}return maxRank}function orderByRank(g,nodes){return nodes.sort(function(x,y){return getRank(g,x)-getRank(g,y)})}function edgeToId(e){return escapeId(e.v)+":"+escapeId(e.w)+":"+escapeId(e.name)}var ID_DELIM=/:/g;function escapeId(str){return str?String(str).replace(ID_DELIM,"\\:"):""}function applyStyle(dom,styleFn){if(styleFn){dom.attr("style",styleFn)}}function applyClass(dom,classFn,otherClasses){if(classFn){dom.attr("class",classFn).attr("class",otherClasses+" "+dom.attr("class"))}}function applyTransition(selection,g){var graph=g.graph();if(_.isPlainObject(graph)){var transition=graph.transition;if(_.isFunction(transition)){return transition(selection)}}return selection}},{"./lodash":20}],26:[function(require,module,exports){module.exports="0.4.4-pre"},{}],27:[function(require,module,exports){module.exports={graphlib:require("./lib/graphlib"),layout:require("./lib/layout"),debug:require("./lib/debug"),util:{time:require("./lib/util").time,notime:require("./lib/util").notime},version:require("./lib/version")}},{"./lib/debug":32,"./lib/graphlib":33,"./lib/layout":35,"./lib/util":55,"./lib/version":56}],28:[function(require,module,exports){"use strict";var _=require("./lodash"),greedyFAS=require("./greedy-fas");module.exports={run:run,undo:undo};function run(g){var fas=g.graph().acyclicer==="greedy"?greedyFAS(g,weightFn(g)):dfsFAS(g);_.each(fas,function(e){var label=g.edge(e);g.removeEdge(e);label.forwardName=e.name;label.reversed=true;g.setEdge(e.w,e.v,label,_.uniqueId("rev"))});function weightFn(g){return function(e){return g.edge(e).weight}}}function dfsFAS(g){var fas=[],stack={},visited={};function dfs(v){if(_.has(visited,v)){return}visited[v]=true;stack[v]=true;_.each(g.outEdges(v),function(e){if(_.has(stack,e.w)){fas.push(e)}else{dfs(e.w)}});delete stack[v]}_.each(g.nodes(),dfs);return fas}function undo(g){_.each(g.edges(),function(e){var label=g.edge(e);if(label.reversed){g.removeEdge(e);var forwardName=label.forwardName;delete label.reversed;delete label.forwardName;g.setEdge(e.w,e.v,label,forwardName)}})}},{"./greedy-fas":34,"./lodash":36}],29:[function(require,module,exports){var _=require("./lodash"),util=require("./util");module.exports=addBorderSegments;function addBorderSegments(g){function dfs(v){var children=g.children(v),node=g.node(v);if(children.length){_.each(children,dfs)}if(_.has(node,"minRank")){node.borderLeft=[];node.borderRight=[];for(var rank=node.minRank,maxRank=node.maxRank+1;rank0;--i){entry=buckets[i].dequeue();if(entry){results=results.concat(removeNode(g,buckets,zeroIdx,entry,true));break}}}}return results}function removeNode(g,buckets,zeroIdx,entry,collectPredecessors){var results=collectPredecessors?[]:undefined;_.each(g.inEdges(entry.v),function(edge){var weight=g.edge(edge),uEntry=g.node(edge.v);if(collectPredecessors){results.push({v:edge.v,w:edge.w})}uEntry.out-=weight;assignBucket(buckets,zeroIdx,uEntry)});_.each(g.outEdges(entry.v),function(edge){var weight=g.edge(edge),w=edge.w,wEntry=g.node(w);wEntry["in"]-=weight;assignBucket(buckets,zeroIdx,wEntry)});g.removeNode(entry.v);return results}function buildState(g,weightFn){var fasGraph=new Graph,maxIn=0,maxOut=0;_.each(g.nodes(),function(v){fasGraph.setNode(v,{v:v,"in":0,out:0})});_.each(g.edges(),function(e){var prevWeight=fasGraph.edge(e.v,e.w)||0,weight=weightFn(e),edgeWeight=prevWeight+weight;fasGraph.setEdge(e.v,e.w,edgeWeight);maxOut=Math.max(maxOut,fasGraph.node(e.v).out+=weight);maxIn=Math.max(maxIn,fasGraph.node(e.w)["in"]+=weight)});var buckets=_.range(maxOut+maxIn+3).map(function(){return new List});var zeroIdx=maxIn+1;_.each(fasGraph.nodes(),function(v){assignBucket(buckets,zeroIdx,fasGraph.node(v))});return{graph:fasGraph,buckets:buckets,zeroIdx:zeroIdx}}function assignBucket(buckets,zeroIdx,entry){if(!entry.out){buckets[0].enqueue(entry)}else if(!entry["in"]){buckets[buckets.length-1].enqueue(entry)}else{buckets[entry.out-entry["in"]+zeroIdx].enqueue(entry)}}},{"./data/list":31,"./graphlib":33,"./lodash":36}],35:[function(require,module,exports){"use strict";var _=require("./lodash"),acyclic=require("./acyclic"),normalize=require("./normalize"),rank=require("./rank"),normalizeRanks=require("./util").normalizeRanks,parentDummyChains=require("./parent-dummy-chains"),removeEmptyRanks=require("./util").removeEmptyRanks,nestingGraph=require("./nesting-graph"),addBorderSegments=require("./add-border-segments"),coordinateSystem=require("./coordinate-system"),order=require("./order"),position=require("./position"),util=require("./util"),Graph=require("./graphlib").Graph;module.exports=layout;function layout(g,opts){var time=opts&&opts.debugTiming?util.time:util.notime;time("layout",function(){var layoutGraph=time(" buildLayoutGraph",function(){return buildLayoutGraph(g)});time(" runLayout",function(){runLayout(layoutGraph,time)});time(" updateInputGraph",function(){updateInputGraph(g,layoutGraph)})})}function runLayout(g,time){time(" makeSpaceForEdgeLabels",function(){makeSpaceForEdgeLabels(g)});time(" removeSelfEdges",function(){removeSelfEdges(g)});time(" acyclic",function(){acyclic.run(g)});time(" nestingGraph.run",function(){nestingGraph.run(g)});time(" rank",function(){rank(util.asNonCompoundGraph(g))});time(" injectEdgeLabelProxies",function(){injectEdgeLabelProxies(g)});time(" removeEmptyRanks",function(){removeEmptyRanks(g)});time(" nestingGraph.cleanup",function(){nestingGraph.cleanup(g)});time(" normalizeRanks",function(){normalizeRanks(g)});time(" assignRankMinMax",function(){assignRankMinMax(g)});time(" removeEdgeLabelProxies",function(){removeEdgeLabelProxies(g)});time(" normalize.run",function(){normalize.run(g)});time(" parentDummyChains",function(){ parentDummyChains(g)});time(" addBorderSegments",function(){addBorderSegments(g)});time(" order",function(){order(g)});time(" insertSelfEdges",function(){insertSelfEdges(g)});time(" adjustCoordinateSystem",function(){coordinateSystem.adjust(g)});time(" position",function(){position(g)});time(" positionSelfEdges",function(){positionSelfEdges(g)});time(" removeBorderNodes",function(){removeBorderNodes(g)});time(" normalize.undo",function(){normalize.undo(g)});time(" fixupEdgeLabelCoords",function(){fixupEdgeLabelCoords(g)});time(" undoCoordinateSystem",function(){coordinateSystem.undo(g)});time(" translateGraph",function(){translateGraph(g)});time(" assignNodeIntersects",function(){assignNodeIntersects(g)});time(" reversePoints",function(){reversePointsForReversedEdges(g)});time(" acyclic.undo",function(){acyclic.undo(g)})}function updateInputGraph(inputGraph,layoutGraph){_.each(inputGraph.nodes(),function(v){var inputLabel=inputGraph.node(v),layoutLabel=layoutGraph.node(v);if(inputLabel){inputLabel.x=layoutLabel.x;inputLabel.y=layoutLabel.y;if(layoutGraph.children(v).length){inputLabel.width=layoutLabel.width;inputLabel.height=layoutLabel.height}}});_.each(inputGraph.edges(),function(e){var inputLabel=inputGraph.edge(e),layoutLabel=layoutGraph.edge(e);inputLabel.points=layoutLabel.points;if(_.has(layoutLabel,"x")){inputLabel.x=layoutLabel.x;inputLabel.y=layoutLabel.y}});inputGraph.graph().width=layoutGraph.graph().width;inputGraph.graph().height=layoutGraph.graph().height}var graphNumAttrs=["nodesep","edgesep","ranksep","marginx","marginy"],graphDefaults={ranksep:50,edgesep:20,nodesep:50,rankdir:"tb"},graphAttrs=["acyclicer","ranker","rankdir","align"],nodeNumAttrs=["width","height"],nodeDefaults={width:0,height:0},edgeNumAttrs=["minlen","weight","width","height","labeloffset"],edgeDefaults={minlen:1,weight:1,width:0,height:0,labeloffset:10,labelpos:"r"},edgeAttrs=["labelpos"];function buildLayoutGraph(inputGraph){var g=new Graph({multigraph:true,compound:true}),graph=canonicalize(inputGraph.graph());g.setGraph(_.merge({},graphDefaults,selectNumberAttrs(graph,graphNumAttrs),_.pick(graph,graphAttrs)));_.each(inputGraph.nodes(),function(v){var node=canonicalize(inputGraph.node(v));g.setNode(v,_.defaults(selectNumberAttrs(node,nodeNumAttrs),nodeDefaults));g.setParent(v,inputGraph.parent(v))});_.each(inputGraph.edges(),function(e){var edge=canonicalize(inputGraph.edge(e));g.setEdge(e,_.merge({},edgeDefaults,selectNumberAttrs(edge,edgeNumAttrs),_.pick(edge,edgeAttrs)))});return g}function makeSpaceForEdgeLabels(g){var graph=g.graph();graph.ranksep/=2;_.each(g.edges(),function(e){var edge=g.edge(e);edge.minlen*=2;if(edge.labelpos.toLowerCase()!=="c"){if(graph.rankdir==="TB"||graph.rankdir==="BT"){edge.width+=edge.labeloffset}else{edge.height+=edge.labeloffset}}})}function injectEdgeLabelProxies(g){_.each(g.edges(),function(e){var edge=g.edge(e);if(edge.width&&edge.height){var v=g.node(e.v),w=g.node(e.w),label={rank:(w.rank-v.rank)/2+v.rank,e:e};util.addDummyNode(g,"edge-proxy",label,"_ep")}})}function assignRankMinMax(g){var maxRank=0;_.each(g.nodes(),function(v){var node=g.node(v);if(node.borderTop){node.minRank=g.node(node.borderTop).rank;node.maxRank=g.node(node.borderBottom).rank;maxRank=_.max(maxRank,node.maxRank)}});g.graph().maxRank=maxRank}function removeEdgeLabelProxies(g){_.each(g.nodes(),function(v){var node=g.node(v);if(node.dummy==="edge-proxy"){g.edge(node.e).labelRank=node.rank;g.removeNode(v)}})}function translateGraph(g){var minX=Number.POSITIVE_INFINITY,maxX=0,minY=Number.POSITIVE_INFINITY,maxY=0,graphLabel=g.graph(),marginX=graphLabel.marginx||0,marginY=graphLabel.marginy||0;function getExtremes(attrs){var x=attrs.x,y=attrs.y,w=attrs.width,h=attrs.height;minX=Math.min(minX,x-w/2);maxX=Math.max(maxX,x+w/2);minY=Math.min(minY,y-h/2);maxY=Math.max(maxY,y+h/2)}_.each(g.nodes(),function(v){getExtremes(g.node(v))});_.each(g.edges(),function(e){var edge=g.edge(e);if(_.has(edge,"x")){getExtremes(edge)}});minX-=marginX;minY-=marginY;_.each(g.nodes(),function(v){var node=g.node(v);node.x-=minX;node.y-=minY});_.each(g.edges(),function(e){var edge=g.edge(e);_.each(edge.points,function(p){p.x-=minX;p.y-=minY});if(_.has(edge,"x")){edge.x-=minX}if(_.has(edge,"y")){edge.y-=minY}});graphLabel.width=maxX-minX+marginX;graphLabel.height=maxY-minY+marginY}function assignNodeIntersects(g){_.each(g.edges(),function(e){var edge=g.edge(e),nodeV=g.node(e.v),nodeW=g.node(e.w),p1,p2;if(!edge.points){edge.points=[];p1=nodeW;p2=nodeV}else{p1=edge.points[0];p2=edge.points[edge.points.length-1]}edge.points.unshift(util.intersectRect(nodeV,p1));edge.points.push(util.intersectRect(nodeW,p2))})}function fixupEdgeLabelCoords(g){_.each(g.edges(),function(e){var edge=g.edge(e);if(_.has(edge,"x")){if(edge.labelpos==="l"||edge.labelpos==="r"){edge.width-=edge.labeloffset}switch(edge.labelpos){case"l":edge.x-=edge.width/2+edge.labeloffset;break;case"r":edge.x+=edge.width/2+edge.labeloffset;break}}})}function reversePointsForReversedEdges(g){_.each(g.edges(),function(e){var edge=g.edge(e);if(edge.reversed){edge.points.reverse()}})}function removeBorderNodes(g){_.each(g.nodes(),function(v){if(g.children(v).length){var node=g.node(v),t=g.node(node.borderTop),b=g.node(node.borderBottom),l=g.node(_.last(node.borderLeft)),r=g.node(_.last(node.borderRight));node.width=Math.abs(r.x-l.x);node.height=Math.abs(b.y-t.y);node.x=l.x+node.width/2;node.y=t.y+node.height/2}});_.each(g.nodes(),function(v){if(g.node(v).dummy==="border"){g.removeNode(v)}})}function removeSelfEdges(g){_.each(g.edges(),function(e){if(e.v===e.w){var node=g.node(e.v);if(!node.selfEdges){node.selfEdges=[]}node.selfEdges.push({e:e,label:g.edge(e)});g.removeEdge(e)}})}function insertSelfEdges(g){var layers=util.buildLayerMatrix(g);_.each(layers,function(layer){var orderShift=0;_.each(layer,function(v,i){var node=g.node(v);node.order=i+orderShift;_.each(node.selfEdges,function(selfEdge){util.addDummyNode(g,"selfedge",{width:selfEdge.label.width,height:selfEdge.label.height,rank:node.rank,order:i+ ++orderShift,e:selfEdge.e,label:selfEdge.label},"_se")});delete node.selfEdges})})}function positionSelfEdges(g){_.each(g.nodes(),function(v){var node=g.node(v);if(node.dummy==="selfedge"){var selfNode=g.node(node.e.v),x=selfNode.x+selfNode.width/2,y=selfNode.y,dx=node.x-x,dy=selfNode.height/2;g.setEdge(node.e,node.label);g.removeNode(v);node.label.points=[{x:x+2*dx/3,y:y-dy},{x:x+5*dx/6,y:y-dy},{x:x+dx,y:y},{x:x+5*dx/6,y:y+dy},{x:x+2*dx/3,y:y+dy}];node.label.x=node.x;node.label.y=node.y}})}function selectNumberAttrs(obj,attrs){return _.mapValues(_.pick(obj,attrs),Number)}function canonicalize(attrs){var newAttrs={};_.each(attrs,function(v,k){newAttrs[k.toLowerCase()]=v});return newAttrs}},{"./acyclic":28,"./add-border-segments":29,"./coordinate-system":30,"./graphlib":33,"./lodash":36,"./nesting-graph":37,"./normalize":38,"./order":43,"./parent-dummy-chains":48,"./position":50,"./rank":52,"./util":55}],36:[function(require,module,exports){module.exports=require(20)},{"/Users/andrew/Documents/dev/dagre-d3/lib/lodash.js":20,lodash:77}],37:[function(require,module,exports){var _=require("./lodash"),util=require("./util");module.exports={run:run,cleanup:cleanup};function run(g){var root=util.addDummyNode(g,"root",{},"_root"),depths=treeDepths(g),height=_.max(depths)-1,nodeSep=2*height+1;g.graph().nestingRoot=root;_.each(g.edges(),function(e){g.edge(e).minlen*=nodeSep});var weight=sumWeights(g)+1;_.each(g.children(),function(child){dfs(g,root,nodeSep,weight,height,depths,child)});g.graph().nodeRankFactor=nodeSep}function dfs(g,root,nodeSep,weight,height,depths,v){var children=g.children(v);if(!children.length){if(v!==root){g.setEdge(root,v,{weight:0,minlen:nodeSep})}return}var top=util.addBorderNode(g,"_bt"),bottom=util.addBorderNode(g,"_bb"),label=g.node(v);g.setParent(top,v);label.borderTop=top;g.setParent(bottom,v);label.borderBottom=bottom;_.each(children,function(child){dfs(g,root,nodeSep,weight,height,depths,child);var childNode=g.node(child),childTop=childNode.borderTop?childNode.borderTop:child,childBottom=childNode.borderBottom?childNode.borderBottom:child,thisWeight=childNode.borderTop?weight:2*weight,minlen=childTop!==childBottom?1:height-depths[v]+1;g.setEdge(top,childTop,{weight:thisWeight,minlen:minlen,nestingEdge:true});g.setEdge(childBottom,bottom,{weight:thisWeight,minlen:minlen,nestingEdge:true})});if(!g.parent(v)){g.setEdge(root,top,{weight:0,minlen:height+depths[v]})}}function treeDepths(g){var depths={};function dfs(v,depth){var children=g.children(v);if(children&&children.length){_.each(children,function(child){dfs(child,depth+1)})}depths[v]=depth}_.each(g.children(),function(v){dfs(v,1)});return depths}function sumWeights(g){return _.reduce(g.edges(),function(acc,e){return acc+g.edge(e).weight},0)}function cleanup(g){var graphLabel=g.graph();g.removeNode(graphLabel.nestingRoot);delete graphLabel.nestingRoot;_.each(g.edges(),function(e){var edge=g.edge(e);if(edge.nestingEdge){g.removeEdge(e)}})}},{"./lodash":36,"./util":55}],38:[function(require,module,exports){"use strict";var _=require("./lodash"),util=require("./util");module.exports={run:run,undo:undo};function run(g){g.graph().dummyChains=[];_.each(g.edges(),function(edge){normalizeEdge(g,edge)})}function normalizeEdge(g,e){var v=e.v,vRank=g.node(v).rank,w=e.w,wRank=g.node(w).rank,name=e.name,edgeLabel=g.edge(e),labelRank=edgeLabel.labelRank;if(wRank===vRank+1)return;g.removeEdge(e);var dummy,attrs,i;for(i=0,++vRank;vRank0){if(index%2){weightSum+=tree[index+1]}index=index-1>>1;tree[index]+=entry.weight}cc+=entry.weight*weightSum}));return cc}},{"../lodash":36}],43:[function(require,module,exports){"use strict";var _=require("../lodash"),initOrder=require("./init-order"),crossCount=require("./cross-count"),sortSubgraph=require("./sort-subgraph"),buildLayerGraph=require("./build-layer-graph"),addSubgraphConstraints=require("./add-subgraph-constraints"),Graph=require("../graphlib").Graph,util=require("../util");module.exports=order;function order(g){var maxRank=util.maxRank(g),downLayerGraphs=buildLayerGraphs(g,_.range(1,maxRank+1),"inEdges"),upLayerGraphs=buildLayerGraphs(g,_.range(maxRank-1,-1,-1),"outEdges");var layering=initOrder(g);assignOrder(g,layering);var bestCC=Number.POSITIVE_INFINITY,best;for(var i=0,lastBest=0;lastBest<4;++i,++lastBest){sweepLayerGraphs(i%2?downLayerGraphs:upLayerGraphs,i%4>=2);layering=util.buildLayerMatrix(g);var cc=crossCount(g,layering);if(cc=vEntry.barycenter){mergeEntries(vEntry,uEntry)}}}function handleOut(vEntry){return function(wEntry){wEntry["in"].push(vEntry);if(--wEntry.indegree===0){sourceSet.push(wEntry)}}}while(sourceSet.length){var entry=sourceSet.pop();entries.push(entry);_.each(entry["in"].reverse(),handleIn(entry));_.each(entry.out,handleOut(entry))}return _.chain(entries).filter(function(entry){return!entry.merged}).map(function(entry){return _.pick(entry,["vs","i","barycenter","weight"])}).value()}function mergeEntries(target,source){var sum=0,weight=0;if(target.weight){sum+=target.barycenter*target.weight;weight+=target.weight}if(source.weight){sum+=source.barycenter*source.weight;weight+=source.weight}target.vs=source.vs.concat(target.vs);target.barycenter=sum/weight;target.weight=weight;target.i=Math.min(source.i,target.i);source.merged=true}},{"../lodash":36}],46:[function(require,module,exports){var _=require("../lodash"),barycenter=require("./barycenter"),resolveConflicts=require("./resolve-conflicts"),sort=require("./sort");module.exports=sortSubgraph;function sortSubgraph(g,v,cg,biasRight){var movable=g.children(v),node=g.node(v),bl=node?node.borderLeft:undefined,br=node?node.borderRight:undefined,subgraphs={};if(bl){movable=_.filter(movable,function(w){return w!==bl&&w!==br})}var barycenters=barycenter(g,movable);_.each(barycenters,function(entry){if(g.children(entry.v).length){var subgraphResult=sortSubgraph(g,entry.v,cg,biasRight);subgraphs[entry.v]=subgraphResult;if(_.has(subgraphResult,"barycenter")){mergeBarycenters(entry,subgraphResult)}}});var entries=resolveConflicts(barycenters,cg);expandSubgraphs(entries,subgraphs);var result=sort(entries,biasRight);if(bl){result.vs=_.flatten([bl,result.vs,br],true);if(g.predecessors(bl).length){var blPred=g.node(g.predecessors(bl)[0]),brPred=g.node(g.predecessors(br)[0]);if(!_.has(result,"barycenter")){result.barycenter=0;result.weight=0}result.barycenter=(result.barycenter*result.weight+blPred.order+brPred.order)/(result.weight+2);result.weight+=2}}return result}function expandSubgraphs(entries,subgraphs){_.each(entries,function(entry){entry.vs=_.flatten(entry.vs.map(function(v){if(subgraphs[v]){return subgraphs[v].vs}return v}),true)})}function mergeBarycenters(target,other){if(!_.isUndefined(target.barycenter)){target.barycenter=(target.barycenter*target.weight+other.barycenter*other.weight)/(target.weight+other.weight);target.weight+=other.weight}else{target.barycenter=other.barycenter;target.weight=other.weight}}},{"../lodash":36,"./barycenter":40,"./resolve-conflicts":45,"./sort":47}],47:[function(require,module,exports){var _=require("../lodash"),util=require("../util");module.exports=sort;function sort(entries,biasRight){var parts=util.partition(entries,function(entry){return _.has(entry,"barycenter")});var sortable=parts.lhs,unsortable=_.sortBy(parts.rhs,function(entry){return-entry.i}),vs=[],sum=0,weight=0,vsIndex=0;sortable.sort(compareWithBias(!!biasRight));vsIndex=consumeUnsortable(vs,unsortable,vsIndex);_.each(sortable,function(entry){vsIndex+=entry.vs.length;vs.push(entry.vs);sum+=entry.barycenter*entry.weight;weight+=entry.weight;vsIndex=consumeUnsortable(vs,unsortable,vsIndex)});var result={vs:_.flatten(vs,true)};if(weight){result.barycenter=sum/weight;result.weight=weight}return result}function consumeUnsortable(vs,unsortable,index){var last;while(unsortable.length&&(last=_.last(unsortable)).i<=index){unsortable.pop();vs.push(last.vs);index++}return index}function compareWithBias(bias){return function(entryV,entryW){if(entryV.barycenterentryW.barycenter){return 1}return!bias?entryV.i-entryW.i:entryW.i-entryV.i}}},{"../lodash":36,"../util":55}],48:[function(require,module,exports){var _=require("./lodash");module.exports=parentDummyChains;function parentDummyChains(g){var postorderNums=postorder(g);_.each(g.graph().dummyChains,function(v){var node=g.node(v),edgeObj=node.edgeObj,pathData=findPath(g,postorderNums,edgeObj.v,edgeObj.w),path=pathData.path,lca=pathData.lca,pathIdx=0,pathV=path[pathIdx],ascending=true;while(v!==edgeObj.w){node=g.node(v);if(ascending){while((pathV=path[pathIdx])!==lca&&g.node(pathV).maxRanklow||lim>postorderNums[parent].lim));lca=parent;parent=w;while((parent=g.parent(parent))!==lca){wPath.push(parent)}return{path:vPath.concat(wPath.reverse()),lca:lca}}function postorder(g){var result={},lim=0;function dfs(v){var low=lim;_.each(g.children(v),dfs);result[v]={low:low,lim:lim++}}_.each(g.children(),dfs);return result}},{"./lodash":36}],49:[function(require,module,exports){"use strict";var _=require("../lodash"),Graph=require("../graphlib").Graph,util=require("../util");module.exports={positionX:positionX,findType1Conflicts:findType1Conflicts,findType2Conflicts:findType2Conflicts,addConflict:addConflict,hasConflict:hasConflict,verticalAlignment:verticalAlignment,horizontalCompaction:horizontalCompaction,alignCoordinates:alignCoordinates,findSmallestWidthAlignment:findSmallestWidthAlignment,balance:balance};function findType1Conflicts(g,layering){var conflicts={};function visitLayer(prevLayer,layer){var k0=0,scanPos=0,prevLayerLength=prevLayer.length,lastNode=_.last(layer);_.each(layer,function(v,i){var w=findOtherInnerSegmentNode(g,v),k1=w?g.node(w).order:prevLayerLength;if(w||v===lastNode){_.each(layer.slice(scanPos,i+1),function(scanNode){_.each(g.predecessors(scanNode),function(u){var uLabel=g.node(u),uPos=uLabel.order;if((uPosnextNorthBorder)){addConflict(conflicts,u,v)}})}})}function visitLayer(north,south){var prevNorthPos=-1,nextNorthPos,southPos=0;_.each(south,function(v,southLookahead){if(g.node(v).dummy==="border"){var predecessors=g.predecessors(v);if(predecessors.length){nextNorthPos=g.node(predecessors[0]).order;scan(south,southPos,southLookahead,prevNorthPos,nextNorthPos);southPos=southLookahead;prevNorthPos=nextNorthPos}}scan(south,southPos,south.length,nextNorthPos,north.length)});return south}_.reduce(layering,visitLayer);return conflicts}function findOtherInnerSegmentNode(g,v){if(g.node(v).dummy){return _.find(g.predecessors(v),function(u){return g.node(u).dummy})}}function addConflict(conflicts,v,w){if(v>w){var tmp=v;v=w;w=tmp}var conflictsV=conflicts[v];if(!conflictsV){conflicts[v]=conflictsV={}}conflictsV[w]=true}function hasConflict(conflicts,v,w){if(v>w){var tmp=v;v=w;w=tmp}return _.has(conflicts[v],w)}function verticalAlignment(g,layering,conflicts,neighborFn){var root={},align={},pos={};_.each(layering,function(layer){_.each(layer,function(v,order){root[v]=v;align[v]=v;pos[v]=order})});_.each(layering,function(layer){var prevIdx=-1;_.each(layer,function(v){var ws=neighborFn(v);if(ws.length){ws=_.sortBy(ws,function(w){return pos[w]});var mp=(ws.length-1)/2;for(var i=Math.floor(mp),il=Math.ceil(mp);i<=il;++i){var w=ws[i];if(align[v]===v&&prevIdxwLabel.lim){tailLabel=wLabel;flip=true}var candidates=_.filter(g.edges(),function(edge){return flip===isDescendant(t,t.node(edge.v),tailLabel)&&flip!==isDescendant(t,t.node(edge.w),tailLabel)});return _.min(candidates,function(edge){return slack(g,edge)})}function exchangeEdges(t,g,e,f){var v=e.v,w=e.w;t.removeEdge(v,w);t.setEdge(f.v,f.w,{});initLowLimValues(t);initCutValues(t,g);updateRanks(t,g)}function updateRanks(t,g){var root=_.find(t.nodes(),function(v){return!g.node(v).parent}),vs=preorder(t,root);vs=vs.slice(1);_.each(vs,function(v){var parent=t.node(v).parent,edge=g.edge(v,parent),flipped=false;if(!edge){edge=g.edge(parent,v);flipped=true}g.node(v).rank=g.node(parent).rank+(flipped?edge.minlen:-edge.minlen)})}function isTreeEdge(tree,u,v){return tree.hasEdge(u,v)}function isDescendant(tree,vLabel,rootLabel){return rootLabel.low<=vLabel.lim&&vLabel.lim<=rootLabel.lim}},{"../graphlib":33,"../lodash":36,"../util":55,"./feasible-tree":51,"./util":54}],54:[function(require,module,exports){"use strict";var _=require("../lodash");module.exports={longestPath:longestPath,slack:slack};function longestPath(g){var visited={};function dfs(v){var label=g.node(v);if(_.has(visited,v)){return label.rank}visited[v]=true;var rank=_.min(_.map(g.outEdges(v),function(e){return dfs(e.w)-g.edge(e).minlen}));if(rank===Number.POSITIVE_INFINITY){rank=0}return label.rank=rank}_.each(g.sources(),dfs)}function slack(g,e){return g.node(e.w).rank-g.node(e.v).rank-g.edge(e).minlen}},{"../lodash":36}],55:[function(require,module,exports){"use strict";var _=require("./lodash"),Graph=require("./graphlib").Graph;module.exports={addDummyNode:addDummyNode,simplify:simplify,asNonCompoundGraph:asNonCompoundGraph,successorWeights:successorWeights,predecessorWeights:predecessorWeights,intersectRect:intersectRect,buildLayerMatrix:buildLayerMatrix,normalizeRanks:normalizeRanks,removeEmptyRanks:removeEmptyRanks,addBorderNode:addBorderNode,maxRank:maxRank,partition:partition,time:time,notime:notime};function addDummyNode(g,type,attrs,name){var v;do{v=_.uniqueId(name)}while(g.hasNode(v));attrs.dummy=type;g.setNode(v,attrs);return v}function simplify(g){var simplified=(new Graph).setGraph(g.graph());_.each(g.nodes(),function(v){simplified.setNode(v,g.node(v))});_.each(g.edges(),function(e){var simpleLabel=simplified.edge(e.v,e.w)||{weight:0,minlen:1},label=g.edge(e);simplified.setEdge(e.v,e.w,{weight:simpleLabel.weight+label.weight,minlen:Math.max(simpleLabel.minlen,label.minlen)})});return simplified}function asNonCompoundGraph(g){var simplified=new Graph({multigraph:g.isMultigraph()}).setGraph(g.graph());_.each(g.nodes(),function(v){if(!g.children(v).length){simplified.setNode(v,g.node(v))}});_.each(g.edges(),function(e){simplified.setEdge(e,g.edge(e))});return simplified}function successorWeights(g){var weightMap=_.map(g.nodes(),function(v){var sucs={};_.each(g.outEdges(v),function(e){sucs[e.w]=(sucs[e.w]||0)+g.edge(e).weight});return sucs});return _.zipObject(g.nodes(),weightMap)}function predecessorWeights(g){var weightMap=_.map(g.nodes(),function(v){var preds={};_.each(g.inEdges(v),function(e){preds[e.v]=(preds[e.v]||0)+g.edge(e).weight});return preds});return _.zipObject(g.nodes(),weightMap)}function intersectRect(rect,point){var x=rect.x;var y=rect.y;var dx=point.x-x;var dy=point.y-y;var w=rect.width/2;var h=rect.height/2;if(!dx&&!dy){throw new Error("Not possible to find intersection inside of the rectangle")}var sx,sy;if(Math.abs(dy)*w>Math.abs(dx)*h){if(dy<0){h=-h}sx=h*dx/dy;sy=h}else{if(dx<0){w=-w}sx=w;sy=w*dy/dx}return{x:x+sx,y:y+sy}}function buildLayerMatrix(g){var layering=_.map(_.range(maxRank(g)+1),function(){return[]});_.each(g.nodes(),function(v){var node=g.node(v),rank=node.rank;if(!_.isUndefined(rank)){layering[rank][node.order]=v}});return layering}function normalizeRanks(g){var min=_.min(_.map(g.nodes(),function(v){return g.node(v).rank}));_.each(g.nodes(),function(v){var node=g.node(v);if(_.has(node,"rank")){node.rank-=min}})}function removeEmptyRanks(g){var offset=_.min(_.map(g.nodes(),function(v){return g.node(v).rank}));var layers=[];_.each(g.nodes(),function(v){var rank=g.node(v).rank-offset;if(!_.has(layers,rank)){layers[rank]=[]}layers[rank].push(v)});var delta=0,nodeRankFactor=g.graph().nodeRankFactor;_.each(layers,function(vs,i){if(_.isUndefined(vs)&&i%nodeRankFactor!==0){--delta}else if(delta){_.each(vs,function(v){g.node(v).rank+=delta})}})}function addBorderNode(g,prefix,rank,order){var node={width:0,height:0};if(arguments.length>=4){node.rank=rank;node.order=order}return addDummyNode(g,"border",node,prefix)}function maxRank(g){return _.max(_.map(g.nodes(),function(v){var rank=g.node(v).rank;if(!_.isUndefined(rank)){return rank}}))}function partition(collection,fn){var result={lhs:[],rhs:[]};_.each(collection,function(value){if(fn(value)){result.lhs.push(value)}else{result.rhs.push(value)}});return result}function time(name,fn){var start=_.now();try{return fn()}finally{console.log(name+" time: "+(_.now()-start)+"ms")}}function notime(name,fn){return fn()}},{"./graphlib":33,"./lodash":36}],56:[function(require,module,exports){module.exports="0.7.1"},{}],57:[function(require,module,exports){var lib=require("./lib");module.exports={Graph:lib.Graph,json:require("./lib/json"),alg:require("./lib/alg"),version:lib.version}},{"./lib":73,"./lib/alg":64,"./lib/json":74}],58:[function(require,module,exports){var _=require("../lodash");module.exports=components;function components(g){var visited={},cmpts=[],cmpt;function dfs(v){if(_.has(visited,v))return;visited[v]=true;cmpt.push(v);_.each(g.successors(v),dfs);_.each(g.predecessors(v),dfs)}_.each(g.nodes(),function(v){cmpt=[];dfs(v);if(cmpt.length){cmpts.push(cmpt)}});return cmpts}},{"../lodash":75}],59:[function(require,module,exports){var _=require("../lodash");module.exports=dfs;function dfs(g,vs,order){if(!_.isArray(vs)){vs=[vs]}var acc=[],visited={};_.each(vs,function(v){if(!g.hasNode(v)){throw new Error("Graph does not have node: "+v)}doDfs(g,v,order==="post",visited,acc)});return acc}function doDfs(g,v,postorder,visited,acc){if(!_.has(visited,v)){visited[v]=true;if(!postorder){acc.push(v)}_.each(g.neighbors(v),function(w){doDfs(g,w,postorder,visited,acc)});if(postorder){acc.push(v)}}}},{"../lodash":75}],60:[function(require,module,exports){var dijkstra=require("./dijkstra"),_=require("../lodash");module.exports=dijkstraAll;function dijkstraAll(g,weightFunc,edgeFunc){return _.transform(g.nodes(),function(acc,v){acc[v]=dijkstra(g,v,weightFunc,edgeFunc)},{})}},{"../lodash":75,"./dijkstra":61}],61:[function(require,module,exports){var _=require("../lodash"),PriorityQueue=require("../data/priority-queue");module.exports=dijkstra;var DEFAULT_WEIGHT_FUNC=_.constant(1);function dijkstra(g,source,weightFn,edgeFn){return runDijkstra(g,String(source),weightFn||DEFAULT_WEIGHT_FUNC,edgeFn||function(v){return g.outEdges(v)})}function runDijkstra(g,source,weightFn,edgeFn){var results={},pq=new PriorityQueue,v,vEntry;var updateNeighbors=function(edge){var w=edge.v!==v?edge.v:edge.w,wEntry=results[w],weight=weightFn(edge),distance=vEntry.distance+weight;if(weight<0){throw new Error("dijkstra does not allow negative edge weights. "+"Bad edge: "+edge+" Weight: "+weight)}if(distance0){v=pq.removeMin();vEntry=results[v];if(vEntry.distance===Number.POSITIVE_INFINITY){break}edgeFn(v).forEach(updateNeighbors)}return results}},{"../data/priority-queue":71,"../lodash":75}],62:[function(require,module,exports){var _=require("../lodash"),tarjan=require("./tarjan");module.exports=findCycles;function findCycles(g){return _.filter(tarjan(g),function(cmpt){return cmpt.length>1})}},{"../lodash":75,"./tarjan":69}],63:[function(require,module,exports){var _=require("../lodash");module.exports=floydWarshall;var DEFAULT_WEIGHT_FUNC=_.constant(1);function floydWarshall(g,weightFn,edgeFn){return runFloydWarshall(g,weightFn||DEFAULT_WEIGHT_FUNC,edgeFn||function(v){return g.outEdges(v)})}function runFloydWarshall(g,weightFn,edgeFn){var results={},nodes=g.nodes();nodes.forEach(function(v){results[v]={};results[v][v]={distance:0};nodes.forEach(function(w){if(v!==w){results[v][w]={distance:Number.POSITIVE_INFINITY}}});edgeFn(v).forEach(function(edge){var w=edge.v===v?edge.w:edge.v,d=weightFn(edge);results[v][w]={distance:d,predecessor:v}})});nodes.forEach(function(k){var rowK=results[k];nodes.forEach(function(i){var rowI=results[i];nodes.forEach(function(j){var ik=rowI[k];var kj=rowK[j];var ij=rowI[j];var altDistance=ik.distance+kj.distance;if(altDistance0){v=pq.removeMin();if(_.has(parents,v)){result.setEdge(v,parents[v])}else if(init){throw new Error("Input graph is not connected: "+g)}else{init=true}g.nodeEdges(v).forEach(updateNeighbors)}return result}},{"../data/priority-queue":71,"../graph":72,"../lodash":75}],69:[function(require,module,exports){var _=require("../lodash");module.exports=tarjan;function tarjan(g){var index=0,stack=[],visited={},results=[];function dfs(v){var entry=visited[v]={onStack:true,lowlink:index,index:index++};stack.push(v);g.successors(v).forEach(function(w){if(!_.has(visited,w)){dfs(w);entry.lowlink=Math.min(entry.lowlink,visited[w].lowlink)}else if(visited[w].onStack){entry.lowlink=Math.min(entry.lowlink,visited[w].index)}});if(entry.lowlink===entry.index){var cmpt=[],w;do{w=stack.pop();visited[w].onStack=false;cmpt.push(w)}while(v!==w);results.push(cmpt)}}g.nodes().forEach(function(v){if(!_.has(visited,v)){dfs(v)}});return results}},{"../lodash":75}],70:[function(require,module,exports){var _=require("../lodash");module.exports=topsort;topsort.CycleException=CycleException;function topsort(g){var visited={},stack={},results=[];function visit(node){if(_.has(stack,node)){throw new CycleException}if(!_.has(visited,node)){stack[node]=true;visited[node]=true;_.each(g.predecessors(node),visit);delete stack[node];results.push(node)}}_.each(g.sinks(),visit);if(_.size(visited)!==g.nodeCount()){throw new CycleException}return results}function CycleException(){}},{"../lodash":75}],71:[function(require,module,exports){var _=require("../lodash");module.exports=PriorityQueue;function PriorityQueue(){this._arr=[];this._keyIndices={}}PriorityQueue.prototype.size=function(){return this._arr.length};PriorityQueue.prototype.keys=function(){return this._arr.map(function(x){return x.key})};PriorityQueue.prototype.has=function(key){return _.has(this._keyIndices,key)};PriorityQueue.prototype.priority=function(key){var index=this._keyIndices[key];if(index!==undefined){return this._arr[index].priority}};PriorityQueue.prototype.min=function(){if(this.size()===0){throw new Error("Queue underflow")}return this._arr[0].key};PriorityQueue.prototype.add=function(key,priority){var keyIndices=this._keyIndices;key=String(key);if(!_.has(keyIndices,key)){var arr=this._arr;var index=arr.length;keyIndices[key]=index;arr.push({key:key,priority:priority});this._decrease(index);return true}return false};PriorityQueue.prototype.removeMin=function(){this._swap(0,this._arr.length-1);var min=this._arr.pop();delete this._keyIndices[min.key];this._heapify(0);return min.key};PriorityQueue.prototype.decrease=function(key,priority){var index=this._keyIndices[key];if(priority>this._arr[index].priority){throw new Error("New priority is greater than current priority. "+"Key: "+key+" Old: "+this._arr[index].priority+" New: "+priority)}this._arr[index].priority=priority;this._decrease(index)};PriorityQueue.prototype._heapify=function(i){var arr=this._arr;var l=2*i,r=l+1,largest=i;if(l>1;if(arr[parent].priority1){this.setNode(v,value)}else{this.setNode(v)}},this);return this};Graph.prototype.setNode=function(v,value){if(_.has(this._nodes,v)){if(arguments.length>1){this._nodes[v]=value}return this}this._nodes[v]=arguments.length>1?value:this._defaultNodeLabelFn(v);if(this._isCompound){this._parent[v]=GRAPH_NODE;this._children[v]={};this._children[GRAPH_NODE][v]=true}this._in[v]={};this._preds[v]={};this._out[v]={};this._sucs[v]={};++this._nodeCount;return this};Graph.prototype.node=function(v){return this._nodes[v]};Graph.prototype.hasNode=function(v){return _.has(this._nodes,v)};Graph.prototype.removeNode=function(v){var self=this;if(_.has(this._nodes,v)){var removeEdge=function(e){self.removeEdge(self._edgeObjs[e])};delete this._nodes[v];if(this._isCompound){this._removeFromParentsChildList(v);delete this._parent[v];_.each(this.children(v),function(child){this.setParent(child)},this);delete this._children[v]}_.each(_.keys(this._in[v]),removeEdge);delete this._in[v];delete this._preds[v];_.each(_.keys(this._out[v]),removeEdge);delete this._out[v];delete this._sucs[v];--this._nodeCount}return this};Graph.prototype.setParent=function(v,parent){if(!this._isCompound){throw new Error("Cannot set parent in a non-compound graph")}if(_.isUndefined(parent)){parent=GRAPH_NODE}else{for(var ancestor=parent;!_.isUndefined(ancestor);ancestor=this.parent(ancestor)){if(ancestor===v){throw new Error("Setting "+parent+" as parent of "+v+" would create create a cycle")}}this.setNode(parent)}this.setNode(v);this._removeFromParentsChildList(v);this._parent[v]=parent;this._children[parent][v]=true;return this};Graph.prototype._removeFromParentsChildList=function(v){delete this._children[this._parent[v]][v]};Graph.prototype.parent=function(v){if(this._isCompound){var parent=this._parent[v];if(parent!==GRAPH_NODE){return parent}}};Graph.prototype.children=function(v){if(_.isUndefined(v)){v=GRAPH_NODE}if(this._isCompound){var children=this._children[v];if(children){return _.keys(children)}}else if(v===GRAPH_NODE){return this.nodes()}else if(this.hasNode(v)){return[]}};Graph.prototype.predecessors=function(v){var predsV=this._preds[v];if(predsV){return _.keys(predsV)}};Graph.prototype.successors=function(v){var sucsV=this._sucs[v];if(sucsV){return _.keys(sucsV)}};Graph.prototype.neighbors=function(v){var preds=this.predecessors(v);if(preds){return _.union(preds,this.successors(v))}};Graph.prototype.setDefaultEdgeLabel=function(newDefault){if(!_.isFunction(newDefault)){newDefault=_.constant(newDefault)}this._defaultEdgeLabelFn=newDefault;return this};Graph.prototype.edgeCount=function(){return this._edgeCount};Graph.prototype.edges=function(){return _.values(this._edgeObjs)};Graph.prototype.setPath=function(vs,value){var self=this,args=arguments;_.reduce(vs,function(v,w){if(args.length>1){self.setEdge(v,w,value)}else{self.setEdge(v,w)}return w});return this};Graph.prototype.setEdge=function(){var v,w,name,value,valueSpecified=false;if(_.isPlainObject(arguments[0])){v=arguments[0].v;w=arguments[0].w;name=arguments[0].name;if(arguments.length===2){value=arguments[1];valueSpecified=true}}else{v=arguments[0];w=arguments[1];name=arguments[3];if(arguments.length>2){value=arguments[2];valueSpecified=true}}v=""+v;w=""+w;if(!_.isUndefined(name)){name=""+name}var e=edgeArgsToId(this._isDirected,v,w,name);if(_.has(this._edgeLabels,e)){if(valueSpecified){this._edgeLabels[e]=value}return this}if(!_.isUndefined(name)&&!this._isMultigraph){throw new Error("Cannot set a named edge when isMultigraph = false")}this.setNode(v);this.setNode(w);this._edgeLabels[e]=valueSpecified?value:this._defaultEdgeLabelFn(v,w,name);var edgeObj=edgeArgsToObj(this._isDirected,v,w,name);v=edgeObj.v;w=edgeObj.w;Object.freeze(edgeObj);this._edgeObjs[e]=edgeObj;incrementOrInitEntry(this._preds[w],v);incrementOrInitEntry(this._sucs[v],w);this._in[w][e]=edgeObj;this._out[v][e]=edgeObj;this._edgeCount++;return this};Graph.prototype.edge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name);return this._edgeLabels[e]};Graph.prototype.hasEdge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name);return _.has(this._edgeLabels,e)};Graph.prototype.removeEdge=function(v,w,name){var e=arguments.length===1?edgeObjToId(this._isDirected,arguments[0]):edgeArgsToId(this._isDirected,v,w,name),edge=this._edgeObjs[e];if(edge){v=edge.v;w=edge.w;delete this._edgeLabels[e];delete this._edgeObjs[e];decrementOrRemoveEntry(this._preds[w],v);decrementOrRemoveEntry(this._sucs[v],w);delete this._in[w][e];delete this._out[v][e];this._edgeCount--}return this};Graph.prototype.inEdges=function(v,u){var inV=this._in[v];if(inV){var edges=_.values(inV);if(!u){return edges}return _.filter(edges,function(edge){return edge.v===u})}};Graph.prototype.outEdges=function(v,w){var outV=this._out[v];if(outV){var edges=_.values(outV);if(!w){return edges}return _.filter(edges,function(edge){return edge.w===w})}};Graph.prototype.nodeEdges=function(v,w){var inEdges=this.inEdges(v,w);if(inEdges){return inEdges.concat(this.outEdges(v,w))}};function incrementOrInitEntry(map,k){if(_.has(map,k)){map[k]++}else{map[k]=1}}function decrementOrRemoveEntry(map,k){if(!--map[k]){delete map[k]}}function edgeArgsToId(isDirected,v,w,name){if(!isDirected&&v>w){var tmp=v;v=w;w=tmp}return v+EDGE_KEY_DELIM+w+EDGE_KEY_DELIM+(_.isUndefined(name)?DEFAULT_EDGE_NAME:name)}function edgeArgsToObj(isDirected,v,w,name){if(!isDirected&&v>w){var tmp=v;v=w;w=tmp}var edgeObj={v:v,w:w};if(name){edgeObj.name=name}return edgeObj}function edgeObjToId(isDirected,edgeObj){return edgeArgsToId(isDirected,edgeObj.v,edgeObj.w,edgeObj.name)}},{"./lodash":75}],73:[function(require,module,exports){module.exports={Graph:require("./graph"),version:require("./version")}},{"./graph":72,"./version":76}],74:[function(require,module,exports){var _=require("./lodash"),Graph=require("./graph");module.exports={write:write,read:read};function write(g){var json={options:{directed:g.isDirected(),multigraph:g.isMultigraph(),compound:g.isCompound()},nodes:writeNodes(g),edges:writeEdges(g)};if(!_.isUndefined(g.graph())){json.value=_.clone(g.graph())}return json}function writeNodes(g){return _.map(g.nodes(),function(v){var nodeValue=g.node(v),parent=g.parent(v),node={v:v};if(!_.isUndefined(nodeValue)){node.value=nodeValue}if(!_.isUndefined(parent)){node.parent=parent}return node})}function writeEdges(g){return _.map(g.edges(),function(e){var edgeValue=g.edge(e),edge={v:e.v,w:e.w};if(!_.isUndefined(e.name)){edge.name=e.name}if(!_.isUndefined(edgeValue)){edge.value=edgeValue}return edge})}function read(json){var g=new Graph(json.options).setGraph(json.value);_.each(json.nodes,function(entry){g.setNode(entry.v,entry.value);if(entry.parent){g.setParent(entry.v,entry.parent)}});_.each(json.edges,function(entry){g.setEdge({v:entry.v,w:entry.w,name:entry.name},entry.value)});return g}},{"./graph":72,"./lodash":75}],75:[function(require,module,exports){module.exports=require(20)},{"/Users/andrew/Documents/dev/dagre-d3/lib/lodash.js":20,lodash:77}],76:[function(require,module,exports){module.exports="1.0.1"},{}],77:[function(require,module,exports){(function(global){(function(){var undefined;var arrayPool=[],objectPool=[];var idCounter=0;var keyPrefix=+new Date+"";var largeArraySize=75;var maxPoolSize=40;var whitespace=" \f \ufeff"+"\n\r\u2028\u2029"+" ᠎              ";var reEmptyStringLeading=/\b__p \+= '';/g,reEmptyStringMiddle=/\b(__p \+=) '' \+/g,reEmptyStringTrailing=/(__e\(.*?\)|\b__t\)) \+\n'';/g;var reEsTemplate=/\$\{([^\\}]*(?:\\.[^\\}]*)*)\}/g;var reFlags=/\w*$/;var reFuncName=/^\s*function[ \n\r\t]+\w/;var reInterpolate=/<%=([\s\S]+?)%>/g;var reLeadingSpacesAndZeros=RegExp("^["+whitespace+"]*0+(?=.$)");var reNoMatch=/($^)/;var reThis=/\bthis\b/;var reUnescapedString=/['\n\r\t\u2028\u2029\\]/g;var contextProps=["Array","Boolean","Date","Function","Math","Number","Object","RegExp","String","_","attachEvent","clearTimeout","isFinite","isNaN","parseInt","setTimeout"];var templateCounter=0;var argsClass="[object Arguments]",arrayClass="[object Array]",boolClass="[object Boolean]",dateClass="[object Date]",funcClass="[object Function]",numberClass="[object Number]",objectClass="[object Object]",regexpClass="[object RegExp]",stringClass="[object String]";var cloneableClasses={};cloneableClasses[funcClass]=false;cloneableClasses[argsClass]=cloneableClasses[arrayClass]=cloneableClasses[boolClass]=cloneableClasses[dateClass]=cloneableClasses[numberClass]=cloneableClasses[objectClass]=cloneableClasses[regexpClass]=cloneableClasses[stringClass]=true;var debounceOptions={leading:false,maxWait:0,trailing:false};var descriptor={configurable:false,enumerable:false,value:null,writable:false};var objectTypes={"boolean":false,"function":true,object:true,number:false,string:false,undefined:false};var stringEscapes={"\\":"\\","'":"'","\n":"n","\r":"r"," ":"t","\u2028":"u2028","\u2029":"u2029"};var root=objectTypes[typeof window]&&window||this;var freeExports=objectTypes[typeof exports]&&exports&&!exports.nodeType&&exports;var freeModule=objectTypes[typeof module]&&module&&!module.nodeType&&module;var moduleExports=freeModule&&freeModule.exports===freeExports&&freeExports;var freeGlobal=objectTypes[typeof global]&&global;if(freeGlobal&&(freeGlobal.global===freeGlobal||freeGlobal.window===freeGlobal)){root=freeGlobal}function baseIndexOf(array,value,fromIndex){var index=(fromIndex||0)-1,length=array?array.length:0;while(++index-1?0:-1:cache?0:-1}function cachePush(value){var cache=this.cache,type=typeof value;if(type=="boolean"||value==null){cache[value]=true}else{if(type!="number"&&type!="string"){type="object"}var key=type=="number"?value:keyPrefix+value,typeCache=cache[type]||(cache[type]={});if(type=="object"){(typeCache[key]||(typeCache[key]=[])).push(value)}else{typeCache[key]=true}}}function charAtCallback(value){return value.charCodeAt(0)}function compareAscending(a,b){var ac=a.criteria,bc=b.criteria,index=-1,length=ac.length;while(++indexother||typeof value=="undefined"){return 1}if(value/g,evaluate:/<%([\s\S]+?)%>/g,interpolate:reInterpolate,variable:"",imports:{_:lodash}};function baseBind(bindData){var func=bindData[0],partialArgs=bindData[2],thisArg=bindData[4];function bound(){if(partialArgs){var args=slice(partialArgs);push.apply(args,arguments)}if(this instanceof bound){var thisBinding=baseCreate(func.prototype),result=func.apply(thisBinding,args||arguments);return isObject(result)?result:thisBinding}return func.apply(thisArg,args||arguments)}setBindData(bound,bindData);return bound}function baseClone(value,isDeep,callback,stackA,stackB){if(callback){var result=callback(value);if(typeof result!="undefined"){return result}}var isObj=isObject(value);if(isObj){var className=toString.call(value);if(!cloneableClasses[className]){return value}var ctor=ctorByClass[className];switch(className){case boolClass:case dateClass:return new ctor(+value);case numberClass:case stringClass:return new ctor(value);case regexpClass:result=ctor(value.source,reFlags.exec(value));result.lastIndex=value.lastIndex;return result}}else{return value}var isArr=isArray(value);if(isDeep){var initedStack=!stackA;stackA||(stackA=getArray());stackB||(stackB=getArray());var length=stackA.length;while(length--){if(stackA[length]==value){return stackB[length]}}result=isArr?ctor(value.length):{}}else{result=isArr?slice(value):assign({},value)}if(isArr){if(hasOwnProperty.call(value,"index")){result.index=value.index}if(hasOwnProperty.call(value,"input")){result.input=value.input}}if(!isDeep){return result}stackA.push(value);stackB.push(result);(isArr?forEach:forOwn)(value,function(objValue,key){result[key]=baseClone(objValue,isDeep,callback,stackA,stackB)});if(initedStack){releaseArray(stackA);releaseArray(stackB)}return result}function baseCreate(prototype,properties){return isObject(prototype)?nativeCreate(prototype):{}; diff --git a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js index dbacbf19beee5..dde6069000bc4 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js +++ b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js @@ -100,7 +100,7 @@ sorttable = { this.removeChild(document.getElementById('sorttable_sortfwdind')); sortrevind = document.createElement('span'); sortrevind.id = "sorttable_sortrevind"; - sortrevind.innerHTML = stIsIE ? ' 5' : ' ▴'; + sortrevind.innerHTML = stIsIE ? ' 5' : ' ▾'; this.appendChild(sortrevind); return; } @@ -113,7 +113,7 @@ sorttable = { this.removeChild(document.getElementById('sorttable_sortrevind')); sortfwdind = document.createElement('span'); sortfwdind.id = "sorttable_sortfwdind"; - sortfwdind.innerHTML = stIsIE ? ' 6' : ' ▾'; + sortfwdind.innerHTML = stIsIE ? ' 6' : ' ▴'; this.appendChild(sortfwdind); return; } @@ -134,7 +134,7 @@ sorttable = { this.className += ' sorttable_sorted'; sortfwdind = document.createElement('span'); sortfwdind.id = "sorttable_sortfwdind"; - sortfwdind.innerHTML = stIsIE ? ' 6' : ' ▾'; + sortfwdind.innerHTML = stIsIE ? ' 6' : ' ▴'; this.appendChild(sortfwdind); // build an array to sort. This is a Schwartzian transform thing, diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css index 18c72694f3e2d..3b4ae2ed354b8 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css +++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css @@ -15,33 +15,26 @@ * limitations under the License. */ -#dag-viz-graph svg path { - stroke: #444; - stroke-width: 1.5px; +#dag-viz-graph a, #dag-viz-graph a:hover { + text-decoration: none; } -#dag-viz-graph svg g.cluster rect { - stroke-width: 1px; +#dag-viz-graph .label { + font-weight: normal; + text-shadow: none; } -#dag-viz-graph svg g.node circle { - fill: #444; +#dag-viz-graph svg path { + stroke: #444; + stroke-width: 1.5px; } -#dag-viz-graph svg g.node rect { - fill: #C3EBFF; - stroke: #3EC0FF; +#dag-viz-graph svg g.cluster rect { stroke-width: 1px; } -#dag-viz-graph svg g.node.cached circle { - fill: #444; -} - -#dag-viz-graph svg g.node.cached rect { - fill: #B3F5C5; - stroke: #56F578; - stroke-width: 1px; +#dag-viz-graph div#empty-dag-viz-message { + margin: 15px; } /* Job page specific styles */ @@ -57,12 +50,23 @@ stroke-width: 1px; } -#dag-viz-graph svg.job g.cluster[id*="stage"] rect { +#dag-viz-graph svg.job g.cluster.skipped rect { + fill: #D6D6D6; + stroke: #B7B7B7; + stroke-width: 1px; +} + +#dag-viz-graph svg.job g.cluster.stage rect { fill: #FFFFFF; stroke: #FF99AC; stroke-width: 1px; } +#dag-viz-graph svg.job g.cluster.stage.skipped rect { + stroke: #ADADAD; + stroke-width: 1px; +} + #dag-viz-graph svg.job g#cross-stage-edges path { fill: none; } @@ -71,6 +75,20 @@ fill: #333; } +#dag-viz-graph svg.job g.cluster.skipped text { + fill: #666; +} + +#dag-viz-graph svg.job g.node circle { + fill: #444; +} + +#dag-viz-graph svg.job g.node.cached circle { + fill: #A3F545; + stroke: #52C366; + stroke-width: 2px; +} + /* Stage page specific styles */ #dag-viz-graph svg.stage g.cluster rect { @@ -79,7 +97,7 @@ stroke-width: 1px; } -#dag-viz-graph svg.stage g.cluster[id*="stage"] rect { +#dag-viz-graph svg.stage g.cluster.stage rect { fill: #FFFFFF; stroke: #FFA6B6; stroke-width: 1px; @@ -93,11 +111,14 @@ fill: #333; } -#dag-viz-graph a, #dag-viz-graph a:hover { - text-decoration: none; +#dag-viz-graph svg.stage g.node rect { + fill: #C3EBFF; + stroke: #3EC0FF; + stroke-width: 1px; } -#dag-viz-graph .label { - font-weight: normal; - text-shadow: none; +#dag-viz-graph svg.stage g.node.cached rect { + fill: #B3F5C5; + stroke: #52C366; + stroke-width: 2px; } diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js index f7d0d3c61457c..e96af8768daa0 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js +++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js @@ -57,9 +57,7 @@ var VizConstants = { stageSep: 40, graphPrefix: "graph_", nodePrefix: "node_", - stagePrefix: "stage_", - clusterPrefix: "cluster_", - stageClusterPrefix: "cluster_stage_" + clusterPrefix: "cluster_" }; var JobPageVizConstants = { @@ -86,7 +84,7 @@ function toggleDagViz(forJob) { $(arrowSelector).toggleClass('arrow-open'); var shouldShow = $(arrowSelector).hasClass("arrow-open"); if (shouldShow) { - var shouldRender = graphContainer().select("svg").empty(); + var shouldRender = graphContainer().select("*").empty(); if (shouldRender) { renderDagViz(forJob); } @@ -108,7 +106,7 @@ function toggleDagViz(forJob) { * Output DOM hierarchy: * div#dag-viz-graph > * svg > - * g#cluster_stage_[stageId] + * g.cluster_stage_[stageId] * * Note that the input metadata is populated by o.a.s.ui.UIUtils.showDagViz. * Any changes in the input format here must be reflected there. @@ -117,17 +115,23 @@ function renderDagViz(forJob) { // If there is not a dot file to render, fail fast and report error var jobOrStage = forJob ? "job" : "stage"; - if (metadataContainer().empty()) { - graphContainer() - .append("div") - .text("No visualization information available for this " + jobOrStage); + if (metadataContainer().empty() || + metadataContainer().selectAll("div").empty()) { + var message = + "No visualization information available for this " + jobOrStage + "!
" + + "If this is an old " + jobOrStage + ", its visualization metadata may have been " + + "cleaned up over time.
You may consider increasing the value of "; + if (forJob) { + message += "spark.ui.retainedJobs and spark.ui.retainedStages."; + } else { + message += "spark.ui.retainedStages"; + } + graphContainer().append("div").attr("id", "empty-dag-viz-message").html(message); return; } // Render - var svg = graphContainer() - .append("svg") - .attr("class", jobOrStage); + var svg = graphContainer().append("svg").attr("class", jobOrStage); if (forJob) { renderDagVizForJob(svg); } else { @@ -137,7 +141,7 @@ function renderDagViz(forJob) { // Find cached RDDs and mark them as such metadataContainer().selectAll(".cached-rdd").each(function(v) { var nodeId = VizConstants.nodePrefix + d3.select(this).text(); - svg.selectAll("#" + nodeId).classed("cached", true); + svg.selectAll("g." + nodeId).classed("cached", true); }); resizeSvg(svg); @@ -177,29 +181,35 @@ function renderDagVizForJob(svgContainer) { var dot = metadata.select(".dot-file").text(); var stageId = metadata.attr("stage-id"); var containerId = VizConstants.graphPrefix + stageId; - // Link each graph to the corresponding stage page (TODO: handle stage attempts) - var stageLink = "/stages/stage/?id=" + - stageId.replace(VizConstants.stagePrefix, "") + "&attempt=0&expandDagViz=true"; - var container = svgContainer - .append("a") - .attr("xlink:href", stageLink) - .append("g") - .attr("id", containerId); + var isSkipped = metadata.attr("skipped") == "true"; + var container; + if (isSkipped) { + container = svgContainer + .append("g") + .attr("id", containerId) + .attr("skipped", "true"); + } else { + // Link each graph to the corresponding stage page (TODO: handle stage attempts) + // Use the link from the stage table so it also works for the history server + var attemptId = 0 + var stageLink = d3.select("#stage-" + stageId + "-" + attemptId) + .select("a.name-link") + .attr("href") + "&expandDagViz=true"; + container = svgContainer + .append("a") + .attr("xlink:href", stageLink) + .append("g") + .attr("id", containerId); + } // Now we need to shift the container for this stage so it doesn't overlap with // existing ones, taking into account the position and width of the last stage's // container. We do not need to do this for the first stage of this job. if (i > 0) { - var existingStages = svgContainer - .selectAll("g.cluster") - .filter("[id*=\"" + VizConstants.stageClusterPrefix + "\"]"); + var existingStages = svgContainer.selectAll("g.cluster.stage") if (!existingStages.empty()) { var lastStage = d3.select(existingStages[0].pop()); - var lastStageId = lastStage.attr("id"); - var lastStageWidth = toFloat(svgContainer - .select("#" + lastStageId) - .select("rect") - .attr("width")); + var lastStageWidth = toFloat(lastStage.select("rect").attr("width")); var lastStagePosition = getAbsolutePosition(lastStage); var offset = lastStagePosition.x + lastStageWidth + VizConstants.stageSep; container.attr("transform", "translate(" + offset + ", 0)"); @@ -209,6 +219,12 @@ function renderDagVizForJob(svgContainer) { // Actually render the stage renderDot(dot, container, true); + // Mark elements as skipped if appropriate. Unfortunately we need to mark all + // elements instead of the parent container because of CSS override rules. + if (isSkipped) { + container.selectAll("g").classed("skipped", true); + } + // Round corners on rectangles container .selectAll("rect") @@ -238,6 +254,9 @@ function renderDot(dot, container, forJob) { var renderer = new dagreD3.render(); preprocessGraphLayout(g, forJob); renderer(container, g); + + // Find the stage cluster and mark it for styling and post-processing + container.selectAll("g.cluster[name*=\"Stage\"]").classed("stage", true); } /* -------------------- * @@ -372,14 +391,14 @@ function getAbsolutePosition(d3selection) { function connectRDDs(fromRDDId, toRDDId, edgesContainer, svgContainer) { var fromNodeId = VizConstants.nodePrefix + fromRDDId; var toNodeId = VizConstants.nodePrefix + toRDDId; - var fromPos = getAbsolutePosition(svgContainer.select("#" + fromNodeId)); - var toPos = getAbsolutePosition(svgContainer.select("#" + toNodeId)); + var fromPos = getAbsolutePosition(svgContainer.select("g." + fromNodeId)); + var toPos = getAbsolutePosition(svgContainer.select("g." + toNodeId)); // On the job page, RDDs are rendered as dots (circles). When rendering the path, // we need to account for the radii of these circles. Otherwise the arrow heads // will bleed into the circle itself. var delta = toFloat(svgContainer - .select("g.node#" + toNodeId) + .select("g.node." + toNodeId) .select("circle") .attr("r")); if (fromPos.x < toPos.x) { @@ -431,10 +450,35 @@ function addTooltipsForRDDs(svgContainer) { node.select("circle") .attr("data-toggle", "tooltip") .attr("data-placement", "bottom") - .attr("title", tooltipText) + .attr("title", tooltipText); } + // Link tooltips for all nodes that belong to the same RDD + node.on("mouseenter", function() { triggerTooltipForRDD(node, true); }); + node.on("mouseleave", function() { triggerTooltipForRDD(node, false); }); }); - $("[data-toggle=tooltip]").tooltip({container: "body"}); + + $("[data-toggle=tooltip]") + .filter("g.node circle") + .tooltip({ container: "body", trigger: "manual" }); +} + +/* + * (Job page only) Helper function to show or hide tooltips for all nodes + * in the graph that refer to the same RDD the specified node represents. + */ +function triggerTooltipForRDD(d3node, show) { + var classes = d3node.node().classList; + for (var i = 0; i < classes.length; i++) { + var clazz = classes[i]; + var isRDDClass = clazz.indexOf(VizConstants.nodePrefix) == 0; + if (isRDDClass) { + graphContainer().selectAll("g." + clazz).each(function() { + var circle = d3.select(this).select("circle").node(); + var showOrHide = show ? "show" : "hide"; + $(circle).tooltip(showOrHide); + }); + } + } } /* Helper function to convert attributes to numeric values. */ diff --git a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.css b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.css index d1e6d462b836f..0f400461c5293 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.css +++ b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.css @@ -24,6 +24,65 @@ div#application-timeline, div#job-timeline { margin-top: 5px; } +#task-assignment-timeline div.legend-area { + width: 574px; +} + +#task-assignment-timeline .legend-area > svg { + width: 100%; + height: 55px; +} + +#task-assignment-timeline div.item.range { + padding: 0px; + height: 26px; + border-width: 0; +} + +.task-assignment-timeline-content { + width: 100%; +} + +.task-assignment-timeline-duration-bar { + width: 100%; + height: 26px; +} + +rect.scheduler-delay-proportion { + fill: #80B1D3; + stroke: #6B94B0; +} + +rect.deserialization-time-proportion { + fill: #FB8072; + stroke: #D26B5F; +} + +rect.shuffle-read-time-proportion { + fill: #FDB462; + stroke: #D39651; +} + +rect.executor-runtime-proportion { + fill: #B3DE69; + stroke: #95B957; +} + +rect.shuffle-write-time-proportion { + fill: #FFED6F; + stroke: #D5C65C; +} + +rect.serialization-time-proportion { + fill: #BC80BD; + stroke: #9D6B9E; +} + +rect.getting-result-time-proportion { + fill: #8DD3C7; + stroke: #75B0A6; +} + .vis.timeline { line-height: 14px; } @@ -178,6 +237,10 @@ tr.corresponding-item-hover > td, tr.corresponding-item-hover > th { display: none; } +#task-assignment-timeline.collapsed { + display: none; +} + .control-panel { margin-bottom: 5px; } @@ -186,7 +249,8 @@ tr.corresponding-item-hover > td, tr.corresponding-item-hover > th { margin: 0; } -span.expand-application-timeline, span.expand-job-timeline { +span.expand-application-timeline, span.expand-job-timeline, +span.expand-task-assignment-timeline { cursor: pointer; } diff --git a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js index 558beb8a5867f..ca74ef9d7e94e 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js +++ b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js @@ -46,7 +46,7 @@ function drawApplicationTimeline(groupArray, eventObjArray, startTime) { }; $(this).click(function() { - var jobPagePath = $(getSelectorForJobEntry(this)).find("a").attr("href") + var jobPagePath = $(getSelectorForJobEntry(this)).find("a.name-link").attr("href") window.location.href = jobPagePath }); @@ -105,7 +105,7 @@ function drawJobTimeline(groupArray, eventObjArray, startTime) { }; $(this).click(function() { - var stagePagePath = $(getSelectorForStageEntry(this)).find("a").attr("href") + var stagePagePath = $(getSelectorForStageEntry(this)).find("a.name-link").attr("href") window.location.href = stagePagePath }); @@ -133,6 +133,57 @@ function drawJobTimeline(groupArray, eventObjArray, startTime) { }); } +function drawTaskAssignmentTimeline(groupArray, eventObjArray, minLaunchTime, maxFinishTime) { + var groups = new vis.DataSet(groupArray); + var items = new vis.DataSet(eventObjArray); + var container = $("#task-assignment-timeline")[0] + var options = { + groupOrder: function(a, b) { + return a.value - b.value + }, + editable: false, + align: 'left', + selectable: false, + showCurrentTime: false, + min: minLaunchTime, + max: maxFinishTime, + zoomable: false + }; + + var taskTimeline = new vis.Timeline(container) + taskTimeline.setOptions(options); + taskTimeline.setGroups(groups); + taskTimeline.setItems(items); + + // If a user zooms while a tooltip is displayed, the user may zoom such that the cursor is no + // longer over the task that the tooltip corresponds to. So, when a user zooms, we should hide + // any currently displayed tooltips. + var currentDisplayedTooltip = null; + $("#task-assignment-timeline").on({ + "mouseenter": function() { + currentDisplayedTooltip = this; + }, + "mouseleave": function() { + currentDisplayedTooltip = null; + } + }, ".task-assignment-timeline-content"); + taskTimeline.on("rangechange", function(prop) { + if (currentDisplayedTooltip !== null) { + $(currentDisplayedTooltip).tooltip("hide"); + } + }); + + setupZoomable("#task-assignment-timeline-zoom-lock", taskTimeline); + + $("span.expand-task-assignment-timeline").click(function() { + $("#task-assignment-timeline").toggleClass("collapsed"); + + // Switch the class of the arrow from open to closed. + $(this).find(".expand-task-assignment-timeline-arrow").toggleClass("arrow-open"); + $(this).find(".expand-task-assignment-timeline-arrow").toggleClass("arrow-closed"); + }); +} + function setupExecutorEventAction() { $(".item.box.executor").each(function () { $(this).hover( @@ -147,7 +198,7 @@ function setupExecutorEventAction() { } function setupZoomable(id, timeline) { - $(id + '>input[type="checkbox"]').click(function() { + $(id + ' > input[type="checkbox"]').click(function() { if (this.checked) { timeline.setOptions({zoomable: true}); } else { @@ -155,7 +206,7 @@ function setupZoomable(id, timeline) { } }); - $(id + ">span").click(function() { + $(id + " > span").click(function() { $(this).parent().find('input:checkbox').trigger('click'); }); } diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index 330df1d59a9b1..5a8d17bd99933 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -228,7 +228,7 @@ GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializa * @tparam T result type */ class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T], name: Option[String]) - extends Accumulable[T,T](initialValue, param, name) { + extends Accumulable[T, T](initialValue, param, name) { def this(initialValue: T, param: AccumulatorParam[T]) = this(initialValue, param, None) } diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala index af9765d313e9e..ceeb58075d345 100644 --- a/core/src/main/scala/org/apache/spark/Aggregator.scala +++ b/core/src/main/scala/org/apache/spark/Aggregator.scala @@ -34,8 +34,8 @@ case class Aggregator[K, V, C] ( mergeValue: (C, V) => C, mergeCombiners: (C, C) => C) { - // When spilling is enabled sorting will happen externally, but not necessarily with an - // ExternalSorter. + // When spilling is enabled sorting will happen externally, but not necessarily with an + // ExternalSorter. private val isSpillEnabled = SparkEnv.get.conf.getBoolean("spark.shuffle.spill", true) @deprecated("use combineValuesByKey with TaskContext argument", "0.9.0") @@ -45,7 +45,7 @@ case class Aggregator[K, V, C] ( def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]], context: TaskContext): Iterator[(K, C)] = { if (!isSpillEnabled) { - val combiners = new AppendOnlyMap[K,C] + val combiners = new AppendOnlyMap[K, C] var kv: Product2[K, V] = null val update = (hadValue: Boolean, oldValue: C) => { if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2) @@ -76,7 +76,7 @@ case class Aggregator[K, V, C] ( : Iterator[(K, C)] = { if (!isSpillEnabled) { - val combiners = new AppendOnlyMap[K,C] + val combiners = new AppendOnlyMap[K, C] var kc: Product2[K, C] = null val update = (hadValue: Boolean, oldValue: C) => { if (hadValue) mergeCombiners(oldValue, kc._2) else kc._2 diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 66bda68088502..9514604752640 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -91,7 +91,7 @@ private[spark] class ExecutorAllocationManager( // How long there must be backlogged tasks for before an addition is triggered (seconds) private val schedulerBacklogTimeoutS = conf.getTimeAsSeconds( - "spark.dynamicAllocation.schedulerBacklogTimeout", "5s") + "spark.dynamicAllocation.schedulerBacklogTimeout", "1s") // Same as above, but used only after `schedulerBacklogTimeoutS` is exceeded private val sustainedSchedulerBacklogTimeoutS = conf.getTimeAsSeconds( @@ -99,7 +99,7 @@ private[spark] class ExecutorAllocationManager( // How long an executor must be idle for before it is removed (seconds) private val executorIdleTimeoutS = conf.getTimeAsSeconds( - "spark.dynamicAllocation.executorIdleTimeout", "600s") + "spark.dynamicAllocation.executorIdleTimeout", "60s") // During testing, the methods to actually kill and add executors are mocked out private val testing = conf.getBoolean("spark.dynamicAllocation.testing", false) @@ -268,6 +268,8 @@ private[spark] class ExecutorAllocationManager( numExecutorsTarget = math.max(maxNeeded, minNumExecutors) client.requestTotalExecutors(numExecutorsTarget) numExecutorsToAdd = 1 + logInfo(s"Lowering target number of executors to $numExecutorsTarget because " + + s"not all requests are actually needed (previously $oldNumExecutorsTarget)") numExecutorsTarget - oldNumExecutorsTarget } else if (addTime != NOT_SET && now >= addTime) { val delta = addExecutors(maxNeeded) @@ -292,9 +294,8 @@ private[spark] class ExecutorAllocationManager( private def addExecutors(maxNumExecutorsNeeded: Int): Int = { // Do not request more executors if it would put our target over the upper bound if (numExecutorsTarget >= maxNumExecutors) { - val numExecutorsPending = numExecutorsTarget - executorIds.size - logDebug(s"Not adding executors because there are already ${executorIds.size} registered " + - s"and ${numExecutorsPending} pending executor(s) (limit $maxNumExecutors)") + logDebug(s"Not adding executors because our current target total " + + s"is already $numExecutorsTarget (limit $maxNumExecutors)") numExecutorsToAdd = 1 return 0 } @@ -310,10 +311,19 @@ private[spark] class ExecutorAllocationManager( // Ensure that our target fits within configured bounds: numExecutorsTarget = math.max(math.min(numExecutorsTarget, maxNumExecutors), minNumExecutors) + val delta = numExecutorsTarget - oldNumExecutorsTarget + + // If our target has not changed, do not send a message + // to the cluster manager and reset our exponential growth + if (delta == 0) { + numExecutorsToAdd = 1 + return 0 + } + val addRequestAcknowledged = testing || client.requestTotalExecutors(numExecutorsTarget) if (addRequestAcknowledged) { - val delta = numExecutorsTarget - oldNumExecutorsTarget - logInfo(s"Requesting $delta new executor(s) because tasks are backlogged" + + val executorsString = "executor" + { if (delta > 1) "s" else "" } + logInfo(s"Requesting $delta new $executorsString because tasks are backlogged" + s" (new desired total will be $numExecutorsTarget)") numExecutorsToAdd = if (delta == numExecutorsToAdd) { numExecutorsToAdd * 2 @@ -420,7 +430,7 @@ private[spark] class ExecutorAllocationManager( * This resets all variables used for adding executors. */ private def onSchedulerQueueEmpty(): Unit = synchronized { - logDebug(s"Clearing timer to add executors because there are no more pending tasks") + logDebug("Clearing timer to add executors because there are no more pending tasks") addTime = NOT_SET numExecutorsToAdd = 1 } diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index 91f9ef8ce7185..48792a958130c 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -150,7 +150,7 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: } override def isCompleted: Boolean = jobWaiter.jobFinished - + override def isCancelled: Boolean = _cancelled override def value: Option[Try[T]] = { diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index f2b024ff6cb67..6909015ff66e6 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -29,7 +29,7 @@ import org.apache.spark.util.{ThreadUtils, Utils} /** * A heartbeat from executors to the driver. This is a shared message used by several internal - * components to convey liveness or execution information for in-progress tasks. It will also + * components to convey liveness or execution information for in-progress tasks. It will also * expire the hosts that have not heartbeated for more than spark.network.timeout. */ private[spark] case class Heartbeat( @@ -43,8 +43,8 @@ private[spark] case class Heartbeat( */ private[spark] case object TaskSchedulerIsSet -private[spark] case object ExpireDeadHosts - +private[spark] case object ExpireDeadHosts + private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean) /** @@ -62,18 +62,18 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) // "spark.network.timeout" uses "seconds", while `spark.storage.blockManagerSlaveTimeoutMs` uses // "milliseconds" - private val slaveTimeoutMs = + private val slaveTimeoutMs = sc.conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs", "120s") - private val executorTimeoutMs = + private val executorTimeoutMs = sc.conf.getTimeAsSeconds("spark.network.timeout", s"${slaveTimeoutMs}ms") * 1000 - + // "spark.network.timeoutInterval" uses "seconds", while // "spark.storage.blockManagerTimeoutIntervalMs" uses "milliseconds" - private val timeoutIntervalMs = + private val timeoutIntervalMs = sc.conf.getTimeAsMs("spark.storage.blockManagerTimeoutIntervalMs", "60s") - private val checkTimeoutIntervalMs = + private val checkTimeoutIntervalMs = sc.conf.getTimeAsSeconds("spark.network.timeoutInterval", s"${timeoutIntervalMs}ms") * 1000 - + private var timeoutCheckingTask: ScheduledFuture[_] = null // "eventLoopThread" is used to run some pretty fast actions. The actions running in it should not @@ -140,7 +140,7 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) } } } - + override def onStop(): Unit = { if (timeoutCheckingTask != null) { timeoutCheckingTask.cancel(true) diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala index 7e706bcc42f04..7cf7bc0dc6810 100644 --- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala @@ -50,8 +50,8 @@ private[spark] class HttpFileServer( def stop() { httpServer.stop() - - // If we only stop sc, but the driver process still run as a services then we need to delete + + // If we only stop sc, but the driver process still run as a services then we need to delete // the tmp dir, if not, it will create too many tmp dirs try { Utils.deleteRecursively(baseDir) diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index b8d244408bc5b..82889bcd30988 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -103,7 +103,7 @@ class HashPartitioner(partitions: Int) extends Partitioner { */ class RangePartitioner[K : Ordering : ClassTag, V]( @transient partitions: Int, - @transient rdd: RDD[_ <: Product2[K,V]], + @transient rdd: RDD[_ <: Product2[K, V]], private var ascending: Boolean = true) extends Partitioner { @@ -185,7 +185,7 @@ class RangePartitioner[K : Ordering : ClassTag, V]( } override def equals(other: Any): Boolean = other match { - case r: RangePartitioner[_,_] => + case r: RangePartitioner[_, _] => r.rangeBounds.sameElements(rangeBounds) && r.ascending == ascending case _ => false @@ -249,7 +249,7 @@ private[spark] object RangePartitioner { * @param sampleSizePerPartition max sample size per partition * @return (total number of items, an array of (partitionId, number of items, sample)) */ - def sketch[K:ClassTag]( + def sketch[K : ClassTag]( rdd: RDD[K], sampleSizePerPartition: Int): (Long, Array[(Int, Int, Array[K])]) = { val shift = rdd.id @@ -272,7 +272,7 @@ private[spark] object RangePartitioner { * @param partitions number of partitions * @return selected bounds */ - def determineBounds[K:Ordering:ClassTag]( + def determineBounds[K : Ordering : ClassTag]( candidates: ArrayBuffer[(K, Float)], partitions: Int): Array[K] = { val ordering = implicitly[Ordering[K]] diff --git a/core/src/main/scala/org/apache/spark/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/SizeEstimator.scala deleted file mode 100644 index 54fc3a856adfa..0000000000000 --- a/core/src/main/scala/org/apache/spark/SizeEstimator.scala +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import org.apache.spark.annotation.DeveloperApi - -/** - * Estimates the sizes of Java objects (number of bytes of memory they occupy), for use in - * memory-aware caches. - * - * Based on the following JavaWorld article: - * http://www.javaworld.com/javaworld/javaqa/2003-12/02-qa-1226-sizeof.html - */ -@DeveloperApi -object SizeEstimator { - /** - * :: DeveloperApi :: - * Estimate the number of bytes that the given object takes up on the JVM heap. The estimate - * includes space taken up by objects referenced by the given object, their references, and so on - * and so forth. - * - * This is useful for determining the amount of heap space a broadcast variable will occupy on - * each executor or the amount of space each object will take when caching objects in - * deserialized form. This is not the same as the serialized size of the object, which will - * typically be much smaller. - */ - @DeveloperApi - def estimate(obj: AnyRef): Long = org.apache.spark.util.SizeEstimator.estimate(obj) -} diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index a8fc90ad2050e..46d72841dccce 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -227,7 +227,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { def getSizeAsBytes(key: String, defaultValue: String): Long = { Utils.byteStringAsBytes(get(key, defaultValue)) } - + /** * Get a size parameter as Kibibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Kibibytes are assumed. @@ -244,7 +244,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { def getSizeAsKb(key: String, defaultValue: String): Long = { Utils.byteStringAsKb(get(key, defaultValue)) } - + /** * Get a size parameter as Mebibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Mebibytes are assumed. @@ -261,7 +261,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { def getSizeAsMb(key: String, defaultValue: String): Long = { Utils.byteStringAsMb(get(key, defaultValue)) } - + /** * Get a size parameter as Gibibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Gibibytes are assumed. @@ -278,7 +278,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { def getSizeAsGb(key: String, defaultValue: String): Long = { Utils.byteStringAsGb(get(key, defaultValue)) } - + /** Get a parameter as an Option */ def getOption(key: String): Option[String] = { Option(settings.get(key)).orElse(getDeprecatedConfig(key, this)) @@ -480,8 +480,8 @@ private[spark] object SparkConf extends Logging { "spark.kryoserializer.buffer.mb was previously specified as '0.064'. Fractional values " + "are no longer accepted. To specify the equivalent now, one may use '64k'.") ) - - Map(configs.map { cfg => (cfg.key -> cfg) }:_*) + + Map(configs.map { cfg => (cfg.key -> cfg) } : _*) } /** @@ -508,8 +508,8 @@ private[spark] object SparkConf extends Logging { "spark.reducer.maxSizeInFlight" -> Seq( AlternateConfig("spark.reducer.maxMbInFlight", "1.4")), "spark.kryoserializer.buffer" -> - Seq(AlternateConfig("spark.kryoserializer.buffer.mb", "1.4", - translation = s => s"${s.toDouble * 1000}k")), + Seq(AlternateConfig("spark.kryoserializer.buffer.mb", "1.4", + translation = s => s"${(s.toDouble * 1000).toInt}k")), "spark.kryoserializer.buffer.max" -> Seq( AlternateConfig("spark.kryoserializer.buffer.max.mb", "1.4")), "spark.shuffle.file.buffer" -> Seq( diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index b59f562d05ead..a453c9bf4864a 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -371,6 +371,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli throw new SparkException("An application name must be set in your configuration") } + // System property spark.yarn.app.id must be set if user code ran by AM on a YARN cluster + // yarn-standalone is deprecated, but still supported + if ((master == "yarn-cluster" || master == "yarn-standalone") && + !_conf.contains("spark.yarn.app.id")) { + throw new SparkException("Detected yarn-cluster mode, but isn't running on a cluster. " + + "Deployment to YARN is not supported directly by SparkContext. Please use spark-submit.") + } + if (_conf.getBoolean("spark.logConf", false)) { logInfo("Spark configuration:\n" + _conf.toDebugString) } @@ -381,7 +389,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _conf.set("spark.executor.id", SparkContext.DRIVER_IDENTIFIER) - _jars =_conf.getOption("spark.jars").map(_.split(",")).map(_.filter(_.size != 0)).toSeq.flatten + _jars = _conf.getOption("spark.jars").map(_.split(",")).map(_.filter(_.size != 0)).toSeq.flatten _files = _conf.getOption("spark.files").map(_.split(",")).map(_.filter(_.size != 0)) .toSeq.flatten @@ -430,7 +438,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _ui = if (conf.getBoolean("spark.ui.enabled", true)) { Some(SparkUI.createLiveUI(this, _conf, listenerBus, _jobProgressListener, - _env.securityManager,appName, startTime = startTime)) + _env.securityManager, appName, startTime = startTime)) } else { // For tests, do not enable the UI None @@ -670,7 +678,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * * Note: Return statements are NOT allowed in the given body. */ - private def withScope[U](body: => U): U = RDDOperationScope.withScope[U](this)(body) + private[spark] def withScope[U](body: => U): U = RDDOperationScope.withScope[U](this)(body) // Methods for creating RDDs @@ -689,6 +697,78 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]()) } + /** + * Creates a new RDD[Long] containing elements from `start` to `end`(exclusive), increased by + * `step` every element. + * + * @note if we need to cache this RDD, we should make sure each partition does not exceed limit. + * + * @param start the start value. + * @param end the end value. + * @param step the incremental step + * @param numSlices the partition number of the new RDD. + * @return + */ + def range( + start: Long, + end: Long, + step: Long = 1, + numSlices: Int = defaultParallelism): RDD[Long] = withScope { + assertNotStopped() + // when step is 0, range will run infinitely + require(step != 0, "step cannot be 0") + val numElements: BigInt = { + val safeStart = BigInt(start) + val safeEnd = BigInt(end) + if ((safeEnd - safeStart) % step == 0 || safeEnd > safeStart ^ step > 0) { + (safeEnd - safeStart) / step + } else { + // the remainder has the same sign with range, could add 1 more + (safeEnd - safeStart) / step + 1 + } + } + parallelize(0 until numSlices, numSlices).mapPartitionsWithIndex((i, _) => { + val partitionStart = (i * numElements) / numSlices * step + start + val partitionEnd = (((i + 1) * numElements) / numSlices) * step + start + def getSafeMargin(bi: BigInt): Long = + if (bi.isValidLong) { + bi.toLong + } else if (bi > 0) { + Long.MaxValue + } else { + Long.MinValue + } + val safePartitionStart = getSafeMargin(partitionStart) + val safePartitionEnd = getSafeMargin(partitionEnd) + + new Iterator[Long] { + private[this] var number: Long = safePartitionStart + private[this] var overflow: Boolean = false + + override def hasNext = + if (!overflow) { + if (step > 0) { + number < safePartitionEnd + } else { + number > safePartitionEnd + } + } else false + + override def next() = { + val ret = number + number += step + if (number < ret ^ step < 0) { + // we have Long.MaxValue + Long.MaxValue < Long.MaxValue + // and Long.MinValue + Long.MinValue > Long.MinValue, so iff the step causes a step + // back, we are pretty sure that we have an overflow. + overflow = true + } + ret + } + } + }) + } + /** Distribute a local Scala collection to form an RDD. * * This method is identical to `parallelize`. @@ -837,7 +917,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli classOf[FixedLengthBinaryInputFormat], classOf[LongWritable], classOf[BytesWritable], - conf=conf) + conf = conf) val data = br.map { case (k, v) => val bytes = v.getBytes assert(bytes.length == recordLength, "Byte array does not have correct length") @@ -1079,8 +1159,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli kcf: () => WritableConverter[K], vcf: () => WritableConverter[V]): RDD[(K, V)] = { withScope { assertNotStopped() - val kc = kcf() - val vc = vcf() + val kc = clean(kcf)() + val vc = clean(vcf)() val format = classOf[SequenceFileInputFormat[Writable, Writable]] val writables = hadoopFile(path, format, kc.writableClass(km).asInstanceOf[Class[Writable]], @@ -1187,7 +1267,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable: ClassTag, T] (initialValue: R): Accumulable[R, T] = { - val param = new GrowableAccumulableParam[R,T] + val param = new GrowableAccumulableParam[R, T] val acc = new Accumulable(initialValue, param) cleaner.foreach(_.registerAccumulatorForCleanup(acc)) acc @@ -1236,7 +1316,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val uri = new URI(path) val schemeCorrectedPath = uri.getScheme match { case null | "local" => new File(path).getCanonicalFile.toURI.toString - case _ => path + case _ => path } val hadoopPath = new Path(schemeCorrectedPath) @@ -1804,7 +1884,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * * @param f the closure to clean * @param checkSerializable whether or not to immediately check f for serializability - * @throws SparkException if checkSerializable is set but f is not + * @throws SparkException if checkSerializable is set but f is not * serializable */ private[spark] def clean[F <: AnyRef](f: F, checkSerializable: Boolean = true): F = { @@ -1911,7 +1991,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Note: this code assumes that the task scheduler has been initialized and has contacted // the cluster manager to get an application ID (in case the cluster manager provides one). listenerBus.post(SparkListenerApplicationStart(appName, Some(applicationId), - startTime, sparkUser, applicationAttemptId)) + startTime, sparkUser, applicationAttemptId, schedulerBackend.getDriverLogUrls)) } /** Post the application end event */ diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 0c4d28f786edd..a185954089528 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -298,7 +298,7 @@ object SparkEnv extends Logging { } } - val mapOutputTracker = if (isDriver) { + val mapOutputTracker = if (isDriver) { new MapOutputTrackerMaster(conf) } else { new MapOutputTrackerWorker(conf) @@ -313,7 +313,8 @@ object SparkEnv extends Logging { // Let the user specify short names for shuffle managers val shortShuffleMgrNames = Map( "hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager", - "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager") + "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager", + "tungsten-sort" -> "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager") val shuffleMgrName = conf.get("spark.shuffle.manager", "sort") val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName) val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass) @@ -347,7 +348,7 @@ object SparkEnv extends Logging { val fileServerPort = conf.getInt("spark.fileserver.port", 0) val server = new HttpFileServer(conf, securityManager, fileServerPort) server.initialize() - conf.set("spark.fileserver.uri", server.serverUri) + conf.set("spark.fileserver.uri", server.serverUri) server } else { null @@ -378,7 +379,7 @@ object SparkEnv extends Logging { } val outputCommitCoordinator = mockOutputCommitCoordinator.getOrElse { - new OutputCommitCoordinator(conf) + new OutputCommitCoordinator(conf, isDriver) } val outputCommitCoordinatorRef = registerOrLookupEndpoint("OutputCommitCoordinator", new OutputCommitCoordinatorEndpoint(rpcEnv, outputCommitCoordinator)) diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index 2ec42d3aea169..59ac82ccec53b 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -50,8 +50,8 @@ class SparkHadoopWriter(@transient jobConf: JobConf) private var jID: SerializableWritable[JobID] = null private var taID: SerializableWritable[TaskAttemptID] = null - @transient private var writer: RecordWriter[AnyRef,AnyRef] = null - @transient private var format: OutputFormat[AnyRef,AnyRef] = null + @transient private var writer: RecordWriter[AnyRef, AnyRef] = null + @transient private var format: OutputFormat[AnyRef, AnyRef] = null @transient private var committer: OutputCommitter = null @transient private var jobContext: JobContext = null @transient private var taskContext: TaskAttemptContext = null @@ -114,10 +114,10 @@ class SparkHadoopWriter(@transient jobConf: JobConf) // ********* Private Functions ********* - private def getOutputFormat(): OutputFormat[AnyRef,AnyRef] = { + private def getOutputFormat(): OutputFormat[AnyRef, AnyRef] = { if (format == null) { format = conf.value.getOutputFormat() - .asInstanceOf[OutputFormat[AnyRef,AnyRef]] + .asInstanceOf[OutputFormat[AnyRef, AnyRef]] } format } @@ -138,7 +138,7 @@ class SparkHadoopWriter(@transient jobConf: JobConf) private def getTaskContext(): TaskAttemptContext = { if (taskContext == null) { - taskContext = newTaskAttemptContext(conf.value, taID.value) + taskContext = newTaskAttemptContext(conf.value, taID.value) } taskContext } diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index 398ca41e16151..a1ebbecf93b7b 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -51,7 +51,7 @@ private[spark] object TestUtils { classpathUrls: Seq[URL] = Seq()): URL = { val tempDir = Utils.createTempDir() val files1 = for (name <- classNames) yield { - createCompiledClass(name, tempDir, toStringValue, classpathUrls = classpathUrls) + createCompiledClass(name, tempDir, toStringValue, classpathUrls = classpathUrls) } val files2 = for ((childName, baseName) <- classNamesWithBase) yield { createCompiledClass(childName, tempDir, toStringValue, baseName, classpathUrls) @@ -105,23 +105,18 @@ private[spark] object TestUtils { URI.create(s"string:///${name.replace(".", "/")}${SOURCE.extension}") } - private class JavaSourceFromString(val name: String, val code: String) + private[spark] class JavaSourceFromString(val name: String, val code: String) extends SimpleJavaFileObject(createURI(name), SOURCE) { override def getCharContent(ignoreEncodingErrors: Boolean): String = code } - /** Creates a compiled class with the given name. Class file will be placed in destDir. */ + /** Creates a compiled class with the source file. Class file will be placed in destDir. */ def createCompiledClass( className: String, destDir: File, - toStringValue: String = "", - baseClass: String = null, - classpathUrls: Seq[URL] = Seq()): File = { + sourceFile: JavaSourceFromString, + classpathUrls: Seq[URL]): File = { val compiler = ToolProvider.getSystemJavaCompiler - val extendsText = Option(baseClass).map { c => s" extends ${c}" }.getOrElse("") - val sourceFile = new JavaSourceFromString(className, - "public class " + className + extendsText + " implements java.io.Serializable {" + - " @Override public String toString() { return \"" + toStringValue + "\"; }}") // Calling this outputs a class file in pwd. It's easier to just rename the file than // build a custom FileManager that controls the output location. @@ -144,4 +139,18 @@ private[spark] object TestUtils { assert(out.exists(), "Destination file not moved: " + out.getAbsolutePath()) out } + + /** Creates a compiled class with the given name. Class file will be placed in destDir. */ + def createCompiledClass( + className: String, + destDir: File, + toStringValue: String = "", + baseClass: String = null, + classpathUrls: Seq[URL] = Seq()): File = { + val extendsText = Option(baseClass).map { c => s" extends ${c}" }.getOrElse("") + val sourceFile = new JavaSourceFromString(className, + "public class " + className + extendsText + " implements java.io.Serializable {" + + " @Override public String toString() { return \"" + toStringValue + "\"; }}") + createCompiledClass(className, destDir, sourceFile, classpathUrls) + } } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index 61af867b11b9c..a650df605b92e 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -137,7 +137,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) */ def sample(withReplacement: Boolean, fraction: JDouble): JavaDoubleRDD = sample(withReplacement, fraction, Utils.random.nextLong) - + /** * Return a sampled subset of this RDD. */ diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index db4e996feb31c..ed312770ee131 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -101,7 +101,7 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) /** * Return a sampled subset of this RDD. - * + * * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size * without replacement: probability that each element is chosen; fraction must be [0, 1] @@ -109,10 +109,10 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) */ def sample(withReplacement: Boolean, fraction: Double): JavaRDD[T] = sample(withReplacement, fraction, Utils.random.nextLong) - + /** * Return a sampled subset of this RDD. - * + * * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size * without replacement: probability that each element is chosen; fraction must be [0, 1] diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 8bf0627fc420d..c95615a5a9307 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -60,10 +60,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { @deprecated("Use partitions() instead.", "1.1.0") def splits: JList[Partition] = new java.util.ArrayList(rdd.partitions.toSeq) - + /** Set of partitions in this RDD. */ def partitions: JList[Partition] = new java.util.ArrayList(rdd.partitions.toSeq) + /** The partitioner of this RDD. */ + def partitioner: Optional[Partitioner] = JavaUtils.optionToOptional(rdd.partitioner) + /** The [[org.apache.spark.SparkContext]] that this RDD was created on. */ def context: SparkContext = rdd.context @@ -96,7 +99,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def mapPartitionsWithIndex[R]( f: JFunction2[jl.Integer, java.util.Iterator[T], java.util.Iterator[R]], preservesPartitioning: Boolean = false): JavaRDD[R] = - new JavaRDD(rdd.mapPartitionsWithIndex(((a,b) => f(a,asJavaIterator(b))), + new JavaRDD(rdd.mapPartitionsWithIndex(((a, b) => f(a, asJavaIterator(b))), preservesPartitioning)(fakeClassTag))(fakeClassTag) /** @@ -386,9 +389,16 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { /** * Aggregate the elements of each partition, and then the results for all the partitions, using a - * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to - * modify t1 and return it as its result value to avoid object allocation; however, it should not - * modify t2. + * given associative and commutative function and a neutral "zero value". The function + * op(t1, t2) is allowed to modify t1 and return it as its result value to avoid object + * allocation; however, it should not modify t2. + * + * This behaves somewhat differently from fold operations implemented for non-distributed + * collections in functional languages like Scala. This fold operation may be applied to + * partitions individually, and then fold those results into the final result, rather than + * apply the fold to each element sequentially in some defined ordering. For functions + * that are not commutative, the result may differ from that of a fold applied to a + * non-distributed collection. */ def fold(zeroValue: T)(f: JFunction2[T, T, T]): T = rdd.fold(zeroValue)(f) @@ -485,9 +495,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { new java.util.ArrayList(arr) } - def takeSample(withReplacement: Boolean, num: Int): JList[T] = + def takeSample(withReplacement: Boolean, num: Int): JList[T] = takeSample(withReplacement, num, Utils.random.nextLong) - + def takeSample(withReplacement: Boolean, num: Int, seed: Long): JList[T] = { import scala.collection.JavaConversions._ val arr: java.util.Collection[T] = rdd.takeSample(withReplacement, num, seed).toSeq diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 7409dc2d866f6..55a37f8c944b2 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -47,6 +47,7 @@ private[spark] class PythonRDD( pythonIncludes: JList[String], preservePartitoning: Boolean, pythonExec: String, + pythonVer: String, broadcastVars: JList[Broadcast[PythonBroadcast]], accumulator: Accumulator[JList[Array[Byte]]]) extends RDD[Array[Byte]](parent) { @@ -210,6 +211,8 @@ private[spark] class PythonRDD( val dataOut = new DataOutputStream(stream) // Partition index dataOut.writeInt(split.index) + // Python version of driver + PythonRDD.writeUTF(pythonVer, dataOut) // sparkFilesDir PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut) // Python includes (*.zip and *.egg files) @@ -720,7 +723,7 @@ private[spark] object PythonRDD extends Logging { val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, new JavaToWritableConverter) val fc = Utils.classForName(outputFormatClass).asInstanceOf[Class[F]] - converted.saveAsHadoopFile(path, kc, vc, fc, new JobConf(mergedConf), codec=codec) + converted.saveAsHadoopFile(path, kc, vc, fc, new JobConf(mergedConf), codec = codec) } /** @@ -794,10 +797,10 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort: val bufferSize = SparkEnv.get.conf.getInt("spark.buffer.size", 65536) - /** + /** * We try to reuse a single Socket to transfer accumulator updates, as they are all added * by the DAGScheduler's single-threaded actor anyway. - */ + */ @transient var socket: Socket = _ def openSocket(): Socket = synchronized { @@ -840,6 +843,7 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort: * An Wrapper for Python Broadcast, which is written into disk by Python. It also will * write the data into disk after deserialization, then Python can read it from disks. */ +// scalastyle:off no.finalize private[spark] class PythonBroadcast(@transient var path: String) extends Serializable { /** @@ -881,3 +885,4 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial } } } +// scalastyle:on no.finalize diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index efb6b93cfc35d..90dacaeb93429 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -50,8 +50,15 @@ private[spark] object PythonUtils { /** * Convert list of T into seq of T (for calling API with varargs) */ - def toSeq[T](cols: JList[T]): Seq[T] = { - cols.toList.toSeq + def toSeq[T](vs: JList[T]): Seq[T] = { + vs.toList.toSeq + } + + /** + * Convert list of T into array of T (for calling API with array) + */ + def toArray[T](vs: JList[T]): Array[T] = { + vs.toArray().asInstanceOf[Array[T]] } /** diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala index 3a2c94bd9d875..d24c650d37bb0 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -18,7 +18,7 @@ package org.apache.spark.api.r import java.io.{DataOutputStream, File, FileOutputStream, IOException} -import java.net.{InetSocketAddress, ServerSocket} +import java.net.{InetAddress, InetSocketAddress, ServerSocket} import java.util.concurrent.TimeUnit import io.netty.bootstrap.ServerBootstrap @@ -44,11 +44,11 @@ private[spark] class RBackend { bossGroup = new NioEventLoopGroup(2) val workerGroup = bossGroup val handler = new RBackendHandler(this) - + bootstrap = new ServerBootstrap() .group(bossGroup, workerGroup) .channel(classOf[NioServerSocketChannel]) - + bootstrap.childHandler(new ChannelInitializer[SocketChannel]() { def initChannel(ch: SocketChannel): Unit = { ch.pipeline() @@ -65,7 +65,7 @@ private[spark] class RBackend { } }) - channelFuture = bootstrap.bind(new InetSocketAddress(0)) + channelFuture = bootstrap.bind(new InetSocketAddress("localhost", 0)) channelFuture.syncUninterruptibly() channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort() } @@ -101,7 +101,7 @@ private[spark] object RBackend extends Logging { try { // bind to random port val boundPort = sparkRBackend.init() - val serverSocket = new ServerSocket(0, 1) + val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) val listenPort = serverSocket.getLocalPort() // tell the R process via temporary file diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index 0075d963711f1..2e86984c66b3a 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -77,7 +77,7 @@ private[r] class RBackendHandler(server: RBackend) val reply = bos.toByteArray ctx.write(reply) } - + override def channelReadComplete(ctx: ChannelHandlerContext): Unit = { ctx.flush() } @@ -124,7 +124,7 @@ private[r] class RBackendHandler(server: RBackend) } throw new Exception(s"No matched method found for $cls.$methodName") } - val ret = methods.head.invoke(obj, args:_*) + val ret = methods.head.invoke(obj, args : _*) // Write status bit writeInt(dos, 0) @@ -135,7 +135,7 @@ private[r] class RBackendHandler(server: RBackend) matchMethod(numArgs, args, x.getParameterTypes) }.head - val obj = ctor.newInstance(args:_*) + val obj = ctor.newInstance(args : _*) writeInt(dos, 0) writeObject(dos, obj.asInstanceOf[AnyRef]) diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index 6fea5e1144f2f..4dfa7325934ff 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -18,7 +18,7 @@ package org.apache.spark.api.r import java.io._ -import java.net.ServerSocket +import java.net.{InetAddress, ServerSocket} import java.util.{Map => JMap} import scala.collection.JavaConversions._ @@ -55,7 +55,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( val parentIterator = firstParent[T].iterator(partition, context) // we expect two connections - val serverSocket = new ServerSocket(0, 2) + val serverSocket = new ServerSocket(0, 2, InetAddress.getByName("localhost")) val listenPort = serverSocket.getLocalPort() // The stdout/stderr is shared by multiple tasks, because we use one daemon @@ -309,7 +309,7 @@ private class StringRRDD[T: ClassTag]( } private object SpecialLengths { - val TIMING_DATA = -1 + val TIMING_DATA = -1 } private[r] class BufferedStreamThread( @@ -355,7 +355,6 @@ private[r] object RRDD { val sparkConf = new SparkConf().setAppName(appName) .setSparkHome(sparkHome) - .setJars(jars) // Override `master` if we have a user-specified value if (master != "") { @@ -373,7 +372,11 @@ private[r] object RRDD { sparkConf.setExecutorEnv(name.asInstanceOf[String], value.asInstanceOf[String]) } - new JavaSparkContext(sparkConf) + val jsc = new JavaSparkContext(sparkConf) + jars.foreach { jar => + jsc.addJar(jar) + } + jsc } /** @@ -414,7 +417,7 @@ private[r] object RRDD { synchronized { if (daemonChannel == null) { // we expect one connections - val serverSocket = new ServerSocket(0, 1) + val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) val daemonPort = serverSocket.getLocalPort errThread = createRProcess(rLibDir, daemonPort, "daemon.R") // the socket used to send out the input of task diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index 371dfe454d1a2..f8e3f1a79082e 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -157,9 +157,11 @@ private[spark] object SerDe { val keysLen = readInt(in) val keys = (0 until keysLen).map(_ => readTypedObject(in, keysType)) - val valuesType = readObjectType(in) val valuesLen = readInt(in) - val values = (0 until valuesLen).map(_ => readTypedObject(in, valuesType)) + val values = (0 until valuesLen).map(_ => { + val valueType = readObjectType(in) + readTypedObject(in, valueType) + }) mapAsJavaMap(keys.zip(values).toMap) } else { new java.util.HashMap[Object, Object]() diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 4457c75e8b0fc..b69af639f7862 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -125,7 +125,7 @@ private[broadcast] object HttpBroadcast extends Logging { securityManager = securityMgr if (isDriver) { createServer(conf) - conf.set("spark.httpBroadcast.uri", serverUri) + conf.set("spark.httpBroadcast.uri", serverUri) } serverUri = conf.get("spark.httpBroadcast.uri") cleaner = new MetadataCleaner(MetadataCleanerType.HTTP_BROADCAST, cleanup, conf) @@ -187,7 +187,7 @@ private[broadcast] object HttpBroadcast extends Logging { } private def read[T: ClassTag](id: Long): T = { - logDebug("broadcast read server: " + serverUri + " id: broadcast-" + id) + logDebug("broadcast read server: " + serverUri + " id: broadcast-" + id) val url = serverUri + "/" + BroadcastBlockId(id).name var uc: URLConnection = null diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala index c048b78910f38..b4edb6109e839 100644 --- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala @@ -65,7 +65,7 @@ private object FaultToleranceTest extends App with Logging { private val workers = ListBuffer[TestWorkerInfo]() private var sc: SparkContext = _ - private val zk = SparkCuratorUtil.newClient(conf) + private val zk = SparkCuratorUtil.newClient(conf) private var numPassed = 0 private var numFailed = 0 diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala index 860e1a24901b6..0550f00a172ab 100644 --- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala @@ -43,6 +43,8 @@ class LocalSparkCluster( private val localHostname = Utils.localHostName() private val masterActorSystems = ArrayBuffer[ActorSystem]() private val workerActorSystems = ArrayBuffer[ActorSystem]() + // exposed for testing + var masterWebUIPort = -1 def start(): Array[String] = { logInfo("Starting a local Spark cluster with " + numWorkers + " workers.") @@ -53,7 +55,9 @@ class LocalSparkCluster( .set("spark.shuffle.service.enabled", "false") /* Start the Master */ - val (masterSystem, masterPort, _, _) = Master.startSystemAndActor(localHostname, 0, 0, _conf) + val (masterSystem, masterPort, webUiPort, _) = + Master.startSystemAndActor(localHostname, 0, 0, _conf) + masterWebUIPort = webUiPort masterActorSystems += masterSystem val masterUrl = "spark://" + Utils.localHostNameForURI() + ":" + masterPort val masters = Array(masterUrl) diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index 53e18c4bcec23..c2ed43a5397d6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -18,9 +18,11 @@ package org.apache.spark.deploy import java.net.URI +import java.io.File import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConversions._ +import scala.util.Try import org.apache.spark.api.python.PythonUtils import org.apache.spark.util.{RedirectThread, Utils} @@ -81,16 +83,13 @@ object PythonRunner { throw new IllegalArgumentException("Launching Python applications through " + s"spark-submit is currently only supported for local files: $path") } - val windows = Utils.isWindows || testWindows - var formattedPath = if (windows) Utils.formatWindowsPath(path) else path - - // Strip the URI scheme from the path - formattedPath = - new URI(formattedPath).getScheme match { - case null => formattedPath - case Utils.windowsDrive(d) if windows => formattedPath - case _ => new URI(formattedPath).getPath - } + // get path when scheme is file. + val uri = Try(new URI(path)).getOrElse(new File(path).toURI) + var formattedPath = uri.getScheme match { + case null => path + case "file" | "local" => uri.getPath + case _ => null + } // Guard against malformed paths potentially throwing NPE if (formattedPath == null) { @@ -99,7 +98,9 @@ object PythonRunner { // In Windows, the drive should not be prefixed with "/" // For instance, python does not understand "/C:/path/to/sheep.py" - formattedPath = if (windows) formattedPath.stripPrefix("/") else formattedPath + if (Utils.isWindows && formattedPath.matches("/[a-zA-Z]:/.*")) { + formattedPath = formattedPath.stripPrefix("/") + } formattedPath } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 329fa06ba8ba5..8cf4d58847d8e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -361,7 +361,7 @@ object SparkSubmit { pyArchives = pythonPath.mkString(",") } - pyArchives = pyArchives.split(",").map { localPath=> + pyArchives = pyArchives.split(",").map { localPath => val localURI = Utils.resolveURI(localPath) if (localURI.getScheme != "local") { args.files = mergeFileLists(args.files, localURI.toString) @@ -428,6 +428,8 @@ object SparkSubmit { OptionAssigner(args.executorCores, YARN, CLIENT, sysProp = "spark.executor.cores"), OptionAssigner(args.files, YARN, CLIENT, sysProp = "spark.yarn.dist.files"), OptionAssigner(args.archives, YARN, CLIENT, sysProp = "spark.yarn.dist.archives"), + OptionAssigner(args.principal, YARN, CLIENT, sysProp = "spark.yarn.principal"), + OptionAssigner(args.keytab, YARN, CLIENT, sysProp = "spark.yarn.keytab"), // Yarn cluster only OptionAssigner(args.name, YARN, CLUSTER, clOption = "--name"), @@ -440,10 +442,8 @@ object SparkSubmit { OptionAssigner(args.files, YARN, CLUSTER, clOption = "--files"), OptionAssigner(args.archives, YARN, CLUSTER, clOption = "--archives"), OptionAssigner(args.jars, YARN, CLUSTER, clOption = "--addJars"), - - // Yarn client or cluster - OptionAssigner(args.principal, YARN, ALL_DEPLOY_MODES, clOption = "--principal"), - OptionAssigner(args.keytab, YARN, ALL_DEPLOY_MODES, clOption = "--keytab"), + OptionAssigner(args.principal, YARN, CLUSTER, clOption = "--principal"), + OptionAssigner(args.keytab, YARN, CLUSTER, clOption = "--keytab"), // Other options OptionAssigner(args.executorCores, STANDALONE, ALL_DEPLOY_MODES, @@ -753,7 +753,9 @@ private[spark] object SparkSubmitUtils { * @param artifactId the artifactId of the coordinate * @param version the version of the coordinate */ - private[deploy] case class MavenCoordinate(groupId: String, artifactId: String, version: String) + private[deploy] case class MavenCoordinate(groupId: String, artifactId: String, version: String) { + override def toString: String = s"$groupId:$artifactId:$version" + } /** * Extracts maven coordinates from a comma-delimited string. Coordinates should be provided @@ -776,6 +778,10 @@ private[spark] object SparkSubmitUtils { } } + /** Path of the local Maven cache. */ + private[spark] def m2Path: File = new File(System.getProperty("user.home"), + ".m2" + File.separator + "repository" + File.separator) + /** * Extracts maven coordinates from a comma-delimited string * @param remoteRepos Comma-delimited string of remote repositories @@ -789,8 +795,7 @@ private[spark] object SparkSubmitUtils { val localM2 = new IBiblioResolver localM2.setM2compatible(true) - val m2Path = ".m2" + File.separator + "repository" + File.separator - localM2.setRoot(new File(System.getProperty("user.home"), m2Path).toURI.toString) + localM2.setRoot(m2Path.toURI.toString) localM2.setUsepoms(true) localM2.setName("local-m2-cache") cr.add(localM2) @@ -864,7 +869,7 @@ private[spark] object SparkSubmitUtils { md.addDependency(dd) } } - + /** Add exclusion rules for dependencies already included in the spark-assembly */ def addExclusionRules( ivySettings: IvySettings, @@ -915,69 +920,72 @@ private[spark] object SparkSubmitUtils { "" } else { val sysOut = System.out - // To prevent ivy from logging to system out - System.setOut(printStream) - val artifacts = extractMavenCoordinates(coordinates) - // Default configuration name for ivy - val ivyConfName = "default" - // set ivy settings for location of cache - val ivySettings: IvySettings = new IvySettings - // Directories for caching downloads through ivy and storing the jars when maven coordinates - // are supplied to spark-submit - val alternateIvyCache = ivyPath.getOrElse("") - val packagesDirectory: File = - if (alternateIvyCache.trim.isEmpty) { - new File(ivySettings.getDefaultIvyUserDir, "jars") + try { + // To prevent ivy from logging to system out + System.setOut(printStream) + val artifacts = extractMavenCoordinates(coordinates) + // Default configuration name for ivy + val ivyConfName = "default" + // set ivy settings for location of cache + val ivySettings: IvySettings = new IvySettings + // Directories for caching downloads through ivy and storing the jars when maven coordinates + // are supplied to spark-submit + val alternateIvyCache = ivyPath.getOrElse("") + val packagesDirectory: File = + if (alternateIvyCache.trim.isEmpty) { + new File(ivySettings.getDefaultIvyUserDir, "jars") + } else { + ivySettings.setDefaultIvyUserDir(new File(alternateIvyCache)) + ivySettings.setDefaultCache(new File(alternateIvyCache, "cache")) + new File(alternateIvyCache, "jars") + } + printStream.println( + s"Ivy Default Cache set to: ${ivySettings.getDefaultCache.getAbsolutePath}") + printStream.println(s"The jars for the packages stored in: $packagesDirectory") + // create a pattern matcher + ivySettings.addMatcher(new GlobPatternMatcher) + // create the dependency resolvers + val repoResolver = createRepoResolvers(remoteRepos, ivySettings) + ivySettings.addResolver(repoResolver) + ivySettings.setDefaultResolver(repoResolver.getName) + + val ivy = Ivy.newInstance(ivySettings) + // Set resolve options to download transitive dependencies as well + val resolveOptions = new ResolveOptions + resolveOptions.setTransitive(true) + val retrieveOptions = new RetrieveOptions + // Turn downloading and logging off for testing + if (isTest) { + resolveOptions.setDownload(false) + resolveOptions.setLog(LogOptions.LOG_QUIET) + retrieveOptions.setLog(LogOptions.LOG_QUIET) } else { - ivySettings.setDefaultIvyUserDir(new File(alternateIvyCache)) - ivySettings.setDefaultCache(new File(alternateIvyCache, "cache")) - new File(alternateIvyCache, "jars") + resolveOptions.setDownload(true) } - printStream.println( - s"Ivy Default Cache set to: ${ivySettings.getDefaultCache.getAbsolutePath}") - printStream.println(s"The jars for the packages stored in: $packagesDirectory") - // create a pattern matcher - ivySettings.addMatcher(new GlobPatternMatcher) - // create the dependency resolvers - val repoResolver = createRepoResolvers(remoteRepos, ivySettings) - ivySettings.addResolver(repoResolver) - ivySettings.setDefaultResolver(repoResolver.getName) - - val ivy = Ivy.newInstance(ivySettings) - // Set resolve options to download transitive dependencies as well - val resolveOptions = new ResolveOptions - resolveOptions.setTransitive(true) - val retrieveOptions = new RetrieveOptions - // Turn downloading and logging off for testing - if (isTest) { - resolveOptions.setDownload(false) - resolveOptions.setLog(LogOptions.LOG_QUIET) - retrieveOptions.setLog(LogOptions.LOG_QUIET) - } else { - resolveOptions.setDownload(true) - } - // A Module descriptor must be specified. Entries are dummy strings - val md = getModuleDescriptor - md.setDefaultConf(ivyConfName) + // A Module descriptor must be specified. Entries are dummy strings + val md = getModuleDescriptor + md.setDefaultConf(ivyConfName) - // Add exclusion rules for Spark and Scala Library - addExclusionRules(ivySettings, ivyConfName, md) - // add all supplied maven artifacts as dependencies - addDependenciesToIvy(md, artifacts, ivyConfName) + // Add exclusion rules for Spark and Scala Library + addExclusionRules(ivySettings, ivyConfName, md) + // add all supplied maven artifacts as dependencies + addDependenciesToIvy(md, artifacts, ivyConfName) - // resolve dependencies - val rr: ResolveReport = ivy.resolve(md, resolveOptions) - if (rr.hasError) { - throw new RuntimeException(rr.getAllProblemMessages.toString) + // resolve dependencies + val rr: ResolveReport = ivy.resolve(md, resolveOptions) + if (rr.hasError) { + throw new RuntimeException(rr.getAllProblemMessages.toString) + } + // retrieve all resolved dependencies + ivy.retrieve(rr.getModuleDescriptor.getModuleRevisionId, + packagesDirectory.getAbsolutePath + File.separator + + "[organization]_[artifact]-[revision].[ext]", + retrieveOptions.setConfs(Array(ivyConfName))) + resolveDependencyPaths(rr.getArtifacts.toArray, packagesDirectory) + } finally { + System.setOut(sysOut) } - // retrieve all resolved dependencies - ivy.retrieve(rr.getModuleDescriptor.getModuleRevisionId, - packagesDirectory.getAbsolutePath + File.separator + - "[organization]_[artifact]-[revision].[ext]", - retrieveOptions.setConfs(Array(ivyConfName))) - System.setOut(sysOut) - resolveDependencyPaths(rr.getArtifacts.toArray, packagesDirectory) } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index c0e4c771908b3..cc6a7bd9f4119 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -169,6 +169,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S deployMode = Option(deployMode).orElse(env.get("DEPLOY_MODE")).orNull numExecutors = Option(numExecutors) .getOrElse(sparkProperties.get("spark.executor.instances").orNull) + keytab = Option(keytab).orElse(sparkProperties.get("spark.yarn.keytab")).orNull + principal = Option(principal).orElse(sparkProperties.get("spark.yarn.principal")).orNull // Try to set main class from JAR if no --class argument is given if (mainClass == null && !isPython && !isR && primaryResource != null) { diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 517cbe5176241..5a0eb585a9049 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -25,7 +25,8 @@ import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.status.api.v1.{ApplicationInfo, ApplicationsListResource, JsonRootResource, UIRoot} +import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationInfo, ApplicationsListResource, + UIRoot} import org.apache.spark.ui.{SparkUI, UIUtils, WebUI} import org.apache.spark.ui.JettyUtils._ import org.apache.spark.util.{SignalLogger, Utils} @@ -125,7 +126,7 @@ class HistoryServer( def initialize() { attachPage(new HistoryPage(this)) - attachHandler(JsonRootResource.getJsonServlet(this)) + attachHandler(ApiRootResource.getServletHandler(this)) attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala index a2a97a7877ce7..4692d22651c93 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala @@ -23,7 +23,7 @@ import org.apache.spark.util.Utils /** * Command-line parser for the master. */ -private[history] class HistoryServerArguments(conf: SparkConf, args: Array[String]) +private[history] class HistoryServerArguments(conf: SparkConf, args: Array[String]) extends Logging { private var propertiesFile: String = null diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala index 9c3f79f1244b7..66a9ff38678c6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala @@ -30,6 +30,11 @@ private[spark] class MasterSource(val master: Master) extends Source { override def getValue: Int = master.workers.size }) + // Gauge for alive worker numbers in cluster + metricRegistry.register(MetricRegistry.name("aliveWorkers"), new Gauge[Int]{ + override def getValue: Int = master.workers.filter(_.state == WorkerState.ALIVE).size + }) + // Gauge for application numbers in cluster metricRegistry.register(MetricRegistry.name("apps"), new Gauge[Int] { override def getValue: Int = master.apps.size diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala index 80db6d474b5c1..328d95a7a0c68 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala @@ -32,7 +32,7 @@ import org.apache.spark.deploy.SparkCuratorUtil private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serialization: Serialization) extends PersistenceEngine with Logging { - + private val WORKING_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + "/master_status" private val zk: CuratorFramework = SparkCuratorUtil.newClient(conf) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index f0b270d799d23..cc1288d6782ca 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -74,6 +74,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { val workerHeaders = Seq("Worker Id", "Address", "State", "Cores", "Memory") val workers = state.workers.sortBy(_.id) + val aliveWorkers = state.workers.filter(_.state == WorkerState.ALIVE) val workerTable = UIUtils.listingTable(workerHeaders, workerRow, workers) val appHeaders = Seq("Application ID", "Name", "Cores", "Memory per Node", "Submitted Time", @@ -107,12 +108,12 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { }.getOrElse { Seq.empty } } -
  • Workers: {state.workers.size}
  • -
  • Cores: {state.workers.map(_.cores).sum} Total, - {state.workers.map(_.coresUsed).sum} Used
  • -
  • Memory: - {Utils.megabytesToString(state.workers.map(_.memory).sum)} Total, - {Utils.megabytesToString(state.workers.map(_.memoryUsed).sum)} Used
  • +
  • Alive Workers: {aliveWorkers.size}
  • +
  • Cores in use: {aliveWorkers.map(_.cores).sum} Total, + {aliveWorkers.map(_.coresUsed).sum} Used
  • +
  • Memory in use: + {Utils.megabytesToString(aliveWorkers.map(_.memory).sum)} Total, + {Utils.megabytesToString(aliveWorkers.map(_.memoryUsed).sum)} Used
  • Applications: {state.activeApps.size} Running, {state.completedApps.size} Completed
  • diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index eb26e9f99c70b..2111a8581f2e4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -19,7 +19,8 @@ package org.apache.spark.deploy.master.ui import org.apache.spark.Logging import org.apache.spark.deploy.master.Master -import org.apache.spark.status.api.v1.{ApplicationsListResource, ApplicationInfo, JsonRootResource, UIRoot} +import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationsListResource, ApplicationInfo, + UIRoot} import org.apache.spark.ui.{SparkUI, WebUI} import org.apache.spark.ui.JettyUtils._ import org.apache.spark.util.RpcUtils @@ -47,7 +48,7 @@ class MasterWebUI(val master: Master, requestedPort: Int) attachPage(new HistoryNotFoundPage(this)) attachPage(masterPage) attachHandler(createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR, "/static")) - attachHandler(JsonRootResource.getJsonServlet(this)) + attachHandler(ApiRootResource.getServletHandler(this)) attachHandler(createRedirectHandler( "/app/kill", "/", masterPage.handleAppKillRequest, httpMethods = Set("POST"))) attachHandler(createRedirectHandler( diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala index 6078f50518ba4..1fe956320a1b8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala @@ -57,7 +57,11 @@ private[spark] class RestSubmissionClient(master: String) extends Logging { private val supportedMasterPrefixes = Seq("spark://", "mesos://") - private val masters: Array[String] = Utils.parseStandaloneMasterUrls(master) + private val masters: Array[String] = if (master.startsWith("spark://")) { + Utils.parseStandaloneMasterUrls(master) + } else { + Array(master) + } // Set of masters that lost contact with us, used to keep track of // whether there are masters still alive for us to communicate with diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 8f3cc54051048..ebc6cd76c6afd 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -324,9 +324,6 @@ private[worker] class Worker( map(e => new ExecutorDescription(e.appId, e.execId, e.cores, e.state)) sender ! WorkerSchedulerStateResponse(workerId, execs.toList, drivers.keys.toSeq) - case Heartbeat => - logInfo(s"Received heartbeat from driver ${sender.path}") - case RegisterWorkerFailed(message) => if (!registered) { logError("Worker registration failed: " + message) @@ -557,7 +554,7 @@ private[deploy] object Worker extends Logging { conf = conf, securityManager = securityMgr) val masterAkkaUrls = masterUrls.map(Master.toAkkaUrl(_, AkkaUtils.protocol(actorSystem))) actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory, - masterAkkaUrls, systemName, actorName, workDir, conf, securityMgr), name = actorName) + masterAkkaUrls, systemName, actorName, workDir, conf, securityMgr), name = actorName) (actorSystem, boundPort) } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala index 88170d4df3053..dc2bee6f2bdca 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala @@ -29,6 +29,7 @@ import org.apache.spark.util.logging.RollingFileAppender private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with Logging { private val worker = parent.worker private val workDir = parent.workDir + private val supportedLogTypes = Set("stderr", "stdout") def renderLog(request: HttpServletRequest): String = { val defaultBytes = 100 * 1024 @@ -129,6 +130,11 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with offsetOption: Option[Long], byteLength: Int ): (String, Long, Long, Long) = { + + if (!supportedLogTypes.contains(logType)) { + return ("Error: Log type must be one of " + supportedLogTypes.mkString(", "), 0, 0, 0) + } + try { val files = RollingFileAppender.getSortedRolledOverFiles(logDirectory, logType) logDebug(s"Sorted log files of type $logType in $logDirectory:\n${files.mkString("\n")}") diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index ed159dec4f998..f3a26f54a81fb 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -33,7 +33,7 @@ import org.apache.spark.deploy.worker.WorkerWatcher import org.apache.spark.scheduler.TaskDescription import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.serializer.SerializerInstance -import org.apache.spark.util.{SignalLogger, Utils} +import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils} private[spark] class CoarseGrainedExecutorBackend( override val rpcEnv: RpcEnv, @@ -55,18 +55,19 @@ private[spark] class CoarseGrainedExecutorBackend( private[this] val ser: SerializerInstance = env.closureSerializer.newInstance() override def onStart() { - import scala.concurrent.ExecutionContext.Implicits.global logInfo("Connecting to driver: " + driverUrl) rpcEnv.asyncSetupEndpointRefByURI(driverUrl).flatMap { ref => + // This is a very fast action so we can use "ThreadUtils.sameThread" driver = Some(ref) ref.ask[RegisteredExecutor.type]( RegisterExecutor(executorId, self, hostPort, cores, extractLogUrls)) - } onComplete { + }(ThreadUtils.sameThread).onComplete { + // This is a very fast action so we can use "ThreadUtils.sameThread" case Success(msg) => Utils.tryLogNonFatalError { Option(self).foreach(_.send(msg)) // msg must be RegisteredExecutor } case Failure(e) => logError(s"Cannot register with driver: $driverUrl", e) - } + }(ThreadUtils.sameThread) } def extractLogUrls: Map[String, String] = { diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 06152f16ae618..38b61d7242fce 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -43,22 +43,22 @@ class TaskMetrics extends Serializable { private var _hostname: String = _ def hostname: String = _hostname private[spark] def setHostname(value: String) = _hostname = value - + /** * Time taken on the executor to deserialize this task */ private var _executorDeserializeTime: Long = _ def executorDeserializeTime: Long = _executorDeserializeTime private[spark] def setExecutorDeserializeTime(value: Long) = _executorDeserializeTime = value - - + + /** * Time the executor spends actually running the task (including fetching shuffle data) */ private var _executorRunTime: Long = _ def executorRunTime: Long = _executorRunTime private[spark] def setExecutorRunTime(value: Long) = _executorRunTime = value - + /** * The number of bytes this task transmitted back to the driver as the TaskResult */ @@ -261,7 +261,7 @@ case class InputMetrics(readMethod: DataReadMethod.Value) { */ private var _recordsRead: Long = _ def recordsRead: Long = _recordsRead - def incRecordsRead(records: Long): Unit = _recordsRead += records + def incRecordsRead(records: Long): Unit = _recordsRead += records /** * Invoke the bytesReadCallback and mutate bytesRead. @@ -315,7 +315,7 @@ class ShuffleReadMetrics extends Serializable { def remoteBlocksFetched: Int = _remoteBlocksFetched private[spark] def incRemoteBlocksFetched(value: Int) = _remoteBlocksFetched += value private[spark] def decRemoteBlocksFetched(value: Int) = _remoteBlocksFetched -= value - + /** * Number of local blocks fetched in this shuffle by this task */ @@ -333,7 +333,7 @@ class ShuffleReadMetrics extends Serializable { def fetchWaitTime: Long = _fetchWaitTime private[spark] def incFetchWaitTime(value: Long) = _fetchWaitTime += value private[spark] def decFetchWaitTime(value: Long) = _fetchWaitTime -= value - + /** * Total number of remote bytes read from the shuffle by this task */ @@ -381,7 +381,7 @@ class ShuffleWriteMetrics extends Serializable { def shuffleBytesWritten: Long = _shuffleBytesWritten private[spark] def incShuffleBytesWritten(value: Long) = _shuffleBytesWritten += value private[spark] def decShuffleBytesWritten(value: Long) = _shuffleBytesWritten -= value - + /** * Time the task spent blocking on writes to disk or buffer cache, in nanoseconds */ @@ -389,7 +389,7 @@ class ShuffleWriteMetrics extends Serializable { def shuffleWriteTime: Long = _shuffleWriteTime private[spark] def incShuffleWriteTime(value: Long) = _shuffleWriteTime += value private[spark] def decShuffleWriteTime(value: Long) = _shuffleWriteTime -= value - + /** * Total number of records written to the shuffle by this task */ diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 0756cdb2ed8e6..0d8ac1f80a9f4 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -17,7 +17,7 @@ package org.apache.spark.io -import java.io.{InputStream, OutputStream} +import java.io.{IOException, InputStream, OutputStream} import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} import net.jpountz.lz4.{LZ4BlockInputStream, LZ4BlockOutputStream} @@ -154,8 +154,53 @@ class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec { override def compressedOutputStream(s: OutputStream): OutputStream = { val blockSize = conf.getSizeAsBytes("spark.io.compression.snappy.blockSize", "32k").toInt - new SnappyOutputStream(s, blockSize) + new SnappyOutputStreamWrapper(new SnappyOutputStream(s, blockSize)) } override def compressedInputStream(s: InputStream): InputStream = new SnappyInputStream(s) } + +/** + * Wrapper over [[SnappyOutputStream]] which guards against write-after-close and double-close + * issues. See SPARK-7660 for more details. This wrapping can be removed if we upgrade to a version + * of snappy-java that contains the fix for https://github.com/xerial/snappy-java/issues/107. + */ +private final class SnappyOutputStreamWrapper(os: SnappyOutputStream) extends OutputStream { + + private[this] var closed: Boolean = false + + override def write(b: Int): Unit = { + if (closed) { + throw new IOException("Stream is closed") + } + os.write(b) + } + + override def write(b: Array[Byte]): Unit = { + if (closed) { + throw new IOException("Stream is closed") + } + os.write(b) + } + + override def write(b: Array[Byte], off: Int, len: Int): Unit = { + if (closed) { + throw new IOException("Stream is closed") + } + os.write(b, off, len) + } + + override def flush(): Unit = { + if (closed) { + throw new IOException("Stream is closed") + } + os.flush() + } + + override def close(): Unit = { + if (!closed) { + closed = true + os.close() + } + } +} diff --git a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala index cfd20392d12f1..390d148bc97f9 100644 --- a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala @@ -60,7 +60,7 @@ trait SparkHadoopMapReduceUtil { val taskTypeClass = Class.forName("org.apache.hadoop.mapreduce.TaskType") .asInstanceOf[Class[Enum[_]]] val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke( - taskTypeClass, if(isMap) "MAP" else "REDUCE") + taskTypeClass, if (isMap) "MAP" else "REDUCE") val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], taskTypeClass, classOf[Int], classOf[Int]) ctor.newInstance(jtIdentifier, new JInteger(jobId), taskType, new JInteger(taskId), diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala index e8b3074e8f1a6..11dfcfe2f04e1 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala @@ -26,9 +26,9 @@ import org.apache.spark.SecurityManager import org.apache.spark.metrics.MetricsSystem private[spark] class Slf4jSink( - val property: Properties, + val property: Properties, val registry: MetricRegistry, - securityMgr: SecurityManager) + securityMgr: SecurityManager) extends Sink { val SLF4J_DEFAULT_PERIOD = 10 val SLF4J_DEFAULT_UNIT = "SECONDS" diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/package.scala b/core/src/main/scala/org/apache/spark/metrics/sink/package.scala index 90e3aa70b99ef..670e683663324 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/package.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/package.scala @@ -20,4 +20,4 @@ package org.apache.spark.metrics /** * Sinks used in Spark's metrics system. */ -package object sink +package object sink diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala index b573f1a8a5fcb..67a376102994c 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala @@ -110,7 +110,7 @@ private[nio] class BlockMessage() { def getType: Int = typ def getId: BlockId = id def getData: ByteBuffer = data - def getLevel: StorageLevel = level + def getLevel: StorageLevel = level def toBufferMessage: BufferMessage = { val buffers = new ArrayBuffer[ByteBuffer]() @@ -155,7 +155,7 @@ private[nio] class BlockMessage() { override def toString: String = { "BlockMessage [type = " + typ + ", id = " + id + ", level = " + level + - ", data = " + (if (data != null) data.remaining.toString else "null") + "]" + ", data = " + (if (data != null) data.remaining.toString else "null") + "]" } } diff --git a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala index 1ba25aa74aa02..7d0806f0c2580 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala @@ -114,8 +114,8 @@ private[nio] object BlockMessageArray { val blockMessages = (0 until 10).map { i => if (i % 2 == 0) { - val buffer = ByteBuffer.allocate(100) - buffer.clear + val buffer = ByteBuffer.allocate(100) + buffer.clear() BlockMessage.fromPutBlock(PutBlock(TestBlockId(i.toString), buffer, StorageLevel.MEMORY_ONLY_SER)) } else { diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala index 6b898bd4bfc1b..1499da07bb83b 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala @@ -326,15 +326,14 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, // MUST be called within the selector loop def connect() { - try{ + try { channel.register(selector, SelectionKey.OP_CONNECT) channel.connect(address) logInfo("Initiating connection to [" + address + "]") } catch { - case e: Exception => { + case e: Exception => logError("Error connecting to " + address, e) callOnExceptionCallbacks(e) - } } } diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 497871ed6d5e5..c0bca2c4bc994 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -635,12 +635,11 @@ private[nio] class ConnectionManager( val message = securityMsgResp.toBufferMessage if (message == null) throw new IOException("Error creating security message") sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message) - } catch { - case e: Exception => { + } catch { + case e: Exception => logError("Error handling sasl client authentication", e) waitingConn.close() throw new IOException("Error evaluating sasl response: ", e) - } } } } diff --git a/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala index 747a2088a7258..232c552f9865d 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala @@ -75,7 +75,7 @@ private[nio] class SecurityMessage extends Logging { for (i <- 1 to idLength) { idBuilder += buffer.getChar() } - connectionId = idBuilder.toString() + connectionId = idBuilder.toString() val tokenLength = buffer.getInt() token = new Array[Byte](tokenLength) diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala index 3ef3cc219dec6..91b07ce3af1b6 100644 --- a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala @@ -32,12 +32,12 @@ import org.apache.spark.util.collection.OpenHashMap * An ApproximateEvaluator for counts by key. Returns a map of key to confidence interval. */ private[spark] class GroupedCountEvaluator[T : ClassTag](totalOutputs: Int, confidence: Double) - extends ApproximateEvaluator[OpenHashMap[T,Long], Map[T, BoundedDouble]] { + extends ApproximateEvaluator[OpenHashMap[T, Long], Map[T, BoundedDouble]] { var outputsMerged = 0 - var sums = new OpenHashMap[T,Long]() // Sum of counts for each key + var sums = new OpenHashMap[T, Long]() // Sum of counts for each key - override def merge(outputId: Int, taskResult: OpenHashMap[T,Long]) { + override def merge(outputId: Int, taskResult: OpenHashMap[T, Long]) { outputsMerged += 1 taskResult.foreach { case (key, value) => sums.changeValue(key, value, _ + value) diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index ec185340c3a2d..ca1eb1f4e4a9a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -19,8 +19,10 @@ package org.apache.spark.rdd import java.util.concurrent.atomic.AtomicLong +import org.apache.spark.util.ThreadUtils + import scala.collection.mutable.ArrayBuffer -import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.ExecutionContext import scala.reflect.ClassTag import org.apache.spark.{ComplexFutureAction, FutureAction, Logging} @@ -66,6 +68,8 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi val f = new ComplexFutureAction[Seq[T]] f.run { + // This is a blocking action so we should use "AsyncRDDActions.futureExecutionContext" which + // is a cached thread pool. val results = new ArrayBuffer[T](num) val totalParts = self.partitions.length var partsScanned = 0 @@ -81,9 +85,9 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi numPartsToTry = partsScanned * 4 } else { // the left side of max is >=1 whenever partsScanned >= 2 - numPartsToTry = Math.max(1, + numPartsToTry = Math.max(1, (1.5 * num * partsScanned / results.size).toInt - partsScanned) - numPartsToTry = Math.min(numPartsToTry, partsScanned * 4) + numPartsToTry = Math.min(numPartsToTry, partsScanned * 4) } } @@ -101,7 +105,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi partsScanned += numPartsToTry } results.toSeq - } + }(AsyncRDDActions.futureExecutionContext) f } @@ -123,3 +127,8 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi (index, data) => Unit, Unit) } } + +private object AsyncRDDActions { + val futureExecutionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("AsyncRDDActions-future", 128)) +} diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index 0d130dd4c7a60..a4715e3437d94 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -49,7 +49,7 @@ class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String) if (fs.exists(cpath)) { val dirContents = fs.listStatus(cpath).map(_.getPath) val partitionFiles = dirContents.filter(_.getName.startsWith("part-")).map(_.toString).sorted - val numPart = partitionFiles.length + val numPart = partitionFiles.length if (numPart > 0 && (! partitionFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) || ! partitionFiles(numPart-1).endsWith(CheckpointRDD.splitIdToFile(numPart-1)))) { throw new SparkException("Invalid checkpoint directory: " + checkpointPath) diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala index 0c1b02c07d09f..663eebb8e4191 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -310,11 +310,11 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: def throwBalls() { if (noLocality) { // no preferredLocations in parent RDD, no randomization needed if (maxPartitions > groupArr.size) { // just return prev.partitions - for ((p,i) <- prev.partitions.zipWithIndex) { + for ((p, i) <- prev.partitions.zipWithIndex) { groupArr(i).arr += p } } else { // no locality available, then simply split partitions based on positions in array - for(i <- 0 until maxPartitions) { + for (i <- 0 until maxPartitions) { val rangeStart = ((i.toLong * prev.partitions.length) / maxPartitions).toInt val rangeEnd = (((i.toLong + 1) * prev.partitions.length) / maxPartitions).toInt (rangeStart until rangeEnd).foreach{ j => groupArr(i).arr += prev.partitions(j) } diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 2ab967f4bb313..84456d6d868dc 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -196,7 +196,7 @@ class NewHadoopRDD[K, V]( override def getPreferredLocations(hsplit: Partition): Seq[String] = { val split = hsplit.asInstanceOf[NewHadoopPartition].serializableHadoopSplit.value val locs = HadoopRDD.SPLIT_INFO_REFLECTIONS match { - case Some(c) => + case Some(c) => try { val infos = c.newGetLocationInfo.invoke(split).asInstanceOf[Array[AnyRef]] Some(HadoopRDD.convertSplitLocationInfo(infos)) diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index a6d5d2c94e17f..cfd3e26faf2b9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -296,6 +296,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * before sending results to a reducer, similarly to a "combiner" in MapReduce. */ def reduceByKeyLocally(func: (V, V) => V): Map[K, V] = self.withScope { + val cleanedF = self.sparkContext.clean(func) if (keyClass.isArray) { throw new SparkException("reduceByKeyLocally() does not support array keys") @@ -305,7 +306,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val map = new JHashMap[K, V] iter.foreach { pair => val old = map.get(pair._1) - map.put(pair._1, if (old == null) pair._2 else func(old, pair._2)) + map.put(pair._1, if (old == null) pair._2 else cleanedF(old, pair._2)) } Iterator(map) } : Iterator[JHashMap[K, V]] @@ -313,7 +314,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val mergeMaps = (m1: JHashMap[K, V], m2: JHashMap[K, V]) => { m2.foreach { pair => val old = m1.get(pair._1) - m1.put(pair._1, if (old == null) pair._2 else func(old, pair._2)) + m1.put(pair._1, if (old == null) pair._2 else cleanedF(old, pair._2)) } m1 } : JHashMap[K, V] @@ -327,7 +328,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) reduceByKeyLocally(func) } - /** + /** * Count the number of elements for each key, collecting the results to a local Map. * * Note that this method should only be used if the resulting map is expected to be small, as @@ -466,7 +467,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val mergeValue = (buf: CompactBuffer[V], v: V) => buf += v val mergeCombiners = (c1: CompactBuffer[V], c2: CompactBuffer[V]) => c1 ++= c2 val bufs = combineByKey[CompactBuffer[V]]( - createCombiner, mergeValue, mergeCombiners, partitioner, mapSideCombine=false) + createCombiner, mergeValue, mergeCombiners, partitioner, mapSideCombine = false) bufs.asInstanceOf[RDD[(K, Iterable[V])]] } @@ -1010,7 +1011,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) jobFormat.checkOutputSpecs(job) } - val writeShard = (context: TaskContext, iter: Iterator[(K,V)]) => { + val writeShard = (context: TaskContext, iter: Iterator[(K, V)]) => { val config = wrappedConf.value /* "reduce task" */ val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId, @@ -1026,7 +1027,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context) - val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]] + val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K, V]] require(writer != null, "Unable to obtain RecordWriter") var recordsWritten = 0L Utils.tryWithSafeFinally { diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala index 7598ff617b399..9e3880714a79f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala @@ -86,7 +86,7 @@ class PartitionerAwareUnionRDD[T: ClassTag]( } val location = if (locations.isEmpty) { None - } else { + } else { // Find the location that maximum number of parent partitions prefer Some(locations.groupBy(x => x).maxBy(_._2.length)._1) } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 02a94baf372d9..10610f4b6f1ff 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -434,11 +434,11 @@ abstract class RDD[T: ClassTag]( * @return A random sub-sample of the RDD without replacement. */ private[spark] def randomSampleWithRange(lb: Double, ub: Double, seed: Long): RDD[T] = { - this.mapPartitionsWithIndex { case (index, partition) => + this.mapPartitionsWithIndex( { (index, partition) => val sampler = new BernoulliCellSampler[T](lb, ub) sampler.setSeed(seed + index) sampler.sample(partition) - } + }, preservesPartitioning = true) } /** @@ -454,7 +454,7 @@ abstract class RDD[T: ClassTag]( withReplacement: Boolean, num: Int, seed: Long = Utils.random.nextLong): Array[T] = { - val numStDev = 10.0 + val numStDev = 10.0 if (num < 0) { throw new IllegalArgumentException("Negative number of elements requested") @@ -1015,9 +1015,16 @@ abstract class RDD[T: ClassTag]( /** * Aggregate the elements of each partition, and then the results for all the partitions, using a - * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to - * modify t1 and return it as its result value to avoid object allocation; however, it should not - * modify t2. + * given associative and commutative function and a neutral "zero value". The function + * op(t1, t2) is allowed to modify t1 and return it as its result value to avoid object + * allocation; however, it should not modify t2. + * + * This behaves somewhat differently from fold operations implemented for non-distributed + * collections in functional languages like Scala. This fold operation may be applied to + * partitions individually, and then fold those results into the final result, rather than + * apply the fold to each element sequentially in some defined ordering. For functions + * that are not commutative, the result may differ from that of a fold applied to a + * non-distributed collection. */ def fold(zeroValue: T)(op: (T, T) => T): T = withScope { // Clone the zero value since we will also be serializing it as part of tasks @@ -1131,8 +1138,8 @@ abstract class RDD[T: ClassTag]( if (elementClassTag.runtimeClass.isArray) { throw new SparkException("countByValueApprox() does not support arrays") } - val countPartition: (TaskContext, Iterator[T]) => OpenHashMap[T,Long] = { (ctx, iter) => - val map = new OpenHashMap[T,Long] + val countPartition: (TaskContext, Iterator[T]) => OpenHashMap[T, Long] = { (ctx, iter) => + val map = new OpenHashMap[T, Long] iter.foreach { t => map.changeValue(t, 1L, _ + 1L) } @@ -1524,7 +1531,7 @@ abstract class RDD[T: ClassTag]( * doCheckpoint() is called recursively on the parent RDDs. */ private[spark] def doCheckpoint(): Unit = { - RDDOperationScope.withScope(sc, "checkpoint", false, true) { + RDDOperationScope.withScope(sc, "checkpoint", allowNesting = false, ignoreParent = true) { if (!doCheckpointCalled) { doCheckpointCalled = true if (checkpointData.isDefined) { @@ -1578,15 +1585,15 @@ abstract class RDD[T: ClassTag]( case 0 => Seq.empty case 1 => val d = rdd.dependencies.head - debugString(d.rdd, prefix, d.isInstanceOf[ShuffleDependency[_,_,_]], true) + debugString(d.rdd, prefix, d.isInstanceOf[ShuffleDependency[_, _, _]], true) case _ => val frontDeps = rdd.dependencies.take(len - 1) val frontDepStrings = frontDeps.flatMap( - d => debugString(d.rdd, prefix, d.isInstanceOf[ShuffleDependency[_,_,_]])) + d => debugString(d.rdd, prefix, d.isInstanceOf[ShuffleDependency[_, _, _]])) val lastDep = rdd.dependencies.last val lastDepStrings = - debugString(lastDep.rdd, prefix, lastDep.isInstanceOf[ShuffleDependency[_,_,_]], true) + debugString(lastDep.rdd, prefix, lastDep.isInstanceOf[ShuffleDependency[_, _, _]], true) (frontDepStrings ++ lastDepStrings) } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala b/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala index 93ec606f2de7d..6b09dfafc889c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDOperationScope.scala @@ -24,7 +24,7 @@ import com.fasterxml.jackson.annotation.JsonInclude.Include import com.fasterxml.jackson.databind.ObjectMapper import com.fasterxml.jackson.module.scala.DefaultScalaModule -import org.apache.spark.SparkContext +import org.apache.spark.{Logging, SparkContext} /** * A general, named code block representing an operation that instantiates RDDs. @@ -43,9 +43,8 @@ import org.apache.spark.SparkContext @JsonPropertyOrder(Array("id", "name", "parent")) private[spark] class RDDOperationScope( val name: String, - val parent: Option[RDDOperationScope] = None) { - - val id: Int = RDDOperationScope.nextScopeId() + val parent: Option[RDDOperationScope] = None, + val id: String = RDDOperationScope.nextScopeId().toString) { def toJson: String = { RDDOperationScope.jsonMapper.writeValueAsString(this) @@ -75,7 +74,7 @@ private[spark] class RDDOperationScope( * A collection of utility methods to construct a hierarchical representation of RDD scopes. * An RDD scope tracks the series of operations that created a given RDD. */ -private[spark] object RDDOperationScope { +private[spark] object RDDOperationScope extends Logging { private val jsonMapper = new ObjectMapper().registerModule(DefaultScalaModule) private val scopeCounter = new AtomicInteger(0) @@ -88,15 +87,26 @@ private[spark] object RDDOperationScope { /** * Execute the given body such that all RDDs created in this body will have the same scope. - * The name of the scope will be the name of the method that immediately encloses this one. + * The name of the scope will be the first method name in the stack trace that is not the + * same as this method's. * * Note: Return statements are NOT allowed in body. */ private[spark] def withScope[T]( sc: SparkContext, allowNesting: Boolean = false)(body: => T): T = { - val callerMethodName = Thread.currentThread.getStackTrace()(3).getMethodName - withScope[T](sc, callerMethodName, allowNesting)(body) + val stackTrace = Thread.currentThread.getStackTrace().tail // ignore "Thread#getStackTrace" + val ourMethodName = stackTrace(1).getMethodName // i.e. withScope + // Climb upwards to find the first method that's called something different + val callerMethodName = stackTrace + .find(_.getMethodName != ourMethodName) + .map(_.getMethodName) + .getOrElse { + // Log a warning just in case, but this should almost certainly never happen + logWarning("No valid method name for this RDD operation scope!") + "N/A" + } + withScope[T](sc, callerMethodName, allowNesting, ignoreParent = false)(body) } /** @@ -116,7 +126,7 @@ private[spark] object RDDOperationScope { sc: SparkContext, name: String, allowNesting: Boolean, - ignoreParent: Boolean = false)(body: => T): T = { + ignoreParent: Boolean)(body: => T): T = { // Save the old scope to restore it later val scopeKey = SparkContext.RDD_SCOPE_KEY val noOverrideKey = SparkContext.RDD_SCOPE_NO_OVERRIDE_KEY diff --git a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala index 3dfcf67f0eb66..4b5f15dd06b85 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala @@ -104,13 +104,13 @@ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag if (!convertKey && !convertValue) { self.saveAsHadoopFile(path, keyWritableClass, valueWritableClass, format, jobConf, codec) } else if (!convertKey && convertValue) { - self.map(x => (x._1,anyToWritable(x._2))).saveAsHadoopFile( + self.map(x => (x._1, anyToWritable(x._2))).saveAsHadoopFile( path, keyWritableClass, valueWritableClass, format, jobConf, codec) } else if (convertKey && !convertValue) { - self.map(x => (anyToWritable(x._1),x._2)).saveAsHadoopFile( + self.map(x => (anyToWritable(x._1), x._2)).saveAsHadoopFile( path, keyWritableClass, valueWritableClass, format, jobConf, codec) } else if (convertKey && convertValue) { - self.map(x => (anyToWritable(x._1),anyToWritable(x._2))).saveAsHadoopFile( + self.map(x => (anyToWritable(x._1), anyToWritable(x._2))).saveAsHadoopFile( path, keyWritableClass, valueWritableClass, format, jobConf, codec) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index 633aeba3bbae6..f7cb1791d4ac6 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -125,7 +125,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( integrate(0, t => getSeq(t._1) += t._2) // the second dep is rdd2; remove all of its keys integrate(1, t => map.remove(t._1)) - map.iterator.map { t => t._2.iterator.map { (t._1, _) } }.flatten + map.iterator.map { t => t._2.iterator.map { (t._1, _) } }.flatten } override def clearDependencies() { diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala index a96b6c3d23454..81f40ad33aa5d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala @@ -123,7 +123,7 @@ private[spark] class ZippedPartitionsRDD3 } private[spark] class ZippedPartitionsRDD4 - [A: ClassTag, B: ClassTag, C: ClassTag, D:ClassTag, V: ClassTag]( + [A: ClassTag, B: ClassTag, C: ClassTag, D: ClassTag, V: ClassTag]( sc: SparkContext, var f: (Iterator[A], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V], var rdd1: RDD[A], diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 5d812918a13d1..75a567fb31520 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -193,9 +193,15 @@ class DAGScheduler( def getCacheLocs(rdd: RDD[_]): Seq[Seq[TaskLocation]] = cacheLocs.synchronized { // Note: this doesn't use `getOrElse()` because this method is called O(num tasks) times if (!cacheLocs.contains(rdd.id)) { - val blockIds = rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index)).toArray[BlockId] - val locs: Seq[Seq[TaskLocation]] = blockManagerMaster.getLocations(blockIds).map { bms => - bms.map(bm => TaskLocation(bm.host, bm.executorId)) + // Note: if the storage level is NONE, we don't need to get locations from block manager. + val locs: Seq[Seq[TaskLocation]] = if (rdd.getStorageLevel == StorageLevel.NONE) { + Seq.fill(rdd.partitions.size)(Nil) + } else { + val blockIds = + rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index)).toArray[BlockId] + blockManagerMaster.getLocations(blockIds).map { bms => + bms.map(bm => TaskLocation(bm.host, bm.executorId)) + } } cacheLocs(rdd.id) = locs } @@ -208,19 +214,17 @@ class DAGScheduler( /** * Get or create a shuffle map stage for the given shuffle dependency's map side. - * The jobId value passed in will be used if the stage doesn't already exist with - * a lower jobId (jobId always increases across jobs.) */ private def getShuffleMapStage( shuffleDep: ShuffleDependency[_, _, _], - jobId: Int): ShuffleMapStage = { + firstJobId: Int): ShuffleMapStage = { shuffleToMapStage.get(shuffleDep.shuffleId) match { case Some(stage) => stage case None => // We are going to register ancestor shuffle dependencies - registerShuffleDependencies(shuffleDep, jobId) + registerShuffleDependencies(shuffleDep, firstJobId) // Then register current shuffleDep - val stage = newOrUsedShuffleStage(shuffleDep, jobId) + val stage = newOrUsedShuffleStage(shuffleDep, firstJobId) shuffleToMapStage(shuffleDep.shuffleId) = stage stage @@ -230,15 +234,15 @@ class DAGScheduler( /** * Helper function to eliminate some code re-use when creating new stages. */ - private def getParentStagesAndId(rdd: RDD[_], jobId: Int): (List[Stage], Int) = { - val parentStages = getParentStages(rdd, jobId) + private def getParentStagesAndId(rdd: RDD[_], firstJobId: Int): (List[Stage], Int) = { + val parentStages = getParentStages(rdd, firstJobId) val id = nextStageId.getAndIncrement() (parentStages, id) } /** * Create a ShuffleMapStage as part of the (re)-creation of a shuffle map stage in - * newOrUsedShuffleStage. The stage will be associated with the provided jobId. + * newOrUsedShuffleStage. The stage will be associated with the provided firstJobId. * Production of shuffle map stages should always use newOrUsedShuffleStage, not * newShuffleMapStage directly. */ @@ -246,21 +250,19 @@ class DAGScheduler( rdd: RDD[_], numTasks: Int, shuffleDep: ShuffleDependency[_, _, _], - jobId: Int, + firstJobId: Int, callSite: CallSite): ShuffleMapStage = { - val (parentStages: List[Stage], id: Int) = getParentStagesAndId(rdd, jobId) + val (parentStages: List[Stage], id: Int) = getParentStagesAndId(rdd, firstJobId) val stage: ShuffleMapStage = new ShuffleMapStage(id, rdd, numTasks, parentStages, - jobId, callSite, shuffleDep) + firstJobId, callSite, shuffleDep) stageIdToStage(id) = stage - updateJobIdStageIdMaps(jobId, stage) + updateJobIdStageIdMaps(firstJobId, stage) stage } /** - * Create a ResultStage -- either directly for use as a result stage, or as part of the - * (re)-creation of a shuffle map stage in newOrUsedShuffleStage. The stage will be associated - * with the provided jobId. + * Create a ResultStage associated with the provided jobId. */ private def newResultStage( rdd: RDD[_], @@ -277,16 +279,16 @@ class DAGScheduler( /** * Create a shuffle map Stage for the given RDD. The stage will also be associated with the - * provided jobId. If a stage for the shuffleId existed previously so that the shuffleId is + * provided firstJobId. If a stage for the shuffleId existed previously so that the shuffleId is * present in the MapOutputTracker, then the number and location of available outputs are * recovered from the MapOutputTracker */ private def newOrUsedShuffleStage( shuffleDep: ShuffleDependency[_, _, _], - jobId: Int): ShuffleMapStage = { + firstJobId: Int): ShuffleMapStage = { val rdd = shuffleDep.rdd val numTasks = rdd.partitions.size - val stage = newShuffleMapStage(rdd, numTasks, shuffleDep, jobId, rdd.creationSite) + val stage = newShuffleMapStage(rdd, numTasks, shuffleDep, firstJobId, rdd.creationSite) if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) { val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId) val locs = MapOutputTracker.deserializeMapStatuses(serLocs) @@ -304,10 +306,10 @@ class DAGScheduler( } /** - * Get or create the list of parent stages for a given RDD. The stages will be assigned the - * provided jobId if they haven't already been created with a lower jobId. + * Get or create the list of parent stages for a given RDD. The new Stages will be created with + * the provided firstJobId. */ - private def getParentStages(rdd: RDD[_], jobId: Int): List[Stage] = { + private def getParentStages(rdd: RDD[_], firstJobId: Int): List[Stage] = { val parents = new HashSet[Stage] val visited = new HashSet[RDD[_]] // We are manually maintaining a stack here to prevent StackOverflowError @@ -321,7 +323,7 @@ class DAGScheduler( for (dep <- r.dependencies) { dep match { case shufDep: ShuffleDependency[_, _, _] => - parents += getShuffleMapStage(shufDep, jobId) + parents += getShuffleMapStage(shufDep, firstJobId) case _ => waitingForVisit.push(dep.rdd) } @@ -336,11 +338,11 @@ class DAGScheduler( } /** Find ancestor missing shuffle dependencies and register into shuffleToMapStage */ - private def registerShuffleDependencies(shuffleDep: ShuffleDependency[_, _, _], jobId: Int) { + private def registerShuffleDependencies(shuffleDep: ShuffleDependency[_, _, _], firstJobId: Int) { val parentsWithNoMapStage = getAncestorShuffleDependencies(shuffleDep.rdd) while (parentsWithNoMapStage.nonEmpty) { val currentShufDep = parentsWithNoMapStage.pop() - val stage = newOrUsedShuffleStage(currentShufDep, jobId) + val stage = newOrUsedShuffleStage(currentShufDep, firstJobId) shuffleToMapStage(currentShufDep.shuffleId) = stage } } @@ -386,11 +388,12 @@ class DAGScheduler( def visit(rdd: RDD[_]) { if (!visited(rdd)) { visited += rdd - if (getCacheLocs(rdd).contains(Nil)) { + val rddHasUncachedPartitions = getCacheLocs(rdd).contains(Nil) + if (rddHasUncachedPartitions) { for (dep <- rdd.dependencies) { dep match { case shufDep: ShuffleDependency[_, _, _] => - val mapStage = getShuffleMapStage(shufDep, stage.jobId) + val mapStage = getShuffleMapStage(shufDep, stage.firstJobId) if (!mapStage.isAvailable) { missing += mapStage } @@ -577,7 +580,7 @@ class DAGScheduler( private[scheduler] def doCancelAllJobs() { // Cancel all running jobs. - runningStages.map(_.jobId).foreach(handleJobCancellation(_, + runningStages.map(_.firstJobId).foreach(handleJobCancellation(_, reason = "as part of cancellation of all jobs")) activeJobs.clear() // These should already be empty by this point, jobIdToActiveJob.clear() // but just in case we lost track of some jobs... @@ -603,7 +606,7 @@ class DAGScheduler( clearCacheLocs() val failedStagesCopy = failedStages.toArray failedStages.clear() - for (stage <- failedStagesCopy.sortBy(_.jobId)) { + for (stage <- failedStagesCopy.sortBy(_.firstJobId)) { submitStage(stage) } } @@ -623,7 +626,7 @@ class DAGScheduler( logTrace("failed: " + failedStages) val waitingStagesCopy = waitingStages.toArray waitingStages.clear() - for (stage <- waitingStagesCopy.sortBy(_.jobId)) { + for (stage <- waitingStagesCopy.sortBy(_.firstJobId)) { submitStage(stage) } } @@ -843,7 +846,7 @@ class DAGScheduler( } } - val properties = jobIdToActiveJob.get(stage.jobId).map(_.properties).orNull + val properties = jobIdToActiveJob.get(stage.firstJobId).map(_.properties).orNull runningStages += stage // SparkListenerStageSubmitted should be posted before testing whether tasks are @@ -909,7 +912,7 @@ class DAGScheduler( stage.pendingTasks ++= tasks logDebug("New pending tasks: " + stage.pendingTasks) taskScheduler.submitTasks( - new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties)) + new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.firstJobId, properties)) stage.latestInfo.submissionTime = Some(clock.getTimeMillis()) } else { // Because we posted SparkListenerStageSubmitted earlier, we should mark @@ -1323,7 +1326,7 @@ class DAGScheduler( for (dep <- rdd.dependencies) { dep match { case shufDep: ShuffleDependency[_, _, _] => - val mapStage = getShuffleMapStage(shufDep, stage.jobId) + val mapStage = getShuffleMapStage(shufDep, stage.firstJobId) if (!mapStage.isAvailable) { waitingForVisit.push(mapStage.rdd) } // Otherwise there's no need to follow the dependency back @@ -1364,10 +1367,10 @@ class DAGScheduler( private def getPreferredLocsInternal( rdd: RDD[_], partition: Int, - visited: HashSet[(RDD[_],Int)]): Seq[TaskLocation] = { + visited: HashSet[(RDD[_], Int)]): Seq[TaskLocation] = { // If the partition has already been visited, no need to re-visit. // This avoids exponential path exploration. SPARK-695 - if (!visited.add((rdd,partition))) { + if (!visited.add((rdd, partition))) { // Nil has already been returned for previously visited partitions. return Nil } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala index 12668b6c0988e..02c67073af6a0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala @@ -17,9 +17,8 @@ package org.apache.spark.scheduler -import com.codahale.metrics.{Gauge,MetricRegistry} +import com.codahale.metrics.{Gauge, MetricRegistry} -import org.apache.spark.SparkContext import org.apache.spark.metrics.source.Source private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler) diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index 0b1d47cff3746..8321037cdc026 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -38,7 +38,7 @@ private case class AskPermissionToCommitOutput(stage: Int, task: Long, taskAttem * This class was introduced in SPARK-4879; see that JIRA issue (and the associated pull requests) * for an extensive design discussion. */ -private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging { +private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) extends Logging { // Initialized by SparkEnv var coordinatorRef: Option[RpcEndpointRef] = None @@ -129,9 +129,11 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging { } def stop(): Unit = synchronized { - coordinatorRef.foreach(_ send StopCoordinator) - coordinatorRef = None - authorizedCommittersByStage.clear() + if (isDriver) { + coordinatorRef.foreach(_ send StopCoordinator) + coordinatorRef = None + authorizedCommittersByStage.clear() + } } // Marked private[scheduler] instead of private so this can be mocked in tests diff --git a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala index 86f357abb8723..c6d957b65f3fb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala @@ -41,7 +41,7 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { * * @param logData Stream containing event log data. * @param sourceName Filename (or other source identifier) from whence @logData is being read - * @param maybeTruncated Indicate whether log file might be truncated (some abnormal situations + * @param maybeTruncated Indicate whether log file might be truncated (some abnormal situations * encountered, log file might not finished writing) or not */ def replay( @@ -62,7 +62,7 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { if (!maybeTruncated || lines.hasNext) { throw jpe } else { - logWarning(s"Got JsonParseException from log file $sourceName" + + logWarning(s"Got JsonParseException from log file $sourceName" + s" at line $lineNumber, the file might not have finished writing cleanly.") } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala index c0f3d5a13d623..bf81b9aca4810 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala @@ -28,9 +28,9 @@ private[spark] class ResultStage( rdd: RDD[_], numTasks: Int, parents: List[Stage], - jobId: Int, + firstJobId: Int, callSite: CallSite) - extends Stage(id, rdd, numTasks, parents, jobId, callSite) { + extends Stage(id, rdd, numTasks, parents, firstJobId, callSite) { // The active job for this result stage. Will be empty if the job has already finished // (e.g., because the job was cancelled). diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala index 646820520ea1b..8801a761afae3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala @@ -49,4 +49,11 @@ private[spark] trait SchedulerBackend { */ def applicationAttemptId(): Option[String] = None + /** + * Get the URLs for the driver logs. These URLs are used to display the links in the UI + * Executors tab for the driver. + * @return Map containing the log names and their respective URLs + */ + def getDriverLogUrls: Option[Map[String, String]] = None + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala index 5e62c8468f007..864941d468af9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala @@ -56,7 +56,7 @@ private[spark] class FairSchedulingAlgorithm extends SchedulingAlgorithm { val minShareRatio2 = runningTasks2.toDouble / math.max(minShare2, 1.0).toDouble val taskToWeightRatio1 = runningTasks1.toDouble / s1.weight.toDouble val taskToWeightRatio2 = runningTasks2.toDouble / s2.weight.toDouble - var compare:Int = 0 + var compare: Int = 0 if (s1Needy && !s2Needy) { return true diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala index d02210743484c..66c75f325fcde 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala @@ -30,10 +30,10 @@ private[spark] class ShuffleMapStage( rdd: RDD[_], numTasks: Int, parents: List[Stage], - jobId: Int, + firstJobId: Int, callSite: CallSite, val shuffleDep: ShuffleDependency[_, _, _]) - extends Stage(id, rdd, numTasks, parents, jobId, callSite) { + extends Stage(id, rdd, numTasks, parents, firstJobId, callSite) { override def toString: String = "ShuffleMapStage " + id diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 169d4fd3a94f0..9620915f495ab 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -110,8 +110,13 @@ case class SparkListenerExecutorMetricsUpdate( extends SparkListenerEvent @DeveloperApi -case class SparkListenerApplicationStart(appName: String, appId: Option[String], - time: Long, sparkUser: String, appAttemptId: Option[String]) extends SparkListenerEvent +case class SparkListenerApplicationStart( + appName: String, + appId: Option[String], + time: Long, + sparkUser: String, + appAttemptId: Option[String], + driverLogs: Option[Map[String, String]] = None) extends SparkListenerEvent @DeveloperApi case class SparkListenerApplicationEnd(time: Long) extends SparkListenerEvent @@ -265,7 +270,7 @@ class StatsReportListener extends SparkListener with Logging { private[spark] object StatsReportListener extends Logging { // For profiling, the extremes are more interesting - val percentiles = Array[Int](0,5,10,25,50,75,90,95,100) + val percentiles = Array[Int](0, 5, 10, 25, 50, 75, 90, 95, 100) val probabilities = percentiles.map(_ / 100.0) val percentilesHeader = "\t" + percentiles.mkString("%\t") + "%" @@ -299,7 +304,7 @@ private[spark] object StatsReportListener extends Logging { dOpt.foreach { d => showDistribution(heading, d, formatNumber)} } - def showDistribution(heading: String, dOpt: Option[Distribution], format:String) { + def showDistribution(heading: String, dOpt: Option[Distribution], format: String) { def f(d: Double): String = format.format(d) showDistribution(heading, dOpt, f _) } @@ -313,7 +318,7 @@ private[spark] object StatsReportListener extends Logging { } def showBytesDistribution( - heading:String, + heading: String, getMetric: (TaskInfo, TaskMetrics) => Option[Long], taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) { showBytesDistribution(heading, extractLongDistribution(taskInfoMetrics, getMetric)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 5d0ddb8377c33..c59d6e4f5bc04 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -34,7 +34,7 @@ import org.apache.spark.util.CallSite * initiated a job (e.g. count(), save(), etc). For shuffle map stages, we also track the nodes * that each output partition is on. * - * Each Stage also has a jobId, identifying the job that first submitted the stage. When FIFO + * Each Stage also has a firstJobId, identifying the job that first submitted the stage. When FIFO * scheduling is used, this allows Stages from earlier jobs to be computed first or recovered * faster on failure. * @@ -51,7 +51,7 @@ private[spark] abstract class Stage( val rdd: RDD[_], val numTasks: Int, val parents: List[Stage], - val jobId: Int, + val firstJobId: Int, val callSite: CallSite) extends Logging { diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 586d1e06204c1..15101c64f0503 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -125,7 +125,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex if (interruptThread && taskThread != null) { taskThread.interrupt() } - } + } } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala index 1f114a0207f7b..8b2a742b96988 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala @@ -40,6 +40,9 @@ class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long var metrics: TaskMetrics) extends TaskResult[T] with Externalizable { + private var valueObjectDeserialized = false + private var valueObject: T = _ + def this() = this(null.asInstanceOf[ByteBuffer], null, null) override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { @@ -72,10 +75,26 @@ class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long } } metrics = in.readObject().asInstanceOf[TaskMetrics] + valueObjectDeserialized = false } + /** + * When `value()` is called at the first time, it needs to deserialize `valueObject` from + * `valueBytes`. It may cost dozens of seconds for a large instance. So when calling `value` at + * the first time, the caller should avoid to block other threads. + * + * After the first time, `value()` is trivial and just returns the deserialized `valueObject`. + */ def value(): T = { - val resultSer = SparkEnv.get.serializer.newInstance() - resultSer.deserialize(valueBytes) + if (valueObjectDeserialized) { + valueObject + } else { + // This should not run when holding a lock because it may cost dozens of seconds for a large + // value. + val resultSer = SparkEnv.get.serializer.newInstance() + valueObject = resultSer.deserialize(valueBytes) + valueObjectDeserialized = true + valueObject + } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index 391827c1d2156..46a6f6537e2ee 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -54,6 +54,10 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul if (!taskSetManager.canFetchMoreResults(serializedData.limit())) { return } + // deserialize "value" without holding any lock so that it won't block other threads. + // We should call it here, so that when it's called again in + // "TaskSetManager.handleSuccessfulTask", it does not need to deserialize the value. + directResult.value() (directResult, serializedData.limit()) case IndirectTaskResult(blockId, size) => if (!taskSetManager.canFetchMoreResults(size)) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index b4b8a630694bb..ed3dde0fc3055 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -19,9 +19,9 @@ package org.apache.spark.scheduler import java.nio.ByteBuffer import java.util.{TimerTask, Timer} +import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicLong -import scala.concurrent.duration._ import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet @@ -32,7 +32,7 @@ import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.scheduler.TaskLocality.TaskLocality -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.BlockManagerId @@ -64,6 +64,9 @@ private[spark] class TaskSchedulerImpl( // How often to check for speculative tasks val SPECULATION_INTERVAL_MS = conf.getTimeAsMs("spark.speculation.interval", "100ms") + private val speculationScheduler = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("task-scheduler-speculation") + // Threshold above which we warn user initial TaskSet may be starved val STARVATION_TIMEOUT_MS = conf.getTimeAsMs("spark.starvation.timeout", "15s") @@ -142,10 +145,11 @@ private[spark] class TaskSchedulerImpl( if (!isLocal && conf.getBoolean("spark.speculation", false)) { logInfo("Starting speculative execution thread") - sc.env.actorSystem.scheduler.schedule(SPECULATION_INTERVAL_MS milliseconds, - SPECULATION_INTERVAL_MS milliseconds) { - Utils.tryOrStopSparkContext(sc) { checkSpeculatableTasks() } - }(sc.env.actorSystem.dispatcher) + speculationScheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryOrStopSparkContext(sc) { + checkSpeculatableTasks() + } + }, SPECULATION_INTERVAL_MS, SPECULATION_INTERVAL_MS, TimeUnit.MILLISECONDS) } } @@ -412,6 +416,7 @@ private[spark] class TaskSchedulerImpl( } override def stop() { + speculationScheduler.shutdown() if (backend != null) { backend.stop() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 7dc325283d961..673cd0e19eba2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -620,6 +620,12 @@ private[spark] class TaskSetManager( val index = info.index info.markSuccessful() removeRunningTask(tid) + // This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the + // "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not + // "deserialize" the value when holding a lock to avoid blocking other threads. So we call + // "result.value()" in "TaskResultGetter.enqueueSuccessfulTask" before reaching here. + // Note: "result.value()" only deserializes the value when it's called at the first time, so + // here "result.value()" just returns the value and won't block other threads. sched.dagScheduler.taskEnded( tasks(index), Success, result.value(), result.accumUpdates, info, result.metrics) if (!successful(index)) { @@ -775,10 +781,10 @@ private[spark] class TaskSetManager( // that it's okay if we add a task to the same queue twice (if it had multiple preferred // locations), because dequeueTaskFromList will skip already-running tasks. for (index <- getPendingTasksForExecutor(execId)) { - addPendingTask(index, readding=true) + addPendingTask(index, readding = true) } for (index <- getPendingTasksForHost(host)) { - addPendingTask(index, readding=true) + addPendingTask(index, readding = true) } // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage, @@ -855,9 +861,9 @@ private[spark] class TaskSetManager( case TaskLocality.RACK_LOCAL => "spark.locality.wait.rack" case _ => null } - + if (localityWaitKey != null) { - conf.getTimeAsMs(localityWaitKey, defaultWait) + conf.getTimeAsMs(localityWaitKey, defaultWait) } else { 0L } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 70364cea62a80..4be1eda2e9291 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -75,7 +75,8 @@ private[spark] object CoarseGrainedClusterMessages { case class SetupDriver(driver: RpcEndpointRef) extends CoarseGrainedClusterMessage // Exchanged between the driver and the AM in Yarn client mode - case class AddWebUIFilter(filterName:String, filterParams: Map[String, String], proxyBase: String) + case class AddWebUIFilter( + filterName: String, filterParams: Map[String, String], proxyBase: String) extends CoarseGrainedClusterMessage // Messages exchanged between the driver and the cluster manager for executor allocation diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index f107148f3b8c6..fcad959540f5a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -69,6 +69,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp class DriverEndpoint(override val rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) extends ThreadSafeRpcEndpoint with Logging { + // If this DriverEndpoint is changed to support multiple threads, + // then this may need to be changed so that we don't share the serializer + // instance across threads + private val ser = SparkEnv.get.closureSerializer.newInstance() + override protected def log = CoarseGrainedSchedulerBackend.this.log private val addressToExecutorId = new HashMap[RpcAddress, String] @@ -79,7 +84,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp override def onStart() { // Periodically revive offers to allow delay scheduling to work val reviveIntervalMs = conf.getTimeAsMs("spark.scheduler.revive.interval", "1s") - + reviveThread.scheduleAtFixedRate(new Runnable { override def run(): Unit = Utils.tryLogNonFatalError { Option(self).foreach(_.send(ReviveOffers)) @@ -163,7 +168,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } // Make fake resource offers on all executors - def makeOffers() { + private def makeOffers() { launchTasks(scheduler.resourceOffers(executorDataMap.map { case (id, executorData) => new WorkerOffer(id, executorData.executorHost, executorData.freeCores) }.toSeq)) @@ -175,16 +180,15 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } // Make fake resource offers on just one executor - def makeOffers(executorId: String) { + private def makeOffers(executorId: String) { val executorData = executorDataMap(executorId) launchTasks(scheduler.resourceOffers( Seq(new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores)))) } // Launch tasks returned by a set of resource offers - def launchTasks(tasks: Seq[Seq[TaskDescription]]) { + private def launchTasks(tasks: Seq[Seq[TaskDescription]]) { for (task <- tasks.flatten) { - val ser = SparkEnv.get.closureSerializer.newInstance() val serializedTask = ser.serialize(task) if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) { val taskSetId = scheduler.taskIdToTaskSetId(task.taskId) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 2a3a5d925d06f..190ff61d689d1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -149,7 +149,7 @@ private[spark] abstract class YarnSchedulerBackend( } } - override def onStop(): Unit ={ + override def onStop(): Unit = { askAmThreadPool.shutdownNow() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index dc59545b43314..6b8edca5aa485 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -25,9 +25,10 @@ import scala.collection.mutable.{HashMap, HashSet} import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} import org.apache.mesos.{Scheduler => MScheduler, _} +import org.apache.spark.rpc.RpcAddress import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.Utils import org.apache.spark.{SparkContext, SparkEnv, SparkException, TaskState} /** @@ -51,7 +52,7 @@ private[spark] class CoarseMesosSchedulerBackend( val MAX_SLAVE_FAILURES = 2 // Blacklist a slave after this many failures // Maximum number of cores to acquire (TODO: we'll need more flexible controls here) - val maxCores = conf.get("spark.cores.max", Int.MaxValue.toString).toInt + val maxCores = conf.get("spark.cores.max", Int.MaxValue.toString).toInt // Cores we have acquired with each Mesos task ID val coresByTaskId = new HashMap[Int, Int] @@ -115,11 +116,9 @@ private[spark] class CoarseMesosSchedulerBackend( } val command = CommandInfo.newBuilder() .setEnvironment(environment) - val driverUrl = AkkaUtils.address( - AkkaUtils.protocol(sc.env.actorSystem), + val driverUrl = sc.env.rpcEnv.uriOf( SparkEnv.driverActorSystemName, - conf.get("spark.driver.host"), - conf.get("spark.driver.port"), + RpcAddress(conf.get("spark.driver.host"), conf.get("spark.driver.port").toInt), CoarseGrainedSchedulerBackend.ENDPOINT_NAME) val uri = conf.getOption("spark.executor.uri") diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index db0a080b3b0c0..49de85ef48ada 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -146,7 +146,7 @@ private[spark] class MesosSchedulerBackend( private def createExecArg(): Array[Byte] = { if (execArgs == null) { val props = new HashMap[String, String] - for ((key,value) <- sc.conf.getAll) { + for ((key, value) <- sc.conf.getAll) { props(key) = value } // Serialize the map as an array of (String, String) pairs diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala index 928c5cfed417a..e79c543a9de27 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala @@ -37,14 +37,14 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging { .newBuilder() .setMode(Volume.Mode.RW) spec match { - case Array(container_path) => + case Array(container_path) => Some(vol.setContainerPath(container_path)) case Array(container_path, "rw") => Some(vol.setContainerPath(container_path)) case Array(container_path, "ro") => Some(vol.setContainerPath(container_path) .setMode(Volume.Mode.RO)) - case Array(host_path, container_path) => + case Array(host_path, container_path) => Some(vol.setContainerPath(container_path) .setHostPath(host_path)) case Array(host_path, container_path, "rw") => @@ -108,7 +108,7 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging { image: String, volumes: Option[List[Volume]] = None, network: Option[ContainerInfo.DockerInfo.Network] = None, - portmaps: Option[List[ContainerInfo.DockerInfo.PortMapping]] = None):Unit = { + portmaps: Option[List[ContainerInfo.DockerInfo.PortMapping]] = None): Unit = { val docker = ContainerInfo.DockerInfo.newBuilder().setImage(image) diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index e64d06c4d3cfc..3078a1b10be8b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -18,14 +18,12 @@ package org.apache.spark.scheduler.local import java.nio.ByteBuffer -import java.util.concurrent.TimeUnit import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.{Executor, ExecutorBackend} -import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcCallContext, RpcEndpointRef, RpcEnv} +import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer} -import org.apache.spark.util.{ThreadUtils, Utils} private case class ReviveOffers() @@ -47,9 +45,6 @@ private[spark] class LocalEndpoint( private val totalCores: Int) extends ThreadSafeRpcEndpoint with Logging { - private val reviveThread = - ThreadUtils.newDaemonSingleThreadScheduledExecutor("local-revive-thread") - private var freeCores = totalCores private val localExecutorId = SparkContext.DRIVER_IDENTIFIER @@ -79,27 +74,13 @@ private[spark] class LocalEndpoint( context.reply(true) } - def reviveOffers() { val offers = Seq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores)) - val tasks = scheduler.resourceOffers(offers).flatten - for (task <- tasks) { + for (task <- scheduler.resourceOffers(offers).flatten) { freeCores -= scheduler.CPUS_PER_TASK executor.launchTask(executorBackend, taskId = task.taskId, attemptNumber = task.attemptNumber, task.name, task.serializedTask) } - if (tasks.isEmpty && scheduler.activeTaskSets.nonEmpty) { - // Try to reviveOffer after 1 second, because scheduler may wait for locality timeout - reviveThread.schedule(new Runnable { - override def run(): Unit = Utils.tryLogNonFatalError { - Option(self).foreach(_.send(ReviveOffers)) - } - }, 1000, TimeUnit.MILLISECONDS) - } - } - - override def onStop(): Unit = { - reviveThread.shutdownNow() } } diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index dfbde7c8a1b0d..698d1384d580d 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -121,6 +121,8 @@ class JavaSerializer(conf: SparkConf) extends Serializer with Externalizable { private var counterReset = conf.getInt("spark.serializer.objectStreamReset", 100) private var extraDebugInfo = conf.getBoolean("spark.serializer.extraDebugInfo", true) + protected def this() = this(new SparkConf()) // For deserialization only + override def newInstance(): SerializerInstance = { val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader) new JavaSerializerInstance(counterReset, extraDebugInfo, classLoader) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 64ba27f34d2f1..cd8a82347a1e9 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -17,8 +17,9 @@ package org.apache.spark.serializer -import java.io.{EOFException, InputStream, OutputStream} +import java.io.{EOFException, IOException, InputStream, OutputStream} import java.nio.ByteBuffer +import javax.annotation.Nullable import scala.reflect.ClassTag @@ -51,7 +52,7 @@ class KryoSerializer(conf: SparkConf) with Serializable { private val bufferSizeKb = conf.getSizeAsKb("spark.kryoserializer.buffer", "64k") - + if (bufferSizeKb >= ByteUnit.GiB.toKiB(2)) { throw new IllegalArgumentException("spark.kryoserializer.buffer must be less than " + s"2048 mb, got: + ${ByteUnit.KiB.toMiB(bufferSizeKb)} mb.") @@ -136,21 +137,45 @@ class KryoSerializer(conf: SparkConf) } private[spark] -class KryoSerializationStream(kryo: Kryo, outStream: OutputStream) extends SerializationStream { - val output = new KryoOutput(outStream) +class KryoSerializationStream( + serInstance: KryoSerializerInstance, + outStream: OutputStream) extends SerializationStream { + + private[this] var output: KryoOutput = new KryoOutput(outStream) + private[this] var kryo: Kryo = serInstance.borrowKryo() override def writeObject[T: ClassTag](t: T): SerializationStream = { kryo.writeClassAndObject(output, t) this } - override def flush() { output.flush() } - override def close() { output.close() } + override def flush() { + if (output == null) { + throw new IOException("Stream is closed") + } + output.flush() + } + + override def close() { + if (output != null) { + try { + output.close() + } finally { + serInstance.releaseKryo(kryo) + kryo = null + output = null + } + } + } } private[spark] -class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends DeserializationStream { - private val input = new KryoInput(inStream) +class KryoDeserializationStream( + serInstance: KryoSerializerInstance, + inStream: InputStream) extends DeserializationStream { + + private[this] var input: KryoInput = new KryoInput(inStream) + private[this] var kryo: Kryo = serInstance.borrowKryo() override def readObject[T: ClassTag](): T = { try { @@ -163,50 +188,105 @@ class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends Deser } override def close() { - // Kryo's Input automatically closes the input stream it is using. - input.close() + if (input != null) { + try { + // Kryo's Input automatically closes the input stream it is using. + input.close() + } finally { + serInstance.releaseKryo(kryo) + kryo = null + input = null + } + } } } private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance { - private val kryo = ks.newKryo() - // Make these lazy vals to avoid creating a buffer unless we use them + /** + * A re-used [[Kryo]] instance. Methods will borrow this instance by calling `borrowKryo()`, do + * their work, then release the instance by calling `releaseKryo()`. Logically, this is a caching + * pool of size one. SerializerInstances are not thread-safe, hence accesses to this field are + * not synchronized. + */ + @Nullable private[this] var cachedKryo: Kryo = borrowKryo() + + /** + * Borrows a [[Kryo]] instance. If possible, this tries to re-use a cached Kryo instance; + * otherwise, it allocates a new instance. + */ + private[serializer] def borrowKryo(): Kryo = { + if (cachedKryo != null) { + val kryo = cachedKryo + // As a defensive measure, call reset() to clear any Kryo state that might have been modified + // by the last operation to borrow this instance (see SPARK-7766 for discussion of this issue) + kryo.reset() + cachedKryo = null + kryo + } else { + ks.newKryo() + } + } + + /** + * Release a borrowed [[Kryo]] instance. If this serializer instance already has a cached Kryo + * instance, then the given Kryo instance is discarded; otherwise, the Kryo is stored for later + * re-use. + */ + private[serializer] def releaseKryo(kryo: Kryo): Unit = { + if (cachedKryo == null) { + cachedKryo = kryo + } + } + + // Make these lazy vals to avoid creating a buffer unless we use them. private lazy val output = ks.newKryoOutput() private lazy val input = new KryoInput() override def serialize[T: ClassTag](t: T): ByteBuffer = { output.clear() + val kryo = borrowKryo() try { kryo.writeClassAndObject(output, t) } catch { case e: KryoException if e.getMessage.startsWith("Buffer overflow") => throw new SparkException(s"Kryo serialization failed: ${e.getMessage}. To avoid this, " + "increase spark.kryoserializer.buffer.max value.") + } finally { + releaseKryo(kryo) } ByteBuffer.wrap(output.toBytes) } override def deserialize[T: ClassTag](bytes: ByteBuffer): T = { - input.setBuffer(bytes.array) - kryo.readClassAndObject(input).asInstanceOf[T] + val kryo = borrowKryo() + try { + input.setBuffer(bytes.array) + kryo.readClassAndObject(input).asInstanceOf[T] + } finally { + releaseKryo(kryo) + } } override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = { + val kryo = borrowKryo() val oldClassLoader = kryo.getClassLoader - kryo.setClassLoader(loader) - input.setBuffer(bytes.array) - val obj = kryo.readClassAndObject(input).asInstanceOf[T] - kryo.setClassLoader(oldClassLoader) - obj + try { + kryo.setClassLoader(loader) + input.setBuffer(bytes.array) + kryo.readClassAndObject(input).asInstanceOf[T] + } finally { + kryo.setClassLoader(oldClassLoader) + releaseKryo(kryo) + } } override def serializeStream(s: OutputStream): SerializationStream = { - new KryoSerializationStream(kryo, s) + new KryoSerializationStream(this, s) } override def deserializeStream(s: InputStream): DeserializationStream = { - new KryoDeserializationStream(kryo, s) + new KryoDeserializationStream(this, s) } /** @@ -216,7 +296,12 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ def getAutoReset(): Boolean = { val field = classOf[Kryo].getDeclaredField("autoReset") field.setAccessible(true) - field.get(kryo).asInstanceOf[Boolean] + val kryo = borrowKryo() + try { + field.get(kryo).asInstanceOf[Boolean] + } finally { + releaseKryo(kryo) + } } } diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala index 5abfa467c0ec8..bb5db545531d2 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala @@ -27,7 +27,7 @@ import scala.util.control.NonFatal import org.apache.spark.Logging -private[serializer] object SerializationDebugger extends Logging { +private[spark] object SerializationDebugger extends Logging { /** * Improve the given NotSerializableException with the serialization path leading from the given diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index 6078c9d433ebf..f1bdff96d3df1 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -19,6 +19,7 @@ package org.apache.spark.serializer import java.io._ import java.nio.ByteBuffer +import javax.annotation.concurrent.NotThreadSafe import scala.reflect.ClassTag @@ -114,8 +115,12 @@ object Serializer { /** * :: DeveloperApi :: * An instance of a serializer, for use by one thread at a time. + * + * It is legal to create multiple serialization / deserialization streams from the same + * SerializerInstance as long as those streams are all used within the same thread. */ @DeveloperApi +@NotThreadSafe abstract class SerializerInstance { def serialize[T: ClassTag](t: T): ByteBuffer diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala index 6ad427bcac7f9..6c3b3080d2605 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala @@ -76,7 +76,7 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) private val consolidateShuffleFiles = conf.getBoolean("spark.shuffle.consolidateFiles", false) - // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided + // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided private val bufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 /** diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala index f6e6fe5defe09..4cc4ef5f1886e 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala @@ -17,14 +17,17 @@ package org.apache.spark.shuffle +import java.io.IOException + import org.apache.spark.scheduler.MapStatus /** * Obtained inside a map task to write out records to the shuffle system. */ -private[spark] trait ShuffleWriter[K, V] { +private[spark] abstract class ShuffleWriter[K, V] { /** Write a sequence of records to this task's output */ - def write(records: Iterator[_ <: Product2[K, V]]): Unit + @throws[IOException] + def write(records: Iterator[Product2[K, V]]): Unit /** Close this writer, passing along whether the map completed */ def stop(success: Boolean): Option[MapStatus] diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index 80374adc44296..597d46a3d2223 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -80,7 +80,7 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { blocksByAddress, serializer, // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility - SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) + SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) val itr = blockFetcherItr.flatMap(unpackBlock) val completionIter = CompletionIterator[T, Iterator[T]](itr, { diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala index 897f0a5dc5bcc..eb87cee15903c 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala @@ -49,7 +49,7 @@ private[spark] class HashShuffleWriter[K, V]( writeMetrics) /** Write a bunch of records to this task's output */ - override def write(records: Iterator[_ <: Product2[K, V]]): Unit = { + override def write(records: Iterator[Product2[K, V]]): Unit = { val iter = if (dep.aggregator.isDefined) { if (dep.mapSideCombine) { dep.aggregator.get.combineValuesByKey(records, context) diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 15842941daaab..d7fab351ca3b8 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -72,7 +72,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager true } - override def shuffleBlockResolver: IndexShuffleBlockResolver = { + override val shuffleBlockResolver: IndexShuffleBlockResolver = { indexShuffleBlockResolver } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index add2656294ca2..5865e7640c1cf 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -17,9 +17,10 @@ package org.apache.spark.shuffle.sort -import org.apache.spark.{MapOutputTracker, SparkEnv, Logging, TaskContext} +import org.apache.spark._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.scheduler.MapStatus +import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter, BaseShuffleHandle} import org.apache.spark.storage.ShuffleBlockId import org.apache.spark.util.collection.ExternalSorter @@ -35,7 +36,7 @@ private[spark] class SortShuffleWriter[K, V, C]( private val blockManager = SparkEnv.get.blockManager - private var sorter: ExternalSorter[K, V, _] = null + private var sorter: SortShuffleFileWriter[K, V] = null // Are we in the process of stopping? Because map tasks can call stop() with success = true // and then call stop() with success = false if they get an exception, we want to make sure @@ -48,19 +49,28 @@ private[spark] class SortShuffleWriter[K, V, C]( context.taskMetrics.shuffleWriteMetrics = Some(writeMetrics) /** Write a bunch of records to this task's output */ - override def write(records: Iterator[_ <: Product2[K, V]]): Unit = { - if (dep.mapSideCombine) { + override def write(records: Iterator[Product2[K, V]]): Unit = { + sorter = if (dep.mapSideCombine) { require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!") - sorter = new ExternalSorter[K, V, C]( + new ExternalSorter[K, V, C]( dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer) - sorter.insertAll(records) + } else if (SortShuffleWriter.shouldBypassMergeSort( + SparkEnv.get.conf, dep.partitioner.numPartitions, aggregator = None, keyOrdering = None)) { + // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't + // need local aggregation and sorting, write numPartitions files directly and just concatenate + // them at the end. This avoids doing serialization and deserialization twice to merge + // together the spilled files, which would happen with the normal code path. The downside is + // having multiple files open at a time and thus more memory allocated to buffers. + new BypassMergeSortShuffleWriter[K, V](SparkEnv.get.conf, blockManager, dep.partitioner, + writeMetrics, Serializer.getSerializer(dep.serializer)) } else { // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't // care whether the keys get sorted in each partition; that will be done on the reduce side // if the operation being run is sortByKey. - sorter = new ExternalSorter[K, V, V](None, Some(dep.partitioner), None, dep.serializer) - sorter.insertAll(records) + new ExternalSorter[K, V, V]( + aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer) } + sorter.insertAll(records) // Don't bother including the time to open the merged output file in the shuffle write time, // because it just opens a single file, so is typically too fast to measure accurately @@ -100,3 +110,13 @@ private[spark] class SortShuffleWriter[K, V, C]( } } +private[spark] object SortShuffleWriter { + def shouldBypassMergeSort( + conf: SparkConf, + numPartitions: Int, + aggregator: Option[Aggregator[_, _, _]], + keyOrdering: Option[Ordering[_]]): Boolean = { + val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) + numPartitions <= bypassMergeThreshold && aggregator.isEmpty && keyOrdering.isEmpty + } +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala new file mode 100644 index 0000000000000..f2bfef376d3ca --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.unsafe + +import java.util.Collections +import java.util.concurrent.ConcurrentHashMap + +import org.apache.spark._ +import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle._ +import org.apache.spark.shuffle.sort.SortShuffleManager + +/** + * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the new shuffle. + */ +private[spark] class UnsafeShuffleHandle[K, V]( + shuffleId: Int, + numMaps: Int, + dependency: ShuffleDependency[K, V, V]) + extends BaseShuffleHandle(shuffleId, numMaps, dependency) { +} + +private[spark] object UnsafeShuffleManager extends Logging { + + /** + * The maximum number of shuffle output partitions that UnsafeShuffleManager supports. + */ + val MAX_SHUFFLE_OUTPUT_PARTITIONS = PackedRecordPointer.MAXIMUM_PARTITION_ID + 1 + + /** + * Helper method for determining whether a shuffle should use the optimized unsafe shuffle + * path or whether it should fall back to the original sort-based shuffle. + */ + def canUseUnsafeShuffle[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = { + val shufId = dependency.shuffleId + val serializer = Serializer.getSerializer(dependency.serializer) + if (!serializer.supportsRelocationOfSerializedObjects) { + log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because the serializer, " + + s"${serializer.getClass.getName}, does not support object relocation") + false + } else if (dependency.aggregator.isDefined) { + log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because an aggregator is defined") + false + } else if (dependency.keyOrdering.isDefined) { + log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because a key ordering is defined") + false + } else if (dependency.partitioner.numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS) { + log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because it has more than " + + s"$MAX_SHUFFLE_OUTPUT_PARTITIONS partitions") + false + } else { + log.debug(s"Can use UnsafeShuffle for shuffle $shufId") + true + } + } +} + +/** + * A shuffle implementation that uses directly-managed memory to implement several performance + * optimizations for certain types of shuffles. In cases where the new performance optimizations + * cannot be applied, this shuffle manager delegates to [[SortShuffleManager]] to handle those + * shuffles. + * + * UnsafeShuffleManager's optimizations will apply when _all_ of the following conditions hold: + * + * - The shuffle dependency specifies no aggregation or output ordering. + * - The shuffle serializer supports relocation of serialized values (this is currently supported + * by KryoSerializer and Spark SQL's custom serializers). + * - The shuffle produces fewer than 16777216 output partitions. + * - No individual record is larger than 128 MB when serialized. + * + * In addition, extra spill-merging optimizations are automatically applied when the shuffle + * compression codec supports concatenation of serialized streams. This is currently supported by + * Spark's LZF serializer. + * + * At a high-level, UnsafeShuffleManager's design is similar to Spark's existing SortShuffleManager. + * In sort-based shuffle, incoming records are sorted according to their target partition ids, then + * written to a single map output file. Reducers fetch contiguous regions of this file in order to + * read their portion of the map output. In cases where the map output data is too large to fit in + * memory, sorted subsets of the output can are spilled to disk and those on-disk files are merged + * to produce the final output file. + * + * UnsafeShuffleManager optimizes this process in several ways: + * + * - Its sort operates on serialized binary data rather than Java objects, which reduces memory + * consumption and GC overheads. This optimization requires the record serializer to have certain + * properties to allow serialized records to be re-ordered without requiring deserialization. + * See SPARK-4550, where this optimization was first proposed and implemented, for more details. + * + * - It uses a specialized cache-efficient sorter ([[UnsafeShuffleExternalSorter]]) that sorts + * arrays of compressed record pointers and partition ids. By using only 8 bytes of space per + * record in the sorting array, this fits more of the array into cache. + * + * - The spill merging procedure operates on blocks of serialized records that belong to the same + * partition and does not need to deserialize records during the merge. + * + * - When the spill compression codec supports concatenation of compressed data, the spill merge + * simply concatenates the serialized and compressed spill partitions to produce the final output + * partition. This allows efficient data copying methods, like NIO's `transferTo`, to be used + * and avoids the need to allocate decompression or copying buffers during the merge. + * + * For more details on UnsafeShuffleManager's design, see SPARK-7081. + */ +private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { + + if (!conf.getBoolean("spark.shuffle.spill", true)) { + logWarning( + "spark.shuffle.spill was set to false, but this is ignored by the tungsten-sort shuffle " + + "manager; its optimized shuffles will continue to spill to disk when necessary.") + } + + private[this] val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf) + private[this] val shufflesThatFellBackToSortShuffle = + Collections.newSetFromMap(new ConcurrentHashMap[Int, java.lang.Boolean]()) + private[this] val numMapsForShufflesThatUsedNewPath = new ConcurrentHashMap[Int, Int]() + + /** + * Register a shuffle with the manager and obtain a handle for it to pass to tasks. + */ + override def registerShuffle[K, V, C]( + shuffleId: Int, + numMaps: Int, + dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { + if (UnsafeShuffleManager.canUseUnsafeShuffle(dependency)) { + new UnsafeShuffleHandle[K, V]( + shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]]) + } else { + new BaseShuffleHandle(shuffleId, numMaps, dependency) + } + } + + /** + * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive). + * Called on executors by reduce tasks. + */ + override def getReader[K, C]( + handle: ShuffleHandle, + startPartition: Int, + endPartition: Int, + context: TaskContext): ShuffleReader[K, C] = { + sortShuffleManager.getReader(handle, startPartition, endPartition, context) + } + + /** Get a writer for a given partition. Called on executors by map tasks. */ + override def getWriter[K, V]( + handle: ShuffleHandle, + mapId: Int, + context: TaskContext): ShuffleWriter[K, V] = { + handle match { + case unsafeShuffleHandle: UnsafeShuffleHandle[K, V] => + numMapsForShufflesThatUsedNewPath.putIfAbsent(handle.shuffleId, unsafeShuffleHandle.numMaps) + val env = SparkEnv.get + new UnsafeShuffleWriter( + env.blockManager, + shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], + context.taskMemoryManager(), + env.shuffleMemoryManager, + unsafeShuffleHandle, + mapId, + context, + env.conf) + case other => + shufflesThatFellBackToSortShuffle.add(handle.shuffleId) + sortShuffleManager.getWriter(handle, mapId, context) + } + } + + /** Remove a shuffle's metadata from the ShuffleManager. */ + override def unregisterShuffle(shuffleId: Int): Boolean = { + if (shufflesThatFellBackToSortShuffle.remove(shuffleId)) { + sortShuffleManager.unregisterShuffle(shuffleId) + } else { + Option(numMapsForShufflesThatUsedNewPath.remove(shuffleId)).foreach { numMaps => + (0 until numMaps).foreach { mapId => + shuffleBlockResolver.removeDataByMap(shuffleId, mapId) + } + } + true + } + } + + override val shuffleBlockResolver: IndexShuffleBlockResolver = { + sortShuffleManager.shuffleBlockResolver + } + + /** Shut down this ShuffleManager. */ + override def stop(): Unit = { + sortShuffleManager.stop() + } +} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala index 50608588f09ae..390c136df79b3 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala @@ -169,7 +169,7 @@ private[v1] object AllStagesResource { val outputMetrics: Option[OutputMetricDistributions] = new MetricHelper[InternalOutputMetrics, OutputMetricDistributions](rawMetrics, quantiles) { - def getSubmetrics(raw:InternalTaskMetrics): Option[InternalOutputMetrics] = { + def getSubmetrics(raw: InternalTaskMetrics): Option[InternalOutputMetrics] = { raw.outputMetrics } def build: OutputMetricDistributions = new OutputMetricDistributions( @@ -284,7 +284,7 @@ private[v1] object AllStagesResource { * the options (returning None if the metrics are all empty), and extract the quantiles for each * metric. After creating an instance, call metricOption to get the result type. */ -private[v1] abstract class MetricHelper[I,O]( +private[v1] abstract class MetricHelper[I, O]( rawMetrics: Seq[InternalTaskMetrics], quantiles: Array[Double]) { diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/JsonRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala similarity index 96% rename from core/src/main/scala/org/apache/spark/status/api/v1/JsonRootResource.scala rename to core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index c3ec45f54681b..f73c742732dec 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/JsonRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -39,7 +39,7 @@ import org.apache.spark.ui.SparkUI * HistoryServerSuite. */ @Path("/v1") -private[v1] class JsonRootResource extends UIRootFromServletContext { +private[v1] class ApiRootResource extends UIRootFromServletContext { @Path("applications") def getApplicationList(): ApplicationListResource = { @@ -101,7 +101,7 @@ private[v1] class JsonRootResource extends UIRootFromServletContext { @Path("applications/{appId}/stages") - def getStages(@PathParam("appId") appId: String): AllStagesResource= { + def getStages(@PathParam("appId") appId: String): AllStagesResource = { uiRoot.withSparkUI(appId, None) { ui => new AllStagesResource(ui) } @@ -110,14 +110,14 @@ private[v1] class JsonRootResource extends UIRootFromServletContext { @Path("applications/{appId}/{attemptId}/stages") def getStages( @PathParam("appId") appId: String, - @PathParam("attemptId") attemptId: String): AllStagesResource= { + @PathParam("attemptId") attemptId: String): AllStagesResource = { uiRoot.withSparkUI(appId, Some(attemptId)) { ui => new AllStagesResource(ui) } } @Path("applications/{appId}/stages/{stageId: \\d+}") - def getStage(@PathParam("appId") appId: String): OneStageResource= { + def getStage(@PathParam("appId") appId: String): OneStageResource = { uiRoot.withSparkUI(appId, None) { ui => new OneStageResource(ui) } @@ -166,12 +166,12 @@ private[v1] class JsonRootResource extends UIRootFromServletContext { } -private[spark] object JsonRootResource { +private[spark] object ApiRootResource { - def getJsonServlet(uiRoot: UIRoot): ServletContextHandler = { + def getServletHandler(uiRoot: UIRoot): ServletContextHandler = { val jerseyContext = new ServletContextHandler(ServletContextHandler.NO_SESSIONS) - jerseyContext.setContextPath("/json") - val holder:ServletHolder = new ServletHolder(classOf[ServletContainer]) + jerseyContext.setContextPath("/api") + val holder: ServletHolder = new ServletHolder(classOf[ServletContainer]) holder.setInitParameter("com.sun.jersey.config.property.resourceConfigClass", "com.sun.jersey.api.core.PackagesResourceConfig") holder.setInitParameter("com.sun.jersey.config.property.packages", diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala index 07b224fac4786..dfdc09c6caf3b 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala @@ -25,7 +25,7 @@ import org.apache.spark.ui.SparkUI private[v1] class OneRDDResource(ui: SparkUI) { @GET - def rddData(@PathParam("rddId") rddId: Int): RDDStorageInfo = { + def rddData(@PathParam("rddId") rddId: Int): RDDStorageInfo = { AllRDDResource.getRDDStorageInfo(rddId, ui.storageListener, true).getOrElse( throw new NotFoundException(s"no rdd found w/ id $rddId") ) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala index fd24aea63a8a1..f9812f06cf527 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala @@ -83,7 +83,7 @@ private[v1] class OneStageResource(ui: SparkUI) { withStageAttempt(stageId, stageAttemptId) { stage => val tasks = stage.ui.taskData.values.map{AllStagesResource.convertTaskData}.toIndexedSeq .sorted(OneStageResource.ordering(sortBy)) - tasks.slice(offset, offset + length) + tasks.slice(offset, offset + length) } } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala b/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala index cee29786c3019..0c71cd2382225 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala @@ -16,40 +16,33 @@ */ package org.apache.spark.status.api.v1 -import java.text.SimpleDateFormat +import java.text.{ParseException, SimpleDateFormat} import java.util.TimeZone import javax.ws.rs.WebApplicationException import javax.ws.rs.core.Response import javax.ws.rs.core.Response.Status -import scala.util.Try - private[v1] class SimpleDateParam(val originalValue: String) { - val timestamp: Long = { - SimpleDateParam.formats.collectFirst { - case fmt if Try(fmt.parse(originalValue)).isSuccess => - fmt.parse(originalValue).getTime() - }.getOrElse( - throw new WebApplicationException( - Response - .status(Status.BAD_REQUEST) - .entity("Couldn't parse date: " + originalValue) - .build() - ) - ) - } -} -private[v1] object SimpleDateParam { - - val formats: Seq[SimpleDateFormat] = { - - val gmtDay = new SimpleDateFormat("yyyy-MM-dd") - gmtDay.setTimeZone(TimeZone.getTimeZone("GMT")) - - Seq( - new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSz"), - gmtDay - ) + val timestamp: Long = { + val format = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSz") + try { + format.parse(originalValue).getTime() + } catch { + case _: ParseException => + val gmtDay = new SimpleDateFormat("yyyy-MM-dd") + gmtDay.setTimeZone(TimeZone.getTimeZone("GMT")) + try { + gmtDay.parse(originalValue).getTime() + } catch { + case _: ParseException => + throw new WebApplicationException( + Response + .status(Status.BAD_REQUEST) + .entity("Couldn't parse date: " + originalValue) + .build() + ) + } + } } } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index ef3c8570d8186..2bec64f2ef02b 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -134,7 +134,7 @@ class StageData private[spark]( val accumulatorUpdates: Seq[AccumulableInfo], val tasks: Option[Map[Long, TaskData]], - val executorSummary:Option[Map[String,ExecutorStageSummary]]) + val executorSummary: Option[Map[String, ExecutorStageSummary]]) class TaskData private[spark]( val taskId: Long, diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index cc794e5c90ffa..5048c7dab240b 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -17,12 +17,11 @@ package org.apache.spark.storage -import java.io.{BufferedOutputStream, ByteArrayOutputStream, File, InputStream, OutputStream} +import java.io._ import java.nio.{ByteBuffer, MappedByteBuffer} import scala.collection.mutable.{ArrayBuffer, HashMap} -import scala.concurrent.{Await, Future} -import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.{ExecutionContext, Await, Future} import scala.concurrent.duration._ import scala.util.Random @@ -77,6 +76,9 @@ private[spark] class BlockManager( private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo] + private val futureExecutionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("block-manager-future", 128)) + // Actual storage of where blocks are kept private var externalBlockStoreInitialized = false private[spark] val memoryStore = new MemoryStore(this, maxMemory) @@ -266,11 +268,13 @@ private[spark] class BlockManager( asyncReregisterLock.synchronized { if (asyncReregisterTask == null) { asyncReregisterTask = Future[Unit] { + // This is a blocking action and should run in futureExecutionContext which is a cached + // thread pool reregister() asyncReregisterLock.synchronized { asyncReregisterTask = null } - } + }(futureExecutionContext) } } } @@ -485,16 +489,17 @@ private[spark] class BlockManager( if (level.useOffHeap) { logDebug(s"Getting block $blockId from ExternalBlockStore") if (externalBlockStore.contains(blockId)) { - externalBlockStore.getBytes(blockId) match { - case Some(bytes) => - if (!asBlockResult) { - return Some(bytes) - } else { - return Some(new BlockResult( - dataDeserialize(blockId, bytes), DataReadMethod.Memory, info.size)) - } + val result = if (asBlockResult) { + externalBlockStore.getValues(blockId) + .map(new BlockResult(_, DataReadMethod.Memory, info.size)) + } else { + externalBlockStore.getBytes(blockId) + } + result match { + case Some(values) => + return result case None => - logDebug(s"Block $blockId not found in externalBlockStore") + logDebug(s"Block $blockId not found in ExternalBlockStore") } } } @@ -744,7 +749,11 @@ private[spark] class BlockManager( case b: ByteBufferValues if putLevel.replication > 1 => // Duplicate doesn't copy the bytes, but just creates a wrapper val bufferView = b.buffer.duplicate() - Future { replicate(blockId, bufferView, putLevel) } + Future { + // This is a blocking action and should run in futureExecutionContext which is a cached + // thread pool + replicate(blockId, bufferView, putLevel) + }(futureExecutionContext) case _ => null } @@ -1198,8 +1207,19 @@ private[spark] class BlockManager( bytes: ByteBuffer, serializer: Serializer = defaultSerializer): Iterator[Any] = { bytes.rewind() - val stream = wrapForCompression(blockId, new ByteBufferInputStream(bytes, true)) - serializer.newInstance().deserializeStream(stream).asIterator + dataDeserializeStream(blockId, new ByteBufferInputStream(bytes, true), serializer) + } + + /** + * Deserializes a InputStream into an iterator of values and disposes of it when the end of + * the iterator is reached. + */ + def dataDeserializeStream( + blockId: BlockId, + inputStream: InputStream, + serializer: Serializer = defaultSerializer): Iterator[Any] = { + val stream = new BufferedInputStream(inputStream) + serializer.newInstance().deserializeStream(wrapForCompression(blockId, stream)).asIterator } def stop(): Unit = { @@ -1218,6 +1238,7 @@ private[spark] class BlockManager( } metadataCleaner.cancel() broadcastCleaner.cancel() + futureExecutionContext.shutdownNow() logInfo("BlockManager stopped") } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 3afb4c3c02e2d..2cd8c5297b741 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -292,16 +292,16 @@ class BlockManagerMasterEndpoint( blockManagerIdByExecutor.get(id.executorId) match { case Some(oldId) => // A block manager of the same executor already exists, so remove it (assumed dead) - logError("Got two different block manager registrations on same executor - " + logError("Got two different block manager registrations on same executor - " + s" will replace old one $oldId with new one $id") - removeExecutor(id.executorId) + removeExecutor(id.executorId) case None => } logInfo("Registering block manager %s with %s RAM, %s".format( id.hostPort, Utils.bytesToString(maxMemSize), id)) - + blockManagerIdByExecutor(id.executorId) = id - + blockManagerInfo(id) = new BlockManagerInfo( id, System.currentTimeMillis(), maxMemSize, slaveEndpoint) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala index 543df4e1350dd..7478ab0fc2f7a 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala @@ -40,7 +40,7 @@ class BlockManagerSlaveEndpoint( private implicit val asyncExecutionContext = ExecutionContext.fromExecutorService(asyncThreadPool) // Operations that involve removing blocks may be slow and should be done asynchronously - override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RemoveBlock(blockId) => doAsync[Boolean]("removing block " + blockId, context) { blockManager.removeBlock(blockId) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala index 8569c6f3cbbc3..c5ba9af3e2658 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala @@ -17,9 +17,8 @@ package org.apache.spark.storage -import com.codahale.metrics.{Gauge,MetricRegistry} +import com.codahale.metrics.{Gauge, MetricRegistry} -import org.apache.spark.SparkContext import org.apache.spark.metrics.source.Source private[spark] class BlockManagerSource(val blockManager: BlockManager) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index 8bc4e205bc3c6..7eeabd1e0489c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -86,16 +86,6 @@ private[spark] class DiskBlockObjectWriter( extends BlockObjectWriter(blockId) with Logging { - /** Intercepts write calls and tracks total time spent writing. Not thread safe. */ - private class TimeTrackingOutputStream(out: OutputStream) extends OutputStream { - override def write(i: Int): Unit = callWithTiming(out.write(i)) - override def write(b: Array[Byte]): Unit = callWithTiming(out.write(b)) - override def write(b: Array[Byte], off: Int, len: Int): Unit = { - callWithTiming(out.write(b, off, len)) - } - override def close(): Unit = out.close() - override def flush(): Unit = out.flush() - } /** The file channel, used for repositioning / truncating the file. */ private var channel: FileChannel = null @@ -105,6 +95,7 @@ private[spark] class DiskBlockObjectWriter( private var objOut: SerializationStream = null private var initialized = false private var hasBeenClosed = false + private var commitAndCloseHasBeenCalled = false /** * Cursors used to represent positions in the file. @@ -136,7 +127,7 @@ private[spark] class DiskBlockObjectWriter( throw new IllegalStateException("Writer already closed. Cannot be reopened.") } fos = new FileOutputStream(file, true) - ts = new TimeTrackingOutputStream(fos) + ts = new TimeTrackingOutputStream(writeMetrics, fos) channel = fos.getChannel() bs = compressStream(new BufferedOutputStream(ts, bufferSize)) objOut = serializerInstance.serializeStream(bs) @@ -150,9 +141,9 @@ private[spark] class DiskBlockObjectWriter( if (syncWrites) { // Force outstanding writes to disk and track how long it takes objOut.flush() - callWithTiming { - fos.getFD.sync() - } + val start = System.nanoTime() + fos.getFD.sync() + writeMetrics.incShuffleWriteTime(System.nanoTime() - start) } } { objOut.close() @@ -177,20 +168,22 @@ private[spark] class DiskBlockObjectWriter( objOut.flush() bs.flush() close() + finalPosition = file.length() + // In certain compression codecs, more bytes are written after close() is called + writeMetrics.incShuffleBytesWritten(finalPosition - reportedPosition) + } else { + finalPosition = file.length() } - finalPosition = file.length() - // In certain compression codecs, more bytes are written after close() is called - writeMetrics.incShuffleBytesWritten(finalPosition - reportedPosition) + commitAndCloseHasBeenCalled = true } // Discard current writes. We do this by flushing the outstanding writes and then // truncating the file to its initial position. override def revertPartialWritesAndClose() { try { - writeMetrics.decShuffleBytesWritten(reportedPosition - initialPosition) - writeMetrics.decShuffleRecordsWritten(numRecordsWritten) - if (initialized) { + writeMetrics.decShuffleBytesWritten(reportedPosition - initialPosition) + writeMetrics.decShuffleRecordsWritten(numRecordsWritten) objOut.flush() bs.flush() close() @@ -238,6 +231,10 @@ private[spark] class DiskBlockObjectWriter( } override def fileSegment(): FileSegment = { + if (!commitAndCloseHasBeenCalled) { + throw new IllegalStateException( + "fileSegment() is only valid after commitAndClose() has been called") + } new FileSegment(file, initialPosition, finalPosition - initialPosition) } @@ -251,12 +248,6 @@ private[spark] class DiskBlockObjectWriter( reportedPosition = pos } - private def callWithTiming(f: => Unit) = { - val start = System.nanoTime() - f - writeMetrics.incShuffleWriteTime(System.nanoTime() - start) - } - // For testing private[spark] override def flush() { objOut.flush() diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 2a4447705fa65..91ef86389a0c3 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -139,8 +139,8 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon } private def addShutdownHook(): AnyRef = { - Utils.addShutdownHook { () => - logDebug("Shutdown hook called") + Utils.addShutdownHook(Utils.TEMP_DIR_SHUTDOWN_PRIORITY + 1) { () => + logInfo("Shutdown hook called") DiskBlockManager.this.doStop() } } @@ -151,7 +151,7 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon try { Utils.removeShutdownHook(shutdownHook) } catch { - case e: Exception => + case e: Exception => logError(s"Exception while removing shutdown hook.", e) } doStop() diff --git a/core/src/main/scala/org/apache/spark/storage/ExternalBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ExternalBlockManager.scala index 8964762df6af3..f39325a12d244 100644 --- a/core/src/main/scala/org/apache/spark/storage/ExternalBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ExternalBlockManager.scala @@ -32,6 +32,8 @@ import java.nio.ByteBuffer */ private[spark] abstract class ExternalBlockManager { + protected var blockManager: BlockManager = _ + override def toString: String = {"External Block Store"} /** @@ -41,7 +43,9 @@ private[spark] abstract class ExternalBlockManager { * * @throws java.io.IOException if there is any file system failure during the initialization. */ - def init(blockManager: BlockManager, executorId: String): Unit + def init(blockManager: BlockManager, executorId: String): Unit = { + this.blockManager = blockManager + } /** * Drop the block from underlying external block store, if it exists.. @@ -73,6 +77,11 @@ private[spark] abstract class ExternalBlockManager { */ def putBytes(blockId: BlockId, bytes: ByteBuffer): Unit + def putValues(blockId: BlockId, values: Iterator[_]): Unit = { + val bytes = blockManager.dataSerialize(blockId, values) + putBytes(blockId, bytes) + } + /** * Retrieve the block bytes. * @return Some(ByteBuffer) if the block bytes is successfully retrieved @@ -82,6 +91,17 @@ private[spark] abstract class ExternalBlockManager { */ def getBytes(blockId: BlockId): Option[ByteBuffer] + /** + * Retrieve the block data. + * @return Some(Iterator[Any]) if the block data is successfully retrieved + * None if the block does not exist in the external block store. + * + * @throws java.io.IOException if there is any file system failure in getting the block. + */ + def getValues(blockId: BlockId): Option[Iterator[_]] = { + getBytes(blockId).map(buffer => blockManager.dataDeserialize(blockId, buffer)) + } + /** * Get the size of the block saved in the underlying external block store, * which is saved before by putBytes. diff --git a/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala b/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala index 0bf770306ae9b..291394ed34816 100644 --- a/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala @@ -18,9 +18,11 @@ package org.apache.spark.storage import java.nio.ByteBuffer + +import scala.util.control.NonFatal + import org.apache.spark.Logging import org.apache.spark.util.Utils -import scala.util.control.NonFatal /** @@ -40,7 +42,7 @@ private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId: externalBlockManager.map(_.getSize(blockId)).getOrElse(0) } catch { case NonFatal(t) => - logError(s"error in getSize from $blockId", t) + logError(s"Error in getSize($blockId)", t) 0L } } @@ -54,7 +56,7 @@ private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId: values: Array[Any], level: StorageLevel, returnValues: Boolean): PutResult = { - putIterator(blockId, values.toIterator, level, returnValues) + putIntoExternalBlockStore(blockId, values.toIterator, returnValues) } override def putIterator( @@ -62,42 +64,70 @@ private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId: values: Iterator[Any], level: StorageLevel, returnValues: Boolean): PutResult = { - logDebug(s"Attempting to write values for block $blockId") - val bytes = blockManager.dataSerialize(blockId, values) - putIntoExternalBlockStore(blockId, bytes, returnValues) + putIntoExternalBlockStore(blockId, values, returnValues) } private def putIntoExternalBlockStore( blockId: BlockId, - bytes: ByteBuffer, + values: Iterator[_], returnValues: Boolean): PutResult = { - // So that we do not modify the input offsets ! - // duplicate does not copy buffer, so inexpensive - val byteBuffer = bytes.duplicate() - byteBuffer.rewind() - logDebug(s"Attempting to put block $blockId into ExtBlk store") + logTrace(s"Attempting to put block $blockId into ExternalBlockStore") // we should never hit here if externalBlockManager is None. Handle it anyway for safety. try { val startTime = System.currentTimeMillis if (externalBlockManager.isDefined) { - externalBlockManager.get.putBytes(blockId, bytes) + externalBlockManager.get.putValues(blockId, values) + val size = getSize(blockId) + val data = if (returnValues) { + Left(getValues(blockId).get) + } else { + null + } val finishTime = System.currentTimeMillis logDebug("Block %s stored as %s file in ExternalBlockStore in %d ms".format( - blockId, Utils.bytesToString(byteBuffer.limit), finishTime - startTime)) + blockId, Utils.bytesToString(size), finishTime - startTime)) + PutResult(size, data) + } else { + logError(s"Error in putValues($blockId): no ExternalBlockManager has been configured") + PutResult(-1, null, Seq((blockId, BlockStatus.empty))) + } + } catch { + case NonFatal(t) => + logError(s"Error in putValues($blockId)", t) + PutResult(-1, null, Seq((blockId, BlockStatus.empty))) + } + } - if (returnValues) { - PutResult(bytes.limit(), Right(bytes.duplicate())) + private def putIntoExternalBlockStore( + blockId: BlockId, + bytes: ByteBuffer, + returnValues: Boolean): PutResult = { + logTrace(s"Attempting to put block $blockId into ExternalBlockStore") + // we should never hit here if externalBlockManager is None. Handle it anyway for safety. + try { + val startTime = System.currentTimeMillis + if (externalBlockManager.isDefined) { + val byteBuffer = bytes.duplicate() + byteBuffer.rewind() + externalBlockManager.get.putBytes(blockId, byteBuffer) + val size = bytes.limit() + val data = if (returnValues) { + Right(bytes) } else { - PutResult(bytes.limit(), null) + null } + val finishTime = System.currentTimeMillis + logDebug("Block %s stored as %s file in ExternalBlockStore in %d ms".format( + blockId, Utils.bytesToString(size), finishTime - startTime)) + PutResult(size, data) } else { - logError(s"error in putBytes $blockId") - PutResult(bytes.limit(), null, Seq((blockId, BlockStatus.empty))) + logError(s"Error in putBytes($blockId): no ExternalBlockManager has been configured") + PutResult(-1, null, Seq((blockId, BlockStatus.empty))) } } catch { case NonFatal(t) => - logError(s"error in putBytes $blockId", t) - PutResult(bytes.limit(), null, Seq((blockId, BlockStatus.empty))) + logError(s"Error in putBytes($blockId)", t) + PutResult(-1, null, Seq((blockId, BlockStatus.empty))) } } @@ -107,13 +137,19 @@ private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId: externalBlockManager.map(_.removeBlock(blockId)).getOrElse(true) } catch { case NonFatal(t) => - logError(s"error in removing $blockId", t) + logError(s"Error in removeBlock($blockId)", t) true } } override def getValues(blockId: BlockId): Option[Iterator[Any]] = { - getBytes(blockId).map(buffer => blockManager.dataDeserialize(blockId, buffer)) + try { + externalBlockManager.flatMap(_.getValues(blockId)) + } catch { + case NonFatal(t) => + logError(s"Error in getValues($blockId)", t) + None + } } override def getBytes(blockId: BlockId): Option[ByteBuffer] = { @@ -121,7 +157,7 @@ private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId: externalBlockManager.flatMap(_.getBytes(blockId)) } catch { case NonFatal(t) => - logError(s"error in getBytes from $blockId", t) + logError(s"Error in getBytes($blockId)", t) None } } @@ -130,13 +166,13 @@ private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId: try { val ret = externalBlockManager.map(_.blockExists(blockId)).getOrElse(false) if (!ret) { - logInfo(s"remove block $blockId") + logInfo(s"Remove block $blockId") blockManager.removeBlock(blockId, true) } ret } catch { case NonFatal(t) => - logError(s"error in getBytes from $blockId", t) + logError(s"Error in getBytes($blockId)", t) false } } diff --git a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala index 95e2d688d9b17..021a9facfb0b2 100644 --- a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala +++ b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala @@ -24,6 +24,8 @@ import java.io.File * based off an offset and a length. */ private[spark] class FileSegment(val file: File, val offset: Long, val length: Long) { + require(offset >= 0, s"File segment offset cannot be negative (got $offset)") + require(length >= 0, s"File segment length cannot be negative (got $length)") override def toString: String = { "(name=%s, offset=%d, length=%d)".format(file.getName, offset, length) } diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala index bdc6276e41915..b53c86e89a273 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala @@ -22,7 +22,10 @@ import java.nio.ByteBuffer import java.text.SimpleDateFormat import java.util.{Date, Random} +import scala.util.control.NonFatal + import com.google.common.io.ByteStreams + import tachyon.client.{ReadType, WriteType, TachyonFS, TachyonFile} import tachyon.TachyonURI @@ -38,7 +41,6 @@ import org.apache.spark.util.Utils */ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Logging { - var blockManager: BlockManager =_ var rootDirs: String = _ var master: String = _ var client: tachyon.client.TachyonFS = _ @@ -52,7 +54,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log override def init(blockManager: BlockManager, executorId: String): Unit = { - this.blockManager = blockManager + super.init(blockManager, executorId) val storeDir = blockManager.conf.get(ExternalBlockStore.BASE_DIR, "/tmp_spark_tachyon") val appFolderName = blockManager.conf.get(ExternalBlockStore.FOLD_NAME) @@ -95,8 +97,29 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log override def putBytes(blockId: BlockId, bytes: ByteBuffer): Unit = { val file = getFile(blockId) val os = file.getOutStream(WriteType.TRY_CACHE) - os.write(bytes.array()) - os.close() + try { + os.write(bytes.array()) + } catch { + case NonFatal(e) => + logWarning(s"Failed to put bytes of block $blockId into Tachyon", e) + os.cancel() + } finally { + os.close() + } + } + + override def putValues(blockId: BlockId, values: Iterator[_]): Unit = { + val file = getFile(blockId) + val os = file.getOutStream(WriteType.TRY_CACHE) + try { + blockManager.dataSerializeStream(blockId, os, values) + } catch { + case NonFatal(e) => + logWarning(s"Failed to put values of block $blockId into Tachyon", e) + os.cancel() + } finally { + os.close() + } } override def getBytes(blockId: BlockId): Option[ByteBuffer] = { @@ -105,21 +128,31 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log return None } val is = file.getInStream(ReadType.CACHE) - assert (is != null) try { val size = file.length val bs = new Array[Byte](size.asInstanceOf[Int]) ByteStreams.readFully(is, bs) Some(ByteBuffer.wrap(bs)) } catch { - case ioe: IOException => - logWarning(s"Failed to fetch the block $blockId from Tachyon", ioe) + case NonFatal(e) => + logWarning(s"Failed to get bytes of block $blockId from Tachyon", e) None } finally { is.close() } } + override def getValues(blockId: BlockId): Option[Iterator[_]] = { + val file = getFile(blockId) + if (file == null || file.getLocationHosts().size() == 0) { + return None + } + val is = file.getInStream(ReadType.CACHE) + Option(is).map { is => + blockManager.dataDeserializeStream(blockId, is) + } + } + override def getSize(blockId: BlockId): Long = { getFile(blockId.name).length } @@ -184,7 +217,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log tachyonDir = client.getFile(path) } } catch { - case e: Exception => + case NonFatal(e) => logWarning("Attempt " + tries + " to create tachyon dir " + tachyonDir + " failed", e) } } @@ -206,7 +239,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log Utils.deleteRecursively(tachyonDir, client) } } catch { - case e: Exception => + case NonFatal(e) => logError("Exception while deleting tachyon spark dir: " + tachyonDir, e) } } diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index bfe4a180e8a6f..3788916cf39bb 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -19,7 +19,8 @@ package org.apache.spark.ui import java.util.Date -import org.apache.spark.status.api.v1.{ApplicationAttemptInfo, ApplicationInfo, JsonRootResource, UIRoot} +import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationAttemptInfo, ApplicationInfo, + UIRoot} import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext} import org.apache.spark.scheduler._ import org.apache.spark.storage.StorageStatusListener @@ -64,7 +65,7 @@ private[spark] class SparkUI private ( attachTab(new ExecutorsTab(this)) attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) attachHandler(createRedirectHandler("/", "/jobs", basePath = basePath)) - attachHandler(JsonRootResource.getJsonServlet(this)) + attachHandler(ApiRootResource.getServletHandler(this)) // This should be POST only, but, the YARN AM proxy won't proxy POSTs attachHandler(createRedirectHandler( "/stages/stage/kill", "/stages", stagesTab.handleKillRequest, @@ -136,7 +137,7 @@ private[spark] object SparkUI { jobProgressListener: JobProgressListener, securityManager: SecurityManager, appName: String, - startTime: Long): SparkUI = { + startTime: Long): SparkUI = { create(Some(sc), conf, listenerBus, securityManager, appName, jobProgressListener = Some(jobProgressListener), startTime = startTime) } diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index ad16becde85dd..65162f4fdcd62 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -309,7 +309,7 @@ private[spark] object UIUtils extends Logging { started: Int, completed: Int, failed: Int, - skipped:Int, + skipped: Int, total: Int): Seq[Node] = { val completeWidth = "width: %s%%".format((completed.toDouble/total)*100) val startWidth = "width: %s%%".format((started.toDouble/total)*100) @@ -352,10 +352,12 @@ private[spark] object UIUtils extends Logging {
    -
    + @@ -237,11 +249,11 @@ public class GaussianMixtureExample { GaussianMixtureModel gmm = new GaussianMixture().setK(2).run(parsedData.rdd()); // Save and load GaussianMixtureModel - gmm.save(sc, "myGMMModel") - GaussianMixtureModel sameModel = GaussianMixtureModel.load(sc, "myGMMModel") + gmm.save(sc.sc(), "myGMMModel"); + GaussianMixtureModel sameModel = GaussianMixtureModel.load(sc.sc(), "myGMMModel"); // Output the parameters of the mixture model for(int j=0; j println(s"${a.id} -> ${a.cluster}") } + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = PowerIterationClusteringModel.load(sc, "myModelPath") {% endhighlight %} A full example that produces the experiment described in the PIC paper can be found under @@ -360,6 +376,10 @@ PowerIterationClusteringModel model = pic.run(similarities); for (PowerIterationClustering.Assignment a: model.assignments().toJavaRDD().collect()) { System.out.println(a.id() + " -> " + a.cluster()); } + +// Save and load model +model.save(sc.sc(), "myModelPath"); +PowerIterationClusteringModel sameModel = PowerIterationClusteringModel.load(sc.sc(), "myModelPath"); {% endhighlight %}
    diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index 7b397e30b2d90..dfdf6216b270c 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -107,7 +107,8 @@ other signals), you can use the `trainImplicit` method to get better results. {% highlight scala %} val alpha = 0.01 -val model = ALS.trainImplicit(ratings, rank, numIterations, alpha) +val lambda = 0.01 +val model = ALS.trainImplicit(ratings, rank, numIterations, lambda, alpha) {% endhighlight %} diff --git a/docs/mllib-data-types.md b/docs/mllib-data-types.md index 4f2a2f71048f7..d824dab1d7f7b 100644 --- a/docs/mllib-data-types.md +++ b/docs/mllib-data-types.md @@ -31,7 +31,7 @@ The base class of local vectors is implementations: [`DenseVector`](api/scala/index.html#org.apache.spark.mllib.linalg.DenseVector) and [`SparseVector`](api/scala/index.html#org.apache.spark.mllib.linalg.SparseVector). We recommend using the factory methods implemented in -[`Vectors`](api/scala/index.html#org.apache.spark.mllib.linalg.Vector) to create local vectors. +[`Vectors`](api/scala/index.html#org.apache.spark.mllib.linalg.Vectors$) to create local vectors. {% highlight scala %} import org.apache.spark.mllib.linalg.{Vector, Vectors} @@ -57,7 +57,7 @@ The base class of local vectors is implementations: [`DenseVector`](api/java/org/apache/spark/mllib/linalg/DenseVector.html) and [`SparseVector`](api/java/org/apache/spark/mllib/linalg/SparseVector.html). We recommend using the factory methods implemented in -[`Vectors`](api/java/org/apache/spark/mllib/linalg/Vector.html) to create local vectors. +[`Vectors`](api/java/org/apache/spark/mllib/linalg/Vectors.html) to create local vectors. {% highlight java %} import org.apache.spark.mllib.linalg.Vector; @@ -84,7 +84,7 @@ and the following as sparse vectors: with a single column We recommend using NumPy arrays over lists for efficiency, and using the factory methods implemented -in [`Vectors`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.Vector) to create sparse vectors. +in [`Vectors`](api/python/pyspark.mllib.html#pyspark.mllib.linalg.Vectors) to create sparse vectors. {% highlight python %} import numpy as np @@ -241,7 +241,7 @@ The base class of local matrices is [`Matrix`](api/scala/index.html#org.apache.spark.mllib.linalg.Matrix), and we provide one implementation: [`DenseMatrix`](api/scala/index.html#org.apache.spark.mllib.linalg.DenseMatrix). We recommend using the factory methods implemented -in [`Matrices`](api/scala/index.html#org.apache.spark.mllib.linalg.Matrices) to create local +in [`Matrices`](api/scala/index.html#org.apache.spark.mllib.linalg.Matrices$) to create local matrices. {% highlight scala %} @@ -296,70 +296,6 @@ backed by an RDD of its entries. The underlying RDDs of a distributed matrix must be deterministic, because we cache the matrix size. In general the use of non-deterministic RDDs can lead to errors. -### BlockMatrix - -A `BlockMatrix` is a distributed matrix backed by an RDD of `MatrixBlock`s, where a `MatrixBlock` is -a tuple of `((Int, Int), Matrix)`, where the `(Int, Int)` is the index of the block, and `Matrix` is -the sub-matrix at the given index with size `rowsPerBlock` x `colsPerBlock`. -`BlockMatrix` supports methods such as `add` and `multiply` with another `BlockMatrix`. -`BlockMatrix` also has a helper function `validate` which can be used to check whether the -`BlockMatrix` is set up properly. - -
    -
    - -A [`BlockMatrix`](api/scala/index.html#org.apache.spark.mllib.linalg.distributed.BlockMatrix) can be -most easily created from an `IndexedRowMatrix` or `CoordinateMatrix` by calling `toBlockMatrix`. -`toBlockMatrix` creates blocks of size 1024 x 1024 by default. -Users may change the block size by supplying the values through `toBlockMatrix(rowsPerBlock, colsPerBlock)`. - -{% highlight scala %} -import org.apache.spark.mllib.linalg.distributed.{BlockMatrix, CoordinateMatrix, MatrixEntry} - -val entries: RDD[MatrixEntry] = ... // an RDD of (i, j, v) matrix entries -// Create a CoordinateMatrix from an RDD[MatrixEntry]. -val coordMat: CoordinateMatrix = new CoordinateMatrix(entries) -// Transform the CoordinateMatrix to a BlockMatrix -val matA: BlockMatrix = coordMat.toBlockMatrix().cache() - -// Validate whether the BlockMatrix is set up properly. Throws an Exception when it is not valid. -// Nothing happens if it is valid. -matA.validate() - -// Calculate A^T A. -val ata = matA.transpose.multiply(matA) -{% endhighlight %} -
    - -
    - -A [`BlockMatrix`](api/java/org/apache/spark/mllib/linalg/distributed/BlockMatrix.html) can be -most easily created from an `IndexedRowMatrix` or `CoordinateMatrix` by calling `toBlockMatrix`. -`toBlockMatrix` creates blocks of size 1024 x 1024 by default. -Users may change the block size by supplying the values through `toBlockMatrix(rowsPerBlock, colsPerBlock)`. - -{% highlight java %} -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.mllib.linalg.distributed.BlockMatrix; -import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix; -import org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix; - -JavaRDD entries = ... // a JavaRDD of (i, j, v) Matrix Entries -// Create a CoordinateMatrix from a JavaRDD. -CoordinateMatrix coordMat = new CoordinateMatrix(entries.rdd()); -// Transform the CoordinateMatrix to a BlockMatrix -BlockMatrix matA = coordMat.toBlockMatrix().cache(); - -// Validate whether the BlockMatrix is set up properly. Throws an Exception when it is not valid. -// Nothing happens if it is valid. -matA.validate(); - -// Calculate A^T A. -BlockMatrix ata = matA.transpose().multiply(matA); -{% endhighlight %} -
    -
    - ### RowMatrix A `RowMatrix` is a row-oriented distributed matrix without meaningful row indices, backed by an RDD @@ -530,3 +466,67 @@ IndexedRowMatrix indexedRowMatrix = mat.toIndexedRowMatrix(); {% endhighlight %} + +### BlockMatrix + +A `BlockMatrix` is a distributed matrix backed by an RDD of `MatrixBlock`s, where a `MatrixBlock` is +a tuple of `((Int, Int), Matrix)`, where the `(Int, Int)` is the index of the block, and `Matrix` is +the sub-matrix at the given index with size `rowsPerBlock` x `colsPerBlock`. +`BlockMatrix` supports methods such as `add` and `multiply` with another `BlockMatrix`. +`BlockMatrix` also has a helper function `validate` which can be used to check whether the +`BlockMatrix` is set up properly. + +
    +
    + +A [`BlockMatrix`](api/scala/index.html#org.apache.spark.mllib.linalg.distributed.BlockMatrix) can be +most easily created from an `IndexedRowMatrix` or `CoordinateMatrix` by calling `toBlockMatrix`. +`toBlockMatrix` creates blocks of size 1024 x 1024 by default. +Users may change the block size by supplying the values through `toBlockMatrix(rowsPerBlock, colsPerBlock)`. + +{% highlight scala %} +import org.apache.spark.mllib.linalg.distributed.{BlockMatrix, CoordinateMatrix, MatrixEntry} + +val entries: RDD[MatrixEntry] = ... // an RDD of (i, j, v) matrix entries +// Create a CoordinateMatrix from an RDD[MatrixEntry]. +val coordMat: CoordinateMatrix = new CoordinateMatrix(entries) +// Transform the CoordinateMatrix to a BlockMatrix +val matA: BlockMatrix = coordMat.toBlockMatrix().cache() + +// Validate whether the BlockMatrix is set up properly. Throws an Exception when it is not valid. +// Nothing happens if it is valid. +matA.validate() + +// Calculate A^T A. +val ata = matA.transpose.multiply(matA) +{% endhighlight %} +
    + +
    + +A [`BlockMatrix`](api/java/org/apache/spark/mllib/linalg/distributed/BlockMatrix.html) can be +most easily created from an `IndexedRowMatrix` or `CoordinateMatrix` by calling `toBlockMatrix`. +`toBlockMatrix` creates blocks of size 1024 x 1024 by default. +Users may change the block size by supplying the values through `toBlockMatrix(rowsPerBlock, colsPerBlock)`. + +{% highlight java %} +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.mllib.linalg.distributed.BlockMatrix; +import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix; +import org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix; + +JavaRDD entries = ... // a JavaRDD of (i, j, v) Matrix Entries +// Create a CoordinateMatrix from a JavaRDD. +CoordinateMatrix coordMat = new CoordinateMatrix(entries.rdd()); +// Transform the CoordinateMatrix to a BlockMatrix +BlockMatrix matA = coordMat.toBlockMatrix().cache(); + +// Validate whether the BlockMatrix is set up properly. Throws an Exception when it is not valid. +// Nothing happens if it is valid. +matA.validate(); + +// Calculate A^T A. +BlockMatrix ata = matA.transpose().multiply(matA); +{% endhighlight %} +
    +
    diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index f723cd6b9dfab..4fe470a8de810 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -188,7 +188,7 @@ Here we assume the extracted file is `text8` and in same directory as you run th import org.apache.spark._ import org.apache.spark.rdd._ import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.feature.Word2Vec +import org.apache.spark.mllib.feature.{Word2Vec, Word2VecModel} val input = sc.textFile("text8").map(line => line.split(" ").toSeq) @@ -201,6 +201,10 @@ val synonyms = model.findSynonyms("china", 40) for((synonym, cosineSimilarity) <- synonyms) { println(s"$synonym $cosineSimilarity") } + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = Word2VecModel.load(sc, "myModelPath") {% endhighlight %}
    @@ -410,6 +414,7 @@ import org.apache.spark.SparkContext._ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.mllib.feature.ChiSqSelector // Load some data in libsvm format val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") @@ -505,7 +510,7 @@ v_N ### Example -This example below demonstrates how to load a simple vectors file, extract a set of vectors, then transform those vectors using a transforming vector value. +This example below demonstrates how to transform vectors using a transforming vector value.
    @@ -514,16 +519,44 @@ import org.apache.spark.SparkContext._ import org.apache.spark.mllib.feature.ElementwiseProduct import org.apache.spark.mllib.linalg.Vectors -// Load and parse the data: -val data = sc.textFile("data/mllib/kmeans_data.txt") -val parsedData = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble))) +// Create some vector data; also works for sparse vectors +val data = sc.parallelize(Array(Vectors.dense(1.0, 2.0, 3.0), Vectors.dense(4.0, 5.0, 6.0))) val transformingVector = Vectors.dense(0.0, 1.0, 2.0) val transformer = new ElementwiseProduct(transformingVector) // Batch transform and per-row transform give the same results: -val transformedData = transformer.transform(parsedData) -val transformedData2 = parsedData.map(x => transformer.transform(x)) +val transformedData = transformer.transform(data) +val transformedData2 = data.map(x => transformer.transform(x)) + +{% endhighlight %} +
    + +
    +{% highlight java %} +import java.util.Arrays; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.feature.ElementwiseProduct; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; + +// Create some vector data; also works for sparse vectors +JavaRDD data = sc.parallelize(Arrays.asList( + Vectors.dense(1.0, 2.0, 3.0), Vectors.dense(4.0, 5.0, 6.0))); +Vector transformingVector = Vectors.dense(0.0, 1.0, 2.0); +ElementwiseProduct transformer = new ElementwiseProduct(transformingVector); + +// Batch transform and per-row transform give the same results: +JavaRDD transformedData = transformer.transform(data); +JavaRDD transformedData2 = data.map( + new Function() { + @Override + public Vector call(Vector v) { + return transformer.transform(v); + } + } +); {% endhighlight %}
    diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index f8e879496c135..de7d66fb2dedf 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -39,6 +39,7 @@ filtering, dimensionality reduction, as well as underlying optimization primitiv * [Optimization (developer)](mllib-optimization.html) * stochastic gradient descent * limited-memory BFGS (L-BFGS) +* [PMML model export](mllib-pmml-model-export.html) MLlib is under active development. The APIs marked `Experimental`/`DeveloperApi` may change in future releases, diff --git a/docs/mllib-isotonic-regression.md b/docs/mllib-isotonic-regression.md index b521c2f27cd6e..5732bc4c7e79e 100644 --- a/docs/mllib-isotonic-regression.md +++ b/docs/mllib-isotonic-regression.md @@ -60,7 +60,7 @@ Model is created using the training set and a mean squared error is calculated f labels and real labels in the test set. {% highlight scala %} -import org.apache.spark.mllib.regression.IsotonicRegression +import org.apache.spark.mllib.regression.{IsotonicRegression, IsotonicRegressionModel} val data = sc.textFile("data/mllib/sample_isotonic_regression_data.txt") @@ -88,6 +88,10 @@ val predictionAndLabel = test.map { point => // Calculate mean squared error between predicted and real labels. val meanSquaredError = predictionAndLabel.map{case(p, l) => math.pow((p - l), 2)}.mean() println("Mean Squared Error = " + meanSquaredError) + +// Save and load model +model.save(sc, "myModelPath") +val sameModel = IsotonicRegressionModel.load(sc, "myModelPath") {% endhighlight %}
    @@ -150,6 +154,10 @@ Double meanSquaredError = new JavaDoubleRDD(predictionAndLabel.map( ).rdd()).mean(); System.out.println("Mean Squared Error = " + meanSquaredError); + +// Save and load model +model.save(sc.sc(), "myModelPath"); +IsotonicRegressionModel sameModel = IsotonicRegressionModel.load(sc.sc(), "myModelPath"); {% endhighlight %}
    diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 2b2be4d9d0273..3dc8cc902fa72 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -163,11 +163,8 @@ object, and make predictions with the resulting model to compute the training error. {% highlight scala %} -import org.apache.spark.SparkContext import org.apache.spark.mllib.classification.{SVMModel, SVMWithSGD} import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLUtils // Load training data in LIBSVM format. @@ -231,15 +228,13 @@ calling `.rdd()` on your `JavaRDD` object. A self-contained application example that is equivalent to the provided example in Scala is given bellow: {% highlight java %} -import java.util.Random; - import scala.Tuple2; import org.apache.spark.api.java.*; import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.classification.*; import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; -import org.apache.spark.mllib.linalg.Vector; + import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.util.MLUtils; import org.apache.spark.SparkConf; @@ -282,8 +277,8 @@ public class SVMClassifier { System.out.println("Area under ROC = " + auROC); // Save and load model - model.save(sc.sc(), "myModelPath"); - SVMModel sameModel = SVMModel.load(sc.sc(), "myModelPath"); + model.save(sc, "myModelPath"); + SVMModel sameModel = SVMModel.load(sc, "myModelPath"); } } {% endhighlight %} @@ -315,15 +310,12 @@ a dependency.
    -The following example shows how to load a sample dataset, build Logistic Regression model, +The following example shows how to load a sample dataset, build SVM model, and make predictions with the resulting model to compute the training error. -Note that the Python API does not yet support model save/load but will in the future. - {% highlight python %} -from pyspark.mllib.classification import LogisticRegressionWithSGD +from pyspark.mllib.classification import SVMWithSGD, SVMModel from pyspark.mllib.regression import LabeledPoint -from numpy import array # Load and parse the data def parsePoint(line): @@ -334,12 +326,16 @@ data = sc.textFile("data/mllib/sample_svm_data.txt") parsedData = data.map(parsePoint) # Build the model -model = LogisticRegressionWithSGD.train(parsedData) +model = SVMWithSGD.train(parsedData, iterations=100) # Evaluating the model on training data labelsAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) trainErr = labelsAndPreds.filter(lambda (v, p): v != p).count() / float(parsedData.count()) print("Training Error = " + str(trainErr)) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = SVMModel.load(sc, "myModelPath") {% endhighlight %}
    @@ -785,8 +781,7 @@ gradient descent (`stepSize`, `numIterations`, `miniBatchFraction`). For each o all three possible regularizations (none, L1 or L2). For Logistic Regression, [L-BFGS](api/scala/index.html#org.apache.spark.mllib.optimization.LBFGS) -version is implemented under [LogisticRegressionWithLBFGS] -(api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS), and this +version is implemented under [LogisticRegressionWithLBFGS](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS), and this version supports both binary and multinomial Logistic Regression while SGD version only supports binary Logistic Regression. However, L-BFGS version doesn't support L1 regularization but SGD one supports L1 regularization. When L1 regularization is not required, L-BFGS version is strongly diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md index 9780ea52c4994..bf6d124fd5d8d 100644 --- a/docs/mllib-naive-bayes.md +++ b/docs/mllib-naive-bayes.md @@ -14,14 +14,13 @@ and use it for prediction. MLlib supports [multinomial naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes) -and [Bernoulli naive Bayes] (http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html). -These models are typically used for [document classification] -(http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html). +and [Bernoulli naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html). +These models are typically used for [document classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html). Within that context, each observation is a document and each feature represents a term whose value is the frequency of the term (in multinomial naive Bayes) or a zero or one indicating whether the term was found in the document (in Bernoulli naive Bayes). Feature values must be nonnegative. The model type is selected with an optional parameter -"Multinomial" or "Bernoulli" with "Multinomial" as the default. +"multinomial" or "bernoulli" with "multinomial" as the default. [Additive smoothing](http://en.wikipedia.org/wiki/Lidstone_smoothing) can be used by setting the parameter $\lambda$ (default to $1.0$). For document classification, the input feature vectors are usually sparse, and sparse vectors should be supplied as input to take advantage of @@ -35,7 +34,7 @@ sparsity. Since the training data is only used once, it is not necessary to cach [NaiveBayes](api/scala/index.html#org.apache.spark.mllib.classification.NaiveBayes$) implements multinomial naive Bayes. It takes an RDD of [LabeledPoint](api/scala/index.html#org.apache.spark.mllib.regression.LabeledPoint) and an optional -smoothing parameter `lambda` as input, an optional model type parameter (default is Multinomial), and outputs a +smoothing parameter `lambda` as input, an optional model type parameter (default is "multinomial"), and outputs a [NaiveBayesModel](api/scala/index.html#org.apache.spark.mllib.classification.NaiveBayesModel), which can be used for evaluation and prediction. @@ -54,7 +53,7 @@ val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L) val training = splits(0) val test = splits(1) -val model = NaiveBayes.train(training, lambda = 1.0, model = "Multinomial") +val model = NaiveBayes.train(training, lambda = 1.0, modelType = "multinomial") val predictionAndLabel = test.map(p => (model.predict(p.features), p.label)) val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count() @@ -75,6 +74,8 @@ optionally smoothing parameter `lambda` as input, and output a can be used for evaluation and prediction. {% highlight java %} +import scala.Tuple2; + import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.function.Function; @@ -82,7 +83,6 @@ import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.mllib.classification.NaiveBayes; import org.apache.spark.mllib.classification.NaiveBayesModel; import org.apache.spark.mllib.regression.LabeledPoint; -import scala.Tuple2; JavaRDD training = ... // training set JavaRDD test = ... // test set diff --git a/docs/mllib-pmml-model-export.md b/docs/mllib-pmml-model-export.md new file mode 100644 index 0000000000000..42ea2ca81f80d --- /dev/null +++ b/docs/mllib-pmml-model-export.md @@ -0,0 +1,86 @@ +--- +layout: global +title: PMML model export - MLlib +displayTitle: MLlib - PMML model export +--- + +* Table of contents +{:toc} + +## MLlib supported models + +MLlib supports model export to Predictive Model Markup Language ([PMML](http://en.wikipedia.org/wiki/Predictive_Model_Markup_Language)). + +The table below outlines the MLlib models that can be exported to PMML and their equivalent PMML model. + + + + + + + + + + + + + + + + + + + + + + + + + +
    MLlib modelPMML model
    KMeansModelClusteringModel
    LinearRegressionModelRegressionModel (functionName="regression")
    RidgeRegressionModelRegressionModel (functionName="regression")
    LassoModelRegressionModel (functionName="regression")
    SVMModelRegressionModel (functionName="classification" normalizationMethod="none")
    Binary LogisticRegressionModelRegressionModel (functionName="classification" normalizationMethod="logit")
    + +## Examples +
    + +
    +To export a supported `model` (see table above) to PMML, simply call `model.toPMML`. + +Here a complete example of building a KMeansModel and print it out in PMML format: +{% highlight scala %} +import org.apache.spark.mllib.clustering.KMeans +import org.apache.spark.mllib.linalg.Vectors + +// Load and parse the data +val data = sc.textFile("data/mllib/kmeans_data.txt") +val parsedData = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble))).cache() + +// Cluster the data into two classes using KMeans +val numClusters = 2 +val numIterations = 20 +val clusters = KMeans.train(parsedData, numClusters, numIterations) + +// Export to PMML +println("PMML Model:\n" + clusters.toPMML) +{% endhighlight %} + +As well as exporting the PMML model to a String (`model.toPMML` as in the example above), you can export the PMML model to other formats: + +{% highlight scala %} +// Export the model to a String in PMML format +clusters.toPMML + +// Export the model to a local file in PMML format +clusters.toPMML("/tmp/kmeans.xml") + +// Export the model to a directory on a distributed file system in PMML format +clusters.toPMML(sc,"/tmp/kmeans") + +// Export the model to the OutputStream in PMML format +clusters.toPMML(System.out) +{% endhighlight %} + +For unsupported models, either you will not find a `.toPMML` method or an `IllegalArgumentException` will be thrown. + +
    + +
    diff --git a/docs/monitoring.md b/docs/monitoring.md index 1e0fc150862fb..e75018499003a 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -178,9 +178,9 @@ Note that the history server only displays completed Spark jobs. One way to sign In addition to viewing the metrics in the UI, they are also available as JSON. This gives developers an easy way to create new visualizations and monitoring tools for Spark. The JSON is available for -both running applications, and in the history server. The endpoints are mounted at `/json/v1`. Eg., -for the history server, they would typically be accessible at `http://:18080/json/v1`, and -for a running application, at `http://localhost:4040/json/v1`. +both running applications, and in the history server. The endpoints are mounted at `/api/v1`. Eg., +for the history server, they would typically be accessible at `http://:18080/api/v1`, and +for a running application, at `http://localhost:4040/api/v1`. @@ -240,12 +240,12 @@ These endpoints have been strongly versioned to make it easier to develop applic * Individual fields will never be removed for any given endpoint * New endpoints may be added * New fields may be added to existing endpoints -* New versions of the api may be added in the future at a separate endpoint (eg., `json/v2`). New versions are *not* required to be backwards compatible. +* New versions of the api may be added in the future at a separate endpoint (eg., `api/v2`). New versions are *not* required to be backwards compatible. * Api versions may be dropped, but only after at least one minor release of co-existing with a new api version Note that even when examining the UI of a running applications, the `applications/[app-id]` portion is still required, though there is only one application available. Eg. to see the list of jobs for the -running app, you would go to `http://localhost:4040/json/v1/applications/[app-id]/jobs`. This is to +running app, you would go to `http://localhost:4040/api/v1/applications/[app-id]/jobs`. This is to keep the paths consistent in both modes. # Metrics diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 27816515c5de2..10f474f237bfa 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -41,14 +41,15 @@ In addition, if you wish to access an HDFS cluster, you need to add a dependency artifactId = hadoop-client version = -Finally, you need to import some Spark classes and implicit conversions into your program. Add the following lines: +Finally, you need to import some Spark classes into your program. Add the following lines: {% highlight scala %} import org.apache.spark.SparkContext -import org.apache.spark.SparkContext._ import org.apache.spark.SparkConf {% endhighlight %} +(Before Spark 1.3.0, you need to explicitly `import org.apache.spark.SparkContext._` to enable essential implicit conversions.) +
    @@ -97,9 +98,9 @@ to your version of HDFS. Some common HDFS version tags are listed on the [Prebuilt packages](http://spark.apache.org/downloads.html) are also available on the Spark homepage for common HDFS versions. -Finally, you need to import some Spark classes into your program. Add the following lines: +Finally, you need to import some Spark classes into your program. Add the following line: -{% highlight scala %} +{% highlight python %} from pyspark import SparkContext, SparkConf {% endhighlight %} @@ -477,7 +478,6 @@ the [Converter examples]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main for examples of using Cassandra / HBase ```InputFormat``` and ```OutputFormat``` with custom converters.
    - ## RDD Operations @@ -821,11 +821,9 @@ by a key. In Scala, these operations are automatically available on RDDs containing [Tuple2](http://www.scala-lang.org/api/{{site.SCALA_VERSION}}/index.html#scala.Tuple2) objects -(the built-in tuples in the language, created by simply writing `(a, b)`), as long as you -import `org.apache.spark.SparkContext._` in your program to enable Spark's implicit -conversions. The key-value pair operations are available in the +(the built-in tuples in the language, created by simply writing `(a, b)`). The key-value pair operations are available in the [PairRDDFunctions](api/scala/index.html#org.apache.spark.rdd.PairRDDFunctions) class, -which automatically wraps around an RDD of tuples if you import the conversions. +which automatically wraps around an RDD of tuples. For example, the following code uses the `reduceByKey` operation on key-value pairs to count how many times each line of text occurs in a file: @@ -916,7 +914,8 @@ The following table lists some of the common transformations supported by Spark. RDD API doc ([Scala](api/scala/index.html#org.apache.spark.rdd.RDD), [Java](api/java/index.html?org/apache/spark/api/java/JavaRDD.html), - [Python](api/python/pyspark.html#pyspark.RDD)) + [Python](api/python/pyspark.html#pyspark.RDD), + [R](api/R/index.html)) and pair RDD functions doc ([Scala](api/scala/index.html#org.apache.spark.rdd.PairRDDFunctions), [Java](api/java/index.html?org/apache/spark/api/java/JavaPairRDD.html)) @@ -1029,7 +1028,9 @@ The following table lists some of the common actions supported by Spark. Refer t RDD API doc ([Scala](api/scala/index.html#org.apache.spark.rdd.RDD), [Java](api/java/index.html?org/apache/spark/api/java/JavaRDD.html), - [Python](api/python/pyspark.html#pyspark.RDD)) + [Python](api/python/pyspark.html#pyspark.RDD), + [R](api/R/index.html)) + and pair RDD functions doc ([Scala](api/scala/index.html#org.apache.spark.rdd.PairRDDFunctions), [Java](api/java/index.html?org/apache/spark/api/java/JavaPairRDD.html)) @@ -1071,7 +1072,7 @@ for details.
    - @@ -1122,7 +1123,7 @@ ordered data following shuffle then it's possible to use: * `sortBy` to make a globally ordered RDD Operations which can cause a shuffle include **repartition** operations like -[`repartition`](#RepartitionLink), and [`coalesce`](#CoalesceLink), **'ByKey** operations +[`repartition`](#RepartitionLink) and [`coalesce`](#CoalesceLink), **'ByKey** operations (except for counting) like [`groupByKey`](#GroupByLink) and [`reduceByKey`](#ReduceByLink), and **join** operations like [`cogroup`](#CogroupLink) and [`join`](#JoinLink). @@ -1138,7 +1139,7 @@ read the relevant sorted blocks. Certain shuffle operations can consume significant amounts of heap memory since they employ in-memory data structures to organize records before or after transferring them. Specifically, -`reduceByKey` and `aggregateByKey` create these structures on the map side and `'ByKey` operations +`reduceByKey` and `aggregateByKey` create these structures on the map side, and `'ByKey` operations generate these on the reduce side. When data does not fit in memory Spark will spill these tables to disk, incurring the additional overhead of disk I/O and increased garbage collection. @@ -1213,9 +1214,11 @@ storage levels is: Compared to MEMORY_ONLY_SER, OFF_HEAP reduces garbage collection overhead and allows executors to be smaller and to share a pool of memory, making it attractive in environments with large heaps or multiple concurrent applications. Furthermore, as the RDDs reside in Tachyon, - the crash of an executor does not lead to losing the in-memory cache. In this mode, the memory + the crash of an executor does not lead to losing the in-memory cache. In this mode, the memory in Tachyon is discardable. Thus, Tachyon does not attempt to reconstruct a block that it evicts - from memory. + from memory. If you plan to use Tachyon as the off heap store, Spark is compatible with Tachyon + out-of-the-box. Please refer to this page + for the suggested version pairings.
    EndpointMeaning
    saveAsSequenceFile(path)
    (Java and Scala)
    Write the elements of the dataset as a Hadoop SequenceFile in a given path in the local filesystem, HDFS or any other Hadoop-supported file system. This is available on RDDs of key-value pairs that either implement Hadoop's Writable interface. In Scala, it is also + Write the elements of the dataset as a Hadoop SequenceFile in a given path in the local filesystem, HDFS or any other Hadoop-supported file system. This is available on RDDs of key-value pairs that implement Hadoop's Writable interface. In Scala, it is also available on types that are implicitly convertible to Writable (Spark includes conversions for basic types like Int, Double, String, etc).
    @@ -1566,7 +1569,8 @@ You can see some [example Spark programs](http://spark.apache.org/examples.html) In addition, Spark includes several samples in the `examples` directory ([Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples), [Java]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples), - [Python]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python)). + [Python]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python), + [R]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/r)). You can run Java and Scala examples by passing the class name to Spark's `bin/run-example` script; for instance: ./bin/run-example SparkPi @@ -1575,6 +1579,10 @@ For Python examples, use `spark-submit` instead: ./bin/spark-submit examples/src/main/python/pi.py +For R examples, use `spark-submit` instead: + + ./bin/spark-submit examples/src/main/r/dataframe.R + For help on optimizing your programs, the [configuration](configuration.html) and [tuning](tuning.html) guides provide information on best practices. They are especially important for making sure that your data is stored in memory in an efficient format. @@ -1582,4 +1590,4 @@ For help on deploying, the [cluster mode overview](cluster-overview.html) descri in distributed operation and supported cluster managers. Finally, full API documentation is available in -[Scala](api/scala/#org.apache.spark.package), [Java](api/java/) and [Python](api/python/). +[Scala](api/scala/#org.apache.spark.package), [Java](api/java/), [Python](api/python/) and [R](api/R/). diff --git a/docs/quick-start.md b/docs/quick-start.md index 81143da865cf0..bb39e4111f244 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -184,10 +184,10 @@ scala> linesWithSpark.cache() res7: spark.RDD[String] = spark.FilteredRDD@17e51082 scala> linesWithSpark.count() -res8: Long = 15 +res8: Long = 19 scala> linesWithSpark.count() -res9: Long = 15 +res9: Long = 19 {% endhighlight %} It may seem silly to use Spark to explore and cache a 100-line text file. The interesting part is @@ -202,10 +202,10 @@ a cluster, as described in the [programming guide](programming-guide.html#initia >>> linesWithSpark.cache() >>> linesWithSpark.count() -15 +19 >>> linesWithSpark.count() -15 +19 {% endhighlight %} It may seem silly to use Spark to explore and cache a 100-line text file. The interesting part is @@ -423,14 +423,14 @@ dependencies to `spark-submit` through its `--py-files` argument by packaging th We can run this application using the `bin/spark-submit` script: -{% highlight python %} +{% highlight bash %} # Use spark-submit to run your application $ YOUR_SPARK_HOME/bin/spark-submit \ --master local[4] \ SimpleApp.py ... Lines with a: 46, Lines with b: 23 -{% endhighlight python %} +{% endhighlight %} @@ -444,7 +444,8 @@ Congratulations on running your first Spark application! * Finally, Spark includes several samples in the `examples` directory ([Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples), [Java]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples), - [Python]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python)). + [Python]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python), + [R]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/r)). You can run them as follows: {% highlight bash %} @@ -453,4 +454,7 @@ You can run them as follows: # For Python examples, use spark-submit directly: ./bin/spark-submit examples/src/main/python/pi.py + +# For R examples, use spark-submit directly: +./bin/spark-submit examples/src/main/r/dataframe.R {% endhighlight %} diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 51c1339165024..96cf612c54fdd 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -71,9 +71,22 @@ Most of the configs are the same for Spark on YARN as for other deployment modes spark.yarn.scheduler.heartbeat.interval-ms - 5000 + 3000 The interval in ms in which the Spark application master heartbeats into the YARN ResourceManager. + The value is capped at half the value of YARN's configuration for the expiry interval + (yarn.am.liveness-monitor.expiry-interval-ms). + + + + spark.yarn.scheduler.initial-allocation.interval + 200ms + + The initial interval in which the Spark application master eagerly heartbeats to the YARN ResourceManager + when there are pending container allocation requests. It should be no larger than + spark.yarn.scheduler.heartbeat.interval-ms. The allocation interval will doubled on + successive eager heartbeats if pending containers still exist, until + spark.yarn.scheduler.heartbeat.interval-ms is reached. @@ -229,6 +242,22 @@ Most of the configs are the same for Spark on YARN as for other deployment modes running against earlier versions, this property will be ignored. + + spark.yarn.keytab + (none) + + The full path to the file that contains the keytab for the principal specified above. + This keytab will be copied to the node running the Application Master via the Secure Distributed Cache, + for renewing the login tickets and the delegation tokens periodically. + + + + spark.yarn.principal + (none) + + Principal to be used to login to KDC, while running on secure HDFS. + + # Launching Spark on YARN diff --git a/docs/sparkr.md b/docs/sparkr.md new file mode 100644 index 0000000000000..4d82129921a37 --- /dev/null +++ b/docs/sparkr.md @@ -0,0 +1,223 @@ +--- +layout: global +displayTitle: SparkR (R on Spark) +title: SparkR (R on Spark) +--- + +* This will become a table of contents (this text will be scraped). +{:toc} + +# Overview +SparkR is an R package that provides a light-weight frontend to use Apache Spark from R. +In Spark {{site.SPARK_VERSION}}, SparkR provides a distributed data frame implementation that +supports operations like selection, filtering, aggregation etc. (similar to R data frames, +[dplyr](https://github.com/hadley/dplyr)) but on large datasets. + +# SparkR DataFrames + +A DataFrame is a distributed collection of data organized into named columns. It is conceptually +equivalent to a table in a relational database or a data frame in R, but with richer +optimizations under the hood. DataFrames can be constructed from a wide array of sources such as: +structured data files, tables in Hive, external databases, or existing local R data frames. + +All of the examples on this page use sample data included in R or the Spark distribution and can be run using the `./bin/sparkR` shell. + +## Starting Up: SparkContext, SQLContext + +
    +The entry point into SparkR is the `SparkContext` which connects your R program to a Spark cluster. +You can create a `SparkContext` using `sparkR.init` and pass in options such as the application name +etc. Further, to work with DataFrames we will need a `SQLContext`, which can be created from the +SparkContext. If you are working from the SparkR shell, the `SQLContext` and `SparkContext` should +already be created for you. + +{% highlight r %} +sc <- sparkR.init() +sqlContext <- sparkRSQL.init(sc) +{% endhighlight %} + +
    + +## Creating DataFrames +With a `SQLContext`, applications can create `DataFrame`s from a local R data frame, from a [Hive table](sql-programming-guide.html#hive-tables), or from other [data sources](sql-programming-guide.html#data-sources). + +### From local data frames +The simplest way to create a data frame is to convert a local R data frame into a SparkR DataFrame. Specifically we can use `createDataFrame` and pass in the local R data frame to create a SparkR DataFrame. As an example, the following creates a `DataFrame` based using the `faithful` dataset from R. + +
    +{% highlight r %} +df <- createDataFrame(sqlContext, faithful) + +# Displays the content of the DataFrame to stdout +head(df) +## eruptions waiting +##1 3.600 79 +##2 1.800 54 +##3 3.333 74 + +{% endhighlight %} +
    + +### From Data Sources + +SparkR supports operating on a variety of data sources through the `DataFrame` interface. This section describes the general methods for loading and saving data using Data Sources. You can check the Spark SQL programming guide for more [specific options](sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources. + +The general method for creating DataFrames from data sources is `read.df`. This method takes in the `SQLContext`, the path for the file to load and the type of data source. SparkR supports reading JSON and Parquet files natively and through [Spark Packages](http://spark-packages.org/) you can find data source connectors for popular file formats like [CSV](http://spark-packages.org/package/databricks/spark-csv) and [Avro](http://spark-packages.org/package/databricks/spark-avro). + +We can see how to use data sources using an example JSON input file. Note that the file that is used here is _not_ a typical JSON file. Each line in the file must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail. + +
    + +{% highlight r %} +people <- read.df(sqlContext, "./examples/src/main/resources/people.json", "json") +head(people) +## age name +##1 NA Michael +##2 30 Andy +##3 19 Justin + +# SparkR automatically infers the schema from the JSON file +printSchema(people) +# root +# |-- age: integer (nullable = true) +# |-- name: string (nullable = true) + +{% endhighlight %} +
    + +The data sources API can also be used to save out DataFrames into multiple file formats. For example we can save the DataFrame from the previous example +to a Parquet file using `write.df` + +
    +{% highlight r %} +write.df(people, path="people.parquet", source="parquet", mode="overwrite") +{% endhighlight %} +
    + +### From Hive tables + +You can also create SparkR DataFrames from Hive tables. To do this we will need to create a HiveContext which can access tables in the Hive MetaStore. Note that Spark should have been built with [Hive support](building-spark.html#building-with-hive-and-jdbc-support) and more details on the difference between SQLContext and HiveContext can be found in the [SQL programming guide](sql-programming-guide.html#starting-point-sqlcontext). + +
    +{% highlight r %} +# sc is an existing SparkContext. +hiveContext <- sparkRHive.init(sc) + +sql(hiveContext, "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") +sql(hiveContext, "LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") + +# Queries can be expressed in HiveQL. +results <- hiveContext.sql("FROM src SELECT key, value") + +# results is now a DataFrame +head(results) +## key value +## 1 238 val_238 +## 2 86 val_86 +## 3 311 val_311 + +{% endhighlight %} +
    + +## DataFrame Operations + +SparkR DataFrames support a number of functions to do structured data processing. +Here we include some basic examples and a complete list can be found in the [API](api/R/index.html) docs: + +### Selecting rows, columns + +
    +{% highlight r %} +# Create the DataFrame +df <- createDataFrame(sqlContext, faithful) + +# Get basic information about the DataFrame +df +## DataFrame[eruptions:double, waiting:double] + +# Select only the "eruptions" column +head(select(df, df$eruptions)) +## eruptions +##1 3.600 +##2 1.800 +##3 3.333 + +# You can also pass in column name as strings +head(select(df, "eruptions")) + +# Filter the DataFrame to only retain rows with wait times shorter than 50 mins +head(filter(df, df$waiting < 50)) +## eruptions waiting +##1 1.750 47 +##2 1.750 47 +##3 1.867 48 + +{% endhighlight %} + +
    + +### Grouping, Aggregation + +SparkR data frames support a number of commonly used functions to aggregate data after grouping. For example we can compute a histogram of the `waiting` time in the `faithful` dataset as shown below + +
    +{% highlight r %} + +# We use the `n` operator to count the number of times each waiting time appears +head(summarize(groupBy(df, df$waiting), count = n(df$waiting))) +## waiting count +##1 81 13 +##2 60 6 +##3 68 1 + +# We can also sort the output from the aggregation to get the most common waiting times +waiting_counts <- summarize(groupBy(df, df$waiting), count = n(df$waiting)) +head(arrange(waiting_counts, desc(waiting_counts$count))) + +## waiting count +##1 78 15 +##2 83 14 +##3 81 13 + +{% endhighlight %} +
    + +### Operating on Columns + +SparkR also provides a number of functions that can directly applied to columns for data processing and during aggregation. The example below shows the use of basic arithmetic functions. + +
    +{% highlight r %} + +# Convert waiting time from hours to seconds. +# Note that we can assign this to a new column in the same DataFrame +df$waiting_secs <- df$waiting * 60 +head(df) +## eruptions waiting waiting_secs +##1 3.600 79 4740 +##2 1.800 54 3240 +##3 3.333 74 4440 + +{% endhighlight %} +
    + +## Running SQL Queries from SparkR +A SparkR DataFrame can also be registered as a temporary table in Spark SQL and registering a DataFrame as a table allows you to run SQL queries over its data. +The `sql` function enables applications to run SQL queries programmatically and returns the result as a `DataFrame`. + +
    +{% highlight r %} +# Load a JSON file +people <- read.df(sqlContext, "./examples/src/main/resources/people.json", "json") + +# Register this DataFrame as a table. +registerTempTable(people, "people") + +# SQL statements can be run by using the sql method +teenagers <- sql(sqlContext, "SELECT name FROM people WHERE age >= 13 AND age <= 19") +head(teenagers) +## name +##1 Justin + +{% endhighlight %} +
    diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 78b8e8ad515a0..282ea75e1e785 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -11,14 +11,15 @@ title: Spark SQL and DataFrames Spark SQL is a Spark module for structured data processing. It provides a programming abstraction called DataFrames and can also act as distributed SQL query engine. +For how to enable Hive support, please refer to the [Hive Tables](#hive-tables) section. # DataFrames A DataFrame is a distributed collection of data organized into named columns. It is conceptually equivalent to a table in a relational database or a data frame in R/Python, but with richer optimizations under the hood. DataFrames can be constructed from a wide array of sources such as: structured data files, tables in Hive, external databases, or existing RDDs. -The DataFrame API is available in [Scala](api/scala/index.html#org.apache.spark.sql.DataFrame), [Java](api/java/index.html?org/apache/spark/sql/DataFrame.html), and [Python](api/python/pyspark.sql.html#pyspark.sql.DataFrame). +The DataFrame API is available in [Scala](api/scala/index.html#org.apache.spark.sql.DataFrame), [Java](api/java/index.html?org/apache/spark/sql/DataFrame.html), [Python](api/python/pyspark.sql.html#pyspark.sql.DataFrame), and [R](api/R/index.html). -All of the examples on this page use sample data included in the Spark distribution and can be run in the `spark-shell` or the `pyspark` shell. +All of the examples on this page use sample data included in the Spark distribution and can be run in the `spark-shell`, `pyspark` shell, or `sparkR` shell. ## Starting Point: `SQLContext` @@ -64,6 +65,17 @@ from pyspark.sql import SQLContext sqlContext = SQLContext(sc) {% endhighlight %} + + +
    + +The entry point into all relational functionality in Spark is the +`SQLContext` class, or one of its decedents. To create a basic `SQLContext`, all you need is a SparkContext. + +{% highlight r %} +sqlContext <- sparkRSQL.init(sc) +{% endhighlight %} +
    @@ -97,7 +109,7 @@ As an example, the following creates a `DataFrame` based on the content of a JSO val sc: SparkContext // An existing SparkContext. val sqlContext = new org.apache.spark.sql.SQLContext(sc) -val df = sqlContext.jsonFile("examples/src/main/resources/people.json") +val df = sqlContext.read.json("examples/src/main/resources/people.json") // Displays the content of the DataFrame to stdout df.show() @@ -110,7 +122,7 @@ df.show() JavaSparkContext sc = ...; // An existing JavaSparkContext. SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); -DataFrame df = sqlContext.jsonFile("examples/src/main/resources/people.json"); +DataFrame df = sqlContext.read().json("examples/src/main/resources/people.json"); // Displays the content of the DataFrame to stdout df.show(); @@ -123,13 +135,26 @@ df.show(); from pyspark.sql import SQLContext sqlContext = SQLContext(sc) -df = sqlContext.jsonFile("examples/src/main/resources/people.json") +df = sqlContext.read.json("examples/src/main/resources/people.json") # Displays the content of the DataFrame to stdout df.show() {% endhighlight %} + +
    +{% highlight r %} +sqlContext <- SQLContext(sc) + +df <- jsonFile(sqlContext, "examples/src/main/resources/people.json") + +# Displays the content of the DataFrame to stdout +showDF(df) +{% endhighlight %} + +
    + @@ -146,7 +171,7 @@ val sc: SparkContext // An existing SparkContext. val sqlContext = new org.apache.spark.sql.SQLContext(sc) // Create the DataFrame -val df = sqlContext.jsonFile("examples/src/main/resources/people.json") +val df = sqlContext.read.json("examples/src/main/resources/people.json") // Show the content of the DataFrame df.show() @@ -196,7 +221,7 @@ JavaSparkContext sc // An existing SparkContext. SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc) // Create the DataFrame -DataFrame df = sqlContext.jsonFile("examples/src/main/resources/people.json"); +DataFrame df = sqlContext.read().json("examples/src/main/resources/people.json"); // Show the content of the DataFrame df.show(); @@ -252,7 +277,7 @@ from pyspark.sql import SQLContext sqlContext = SQLContext(sc) # Create the DataFrame -df = sqlContext.jsonFile("examples/src/main/resources/people.json") +df = sqlContext.read.json("examples/src/main/resources/people.json") # Show the content of the DataFrame df.show() @@ -296,6 +321,57 @@ df.groupBy("age").count().show() {% endhighlight %} + +
    +{% highlight r %} +sqlContext <- sparkRSQL.init(sc) + +# Create the DataFrame +df <- jsonFile(sqlContext, "examples/src/main/resources/people.json") + +# Show the content of the DataFrame +showDF(df) +## age name +## null Michael +## 30 Andy +## 19 Justin + +# Print the schema in a tree format +printSchema(df) +## root +## |-- age: long (nullable = true) +## |-- name: string (nullable = true) + +# Select only the "name" column +showDF(select(df, "name")) +## name +## Michael +## Andy +## Justin + +# Select everybody, but increment the age by 1 +showDF(select(df, df$name, df$age + 1)) +## name (age + 1) +## Michael null +## Andy 31 +## Justin 20 + +# Select people older than 21 +showDF(where(df, df$age > 21)) +## age name +## 30 Andy + +# Count people by age +showDF(count(groupBy(df, "age"))) +## age count +## null 1 +## 19 1 +## 30 1 + +{% endhighlight %} + +
    + @@ -325,6 +401,14 @@ sqlContext = SQLContext(sc) df = sqlContext.sql("SELECT * FROM table") {% endhighlight %} + +
    +{% highlight r %} +sqlContext <- sparkRSQL.init(sc) +df <- sql(sqlContext, "SELECT * FROM table") +{% endhighlight %} +
    + @@ -693,8 +777,8 @@ In the simplest form, the default data source (`parquet` unless otherwise config
    {% highlight scala %} -val df = sqlContext.load("examples/src/main/resources/users.parquet") -df.select("name", "favorite_color").save("namesAndFavColors.parquet") +val df = sqlContext.read.load("examples/src/main/resources/users.parquet") +df.select("name", "favorite_color").write.save("namesAndFavColors.parquet") {% endhighlight %}
    @@ -703,8 +787,8 @@ df.select("name", "favorite_color").save("namesAndFavColors.parquet") {% highlight java %} -DataFrame df = sqlContext.load("examples/src/main/resources/users.parquet"); -df.select("name", "favorite_color").save("namesAndFavColors.parquet"); +DataFrame df = sqlContext.read().load("examples/src/main/resources/users.parquet"); +df.select("name", "favorite_color").write().save("namesAndFavColors.parquet"); {% endhighlight %} @@ -714,9 +798,18 @@ df.select("name", "favorite_color").save("namesAndFavColors.parquet"); {% highlight python %} -df = sqlContext.load("examples/src/main/resources/users.parquet") -df.select("name", "favorite_color").save("namesAndFavColors.parquet") +df = sqlContext.read.load("examples/src/main/resources/users.parquet") +df.select("name", "favorite_color").write.save("namesAndFavColors.parquet") + +{% endhighlight %} + + + +
    +{% highlight r %} +df <- loadDF(sqlContext, "people.parquet") +saveDF(select(df, "name", "age"), "namesAndAges.parquet") {% endhighlight %}
    @@ -734,8 +827,8 @@ using this syntax.
    {% highlight scala %} -val df = sqlContext.load("examples/src/main/resources/people.json", "json") -df.select("name", "age").save("namesAndAges.parquet", "parquet") +val df = sqlContext.read.format("json").load("examples/src/main/resources/people.json") +df.select("name", "age").write.format("json").save("namesAndAges.parquet") {% endhighlight %}
    @@ -744,8 +837,8 @@ df.select("name", "age").save("namesAndAges.parquet", "parquet") {% highlight java %} -DataFrame df = sqlContext.load("examples/src/main/resources/people.json", "json"); -df.select("name", "age").save("namesAndAges.parquet", "parquet"); +DataFrame df = sqlContext.read().format("json").load("examples/src/main/resources/people.json"); +df.select("name", "age").write().format("parquet").save("namesAndAges.parquet"); {% endhighlight %} @@ -755,8 +848,18 @@ df.select("name", "age").save("namesAndAges.parquet", "parquet"); {% highlight python %} -df = sqlContext.load("examples/src/main/resources/people.json", "json") -df.select("name", "age").save("namesAndAges.parquet", "parquet") +df = sqlContext.read.load("examples/src/main/resources/people.json", format="json") +df.select("name", "age").write.save("namesAndAges.parquet", format="parquet") + +{% endhighlight %} + + +
    + +{% highlight r %} + +df <- loadDF(sqlContext, "people.json", "json") +saveDF(select(df, "name", "age"), "namesAndAges.parquet", "parquet") {% endhighlight %} @@ -804,7 +907,7 @@ new data. Ignore mode means that when saving a DataFrame to a data source, if data already exists, the save operation is expected to not save the contents of the DataFrame and to not - change the existing data. This is similar to a `CREATE TABLE IF NOT EXISTS` in SQL. + change the existing data. This is similar to a CREATE TABLE IF NOT EXISTS in SQL. @@ -844,11 +947,11 @@ import sqlContext.implicits._ val people: RDD[Person] = ... // An RDD of case class objects, from the previous example. // The RDD is implicitly converted to a DataFrame by implicits, allowing it to be stored using Parquet. -people.saveAsParquetFile("people.parquet") +people.write.parquet("people.parquet") // Read in the parquet file created above. Parquet files are self-describing so the schema is preserved. // The result of loading a Parquet file is also a DataFrame. -val parquetFile = sqlContext.parquetFile("people.parquet") +val parquetFile = sqlContext.read.parquet("people.parquet") //Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile") @@ -866,11 +969,11 @@ teenagers.map(t => "Name: " + t(0)).collect().foreach(println) DataFrame schemaPeople = ... // The DataFrame from the previous example. // DataFrames can be saved as Parquet files, maintaining the schema information. -schemaPeople.saveAsParquetFile("people.parquet"); +schemaPeople.write().parquet("people.parquet"); // Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. // The result of loading a parquet file is also a DataFrame. -DataFrame parquetFile = sqlContext.parquetFile("people.parquet"); +DataFrame parquetFile = sqlContext.read().parquet("people.parquet"); //Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile"); @@ -892,11 +995,11 @@ List teenagerNames = teenagers.javaRDD().map(new Function() schemaPeople # The DataFrame from the previous example. # DataFrames can be saved as Parquet files, maintaining the schema information. -schemaPeople.saveAsParquetFile("people.parquet") +schemaPeople.read.parquet("people.parquet") # Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. # The result of loading a parquet file is also a DataFrame. -parquetFile = sqlContext.parquetFile("people.parquet") +parquetFile = sqlContext.write.parquet("people.parquet") # Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile"); @@ -908,6 +1011,31 @@ for teenName in teenNames.collect():
    +
    + +{% highlight r %} +# sqlContext from the previous example is used in this example. + +schemaPeople # The DataFrame from the previous example. + +# DataFrames can be saved as Parquet files, maintaining the schema information. +saveAsParquetFile(schemaPeople, "people.parquet") + +# Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved. +# The result of loading a parquet file is also a DataFrame. +parquetFile <- parquetFile(sqlContext, "people.parquet") + +# Parquet files can also be registered as tables and then used in SQL statements. +registerTempTable(parquetFile, "parquetFile"); +teenagers <- sql(sqlContext, "SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19") +teenNames <- map(teenagers, function(p) { paste("Name:", p$name)}) +for (teenName in collect(teenNames)) { + cat(teenName, "\n") +} +{% endhighlight %} + +
    +
    {% highlight sql %} @@ -959,9 +1087,9 @@ path {% endhighlight %} -By passing `path/to/table` to either `SQLContext.parquetFile` or `SQLContext.load`, Spark SQL will -automatically extract the partitioning information from the paths. Now the schema of the returned -DataFrame becomes: +By passing `path/to/table` to either `SQLContext.read.parquet` or `SQLContext.read.load`, Spark SQL +will automatically extract the partitioning information from the paths. +Now the schema of the returned DataFrame becomes: {% highlight text %} @@ -994,15 +1122,15 @@ import sqlContext.implicits._ // Create a simple DataFrame, stored into a partition directory val df1 = sparkContext.makeRDD(1 to 5).map(i => (i, i * 2)).toDF("single", "double") -df1.saveAsParquetFile("data/test_table/key=1") +df1.write.parquet("data/test_table/key=1") // Create another DataFrame in a new partition directory, // adding a new column and dropping an existing column val df2 = sparkContext.makeRDD(6 to 10).map(i => (i, i * 3)).toDF("single", "triple") -df2.saveAsParquetFile("data/test_table/key=2") +df2.write.parquet("data/test_table/key=2") // Read the partitioned table -val df3 = sqlContext.parquetFile("data/test_table") +val df3 = sqlContext.read.parquet("data/test_table") df3.printSchema() // The final schema consists of all 3 columns in the Parquet files together @@ -1033,7 +1161,7 @@ df2 = sqlContext.createDataFrame(sc.parallelize(range(6, 11)) df2.save("data/test_table/key=2", "parquet") # Read the partitioned table -df3 = sqlContext.parquetFile("data/test_table") +df3 = sqlContext.load("data/test_table", "parquet") df3.printSchema() # The final schema consists of all 3 columns in the Parquet files together @@ -1047,6 +1175,33 @@ df3.printSchema()
    +
    + +{% highlight r %} +# sqlContext from the previous example is used in this example. + +# Create a simple DataFrame, stored into a partition directory +saveDF(df1, "data/test_table/key=1", "parquet", "overwrite") + +# Create another DataFrame in a new partition directory, +# adding a new column and dropping an existing column +saveDF(df2, "data/test_table/key=2", "parquet", "overwrite") + +# Read the partitioned table +df3 <- loadDF(sqlContext, "data/test_table", "parquet") +printSchema(df3) + +# The final schema consists of all 3 columns in the Parquet files together +# with the partiioning column appeared in the partition directory paths. +# root +# |-- single: int (nullable = true) +# |-- double: int (nullable = true) +# |-- triple: int (nullable = true) +# |-- key : int (nullable = true) +{% endhighlight %} + +
    + ### Configuration @@ -1114,12 +1269,10 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext`
    Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. -This conversion can be done using one of two methods in a `SQLContext`: - -* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. -* `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object. +This conversion can be done using `SQLContext.read.json()` on either an RDD of String, +or a JSON file. -Note that the file that is offered as _jsonFile_ is not a typical JSON file. Each +Note that the file that is offered as _a json file_ is not a typical JSON file. Each line must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail. @@ -1130,8 +1283,7 @@ val sqlContext = new org.apache.spark.sql.SQLContext(sc) // A JSON dataset is pointed to by path. // The path can be either a single text file or a directory storing text files. val path = "examples/src/main/resources/people.json" -// Create a DataFrame from the file(s) pointed to by path -val people = sqlContext.jsonFile(path) +val people = sqlContext.read.json(path) // The inferred schema can be visualized using the printSchema() method. people.printSchema() @@ -1149,19 +1301,17 @@ val teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age // an RDD[String] storing one JSON object per string. val anotherPeopleRDD = sc.parallelize( """{"name":"Yin","address":{"city":"Columbus","state":"Ohio"}}""" :: Nil) -val anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD) +val anotherPeople = sqlContext.read.json(anotherPeopleRDD) {% endhighlight %}
    Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. -This conversion can be done using one of two methods in a `SQLContext` : +This conversion can be done using `SQLContext.read().json()` on either an RDD of String, +or a JSON file. -* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. -* `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object. - -Note that the file that is offered as _jsonFile_ is not a typical JSON file. Each +Note that the file that is offered as _a json file_ is not a typical JSON file. Each line must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail. @@ -1171,9 +1321,7 @@ SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); // A JSON dataset is pointed to by path. // The path can be either a single text file or a directory storing text files. -String path = "examples/src/main/resources/people.json"; -// Create a DataFrame from the file(s) pointed to by path -DataFrame people = sqlContext.jsonFile(path); +DataFrame people = sqlContext.read().json("examples/src/main/resources/people.json"); // The inferred schema can be visualized using the printSchema() method. people.printSchema(); @@ -1192,18 +1340,15 @@ DataFrame teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AN List jsonData = Arrays.asList( "{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}"); JavaRDD anotherPeopleRDD = sc.parallelize(jsonData); -DataFrame anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD); +DataFrame anotherPeople = sqlContext.read().json(anotherPeopleRDD); {% endhighlight %}
    Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. -This conversion can be done using one of two methods in a `SQLContext`: - -* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object. -* `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object. +This conversion can be done using `SQLContext.read.json` on a JSON file. -Note that the file that is offered as _jsonFile_ is not a typical JSON file. Each +Note that the file that is offered as _a json file_ is not a typical JSON file. Each line must contain a separate, self-contained valid JSON object. As a consequence, a regular multi-line JSON file will most often fail. @@ -1214,9 +1359,7 @@ sqlContext = SQLContext(sc) # A JSON dataset is pointed to by path. # The path can be either a single text file or a directory storing text files. -path = "examples/src/main/resources/people.json" -# Create a DataFrame from the file(s) pointed to by path -people = sqlContext.jsonFile(path) +people = sqlContext.read.json("examples/src/main/resources/people.json") # The inferred schema can be visualized using the printSchema() method. people.printSchema() @@ -1238,6 +1381,39 @@ anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD) {% endhighlight %}
    +
    +Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame. using +the `jsonFile` function, which loads data from a directory of JSON files where each line of the +files is a JSON object. + +Note that the file that is offered as _a json file_ is not a typical JSON file. Each +line must contain a separate, self-contained valid JSON object. As a consequence, +a regular multi-line JSON file will most often fail. + +{% highlight r %} +# sc is an existing SparkContext. +sqlContext <- sparkRSQL.init(sc) + +# A JSON dataset is pointed to by path. +# The path can be either a single text file or a directory storing text files. +path <- "examples/src/main/resources/people.json" +# Create a DataFrame from the file(s) pointed to by path +people <- jsonFile(sqlContex,t path) + +# The inferred schema can be visualized using the printSchema() method. +printSchema(people) +# root +# |-- age: integer (nullable = true) +# |-- name: string (nullable = true) + +# Register this DataFrame as a table. +registerTempTable(people, "people") + +# SQL statements can be run by using the sql methods provided by `sqlContext`. +teenagers <- sql(sqlContext, "SELECT name FROM people WHERE age >= 13 AND age <= 19") +{% endhighlight %} +
    +
    {% highlight sql %} @@ -1314,10 +1490,7 @@ Row[] results = sqlContext.sql("FROM src SELECT key, value").collect();
    When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and -adds support for finding tables in the MetaStore and writing queries using HiveQL. In addition to -the `sql` method a `HiveContext` also provides an `hql` methods, which allows queries to be -expressed in HiveQL. - +adds support for finding tables in the MetaStore and writing queries using HiveQL. {% highlight python %} # sc is an existing SparkContext. from pyspark.sql import HiveContext @@ -1331,9 +1504,91 @@ results = sqlContext.sql("FROM src SELECT key, value").collect() {% endhighlight %} +
    + +
    + +When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and +adds support for finding tables in the MetaStore and writing queries using HiveQL. +{% highlight r %} +# sc is an existing SparkContext. +sqlContext <- sparkRHive.init(sc) + +sql(sqlContext, "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") +sql(sqlContext, "LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") + +# Queries can be expressed in HiveQL. +results = sqlContext.sql("FROM src SELECT key, value").collect() + +{% endhighlight %} +
    +### Interacting with Different Versions of Hive Metastore + +One of the most important pieces of Spark SQL's Hive support is interaction with Hive metastore, +which enables Spark SQL to access metadata of Hive tables. Starting from Spark 1.4.0, a single binary build of Spark SQL can be used to query different versions of Hive metastores, using the configuration described below. + +Internally, Spark SQL uses two Hive clients, one for executing native Hive commands like `SET` +and `DESCRIBE`, the other dedicated for communicating with Hive metastore. The former uses Hive +jars of version 0.13.1, which are bundled with Spark 1.4.0. The latter uses Hive jars of the +version specified by users. An isolated classloader is used here to avoid dependency conflicts. + + + + + + + + + + + + + + + + + + + + + + + +
    Property NameDefaultMeaning
    spark.sql.hive.metastore.version0.13.1 + Version of the Hive metastore. Available + options are 0.12.0 and 0.13.1. Support for more versions is coming in the future. +
    spark.sql.hive.metastore.jarsbuiltin + Location of the jars that should be used to instantiate the HiveMetastoreClient. This + property can be one of three options: +
      +
    1. builtin
    2. + Use Hive 0.13.1, which is bundled with the Spark assembly jar when -Phive is + enabled. When this option is chosen, spark.sql.hive.metastore.version must be + either 0.13.1 or not defined. +
    3. maven
    4. + Use Hive jars of specified version downloaded from Maven repositories. +
    5. A classpath in the standard format for both Hive and Hadoop.
    6. +
    +
    spark.sql.hive.metastore.sharedPrefixescom.mysql.jdbc,
    org.postgresql,
    com.microsoft.sqlserver,
    oracle.jdbc
    +

    + A comma separated list of class prefixes that should be loaded using the classloader that is + shared between Spark SQL and a specific version of Hive. An example of classes that should + be shared is JDBC drivers that are needed to talk to the metastore. Other classes that need + to be shared are those that interact with classes that are already shared. For example, + custom appenders that are used by log4j. +

    +
    spark.sql.hive.metastore.barrierPrefixes(empty) +

    + A comma separated list of class prefixes that should explicitly be reloaded for each version + of Hive that Spark SQL is communicating with. For example, Hive UDFs that are declared in a + prefix that typically would be shared (i.e. org.apache.spark.*). +

    +
    + + ## JDBC To Other Databases Spark SQL also includes a data source that can read data from other databases using JDBC. This @@ -1367,7 +1622,7 @@ the Data Sources API. The following options are supported: dbtable - The JDBC table that should be read. Note that anything that is valid in a `FROM` clause of + The JDBC table that should be read. Note that anything that is valid in a FROM clause of a SQL query can be used. For example, instead of a full table you could also use a subquery in parentheses. @@ -1430,6 +1685,16 @@ df = sqlContext.load(source="jdbc", url="jdbc:postgresql:dbserver", dbtable="sch +
    + +{% highlight r %} + +df <- loadDF(sqlContext, source="jdbc", url="jdbc:postgresql:dbserver", dbtable="schema.tablename") + +{% endhighlight %} + +
    +
    {% highlight sql %} @@ -1501,7 +1766,7 @@ that these options will be deprecated in future release as more optimizations ar Configures the maximum size in bytes for a table that will be broadcast to all worker nodes when performing a join. By setting this value to -1 broadcasting can be disabled. Note that currently statistics are only supported for Hive Metastore tables where the command - `ANALYZE TABLE <tableName> COMPUTE STATISTICS noscan` has been run. + ANALYZE TABLE <tableName> COMPUTE STATISTICS noscan has been run. @@ -1524,7 +1789,9 @@ that these options will be deprecated in future release as more optimizations ar # Distributed SQL Engine -Spark SQL can also act as a distributed query engine using its JDBC/ODBC or command-line interface. In this mode, end-users or applications can interact with Spark SQL directly to run SQL queries, without the need to write any code. +Spark SQL can also act as a distributed query engine using its JDBC/ODBC or command-line interface. +In this mode, end-users or applications can interact with Spark SQL directly to run SQL queries, +without the need to write any code. ## Running the Thrift JDBC/ODBC server @@ -1603,6 +1870,25 @@ options. ## Upgrading from Spark SQL 1.3 to 1.4 +#### DataFrame data reader/writer interface + +Based on user feedback, we created a new, more fluid API for reading data in (`SQLContext.read`) +and writing data out (`DataFrame.write`), +and deprecated the old APIs (e.g. `SQLContext.parquetFile`, `SQLContext.jsonFile`). + +See the API docs for `SQLContext.read` ( + Scala, + Java, + Python +) and `DataFrame.write` ( + Scala, + Java, + Python +) more information. + + +#### DataFrame.groupBy retains grouping columns + Based on user feedback, we changed the default behavior of `DataFrame.groupBy().agg()` to retain the grouping columns in the resulting `DataFrame`. To keep the behavior in 1.3, set `spark.sql.retainGroupColumns` to `false`.
    @@ -1726,7 +2012,7 @@ sqlContext.udf.register("strLen", (s: String) => s.length())
    {% highlight java %} -sqlContext.udf().register("strLen", (String s) -> { s.length(); }); +sqlContext.udf().register("strLen", (String s) -> s.length(), DataTypes.IntegerType); {% endhighlight %}
    @@ -2354,5 +2640,151 @@ from pyspark.sql.types import *
    +
    + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
    Data typeValue type in RAPI to access or create a data type
    ByteType + integer
    + Note: Numbers will be converted to 1-byte signed integer numbers at runtime. + Please make sure that numbers are within the range of -128 to 127. +
    + "byte" +
    ShortType + integer
    + Note: Numbers will be converted to 2-byte signed integer numbers at runtime. + Please make sure that numbers are within the range of -32768 to 32767. +
    + "short" +
    IntegerType integer + "integer" +
    LongType + integer
    + Note: Numbers will be converted to 8-byte signed integer numbers at runtime. + Please make sure that numbers are within the range of + -9223372036854775808 to 9223372036854775807. + Otherwise, please convert data to decimal.Decimal and use DecimalType. +
    + "long" +
    FloatType + numeric
    + Note: Numbers will be converted to 4-byte single-precision floating + point numbers at runtime. +
    + "float" +
    DoubleType numeric + "double" +
    DecimalType Not supported + Not supported +
    StringType character + "string" +
    BinaryType raw + "binary" +
    BooleanType logical + "bool" +
    TimestampType POSIXct + "timestamp" +
    DateType Date + "date" +
    ArrayType vector or list + list(type="array", elementType=elementType, containsNull=[containsNull])
    + Note: The default value of containsNull is True. +
    MapType enviroment + list(type="map", keyType=keyType, valueType=valueType, valueContainsNull=[valueContainsNull])
    + Note: The default value of valueContainsNull is True. +
    StructType named list + list(type="struct", fields=fields)
    + Note: fields is a Seq of StructFields. Also, two fields with the same + name are not allowed. +
    StructField The value type in R of the data type of this field + (For example, integer for a StructField with the data type IntegerType) + list(name=name, type=dataType, nullable=nullable) +
    + +
    +
    diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index bd863d48d53e3..42b33947873b0 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -1946,10 +1946,10 @@ creates a single receiver (running on a worker machine) that receives a single s Receiving multiple data streams can therefore be achieved by creating multiple input DStreams and configuring them to receive different partitions of the data stream from the source(s). For example, a single Kafka input DStream receiving two topics of data can be split into two -Kafka input streams, each receiving only one topic. This would run two receivers on two workers, -thus allowing data to be received in parallel, and increasing overall throughput. These multiple -DStream can be unioned together to create a single DStream. Then the transformations that was -being applied on the single input DStream can applied on the unified stream. This is done as follows. +Kafka input streams, each receiving only one topic. This would run two receivers, +allowing data to be received in parallel, and increasing overall throughput. These multiple +DStreams can be unioned together to create a single DStream. Then the transformations that were +being applied on a single input DStream can be applied on the unified stream. This is done as follows.
    diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index ab4a96f232c13..ee0904c9e5d54 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -19,8 +19,9 @@ # limitations under the License. # -from __future__ import with_statement, print_function +from __future__ import division, print_function, with_statement +import codecs import hashlib import itertools import logging @@ -47,8 +48,10 @@ else: from urllib.request import urlopen, Request from urllib.error import HTTPError + raw_input = input + xrange = range -SPARK_EC2_VERSION = "1.2.1" +SPARK_EC2_VERSION = "1.3.1" SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__)) VALID_SPARK_VERSIONS = set([ @@ -65,6 +68,8 @@ "1.1.1", "1.2.0", "1.2.1", + "1.3.0", + "1.3.1", ]) SPARK_TACHYON_MAP = { @@ -75,6 +80,8 @@ "1.1.1": "0.5.0", "1.2.0": "0.5.0", "1.2.1": "0.5.0", + "1.3.0": "0.5.0", + "1.3.1": "0.5.0", } DEFAULT_SPARK_VERSION = SPARK_EC2_VERSION @@ -419,13 +426,14 @@ def get_spark_ami(opts): b=opts.spark_ec2_git_branch) ami_path = "%s/%s/%s" % (ami_prefix, opts.region, instance_type) + reader = codecs.getreader("ascii") try: - ami = urlopen(ami_path).read().strip() - print("Spark AMI: " + ami) + ami = reader(urlopen(ami_path)).read().strip() except: print("Could not resolve AMI at: " + ami_path, file=stderr) sys.exit(1) + print("Spark AMI: " + ami) return ami @@ -483,6 +491,8 @@ def launch_cluster(conn, opts, cluster_name): master_group.authorize('udp', 2049, 2049, authorized_address) master_group.authorize('tcp', 4242, 4242, authorized_address) master_group.authorize('udp', 4242, 4242, authorized_address) + # RM in YARN mode uses 8088 + master_group.authorize('tcp', 8088, 8088, authorized_address) if opts.ganglia: master_group.authorize('tcp', 5080, 5080, authorized_address) if slave_group.rules == []: # Group was just now created @@ -746,7 +756,7 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): 'mapreduce', 'spark-standalone', 'tachyon'] if opts.hadoop_major_version == "1": - modules = filter(lambda x: x != "mapreduce", modules) + modules = list(filter(lambda x: x != "mapreduce", modules)) if opts.ganglia: modules.append('ganglia') @@ -860,7 +870,11 @@ def wait_for_cluster_state(conn, opts, cluster_instances, cluster_state): for i in cluster_instances: i.update() - statuses = conn.get_all_instance_status(instance_ids=[i.id for i in cluster_instances]) + max_batch = 100 + statuses = [] + for j in xrange(0, len(cluster_instances), max_batch): + batch = [i.id for i in cluster_instances[j:j + max_batch]] + statuses.extend(conn.get_all_instance_status(instance_ids=batch)) if cluster_state == 'ssh-ready': if all(i.state == 'running' for i in cluster_instances) and \ @@ -1152,7 +1166,7 @@ def get_zones(conn, opts): # Gets the number of items in a partition def get_partition(total, num_partitions, current_partitions): - num_slaves_this_zone = total / num_partitions + num_slaves_this_zone = total // num_partitions if (total % num_partitions) - current_partitions > 0: num_slaves_this_zone += 1 return num_slaves_this_zone diff --git a/examples/pom.xml b/examples/pom.xml index 5b04b4f8d6ca0..e4efee7b5e647 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -97,6 +97,11 @@ + + org.apache.spark + spark-streaming-kafka_${scala.binary.version} + ${project.version} + org.apache.hbase hbase-testing-util @@ -392,45 +397,6 @@ - - - scala-2.10 - - !scala-2.11 - - - - org.apache.spark - spark-streaming-kafka_${scala.binary.version} - ${project.version} - - - - - - org.codehaus.mojo - build-helper-maven-plugin - - - add-scala-sources - generate-sources - - add-source - - - - src/main/scala - scala-2.10/src/main/scala - scala-2.10/src/main/java - - - - - - - - diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java index eac4f898a475d..ec533d174ebdc 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java @@ -28,6 +28,7 @@ import org.apache.spark.ml.classification.ClassificationModel; import org.apache.spark.ml.param.IntParam; import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.ml.util.Identifiable$; import org.apache.spark.mllib.linalg.BLAS; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; @@ -103,7 +104,23 @@ public static void main(String[] args) throws Exception { * However, this should still compile and run successfully. */ class MyJavaLogisticRegression - extends Classifier { + extends Classifier { + + public MyJavaLogisticRegression() { + init(); + } + + public MyJavaLogisticRegression(String uid) { + this.uid_ = uid; + init(); + } + + private String uid_ = Identifiable$.MODULE$.randomUID("myJavaLogReg"); + + @Override + public String uid() { + return uid_; + } /** * Param for max number of iterations @@ -117,7 +134,7 @@ class MyJavaLogisticRegression int getMaxIter() { return (Integer) getOrDefault(maxIter); } - public MyJavaLogisticRegression() { + private void init() { setMaxIter(100); } @@ -137,7 +154,7 @@ public MyJavaLogisticRegressionModel train(DataFrame dataset) { Vector weights = Vectors.zeros(numFeatures); // Learning would happen here. // Create a model, and return it. - return new MyJavaLogisticRegressionModel(this, weights); + return new MyJavaLogisticRegressionModel(uid(), weights).setParent(this); } } @@ -149,17 +166,21 @@ public MyJavaLogisticRegressionModel train(DataFrame dataset) { * However, this should still compile and run successfully. */ class MyJavaLogisticRegressionModel - extends ClassificationModel { - - private MyJavaLogisticRegression parent_; - public MyJavaLogisticRegression parent() { return parent_; } + extends ClassificationModel { private Vector weights_; public Vector weights() { return weights_; } - public MyJavaLogisticRegressionModel(MyJavaLogisticRegression parent_, Vector weights_) { - this.parent_ = parent_; - this.weights_ = weights_; + public MyJavaLogisticRegressionModel(String uid, Vector weights) { + this.uid_ = uid; + this.weights_ = weights; + } + + private String uid_ = Identifiable$.MODULE$.randomUID("myJavaLogReg"); + + @Override + public String uid() { + return uid_; } // This uses the default implementation of transform(), which reads column "features" and outputs @@ -204,6 +225,6 @@ public Vector predictRaw(Vector features) { */ @Override public MyJavaLogisticRegressionModel copy(ParamMap extra) { - return copyValues(new MyJavaLogisticRegressionModel(parent_, weights_), extra); + return copyValues(new MyJavaLogisticRegressionModel(uid(), weights_), extra); } } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java new file mode 100644 index 0000000000000..75063dbf800d8 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.commons.cli.*; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.classification.OneVsRest; +import org.apache.spark.ml.classification.OneVsRestModel; +import org.apache.spark.ml.util.MetadataUtils; +import org.apache.spark.mllib.evaluation.MulticlassMetrics; +import org.apache.spark.mllib.linalg.Matrix; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.rdd.RDD; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.StructField; + +/** + * An example runner for Multiclass to Binary Reduction with One Vs Rest. + * The example uses Logistic Regression as the base classifier. All parameters that + * can be specified on the base classifier can be passed in to the runner options. + * Run with + *
    + * bin/run-example ml.JavaOneVsRestExample [options]
    + * 
    + */ +public class JavaOneVsRestExample { + + private static class Params { + String input; + String testInput = null; + Integer maxIter = 100; + double tol = 1E-6; + boolean fitIntercept = true; + Double regParam = null; + Double elasticNetParam = null; + double fracTest = 0.2; + } + + public static void main(String[] args) { + // parse the arguments + Params params = parse(args); + SparkConf conf = new SparkConf().setAppName("JavaOneVsRestExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // configure the base classifier + LogisticRegression classifier = new LogisticRegression() + .setMaxIter(params.maxIter) + .setTol(params.tol) + .setFitIntercept(params.fitIntercept); + + if (params.regParam != null) { + classifier.setRegParam(params.regParam); + } + if (params.elasticNetParam != null) { + classifier.setElasticNetParam(params.elasticNetParam); + } + + // instantiate the One Vs Rest Classifier + OneVsRest ovr = new OneVsRest().setClassifier(classifier); + + String input = params.input; + RDD inputData = MLUtils.loadLibSVMFile(jsc.sc(), input); + RDD train; + RDD test; + + // compute the train/ test split: if testInput is not provided use part of input + String testInput = params.testInput; + if (testInput != null) { + train = inputData; + // compute the number of features in the training set. + int numFeatures = inputData.first().features().size(); + test = MLUtils.loadLibSVMFile(jsc.sc(), testInput, numFeatures); + } else { + double f = params.fracTest; + RDD[] tmp = inputData.randomSplit(new double[]{1 - f, f}, 12345); + train = tmp[0]; + test = tmp[1]; + } + + // train the multiclass model + DataFrame trainingDataFrame = jsql.createDataFrame(train, LabeledPoint.class); + OneVsRestModel ovrModel = ovr.fit(trainingDataFrame.cache()); + + // score the model on test data + DataFrame testDataFrame = jsql.createDataFrame(test, LabeledPoint.class); + DataFrame predictions = ovrModel.transform(testDataFrame.cache()) + .select("prediction", "label"); + + // obtain metrics + MulticlassMetrics metrics = new MulticlassMetrics(predictions); + StructField predictionColSchema = predictions.schema().apply("prediction"); + Integer numClasses = (Integer) MetadataUtils.getNumClasses(predictionColSchema).get(); + + // compute the false positive rate per label + StringBuilder results = new StringBuilder(); + results.append("label\tfpr\n"); + for (int label = 0; label < numClasses; label++) { + results.append(label); + results.append("\t"); + results.append(metrics.falsePositiveRate((double) label)); + results.append("\n"); + } + + Matrix confusionMatrix = metrics.confusionMatrix(); + // output the Confusion Matrix + System.out.println("Confusion Matrix"); + System.out.println(confusionMatrix); + System.out.println(); + System.out.println(results); + + jsc.stop(); + } + + private static Params parse(String[] args) { + Options options = generateCommandlineOptions(); + CommandLineParser parser = new PosixParser(); + Params params = new Params(); + + try { + CommandLine cmd = parser.parse(options, args); + String value; + if (cmd.hasOption("input")) { + params.input = cmd.getOptionValue("input"); + } + if (cmd.hasOption("maxIter")) { + value = cmd.getOptionValue("maxIter"); + params.maxIter = Integer.parseInt(value); + } + if (cmd.hasOption("tol")) { + value = cmd.getOptionValue("tol"); + params.tol = Double.parseDouble(value); + } + if (cmd.hasOption("fitIntercept")) { + value = cmd.getOptionValue("fitIntercept"); + params.fitIntercept = Boolean.parseBoolean(value); + } + if (cmd.hasOption("regParam")) { + value = cmd.getOptionValue("regParam"); + params.regParam = Double.parseDouble(value); + } + if (cmd.hasOption("elasticNetParam")) { + value = cmd.getOptionValue("elasticNetParam"); + params.elasticNetParam = Double.parseDouble(value); + } + if (cmd.hasOption("testInput")) { + value = cmd.getOptionValue("testInput"); + params.testInput = value; + } + if (cmd.hasOption("fracTest")) { + value = cmd.getOptionValue("fracTest"); + params.fracTest = Double.parseDouble(value); + } + + } catch (ParseException e) { + printHelpAndQuit(options); + } + return params; + } + + private static Options generateCommandlineOptions() { + Option input = OptionBuilder.withArgName("input") + .hasArg() + .isRequired() + .withDescription("input path to labeled examples. This path must be specified") + .create("input"); + Option testInput = OptionBuilder.withArgName("testInput") + .hasArg() + .withDescription("input path to test examples") + .create("testInput"); + Option fracTest = OptionBuilder.withArgName("testInput") + .hasArg() + .withDescription("fraction of data to hold out for testing." + + " If given option testInput, this option is ignored. default: 0.2") + .create("fracTest"); + Option maxIter = OptionBuilder.withArgName("maxIter") + .hasArg() + .withDescription("maximum number of iterations for Logistic Regression. default:100") + .create("maxIter"); + Option tol = OptionBuilder.withArgName("tol") + .hasArg() + .withDescription("the convergence tolerance of iterations " + + "for Logistic Regression. default: 1E-6") + .create("tol"); + Option fitIntercept = OptionBuilder.withArgName("fitIntercept") + .hasArg() + .withDescription("fit intercept for logistic regression. default true") + .create("fitIntercept"); + Option regParam = OptionBuilder.withArgName( "regParam" ) + .hasArg() + .withDescription("the regularization parameter for Logistic Regression.") + .create("regParam"); + Option elasticNetParam = OptionBuilder.withArgName("elasticNetParam" ) + .hasArg() + .withDescription("the ElasticNet mixing parameter for Logistic Regression.") + .create("elasticNetParam"); + + Options options = new Options() + .addOption(input) + .addOption(testInput) + .addOption(fracTest) + .addOption(maxIter) + .addOption(tol) + .addOption(fitIntercept) + .addOption(regParam) + .addOption(elasticNetParam); + + return options; + } + + private static void printHelpAndQuit(Options options) { + HelpFormatter formatter = new HelpFormatter(); + formatter.printHelp("JavaOneVsRestExample", options); + System.exit(-1); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java index 29158d5c85651..dac649d1d5ae6 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java @@ -97,7 +97,7 @@ public static void main(String[] args) { DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class); // Make predictions on test documents using the Transformer.transform() method. - // LogisticRegression.transform will only use the 'features' column. + // LogisticRegressionModel.transform will only use the 'features' column. // Note that model2.transform() outputs a 'myProbability' column instead of the usual // 'probability' column since we renamed the lr.probabilityCol parameter previously. DataFrame results = model2.transform(test); diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java index ef1ec103a879f..54738813d0016 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java @@ -66,7 +66,7 @@ public static void main(String[] args) { .setOutputCol("features"); LogisticRegression lr = new LogisticRegression() .setMaxIter(10) - .setRegParam(0.01); + .setRegParam(0.001); Pipeline pipeline = new Pipeline() .setStages(new PipelineStage[] {tokenizer, hashingTF, lr}); @@ -77,7 +77,7 @@ public static void main(String[] args) { List localTest = Lists.newArrayList( new Document(4L, "spark i j k"), new Document(5L, "l m n"), - new Document(6L, "mapreduce spark"), + new Document(6L, "spark hadoop spark"), new Document(7L, "apache hadoop")); DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class); diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java index 8159ffbe2d269..afee279ec32b1 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java @@ -94,12 +94,12 @@ public String call(Row row) { System.out.println("=== Data source: Parquet File ==="); // DataFrames can be saved as parquet files, maintaining the schema information. - schemaPeople.saveAsParquetFile("people.parquet"); + schemaPeople.write().parquet("people.parquet"); // Read in the parquet file created above. // Parquet files are self-describing so the schema is preserved. // The result of loading a parquet file is also a DataFrame. - DataFrame parquetFile = sqlContext.parquetFile("people.parquet"); + DataFrame parquetFile = sqlContext.read().parquet("people.parquet"); //Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile"); @@ -120,7 +120,7 @@ public String call(Row row) { // The path can be either a single text file or a directory storing text files. String path = "examples/src/main/resources/people.json"; // Create a DataFrame from the file(s) pointed by path - DataFrame peopleFromJsonFile = sqlContext.jsonFile(path); + DataFrame peopleFromJsonFile = sqlContext.read().json(path); // Because the schema of a JSON dataset is automatically inferred, to write queries, // it is better to take a look at what is the schema. @@ -151,7 +151,7 @@ public String call(Row row) { List jsonData = Arrays.asList( "{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}"); JavaRDD anotherPeopleRDD = ctx.parallelize(jsonData); - DataFrame peopleFromJsonRDD = sqlContext.jsonRDD(anotherPeopleRDD.rdd()); + DataFrame peopleFromJsonRDD = sqlContext.read().json(anotherPeopleRDD.rdd()); // Take a look at the schema of this new DataFrame. peopleFromJsonRDD.printSchema(); diff --git a/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java similarity index 100% rename from examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java rename to examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java diff --git a/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java similarity index 100% rename from examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java rename to examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java diff --git a/examples/src/main/python/hbase_inputformat.py b/examples/src/main/python/hbase_inputformat.py index 5b82a14fba413..c5ae5d043b8ea 100644 --- a/examples/src/main/python/hbase_inputformat.py +++ b/examples/src/main/python/hbase_inputformat.py @@ -18,6 +18,7 @@ from __future__ import print_function import sys +import json from pyspark import SparkContext @@ -27,24 +28,24 @@ hbase(main):016:0> create 'test', 'f1' 0 row(s) in 1.0430 seconds -hbase(main):017:0> put 'test', 'row1', 'f1', 'value1' +hbase(main):017:0> put 'test', 'row1', 'f1:a', 'value1' 0 row(s) in 0.0130 seconds -hbase(main):018:0> put 'test', 'row2', 'f1', 'value2' +hbase(main):018:0> put 'test', 'row1', 'f1:b', 'value2' 0 row(s) in 0.0030 seconds -hbase(main):019:0> put 'test', 'row3', 'f1', 'value3' +hbase(main):019:0> put 'test', 'row2', 'f1', 'value3' 0 row(s) in 0.0050 seconds -hbase(main):020:0> put 'test', 'row4', 'f1', 'value4' +hbase(main):020:0> put 'test', 'row3', 'f1', 'value4' 0 row(s) in 0.0110 seconds hbase(main):021:0> scan 'test' ROW COLUMN+CELL - row1 column=f1:, timestamp=1401883411986, value=value1 - row2 column=f1:, timestamp=1401883415212, value=value2 - row3 column=f1:, timestamp=1401883417858, value=value3 - row4 column=f1:, timestamp=1401883420805, value=value4 + row1 column=f1:a, timestamp=1401883411986, value=value1 + row1 column=f1:b, timestamp=1401883415212, value=value2 + row2 column=f1:, timestamp=1401883417858, value=value3 + row3 column=f1:, timestamp=1401883420805, value=value4 4 row(s) in 0.0240 seconds """ if __name__ == "__main__": @@ -64,6 +65,8 @@ table = sys.argv[2] sc = SparkContext(appName="HBaseInputFormat") + # Other options for configuring scan behavior are available. More information available at + # https://github.com/apache/hbase/blob/master/hbase-server/src/main/java/org/apache/hadoop/hbase/mapreduce/TableInputFormat.java conf = {"hbase.zookeeper.quorum": host, "hbase.mapreduce.inputtable": table} if len(sys.argv) > 3: conf = {"hbase.zookeeper.quorum": host, "zookeeper.znode.parent": sys.argv[3], @@ -78,6 +81,8 @@ keyConverter=keyConv, valueConverter=valueConv, conf=conf) + hbase_rdd = hbase_rdd.flatMapValues(lambda v: v.split("\n")).mapValues(json.loads) + output = hbase_rdd.collect() for (k, v) in output: print((k, v)) diff --git a/examples/src/main/python/ml/cross_validator.py b/examples/src/main/python/ml/cross_validator.py new file mode 100644 index 0000000000000..f0ca97c724940 --- /dev/null +++ b/examples/src/main/python/ml/cross_validator.py @@ -0,0 +1,96 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.ml import Pipeline +from pyspark.ml.classification import LogisticRegression +from pyspark.ml.evaluation import BinaryClassificationEvaluator +from pyspark.ml.feature import HashingTF, Tokenizer +from pyspark.ml.tuning import CrossValidator, ParamGridBuilder +from pyspark.sql import Row, SQLContext + +""" +A simple example demonstrating model selection using CrossValidator. +This example also demonstrates how Pipelines are Estimators. +Run with: + + bin/spark-submit examples/src/main/python/ml/cross_validator.py +""" + +if __name__ == "__main__": + sc = SparkContext(appName="CrossValidatorExample") + sqlContext = SQLContext(sc) + + # Prepare training documents, which are labeled. + LabeledDocument = Row("id", "text", "label") + training = sc.parallelize([(0, "a b c d e spark", 1.0), + (1, "b d", 0.0), + (2, "spark f g h", 1.0), + (3, "hadoop mapreduce", 0.0), + (4, "b spark who", 1.0), + (5, "g d a y", 0.0), + (6, "spark fly", 1.0), + (7, "was mapreduce", 0.0), + (8, "e spark program", 1.0), + (9, "a e c l", 0.0), + (10, "spark compile", 1.0), + (11, "hadoop software", 0.0) + ]) \ + .map(lambda x: LabeledDocument(*x)).toDF() + + # Configure an ML pipeline, which consists of tree stages: tokenizer, hashingTF, and lr. + tokenizer = Tokenizer(inputCol="text", outputCol="words") + hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features") + lr = LogisticRegression(maxIter=10) + pipeline = Pipeline(stages=[tokenizer, hashingTF, lr]) + + # We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. + # This will allow us to jointly choose parameters for all Pipeline stages. + # A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. + # We use a ParamGridBuilder to construct a grid of parameters to search over. + # With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, + # this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. + paramGrid = ParamGridBuilder() \ + .addGrid(hashingTF.numFeatures, [10, 100, 1000]) \ + .addGrid(lr.regParam, [0.1, 0.01]) \ + .build() + + crossval = CrossValidator(estimator=pipeline, + estimatorParamMaps=paramGrid, + evaluator=BinaryClassificationEvaluator(), + numFolds=2) # use 3+ folds in practice + + # Run cross-validation, and choose the best set of parameters. + cvModel = crossval.fit(training) + + # Prepare test documents, which are unlabeled. + Document = Row("id", "text") + test = sc.parallelize([(4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop")]) \ + .map(lambda x: Document(*x)).toDF() + + # Make predictions on test documents. cvModel uses the best model found (lrModel). + prediction = cvModel.transform(test) + selected = prediction.select("id", "text", "probability", "prediction") + for row in selected.collect(): + print(row) + + sc.stop() diff --git a/examples/src/main/python/ml/gradient_boosted_trees.py b/examples/src/main/python/ml/gradient_boosted_trees.py new file mode 100644 index 0000000000000..6446f0fe5eeab --- /dev/null +++ b/examples/src/main/python/ml/gradient_boosted_trees.py @@ -0,0 +1,83 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import sys + +from pyspark import SparkContext +from pyspark.ml.classification import GBTClassifier +from pyspark.ml.feature import StringIndexer +from pyspark.ml.regression import GBTRegressor +from pyspark.mllib.evaluation import BinaryClassificationMetrics, RegressionMetrics +from pyspark.mllib.util import MLUtils +from pyspark.sql import Row, SQLContext + +""" +A simple example demonstrating a Gradient Boosted Trees Classification/Regression Pipeline. +Note: GBTClassifier only supports binary classification currently +Run with: + bin/spark-submit examples/src/main/python/ml/gradient_boosted_trees.py +""" + + +def testClassification(train, test): + # Train a GradientBoostedTrees model. + + rf = GBTClassifier(maxIter=30, maxDepth=4, labelCol="indexedLabel") + + model = rf.fit(train) + predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ + .map(lambda x: (x.prediction, x.indexedLabel)) + + metrics = BinaryClassificationMetrics(predictionAndLabels) + print("AUC %.3f" % metrics.areaUnderROC) + + +def testRegression(train, test): + # Train a GradientBoostedTrees model. + + rf = GBTRegressor(maxIter=30, maxDepth=4, labelCol="indexedLabel") + + model = rf.fit(train) + predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ + .map(lambda x: (x.prediction, x.indexedLabel)) + + metrics = RegressionMetrics(predictionAndLabels) + print("rmse %.3f" % metrics.rootMeanSquaredError) + print("r2 %.3f" % metrics.r2) + print("mae %.3f" % metrics.meanAbsoluteError) + + +if __name__ == "__main__": + if len(sys.argv) > 1: + print("Usage: gradient_boosted_trees", file=sys.stderr) + exit(1) + sc = SparkContext(appName="PythonGBTExample") + sqlContext = SQLContext(sc) + + # Load and parse the data file into a dataframe. + df = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + + # Map labels into an indexed column of labels in [0, numLabels) + stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") + si_model = stringIndexer.fit(df) + td = si_model.transform(df) + [train, test] = td.randomSplit([0.7, 0.3]) + testClassification(train, test) + testRegression(train, test) + sc.stop() diff --git a/examples/src/main/python/ml/random_forest_example.py b/examples/src/main/python/ml/random_forest_example.py new file mode 100644 index 0000000000000..c7730e1bfacd9 --- /dev/null +++ b/examples/src/main/python/ml/random_forest_example.py @@ -0,0 +1,87 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import sys + +from pyspark import SparkContext +from pyspark.ml.classification import RandomForestClassifier +from pyspark.ml.feature import StringIndexer +from pyspark.ml.regression import RandomForestRegressor +from pyspark.mllib.evaluation import MulticlassMetrics, RegressionMetrics +from pyspark.mllib.util import MLUtils +from pyspark.sql import Row, SQLContext + +""" +A simple example demonstrating a RandomForest Classification/Regression Pipeline. +Run with: + bin/spark-submit examples/src/main/python/ml/random_forest_example.py +""" + + +def testClassification(train, test): + # Train a RandomForest model. + # Setting featureSubsetStrategy="auto" lets the algorithm choose. + # Note: Use larger numTrees in practice. + + rf = RandomForestClassifier(labelCol="indexedLabel", numTrees=3, maxDepth=4) + + model = rf.fit(train) + predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ + .map(lambda x: (x.prediction, x.indexedLabel)) + + metrics = MulticlassMetrics(predictionAndLabels) + print("weighted f-measure %.3f" % metrics.weightedFMeasure()) + print("precision %s" % metrics.precision()) + print("recall %s" % metrics.recall()) + + +def testRegression(train, test): + # Train a RandomForest model. + # Note: Use larger numTrees in practice. + + rf = RandomForestRegressor(labelCol="indexedLabel", numTrees=3, maxDepth=4) + + model = rf.fit(train) + predictionAndLabels = model.transform(test).select("prediction", "indexedLabel") \ + .map(lambda x: (x.prediction, x.indexedLabel)) + + metrics = RegressionMetrics(predictionAndLabels) + print("rmse %.3f" % metrics.rootMeanSquaredError) + print("r2 %.3f" % metrics.r2) + print("mae %.3f" % metrics.meanAbsoluteError) + + +if __name__ == "__main__": + if len(sys.argv) > 1: + print("Usage: random_forest_example", file=sys.stderr) + exit(1) + sc = SparkContext(appName="PythonRandomForestExample") + sqlContext = SQLContext(sc) + + # Load and parse the data file into a dataframe. + df = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + + # Map labels into an indexed column of labels in [0, numLabels) + stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") + si_model = stringIndexer.fit(df) + td = si_model.transform(df) + [train, test] = td.randomSplit([0.7, 0.3]) + testClassification(train, test) + testRegression(train, test) + sc.stop() diff --git a/examples/src/main/python/ml/simple_params_example.py b/examples/src/main/python/ml/simple_params_example.py new file mode 100644 index 0000000000000..a9f29dab2d602 --- /dev/null +++ b/examples/src/main/python/ml/simple_params_example.py @@ -0,0 +1,98 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import pprint +import sys + +from pyspark import SparkContext +from pyspark.ml.classification import LogisticRegression +from pyspark.mllib.linalg import DenseVector +from pyspark.mllib.regression import LabeledPoint +from pyspark.sql import SQLContext + +""" +A simple example demonstrating ways to specify parameters for Estimators and Transformers. +Run with: + bin/spark-submit examples/src/main/python/ml/simple_params_example.py +""" + +if __name__ == "__main__": + if len(sys.argv) > 1: + print("Usage: simple_params_example", file=sys.stderr) + exit(1) + sc = SparkContext(appName="PythonSimpleParamsExample") + sqlContext = SQLContext(sc) + + # prepare training data. + # We create an RDD of LabeledPoints and convert them into a DataFrame. + # A LabeledPoint is an Object with two fields named label and features + # and Spark SQL identifies these fields and creates the schema appropriately. + training = sc.parallelize([ + LabeledPoint(1.0, DenseVector([0.0, 1.1, 0.1])), + LabeledPoint(0.0, DenseVector([2.0, 1.0, -1.0])), + LabeledPoint(0.0, DenseVector([2.0, 1.3, 1.0])), + LabeledPoint(1.0, DenseVector([0.0, 1.2, -0.5]))]).toDF() + + # Create a LogisticRegression instance with maxIter = 10. + # This instance is an Estimator. + lr = LogisticRegression(maxIter=10) + # Print out the parameters, documentation, and any default values. + print("LogisticRegression parameters:\n" + lr.explainParams() + "\n") + + # We may also set parameters using setter methods. + lr.setRegParam(0.01) + + # Learn a LogisticRegression model. This uses the parameters stored in lr. + model1 = lr.fit(training) + + # Since model1 is a Model (i.e., a Transformer produced by an Estimator), + # we can view the parameters it used during fit(). + # This prints the parameter (name: value) pairs, where names are unique IDs for this + # LogisticRegression instance. + print("Model 1 was fit using parameters:\n") + pprint.pprint(model1.extractParamMap()) + + # We may alternatively specify parameters using a parameter map. + # paramMap overrides all lr parameters set earlier. + paramMap = {lr.maxIter: 20, lr.threshold: 0.55, lr.probabilityCol: "myProbability"} + + # Now learn a new model using the new parameters. + model2 = lr.fit(training, paramMap) + print("Model 2 was fit using parameters:\n") + pprint.pprint(model2.extractParamMap()) + + # prepare test data. + test = sc.parallelize([ + LabeledPoint(1.0, DenseVector([-1.0, 1.5, 1.3])), + LabeledPoint(0.0, DenseVector([3.0, 2.0, -0.1])), + LabeledPoint(0.0, DenseVector([0.0, 2.2, -1.5]))]).toDF() + + # Make predictions on test data using the Transformer.transform() method. + # LogisticRegressionModel.transform will only use the 'features' column. + # Note that model2.transform() outputs a 'myProbability' column instead of the usual + # 'probability' column since we renamed the lr.probabilityCol parameter previously. + result = model2.transform(test) \ + .select("features", "label", "myProbability", "prediction") \ + .collect() + + for row in result: + print("features=%s,label=%s -> prob=%s, prediction=%s" + % (row.features, row.label, row.myProbability, row.prediction)) + + sc.stop() diff --git a/examples/src/main/python/ml/simple_text_classification_pipeline.py b/examples/src/main/python/ml/simple_text_classification_pipeline.py index fab21f003b233..b4f06bf888746 100644 --- a/examples/src/main/python/ml/simple_text_classification_pipeline.py +++ b/examples/src/main/python/ml/simple_text_classification_pipeline.py @@ -48,7 +48,7 @@ # Configure an ML pipeline, which consists of tree stages: tokenizer, hashingTF, and lr. tokenizer = Tokenizer(inputCol="text", outputCol="words") hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features") - lr = LogisticRegression(maxIter=10, regParam=0.01) + lr = LogisticRegression(maxIter=10, regParam=0.001) pipeline = Pipeline(stages=[tokenizer, hashingTF, lr]) # Fit the pipeline to training documents. @@ -58,7 +58,7 @@ Document = Row("id", "text") test = sc.parallelize([(4, "spark i j k"), (5, "l m n"), - (6, "mapreduce spark"), + (6, "spark hadoop spark"), (7, "apache hadoop")]) \ .map(lambda x: Document(*x)).toDF() diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala index 11d5c92c5952d..023bb3ee2d108 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala @@ -104,8 +104,8 @@ object CassandraCQLTest { val casRdd = sc.newAPIHadoopRDD(job.getConfiguration(), classOf[CqlPagingInputFormat], - classOf[java.util.Map[String,ByteBuffer]], - classOf[java.util.Map[String,ByteBuffer]]) + classOf[java.util.Map[String, ByteBuffer]], + classOf[java.util.Map[String, ByteBuffer]]) println("Count: " + casRdd.count) val productSaleRDD = casRdd.map { @@ -118,7 +118,7 @@ object CassandraCQLTest { case (productId, saleCount) => println(productId + ":" + saleCount) } - val casoutputCF = aggregatedRDD.map { + val casoutputCF = aggregatedRDD.map { case (productId, saleCount) => { val outColFamKey = Map("prod_id" -> ByteBufferUtil.bytes(productId)) val outKey: java.util.Map[String, ByteBuffer] = outColFamKey diff --git a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala index 849887d23c9cf..95c96111c9b1f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala @@ -59,5 +59,6 @@ object HBaseTest { hBaseRDD.count() sc.stop() + admin.close() } } diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala index a55e0dc8d36c2..c3fc74a116c0a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala @@ -39,7 +39,7 @@ object LocalLR { def generateData: Array[DataPoint] = { def generatePoint(i: Int): DataPoint = { - val y = if(i % 2 == 0) -1 else 1 + val y = if (i % 2 == 0) -1 else 1 val x = DenseVector.fill(D){rand.nextGaussian + y * R} DataPoint(x, y) } diff --git a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala index 32e02eab8b031..75c82117cbad2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LogQuery.scala @@ -22,7 +22,7 @@ import org.apache.spark.SparkContext._ /** * Executes a roll up-style query against Apache logs. - * + * * Usage: LogQuery [logFile] */ object LogQuery { diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala index 6c0ac8013ce34..30c4261551837 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala @@ -117,7 +117,7 @@ object SparkALS { var us = Array.fill(U)(randomVector(F)) // Iteratively update movies then users - val Rc = sc.broadcast(R) + val Rc = sc.broadcast(R) var msb = sc.broadcast(ms) var usb = sc.broadcast(us) for (iter <- 1 to ITERATIONS) { diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala index 8c01a60844620..1e6b4fb0c7514 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala @@ -44,7 +44,7 @@ object SparkLR { def generateData: Array[DataPoint] = { def generatePoint(i: Int): DataPoint = { - val y = if(i % 2 == 0) -1 else 1 + val y = if (i % 2 == 0) -1 else 1 val x = DenseVector.fill(D){rand.nextGaussian + y * R} DataPoint(x, y) } diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala index 8d092b6506d33..bd7894f184c4c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala @@ -51,7 +51,7 @@ object SparkPageRank { showWarning() val sparkConf = new SparkConf().setAppName("PageRank") - val iters = if (args.length > 0) args(1).toInt else 10 + val iters = if (args.length > 1) args(1).toInt else 10 val ctx = new SparkContext(sparkConf) val lines = ctx.textFile(args(0), 1) val links = lines.map{ s => diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala deleted file mode 100644 index ab6e63deb3c95..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.bagel - -import org.apache.spark._ -import org.apache.spark.bagel._ - -class PageRankUtils extends Serializable { - def computeWithCombiner(numVertices: Long, epsilon: Double)( - self: PRVertex, messageSum: Option[Double], superstep: Int - ): (PRVertex, Array[PRMessage]) = { - val newValue = messageSum match { - case Some(msgSum) if msgSum != 0 => - 0.15 / numVertices + 0.85 * msgSum - case _ => self.value - } - - val terminate = superstep >= 10 - - val outbox: Array[PRMessage] = - if (!terminate) { - self.outEdges.map(targetId => new PRMessage(targetId, newValue / self.outEdges.size)) - } else { - Array[PRMessage]() - } - - (new PRVertex(newValue, self.outEdges, !terminate), outbox) - } - - def computeNoCombiner(numVertices: Long, epsilon: Double) - (self: PRVertex, messages: Option[Array[PRMessage]], superstep: Int) - : (PRVertex, Array[PRMessage]) = - computeWithCombiner(numVertices, epsilon)(self, messages match { - case Some(msgs) => Some(msgs.map(_.value).sum) - case None => None - }, superstep) -} - -class PRCombiner extends Combiner[PRMessage, Double] with Serializable { - def createCombiner(msg: PRMessage): Double = - msg.value - def mergeMsg(combiner: Double, msg: PRMessage): Double = - combiner + msg.value - def mergeCombiners(a: Double, b: Double): Double = - a + b -} - -class PRVertex() extends Vertex with Serializable { - var value: Double = _ - var outEdges: Array[String] = _ - var active: Boolean = _ - - def this(value: Double, outEdges: Array[String], active: Boolean = true) { - this() - this.value = value - this.outEdges = outEdges - this.active = active - } - - override def toString(): String = { - "PRVertex(value=%f, outEdges.length=%d, active=%s)" - .format(value, outEdges.length, active.toString) - } -} - -class PRMessage() extends Message[String] with Serializable { - var targetId: String = _ - var value: Double = _ - - def this(targetId: String, value: Double) { - this() - this.targetId = targetId - this.value = value - } -} - -class CustomPartitioner(partitions: Int) extends Partitioner { - def numPartitions: Int = partitions - - def getPartition(key: Any): Int = { - val hash = key match { - case k: Long => (k & 0x00000000FFFFFFFFL).toInt - case _ => key.hashCode - } - - val mod = key.hashCode % partitions - if (mod < 0) mod + partitions else mod - } - - override def equals(other: Any): Boolean = other match { - case c: CustomPartitioner => - c.numPartitions == numPartitions - case _ => false - } - - override def hashCode: Int = numPartitions -} diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala deleted file mode 100644 index 859abedf2a55e..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.bagel - -import org.apache.spark._ -import org.apache.spark.SparkContext._ - -import org.apache.spark.bagel._ - -import scala.xml.{XML,NodeSeq} - -/** - * Run PageRank on XML Wikipedia dumps from http://wiki.freebase.com/wiki/WEX. Uses the "articles" - * files from there, which contains one line per wiki article in a tab-separated format - * (http://wiki.freebase.com/wiki/WEX/Documentation#articles). - */ -object WikipediaPageRank { - def main(args: Array[String]) { - if (args.length < 4) { - System.err.println( - "Usage: WikipediaPageRank ") - System.exit(-1) - } - val sparkConf = new SparkConf() - sparkConf.setAppName("WikipediaPageRank") - sparkConf.registerKryoClasses(Array(classOf[PRVertex], classOf[PRMessage])) - - val inputFile = args(0) - val threshold = args(1).toDouble - val numPartitions = args(2).toInt - val usePartitioner = args(3).toBoolean - - sparkConf.setAppName("WikipediaPageRank") - val sc = new SparkContext(sparkConf) - - // Parse the Wikipedia page data into a graph - val input = sc.textFile(inputFile) - - println("Counting vertices...") - val numVertices = input.count() - println("Done counting vertices.") - - println("Parsing input file...") - var vertices = input.map(line => { - val fields = line.split("\t") - val (title, body) = (fields(1), fields(3).replace("\\n", "\n")) - val links = - if (body == "\\N") { - NodeSeq.Empty - } else { - try { - XML.loadString(body) \\ "link" \ "target" - } catch { - case e: org.xml.sax.SAXParseException => - System.err.println("Article \"" + title + "\" has malformed XML in body:\n" + body) - NodeSeq.Empty - } - } - val outEdges = links.map(link => new String(link.text)).toArray - val id = new String(title) - (id, new PRVertex(1.0 / numVertices, outEdges)) - }) - if (usePartitioner) { - vertices = vertices.partitionBy(new HashPartitioner(sc.defaultParallelism)).cache() - } else { - vertices = vertices.cache() - } - println("Done parsing input file.") - - // Do the computation - val epsilon = 0.01 / numVertices - val messages = sc.parallelize(Array[(String, PRMessage)]()) - val utils = new PageRankUtils - val result = - Bagel.run( - sc, vertices, messages, combiner = new PRCombiner(), - numPartitions = numPartitions)( - utils.computeWithCombiner(numVertices, epsilon)) - - // Print the result - System.err.println("Articles with PageRank >= " + threshold + ":") - val top = - (result - .filter { case (id, vertex) => vertex.value >= threshold } - .map { case (id, vertex) => "%s\t%s\n".format(id, vertex.value) } - .collect().mkString) - println(top) - - sc.stop() - } -} diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala deleted file mode 100644 index 576a3e371b993..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala +++ /dev/null @@ -1,232 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.bagel - -import java.io.{InputStream, OutputStream, DataInputStream, DataOutputStream} -import java.nio.ByteBuffer - -import scala.collection.mutable.ArrayBuffer -import scala.xml.{XML, NodeSeq} - -import org.apache.spark._ -import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} -import org.apache.spark.SparkContext._ -import org.apache.spark.rdd.RDD - -import scala.reflect.ClassTag - -object WikipediaPageRankStandalone { - def main(args: Array[String]) { - if (args.length < 4) { - System.err.println("Usage: WikipediaPageRankStandalone " + - " ") - System.exit(-1) - } - val sparkConf = new SparkConf() - sparkConf.set("spark.serializer", "spark.bagel.examples.WPRSerializer") - - val inputFile = args(0) - val threshold = args(1).toDouble - val numIterations = args(2).toInt - val usePartitioner = args(3).toBoolean - - sparkConf.setAppName("WikipediaPageRankStandalone") - - val sc = new SparkContext(sparkConf) - - val input = sc.textFile(inputFile) - val partitioner = new HashPartitioner(sc.defaultParallelism) - val links = - if (usePartitioner) { - input.map(parseArticle _).partitionBy(partitioner).cache() - } else { - input.map(parseArticle _).cache() - } - val n = links.count() - val defaultRank = 1.0 / n - val a = 0.15 - - // Do the computation - val startTime = System.currentTimeMillis - val ranks = - pageRank(links, numIterations, defaultRank, a, n, partitioner, usePartitioner, - sc.defaultParallelism) - - // Print the result - System.err.println("Articles with PageRank >= " + threshold + ":") - val top = - (ranks - .filter { case (id, rank) => rank >= threshold } - .map { case (id, rank) => "%s\t%s\n".format(id, rank) } - .collect().mkString) - println(top) - - val time = (System.currentTimeMillis - startTime) / 1000.0 - println("Completed %d iterations in %f seconds: %f seconds per iteration" - .format(numIterations, time, time / numIterations)) - sc.stop() - } - - def parseArticle(line: String): (String, Array[String]) = { - val fields = line.split("\t") - val (title, body) = (fields(1), fields(3).replace("\\n", "\n")) - val id = new String(title) - val links = - if (body == "\\N") { - NodeSeq.Empty - } else { - try { - XML.loadString(body) \\ "link" \ "target" - } catch { - case e: org.xml.sax.SAXParseException => - System.err.println("Article \"" + title + "\" has malformed XML in body:\n" + body) - NodeSeq.Empty - } - } - val outEdges = links.map(link => new String(link.text)).toArray - (id, outEdges) - } - - def pageRank( - links: RDD[(String, Array[String])], - numIterations: Int, - defaultRank: Double, - a: Double, - n: Long, - partitioner: Partitioner, - usePartitioner: Boolean, - numPartitions: Int - ): RDD[(String, Double)] = { - var ranks = links.mapValues { edges => defaultRank } - for (i <- 1 to numIterations) { - val contribs = links.groupWith(ranks).flatMap { - case (id, (linksWrapperIterable, rankWrapperIterable)) => - val linksWrapper = linksWrapperIterable.iterator - val rankWrapper = rankWrapperIterable.iterator - if (linksWrapper.hasNext) { - val linksWrapperHead = linksWrapper.next - if (rankWrapper.hasNext) { - val rankWrapperHead = rankWrapper.next - linksWrapperHead.map(dest => (dest, rankWrapperHead / linksWrapperHead.size)) - } else { - linksWrapperHead.map(dest => (dest, defaultRank / linksWrapperHead.size)) - } - } else { - Array[(String, Double)]() - } - } - ranks = (contribs.combineByKey((x: Double) => x, - (x: Double, y: Double) => x + y, - (x: Double, y: Double) => x + y, - partitioner) - .mapValues(sum => a/n + (1-a)*sum)) - } - ranks - } -} - -class WPRSerializer extends org.apache.spark.serializer.Serializer { - def newInstance(): SerializerInstance = new WPRSerializerInstance() -} - -class WPRSerializerInstance extends SerializerInstance { - def serialize[T: ClassTag](t: T): ByteBuffer = { - throw new UnsupportedOperationException() - } - - def deserialize[T: ClassTag](bytes: ByteBuffer): T = { - throw new UnsupportedOperationException() - } - - def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = { - throw new UnsupportedOperationException() - } - - def serializeStream(s: OutputStream): SerializationStream = { - new WPRSerializationStream(s) - } - - def deserializeStream(s: InputStream): DeserializationStream = { - new WPRDeserializationStream(s) - } -} - -class WPRSerializationStream(os: OutputStream) extends SerializationStream { - val dos = new DataOutputStream(os) - - def writeObject[T: ClassTag](t: T): SerializationStream = t match { - case (id: String, wrapper: ArrayBuffer[_]) => wrapper(0) match { - case links: Array[String] => { - dos.writeInt(0) // links - dos.writeUTF(id) - dos.writeInt(links.length) - for (link <- links) { - dos.writeUTF(link) - } - this - } - case rank: Double => { - dos.writeInt(1) // rank - dos.writeUTF(id) - dos.writeDouble(rank) - this - } - } - case (id: String, rank: Double) => { - dos.writeInt(2) // rank without wrapper - dos.writeUTF(id) - dos.writeDouble(rank) - this - } - } - - def flush() { dos.flush() } - def close() { dos.close() } -} - -class WPRDeserializationStream(is: InputStream) extends DeserializationStream { - val dis = new DataInputStream(is) - - def readObject[T: ClassTag](): T = { - val typeId = dis.readInt() - typeId match { - case 0 => { - val id = dis.readUTF() - val numLinks = dis.readInt() - val links = new Array[String](numLinks) - for (i <- 0 until numLinks) { - val link = dis.readUTF() - links(i) = link - } - (id, ArrayBuffer(links)).asInstanceOf[T] - } - case 1 => { - val id = dis.readUTF() - val rank = dis.readDouble() - (id, ArrayBuffer(rank)).asInstanceOf[T] - } - case 2 => { - val id = dis.readUTF() - val rank = dis.readDouble() - (id, rank).asInstanceOf[T] - } - } - } - - def close() { dis.close() } -} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index 2a2d0677272a0..3ee456edbe01e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -20,6 +20,7 @@ package org.apache.spark.examples.ml import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.ml.classification.{ClassificationModel, Classifier, ClassifierParams} import org.apache.spark.ml.param.{IntParam, ParamMap} +import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.sql.{DataFrame, Row, SQLContext} @@ -106,10 +107,12 @@ private trait MyLogisticRegressionParams extends ClassifierParams { * * NOTE: This is private since it is an example. In practice, you may not want it to be private. */ -private class MyLogisticRegression +private class MyLogisticRegression(override val uid: String) extends Classifier[Vector, MyLogisticRegression, MyLogisticRegressionModel] with MyLogisticRegressionParams { + def this() = this(Identifiable.randomUID("myLogReg")) + setMaxIter(100) // Initialize // The parameter setter is in this class since it should return type MyLogisticRegression. @@ -125,7 +128,7 @@ private class MyLogisticRegression val weights = Vectors.zeros(numFeatures) // Learning would happen here. // Create a model, and return it. - new MyLogisticRegressionModel(this, weights) + new MyLogisticRegressionModel(uid, weights).setParent(this) } } @@ -135,7 +138,7 @@ private class MyLogisticRegression * NOTE: This is private since it is an example. In practice, you may not want it to be private. */ private class MyLogisticRegressionModel( - override val parent: MyLogisticRegression, + override val uid: String, val weights: Vector) extends ClassificationModel[Vector, MyLogisticRegressionModel] with MyLogisticRegressionParams { @@ -173,6 +176,6 @@ private class MyLogisticRegressionModel( * This is used for the default implementation of [[transform()]]. */ override def copy(extra: ParamMap): MyLogisticRegressionModel = { - copyValues(new MyLogisticRegressionModel(parent, weights), extra) + copyValues(new MyLogisticRegressionModel(uid, weights), extra) } } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala new file mode 100644 index 0000000000000..b54466fd48bc5 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionExample.scala @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +import scala.collection.mutable +import scala.language.reflectiveCalls + +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.examples.mllib.AbstractParams +import org.apache.spark.ml.{Pipeline, PipelineStage} +import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel} +import org.apache.spark.sql.DataFrame + +/** + * An example runner for linear regression with elastic-net (mixing L1/L2) regularization. + * Run with + * {{{ + * bin/run-example ml.LinearRegressionExample [options] + * }}} + * A synthetic dataset can be found at `data/mllib/sample_linear_regression_data.txt` which can be + * trained by + * {{{ + * bin/run-example ml.LinearRegressionExample --regParam 0.15 --elasticNetParam 1.0 \ + * data/mllib/sample_linear_regression_data.txt + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object LinearRegressionExample { + + case class Params( + input: String = null, + testInput: String = "", + dataFormat: String = "libsvm", + regParam: Double = 0.0, + elasticNetParam: Double = 0.0, + maxIter: Int = 100, + tol: Double = 1E-6, + fracTest: Double = 0.2) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("LinearRegressionExample") { + head("LinearRegressionExample: an example Linear Regression with Elastic-Net app.") + opt[Double]("regParam") + .text(s"regularization parameter, default: ${defaultParams.regParam}") + .action((x, c) => c.copy(regParam = x)) + opt[Double]("elasticNetParam") + .text(s"ElasticNet mixing parameter. For alpha = 0, the penalty is an L2 penalty. " + + s"For alpha = 1, it is an L1 penalty. For 0 < alpha < 1, the penalty is a combination of " + + s"L1 and L2, default: ${defaultParams.elasticNetParam}") + .action((x, c) => c.copy(elasticNetParam = x)) + opt[Int]("maxIter") + .text(s"maximum number of iterations, default: ${defaultParams.maxIter}") + .action((x, c) => c.copy(maxIter = x)) + opt[Double]("tol") + .text(s"the convergence tolerance of iterations, Smaller value will lead " + + s"to higher accuracy with the cost of more iterations, default: ${defaultParams.tol}") + .action((x, c) => c.copy(tol = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing. If given option testInput, " + + s"this option is ignored. default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) + opt[String]("testInput") + .text(s"input path to test dataset. If given, option fracTest is ignored." + + s" default: ${defaultParams.testInput}") + .action((x, c) => c.copy(testInput = x)) + opt[String]("dataFormat") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(dataFormat = x)) + arg[String]("") + .text("input path to labeled examples") + .required() + .action((x, c) => c.copy(input = x)) + checkConfig { params => + if (params.fracTest < 0 || params.fracTest >= 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).") + } else { + success + } + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf().setAppName(s"LinearRegressionExample with $params") + val sc = new SparkContext(conf) + + println(s"LinearRegressionExample with parameters:\n$params") + + // Load training and test data and cache it. + val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc, params.input, + params.dataFormat, params.testInput, "regression", params.fracTest) + + val lir = new LinearRegression() + .setFeaturesCol("features") + .setLabelCol("label") + .setRegParam(params.regParam) + .setElasticNetParam(params.elasticNetParam) + .setMaxIter(params.maxIter) + .setTol(params.tol) + + // Train the model + val startTime = System.nanoTime() + val lirModel = lir.fit(training) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") + + // Print the weights and intercept for linear regression. + println(s"Weights: ${lirModel.weights} Intercept: ${lirModel.intercept}") + + println("Training data results:") + DecisionTreeExample.evaluateRegressionModel(lirModel, training, "label") + println("Test data results:") + DecisionTreeExample.evaluateRegressionModel(lirModel, test, "label") + + sc.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala new file mode 100644 index 0000000000000..b12f833ce94c8 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +import scala.collection.mutable +import scala.language.reflectiveCalls + +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.examples.mllib.AbstractParams +import org.apache.spark.ml.{Pipeline, PipelineStage} +import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} +import org.apache.spark.ml.feature.StringIndexer +import org.apache.spark.sql.DataFrame + +/** + * An example runner for logistic regression with elastic-net (mixing L1/L2) regularization. + * Run with + * {{{ + * bin/run-example ml.LogisticRegressionExample [options] + * }}} + * A synthetic dataset can be found at `data/mllib/sample_libsvm_data.txt` which can be + * trained by + * {{{ + * bin/run-example ml.LogisticRegressionExample --regParam 0.3 --elasticNetParam 0.8 \ + * data/mllib/sample_libsvm_data.txt + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object LogisticRegressionExample { + + case class Params( + input: String = null, + testInput: String = "", + dataFormat: String = "libsvm", + regParam: Double = 0.0, + elasticNetParam: Double = 0.0, + maxIter: Int = 100, + fitIntercept: Boolean = true, + tol: Double = 1E-6, + fracTest: Double = 0.2) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("LogisticRegressionExample") { + head("LogisticRegressionExample: an example Logistic Regression with Elastic-Net app.") + opt[Double]("regParam") + .text(s"regularization parameter, default: ${defaultParams.regParam}") + .action((x, c) => c.copy(regParam = x)) + opt[Double]("elasticNetParam") + .text(s"ElasticNet mixing parameter. For alpha = 0, the penalty is an L2 penalty. " + + s"For alpha = 1, it is an L1 penalty. For 0 < alpha < 1, the penalty is a combination of " + + s"L1 and L2, default: ${defaultParams.elasticNetParam}") + .action((x, c) => c.copy(elasticNetParam = x)) + opt[Int]("maxIter") + .text(s"maximum number of iterations, default: ${defaultParams.maxIter}") + .action((x, c) => c.copy(maxIter = x)) + opt[Boolean]("fitIntercept") + .text(s"whether to fit an intercept term, default: ${defaultParams.fitIntercept}") + .action((x, c) => c.copy(fitIntercept = x)) + opt[Double]("tol") + .text(s"the convergence tolerance of iterations, Smaller value will lead " + + s"to higher accuracy with the cost of more iterations, default: ${defaultParams.tol}") + .action((x, c) => c.copy(tol = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing. If given option testInput, " + + s"this option is ignored. default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) + opt[String]("testInput") + .text(s"input path to test dataset. If given, option fracTest is ignored." + + s" default: ${defaultParams.testInput}") + .action((x, c) => c.copy(testInput = x)) + opt[String]("dataFormat") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(dataFormat = x)) + arg[String]("") + .text("input path to labeled examples") + .required() + .action((x, c) => c.copy(input = x)) + checkConfig { params => + if (params.fracTest < 0 || params.fracTest >= 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).") + } else { + success + } + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf().setAppName(s"LogisticRegressionExample with $params") + val sc = new SparkContext(conf) + + println(s"LogisticRegressionExample with parameters:\n$params") + + // Load training and test data and cache it. + val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc, params.input, + params.dataFormat, params.testInput, "classification", params.fracTest) + + // Set up Pipeline + val stages = new mutable.ArrayBuffer[PipelineStage]() + + val labelIndexer = new StringIndexer() + .setInputCol("labelString") + .setOutputCol("indexedLabel") + stages += labelIndexer + + val lor = new LogisticRegression() + .setFeaturesCol("features") + .setLabelCol("indexedLabel") + .setRegParam(params.regParam) + .setElasticNetParam(params.elasticNetParam) + .setMaxIter(params.maxIter) + .setTol(params.tol) + + stages += lor + val pipeline = new Pipeline().setStages(stages.toArray) + + // Fit the Pipeline + val startTime = System.nanoTime() + val pipelineModel = pipeline.fit(training) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") + + val lirModel = pipelineModel.stages.last.asInstanceOf[LogisticRegressionModel] + // Print the weights and intercept for logistic regression. + println(s"Weights: ${lirModel.weights} Intercept: ${lirModel.intercept}") + + println("Training data results:") + DecisionTreeExample.evaluateClassificationModel(pipelineModel, training, "indexedLabel") + println("Test data results:") + DecisionTreeExample.evaluateClassificationModel(pipelineModel, test, "indexedLabel") + + sc.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala new file mode 100644 index 0000000000000..6927eb8f275cf --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +import java.util.concurrent.TimeUnit.{NANOSECONDS => NANO} + +import scopt.OptionParser + +import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.examples.mllib.AbstractParams +import org.apache.spark.ml.classification.{OneVsRest, LogisticRegression} +import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SQLContext + +/** + * An example runner for Multiclass to Binary Reduction with One Vs Rest. + * The example uses Logistic Regression as the base classifier. All parameters that + * can be specified on the base classifier can be passed in to the runner options. + * Run with + * {{{ + * ./bin/run-example ml.OneVsRestExample [options] + * }}} + * For local mode, run + * {{{ + * ./bin/spark-submit --class org.apache.spark.examples.ml.OneVsRestExample --driver-memory 1g + * [examples JAR path] [options] + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object OneVsRestExample { + + case class Params private[ml] ( + input: String = null, + testInput: Option[String] = None, + maxIter: Int = 100, + tol: Double = 1E-6, + fitIntercept: Boolean = true, + regParam: Option[Double] = None, + elasticNetParam: Option[Double] = None, + fracTest: Double = 0.2) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("OneVsRest Example") { + head("OneVsRest Example: multiclass to binary reduction using OneVsRest") + opt[String]("input") + .text("input path to labeled examples. This path must be specified") + .required() + .action((x, c) => c.copy(input = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing. If given option testInput, " + + s"this option is ignored. default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) + opt[String]("testInput") + .text("input path to test dataset. If given, option fracTest is ignored") + .action((x, c) => c.copy(testInput = Some(x))) + opt[Int]("maxIter") + .text(s"maximum number of iterations for Logistic Regression." + + s" default: ${defaultParams.maxIter}") + .action((x, c) => c.copy(maxIter = x)) + opt[Double]("tol") + .text(s"the convergence tolerance of iterations for Logistic Regression." + + s" default: ${defaultParams.tol}") + .action((x, c) => c.copy(tol = x)) + opt[Boolean]("fitIntercept") + .text(s"fit intercept for Logistic Regression." + + s" default: ${defaultParams.fitIntercept}") + .action((x, c) => c.copy(fitIntercept = x)) + opt[Double]("regParam") + .text(s"the regularization parameter for Logistic Regression.") + .action((x, c) => c.copy(regParam = Some(x))) + opt[Double]("elasticNetParam") + .text(s"the ElasticNet mixing parameter for Logistic Regression.") + .action((x, c) => c.copy(elasticNetParam = Some(x))) + checkConfig { params => + if (params.fracTest < 0 || params.fracTest >= 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).") + } else { + success + } + } + } + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + private def run(params: Params) { + val conf = new SparkConf().setAppName(s"OneVsRestExample with $params") + val sc = new SparkContext(conf) + val inputData = MLUtils.loadLibSVMFile(sc, params.input) + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // compute the train/test split: if testInput is not provided use part of input. + val data = params.testInput match { + case Some(t) => { + // compute the number of features in the training set. + val numFeatures = inputData.first().features.size + val testData = MLUtils.loadLibSVMFile(sc, t, numFeatures) + Array[RDD[LabeledPoint]](inputData, testData) + } + case None => { + val f = params.fracTest + inputData.randomSplit(Array(1 - f, f), seed = 12345) + } + } + val Array(train, test) = data.map(_.toDF().cache()) + + // instantiate the base classifier + val classifier = new LogisticRegression() + .setMaxIter(params.maxIter) + .setTol(params.tol) + .setFitIntercept(params.fitIntercept) + + // Set regParam, elasticNetParam if specified in params + params.regParam.foreach(classifier.setRegParam) + params.elasticNetParam.foreach(classifier.setElasticNetParam) + + // instantiate the One Vs Rest Classifier. + + val ovr = new OneVsRest() + ovr.setClassifier(classifier) + + // train the multiclass model. + val (trainingDuration, ovrModel) = time(ovr.fit(train)) + + // score the model on test data. + val (predictionDuration, predictions) = time(ovrModel.transform(test)) + + // evaluate the model + val predictionsAndLabels = predictions.select("prediction", "label") + .map(row => (row.getDouble(0), row.getDouble(1))) + + val metrics = new MulticlassMetrics(predictionsAndLabels) + + val confusionMatrix = metrics.confusionMatrix + + // compute the false positive rate per label + val predictionColSchema = predictions.schema("prediction") + val numClasses = MetadataUtils.getNumClasses(predictionColSchema).get + val fprs = Range(0, numClasses).map(p => (p, metrics.falsePositiveRate(p.toDouble))) + + println(s" Training Time ${trainingDuration} sec\n") + + println(s" Prediction Time ${predictionDuration} sec\n") + + println(s" Confusion Matrix\n ${confusionMatrix.toString}\n") + + println("label\tfpr") + + println(fprs.map {case (label, fpr) => label + "\t" + fpr}.mkString("\n")) + + sc.stop() + } + + private def time[R](block: => R): (Long, R) = { + val t0 = System.nanoTime() + val result = block // call-by-name + val t1 = System.nanoTime() + (NANO.toSeconds(t1 - t0), result) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala index e8a991f50e338..a0561e2573fc9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala @@ -87,7 +87,7 @@ object SimpleParamsExample { LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)))) // Make predictions on test data using the Transformer.transform() method. - // LogisticRegression.transform will only use the 'features' column. + // LogisticRegressionModel.transform will only use the 'features' column. // Note that model2.transform() outputs a 'myProbability' column instead of the usual // 'probability' column since we renamed the lr.probabilityCol parameter previously. model2.transform(test.toDF()) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala index 6772efd2c581c..1324b066c30c3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala @@ -64,7 +64,7 @@ object SimpleTextClassificationPipeline { .setOutputCol("features") val lr = new LogisticRegression() .setMaxIter(10) - .setRegParam(0.01) + .setRegParam(0.001) val pipeline = new Pipeline() .setStages(Array(tokenizer, hashingTF, lr)) @@ -75,7 +75,7 @@ object SimpleTextClassificationPipeline { val test = sc.parallelize(Seq( Document(4L, "spark i j k"), Document(5L, "l m n"), - Document(6L, "mapreduce spark"), + Document(6L, "spark hadoop spark"), Document(7L, "apache hadoop"))) // Make predictions on test documents. diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala index e943d6c889fab..520893b26d595 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala @@ -103,10 +103,10 @@ object DatasetExample { tmpDir.deleteOnExit() val outputDir = new File(tmpDir, "dataset").toString println(s"Saving to $outputDir as Parquet file.") - df.saveAsParquetFile(outputDir) + df.write.parquet(outputDir) println(s"Loading Parquet file with UDT from $outputDir.") - val newDataset = sqlContext.parquetFile(outputDir) + val newDataset = sqlContext.read.parquet(outputDir) println(s"Schema from Parquet: ${newDataset.schema.prettyJson}") val newFeatures = newDataset.select("features").map { case Row(v: Vector) => v } diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index b0613632c9946..3381941673db8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -22,7 +22,6 @@ import scala.language.reflectiveCalls import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint @@ -354,7 +353,11 @@ object DecisionTreeRunner { /** * Calculates the mean squared error for regression. + * + * This is just for demo purpose. In general, don't copy this code because it is NOT efficient + * due to the use of structural types, which leads to one reflection call per record. */ + // scalastyle:off structural.type private[mllib] def meanSquaredError( model: { def predict(features: Vector): Double }, data: RDD[LabeledPoint]): Double = { @@ -363,4 +366,5 @@ object DecisionTreeRunner { err * err }.mean() } + // scalastyle:on structural.type } diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala index df76b45e50810..f8c71ccabc43b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala @@ -40,23 +40,23 @@ object DenseGaussianMixture { private def run(inputFile: String, k: Int, convergenceTol: Double, maxIterations: Int) { val conf = new SparkConf().setAppName("Gaussian Mixture Model EM example") - val ctx = new SparkContext(conf) - + val ctx = new SparkContext(conf) + val data = ctx.textFile(inputFile).map { line => Vectors.dense(line.trim.split(' ').map(_.toDouble)) }.cache() - + val clusters = new GaussianMixture() .setK(k) .setConvergenceTol(convergenceTol) .setMaxIterations(maxIterations) .run(data) - + for (i <- 0 until clusters.k) { - println("weight=%f\nmu=%s\nsigma=\n%s\n" format + println("weight=%f\nmu=%s\nsigma=\n%s\n" format (clusters.weights(i), clusters.gaussians(i).mu, clusters.gaussians(i).sigma)) } - + println("Cluster labels (first <= 100):") val clusterLabels = clusters.predict(data) clusterLabels.take(100).foreach { x => diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala index a11890d6f2b1c..3ebb112fc069e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala @@ -36,22 +36,21 @@ object AvroConversionUtil extends Serializable { return null } schema.getType match { - case UNION => unpackUnion(obj, schema) - case ARRAY => unpackArray(obj, schema) - case FIXED => unpackFixed(obj, schema) - case MAP => unpackMap(obj, schema) - case BYTES => unpackBytes(obj) - case RECORD => unpackRecord(obj) - case STRING => obj.toString - case ENUM => obj.toString - case NULL => obj + case UNION => unpackUnion(obj, schema) + case ARRAY => unpackArray(obj, schema) + case FIXED => unpackFixed(obj, schema) + case MAP => unpackMap(obj, schema) + case BYTES => unpackBytes(obj) + case RECORD => unpackRecord(obj) + case STRING => obj.toString + case ENUM => obj.toString + case NULL => obj case BOOLEAN => obj - case DOUBLE => obj - case FLOAT => obj - case INT => obj - case LONG => obj - case other => throw new SparkException( - s"Unknown Avro schema type ${other.getName}") + case DOUBLE => obj + case FLOAT => obj + case INT => obj + case LONG => obj + case other => throw new SparkException(s"Unknown Avro schema type ${other.getName}") } } diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala index 273bee0a8b30f..90d48a64106c7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala @@ -18,20 +18,34 @@ package org.apache.spark.examples.pythonconverters import scala.collection.JavaConversions._ +import scala.util.parsing.json.JSONObject import org.apache.spark.api.python.Converter import org.apache.hadoop.hbase.client.{Put, Result} import org.apache.hadoop.hbase.io.ImmutableBytesWritable import org.apache.hadoop.hbase.util.Bytes +import org.apache.hadoop.hbase.KeyValue.Type +import org.apache.hadoop.hbase.CellUtil /** - * Implementation of [[org.apache.spark.api.python.Converter]] that converts an - * HBase Result to a String + * Implementation of [[org.apache.spark.api.python.Converter]] that converts all + * the records in an HBase Result to a String */ class HBaseResultToStringConverter extends Converter[Any, String] { override def convert(obj: Any): String = { + import collection.JavaConverters._ val result = obj.asInstanceOf[Result] - Bytes.toStringBinary(result.value()) + val output = result.listCells.asScala.map(cell => + Map( + "row" -> Bytes.toStringBinary(CellUtil.cloneRow(cell)), + "columnFamily" -> Bytes.toStringBinary(CellUtil.cloneFamily(cell)), + "qualifier" -> Bytes.toStringBinary(CellUtil.cloneQualifier(cell)), + "timestamp" -> cell.getTimestamp.toString, + "type" -> Type.codeToType(cell.getTypeByte).toString, + "value" -> Bytes.toStringBinary(CellUtil.cloneValue(cell)) + ) + ) + output.map(JSONObject(_).toString()).mkString("\n") } } diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala index 6331d1c0060f8..b11e32047dc34 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala @@ -58,10 +58,10 @@ object RDDRelation { df.where($"key" === 1).orderBy($"value".asc).select($"key").collect().foreach(println) // Write out an RDD as a parquet file. - df.saveAsParquetFile("pair.parquet") + df.write.parquet("pair.parquet") // Read in parquet file. Parquet files are self-describing so the schmema is preserved. - val parquetFile = sqlContext.parquetFile("pair.parquet") + val parquetFile = sqlContext.read.parquet("pair.parquet") // Queries can be run using the DSL on parequet files just like the original RDD. parquetFile.where($"key" === 1).select($"value".as("a")).collect().foreach(println) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala index 92867b44be138..016de4c63d1d2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala @@ -104,10 +104,8 @@ extends Actor with ActorHelper { object FeederActor { def main(args: Array[String]) { - if(args.length < 2){ - System.err.println( - "Usage: FeederActor \n" - ) + if (args.length < 2){ + System.err.println("Usage: FeederActor \n") System.exit(1) } val Seq(host, port) = args.toSeq diff --git a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala similarity index 97% rename from examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala rename to examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala index 11a8cf09533ce..fbe394de4a179 100644 --- a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala @@ -51,7 +51,7 @@ object DirectKafkaWordCount { // Create context with 2 second batch interval val sparkConf = new SparkConf().setAppName("DirectKafkaWordCount") - val ssc = new StreamingContext(sparkConf, Seconds(2)) + val ssc = new StreamingContext(sparkConf, Seconds(2)) // Create direct kafka stream with brokers and topics val topicsSet = topics.split(",").toSet diff --git a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala similarity index 95% rename from examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala rename to examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala index f407367a54f6c..60416ee343544 100644 --- a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala @@ -49,10 +49,10 @@ object KafkaWordCount { val Array(zkQuorum, group, topics, numThreads) = args val sparkConf = new SparkConf().setAppName("KafkaWordCount") - val ssc = new StreamingContext(sparkConf, Seconds(2)) + val ssc = new StreamingContext(sparkConf, Seconds(2)) ssc.checkpoint("checkpoint") - val topicMap = topics.split(",").map((_,numThreads.toInt)).toMap + val topicMap = topics.split(",").map((_, numThreads.toInt)).toMap val lines = KafkaUtils.createStream(ssc, zkQuorum, group, topicMap).map(_._2) val words = lines.flatMap(_.split(" ")) val wordCounts = words.map(x => (x, 1L)) @@ -96,7 +96,7 @@ object KafkaWordCountProducer { producer.send(message) } - Thread.sleep(100) + Thread.sleep(1000) } } diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala index 85b9a54b40baf..813c8554f5193 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala @@ -40,7 +40,7 @@ object MQTTPublisher { StreamingExamples.setStreamingLogLevels() val Seq(brokerUrl, topic) = args.toSeq - + var client: MqttClient = null try { @@ -49,7 +49,7 @@ object MQTTPublisher { client.connect() - val msgtopic = client.getTopic(topic) + val msgtopic = client.getTopic(topic) val msgContent = "hello mqtt demo for spark streaming" val message = new MqttMessage(msgContent.getBytes("utf-8")) @@ -59,10 +59,10 @@ object MQTTPublisher { println(s"Published data. topic: ${msgtopic.getName()}; Message: $message") } catch { case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT => - Thread.sleep(10) + Thread.sleep(10) println("Queue is full, wait for to consume data from the message queue") - } - } + } + } } catch { case e: MqttException => println("Exception Caught: " + e) } finally { @@ -107,7 +107,7 @@ object MQTTWordCount { val lines = MQTTUtils.createStream(ssc, brokerUrl, topic, StorageLevel.MEMORY_ONLY_SER_2) val words = lines.flatMap(x => x.split(" ")) val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) - + wordCounts.print() ssc.start() ssc.awaitTermination() diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala index 54d996b8ac990..889f052c70263 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala @@ -57,8 +57,7 @@ object PageViewGenerator { 404 -> .05) val userZipCode = Map(94709 -> .5, 94117 -> .5) - val userID = Map((1 to 100).map(_ -> .01):_*) - + val userID = Map((1 to 100).map(_ -> .01) : _*) def pickFromDistribution[T](inputMap : Map[T, Double]) : T = { val rand = new Random().nextDouble() diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 1f3e619d97a24..71f2b6fe18bd1 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -42,15 +42,46 @@ org.apache.flume flume-ng-sdk + + + + com.google.guava + guava + + + + org.apache.thrift + libthrift + + org.apache.flume flume-ng-core + + + com.google.guava + guava + + + org.apache.thrift + libthrift + + org.scala-lang scala-library + + + com.google.guava + guava + test + + + + diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala index fd01807fc3ac4..dc2a4ab138e18 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala @@ -21,7 +21,6 @@ import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable -import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.flume.Channel import org.apache.commons.lang3.RandomStringUtils @@ -45,8 +44,7 @@ import org.apache.commons.lang3.RandomStringUtils private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Channel, val transactionTimeout: Int, val backOffInterval: Int) extends SparkFlumeProtocol with Logging { val transactionExecutorOpt = Option(Executors.newFixedThreadPool(threads, - new ThreadFactoryBuilder().setDaemon(true) - .setNameFormat("Spark Sink Processor Thread - %d").build())) + new SparkSinkThreadFactory("Spark Sink Processor Thread - %d"))) // Protected by `sequenceNumberToProcessor` private val sequenceNumberToProcessor = mutable.HashMap[CharSequence, TransactionProcessor]() // This sink will not persist sequence numbers and reuses them if it gets restarted. diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkThreadFactory.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkThreadFactory.scala new file mode 100644 index 0000000000000..845fc8debda75 --- /dev/null +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSinkThreadFactory.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.flume.sink + +import java.util.concurrent.ThreadFactory +import java.util.concurrent.atomic.AtomicLong + +/** + * Thread factory that generates daemon threads with a specified name format. + */ +private[sink] class SparkSinkThreadFactory(nameFormat: String) extends ThreadFactory { + + private val threadId = new AtomicLong() + + override def newThread(r: Runnable): Thread = { + val t = new Thread(r, nameFormat.format(threadId.incrementAndGet())) + t.setDaemon(true) + t + } + +} diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala index ea45b14294df9..7ad43b1d7b0a0 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala @@ -143,7 +143,7 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String, eventBatch.setErrorMsg(msg) } else { // At this point, the events are available, so fill them into the event batch - eventBatch = new EventBatch("",seqNum, events) + eventBatch = new EventBatch("", seqNum, events) } }) } catch { diff --git a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala index 650b2fbe1c142..fa43629d49771 100644 --- a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala +++ b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala @@ -24,16 +24,24 @@ import scala.collection.JavaConversions._ import scala.concurrent.{ExecutionContext, Future} import scala.util.{Failure, Success} -import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.avro.ipc.NettyTransceiver import org.apache.avro.ipc.specific.SpecificRequestor import org.apache.flume.Context import org.apache.flume.channel.MemoryChannel import org.apache.flume.event.EventBuilder import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory + +// Due to MNG-1378, there is not a way to include test dependencies transitively. +// We cannot include Spark core tests as a dependency here because it depends on +// Spark core main, which has too many dependencies to require here manually. +// For this reason, we continue to use FunSuite and ignore the scalastyle checks +// that fail if this is detected. +//scalastyle:off import org.scalatest.FunSuite class SparkSinkSuite extends FunSuite { +//scalastyle:on + val eventsPerBatch = 1000 val channelCapacity = 5000 @@ -185,9 +193,8 @@ class SparkSinkSuite extends FunSuite { count: Int): Seq[(NettyTransceiver, SparkFlumeProtocol.Callback)] = { (1 to count).map(_ => { - lazy val channelFactoryExecutor = - Executors.newCachedThreadPool(new ThreadFactoryBuilder().setDaemon(true). - setNameFormat("Flume Receiver Channel Thread - %d").build()) + lazy val channelFactoryExecutor = Executors.newCachedThreadPool( + new SparkSinkThreadFactory("Flume Receiver Channel Thread - %d")) lazy val channelFactory = new NioClientSocketChannelFactory(channelFactoryExecutor, channelFactoryExecutor) val transceiver = new NettyTransceiver(address, channelFactory) diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 8df7edbdcad33..a345c03582ad6 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -41,6 +41,13 @@ ${project.version} provided + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-streaming-flume-sink_${scala.binary.version} diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala index dc629df4f4ac2..65c49c131518b 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala @@ -60,7 +60,7 @@ private[streaming] object EventTransformer extends Logging { out.write(body) val numHeaders = headers.size() out.writeInt(numHeaders) - for ((k,v) <- headers) { + for ((k, v) <- headers) { val keyBuff = Utils.serialize(k.toString) out.writeInt(keyBuff.length) out.write(keyBuff) diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala index 60e2994431b38..1e32a365a1eee 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala @@ -152,9 +152,9 @@ class FlumeReceiver( val channelFactory = new NioServerSocketChannelFactory(Executors.newCachedThreadPool(), Executors.newCachedThreadPool()) val channelPipelineFactory = new CompressionChannelPipelineFactory() - + new NettyServer( - responder, + responder, new InetSocketAddress(host, port), channelFactory, channelPipelineFactory, @@ -188,12 +188,12 @@ class FlumeReceiver( override def preferredLocation: Option[String] = Option(host) - /** A Netty Pipeline factory that will decompress incoming data from + /** A Netty Pipeline factory that will decompress incoming data from * and the Netty client and compress data going back to the client. * * The compression on the return is required because Flume requires - * a successful response to indicate it can remove the event/batch - * from the configured channel + * a successful response to indicate it can remove the event/batch + * from the configured channel */ private[streaming] class CompressionChannelPipelineFactory extends ChannelPipelineFactory { diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala index 92fa5b41be89e..583e7dca317ad 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala @@ -110,7 +110,7 @@ private[streaming] class FlumePollingReceiver( } /** - * A wrapper around the transceiver and the Avro IPC API. + * A wrapper around the transceiver and the Avro IPC API. * @param transceiver The transceiver to use for communication with Flume * @param client The client that the callbacks are received on. */ diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index 43c1b865b64a1..d772b9ca9b570 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -18,26 +18,29 @@ package org.apache.spark.streaming.flume import java.net.InetSocketAddress -import java.util.concurrent.{Callable, ExecutorCompletionService, Executors} +import java.util.concurrent._ import scala.collection.JavaConversions._ import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} +import scala.concurrent.duration._ +import scala.language.postfixOps import org.apache.flume.Context import org.apache.flume.channel.MemoryChannel import org.apache.flume.conf.Configurables import org.apache.flume.event.EventBuilder +import org.scalatest.concurrent.Eventually._ -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter -import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.{Seconds, TestOutputStream, StreamingContext} import org.apache.spark.streaming.flume.sink._ import org.apache.spark.util.{ManualClock, Utils} -class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging { +class FlumePollingStreamSuite extends SparkFunSuite with BeforeAndAfter with Logging { val batchCount = 5 val eventsPerBatch = 100 @@ -57,11 +60,11 @@ class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging before(beforeFunction()) - ignore("flume polling test") { + test("flume polling test") { testMultipleTimes(testFlumePolling) } - ignore("flume polling test multiple hosts") { + test("flume polling test multiple hosts") { testMultipleTimes(testFlumePollingMultipleHost) } @@ -100,18 +103,8 @@ class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging Configurables.configure(sink, context) sink.setChannel(channel) sink.start() - // Set up the streaming context and input streams - val ssc = new StreamingContext(conf, batchDuration) - val flumeStream: ReceiverInputDStream[SparkFlumeEvent] = - FlumeUtils.createPollingStream(ssc, Seq(new InetSocketAddress("localhost", sink.getPort())), - StorageLevel.MEMORY_AND_DISK, eventsPerBatch, 1) - val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] - with SynchronizedBuffer[Seq[SparkFlumeEvent]] - val outputStream = new TestOutputStream(flumeStream, outputBuffer) - outputStream.register() - ssc.start() - writeAndVerify(Seq(channel), ssc, outputBuffer) + writeAndVerify(Seq(sink), Seq(channel)) assertChannelIsEmpty(channel) sink.stop() channel.stop() @@ -142,10 +135,22 @@ class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging Configurables.configure(sink2, context) sink2.setChannel(channel2) sink2.start() + try { + writeAndVerify(Seq(sink, sink2), Seq(channel, channel2)) + assertChannelIsEmpty(channel) + assertChannelIsEmpty(channel2) + } finally { + sink.stop() + sink2.stop() + channel.stop() + channel2.stop() + } + } + def writeAndVerify(sinks: Seq[SparkSink], channels: Seq[MemoryChannel]) { // Set up the streaming context and input streams val ssc = new StreamingContext(conf, batchDuration) - val addresses = Seq(sink.getPort(), sink2.getPort()).map(new InetSocketAddress("localhost", _)) + val addresses = sinks.map(sink => new InetSocketAddress("localhost", sink.getPort())) val flumeStream: ReceiverInputDStream[SparkFlumeEvent] = FlumeUtils.createPollingStream(ssc, addresses, StorageLevel.MEMORY_AND_DISK, eventsPerBatch, 5) @@ -155,61 +160,49 @@ class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging outputStream.register() ssc.start() - try { - writeAndVerify(Seq(channel, channel2), ssc, outputBuffer) - assertChannelIsEmpty(channel) - assertChannelIsEmpty(channel2) - } finally { - sink.stop() - sink2.stop() - channel.stop() - channel2.stop() - } - } - - def writeAndVerify(channels: Seq[MemoryChannel], ssc: StreamingContext, - outputBuffer: ArrayBuffer[Seq[SparkFlumeEvent]]) { val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] val executor = Executors.newCachedThreadPool() val executorCompletion = new ExecutorCompletionService[Void](executor) - channels.map(channel => { + + val latch = new CountDownLatch(batchCount * channels.size) + sinks.foreach(_.countdownWhenBatchReceived(latch)) + + channels.foreach(channel => { executorCompletion.submit(new TxnSubmitter(channel, clock)) }) + for (i <- 0 until channels.size) { executorCompletion.take() } - val startTime = System.currentTimeMillis() - while (outputBuffer.size < batchCount * channels.size && - System.currentTimeMillis() - startTime < 15000) { - logInfo("output.size = " + outputBuffer.size) - Thread.sleep(100) - } - val timeTaken = System.currentTimeMillis() - startTime - assert(timeTaken < 15000, "Operation timed out after " + timeTaken + " ms") - logInfo("Stopping context") - ssc.stop() - val flattenedBuffer = outputBuffer.flatten - assert(flattenedBuffer.size === totalEventsPerChannel * channels.size) - var counter = 0 - for (k <- 0 until channels.size; i <- 0 until totalEventsPerChannel) { - val eventToVerify = EventBuilder.withBody((channels(k).getName + " - " + - String.valueOf(i)).getBytes("utf-8"), - Map[String, String]("test-" + i.toString -> "header")) - var found = false - var j = 0 - while (j < flattenedBuffer.size && !found) { - val strToCompare = new String(flattenedBuffer(j).event.getBody.array(), "utf-8") - if (new String(eventToVerify.getBody, "utf-8") == strToCompare && - eventToVerify.getHeaders.get("test-" + i.toString) - .equals(flattenedBuffer(j).event.getHeaders.get("test-" + i.toString))) { - found = true - counter += 1 + latch.await(15, TimeUnit.SECONDS) // Ensure all data has been received. + clock.advance(batchDuration.milliseconds) + + // The eventually is required to ensure that all data in the batch has been processed. + eventually(timeout(10 seconds), interval(100 milliseconds)) { + val flattenedBuffer = outputBuffer.flatten + assert(flattenedBuffer.size === totalEventsPerChannel * channels.size) + var counter = 0 + for (k <- 0 until channels.size; i <- 0 until totalEventsPerChannel) { + val eventToVerify = EventBuilder.withBody((channels(k).getName + " - " + + String.valueOf(i)).getBytes("utf-8"), + Map[String, String]("test-" + i.toString -> "header")) + var found = false + var j = 0 + while (j < flattenedBuffer.size && !found) { + val strToCompare = new String(flattenedBuffer(j).event.getBody.array(), "utf-8") + if (new String(eventToVerify.getBody, "utf-8") == strToCompare && + eventToVerify.getHeaders.get("test-" + i.toString) + .equals(flattenedBuffer(j).event.getHeaders.get("test-" + i.toString))) { + found = true + counter += 1 + } + j += 1 } - j += 1 } + assert(counter === totalEventsPerChannel * channels.size) } - assert(counter === totalEventsPerChannel * channels.size) + ssc.stop() } def assertChannelIsEmpty(channel: MemoryChannel): Unit = { @@ -234,7 +227,6 @@ class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging tx.commit() tx.close() Thread.sleep(500) // Allow some time for the events to reach - clock.advance(batchDuration.milliseconds) } null } diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index 39e6754c81dbf..c926359987d89 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -35,15 +35,15 @@ import org.jboss.netty.channel.ChannelPipeline import org.jboss.netty.channel.socket.SocketChannel import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory import org.jboss.netty.handler.codec.compression._ -import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream} import org.apache.spark.util.Utils -class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with Logging { +class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging { val conf = new SparkConf().setMaster("local[4]").setAppName("FlumeStreamSuite") var ssc: StreamingContext = null @@ -138,7 +138,7 @@ class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with L val status = client.appendBatch(inputEvents.toList) status should be (avro.Status.OK) } - + eventually(timeout(10 seconds), interval(100 milliseconds)) { val outputEvents = outputBuffer.flatten.map { _.event } outputEvents.foreach { diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index 243ce6eaca658..5734d55bf4784 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -41,6 +41,13 @@ ${project.version} provided + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.kafka kafka_${scala.binary.version} diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala index 6715aede7928a..060c2f23eded8 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala @@ -65,6 +65,9 @@ class DirectKafkaInputDStream[ val maxRetries = context.sparkContext.getConf.getInt( "spark.streaming.kafka.maxRetries", 1) + // Keep this consistent with how other streams are named (e.g. "Flume polling stream [2]") + private[streaming] override def name: String = s"Kafka direct stream [$id]" + protected[streaming] override val checkpointData = new DirectKafkaInputDStreamCheckpointData diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala index 6cf254a7b69cb..65d51d87f8486 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala @@ -113,7 +113,7 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable { r.flatMap { tm: TopicMetadata => tm.partitionsMetadata.map { pm: PartitionMetadata => TopicAndPartition(tm.topic, pm.partitionId) - } + } } } } diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala index cca0fac0234e1..04b2dc10d39ea 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala @@ -135,7 +135,7 @@ class KafkaReceiver[ store((msgAndMetadata.key, msgAndMetadata.message)) } } catch { - case e: Throwable => logError("Error handling message; exiting", e) + case e: Throwable => reportError("Error handling message; exiting", e) } } } diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index d7cf500577c2a..0b8a391a2c569 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -189,7 +189,7 @@ object KafkaUtils { sc: SparkContext, kafkaParams: Map[String, String], offsetRanges: Array[OffsetRange] - ): RDD[(K, V)] = { + ): RDD[(K, V)] = sc.withScope { val messageHandler = (mmd: MessageAndMetadata[K, V]) => (mmd.key, mmd.message) val leaders = leadersForRanges(kafkaParams, offsetRanges) new KafkaRDD[K, V, KD, VD, (K, V)](sc, kafkaParams, offsetRanges, leaders, messageHandler) @@ -224,7 +224,7 @@ object KafkaUtils { offsetRanges: Array[OffsetRange], leaders: Map[TopicAndPartition, Broker], messageHandler: MessageAndMetadata[K, V] => R - ): RDD[R] = { + ): RDD[R] = sc.withScope { val leaderMap = if (leaders.isEmpty) { leadersForRanges(kafkaParams, offsetRanges) } else { @@ -233,7 +233,8 @@ object KafkaUtils { case (tp: TopicAndPartition, Broker(host, port)) => (tp, (host, port)) }.toMap } - new KafkaRDD[K, V, KD, VD, R](sc, kafkaParams, offsetRanges, leaderMap, messageHandler) + val cleanedHandler = sc.clean(messageHandler) + new KafkaRDD[K, V, KD, VD, R](sc, kafkaParams, offsetRanges, leaderMap, cleanedHandler) } /** @@ -256,7 +257,7 @@ object KafkaUtils { valueDecoderClass: Class[VD], kafkaParams: JMap[String, String], offsetRanges: Array[OffsetRange] - ): JavaPairRDD[K, V] = { + ): JavaPairRDD[K, V] = jsc.sc.withScope { implicit val keyCmt: ClassTag[K] = ClassTag(keyClass) implicit val valueCmt: ClassTag[V] = ClassTag(valueClass) implicit val keyDecoderCmt: ClassTag[KD] = ClassTag(keyDecoderClass) @@ -294,7 +295,7 @@ object KafkaUtils { offsetRanges: Array[OffsetRange], leaders: JMap[TopicAndPartition, Broker], messageHandler: JFunction[MessageAndMetadata[K, V], R] - ): JavaRDD[R] = { + ): JavaRDD[R] = jsc.sc.withScope { implicit val keyCmt: ClassTag[K] = ClassTag(keyClass) implicit val valueCmt: ClassTag[V] = ClassTag(valueClass) implicit val keyDecoderCmt: ClassTag[KD] = ClassTag(keyDecoderClass) @@ -314,7 +315,7 @@ object KafkaUtils { * Points to note: * - No receivers: This stream does not use any receiver. It directly queries Kafka * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked - * by the stream itself. For interoperability with Kafka monitoring tools that depend on + * by the stream itself. For interoperability with Kafka monitoring tools that depend on * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. * You can access the offsets used in each batch from the generated RDDs (see * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]). @@ -348,8 +349,9 @@ object KafkaUtils { fromOffsets: Map[TopicAndPartition, Long], messageHandler: MessageAndMetadata[K, V] => R ): InputDStream[R] = { + val cleanedHandler = ssc.sc.clean(messageHandler) new DirectKafkaInputDStream[K, V, KD, VD, R]( - ssc, kafkaParams, fromOffsets, messageHandler) + ssc, kafkaParams, fromOffsets, cleanedHandler) } /** @@ -361,7 +363,7 @@ object KafkaUtils { * Points to note: * - No receivers: This stream does not use any receiver. It directly queries Kafka * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked - * by the stream itself. For interoperability with Kafka monitoring tools that depend on + * by the stream itself. For interoperability with Kafka monitoring tools that depend on * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. * You can access the offsets used in each batch from the generated RDDs (see * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]). @@ -425,7 +427,7 @@ object KafkaUtils { * Points to note: * - No receivers: This stream does not use any receiver. It directly queries Kafka * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked - * by the stream itself. For interoperability with Kafka monitoring tools that depend on + * by the stream itself. For interoperability with Kafka monitoring tools that depend on * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. * You can access the offsets used in each batch from the generated RDDs (see * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]). @@ -469,11 +471,12 @@ object KafkaUtils { implicit val keyDecoderCmt: ClassTag[KD] = ClassTag(keyDecoderClass) implicit val valueDecoderCmt: ClassTag[VD] = ClassTag(valueDecoderClass) implicit val recordCmt: ClassTag[R] = ClassTag(recordClass) + val cleanedHandler = jssc.sparkContext.clean(messageHandler.call _) createDirectStream[K, V, KD, VD, R]( jssc.ssc, Map(kafkaParams.toSeq: _*), Map(fromOffsets.mapValues { _.longValue() }.toSeq: _*), - messageHandler.call _ + cleanedHandler ) } @@ -486,7 +489,7 @@ object KafkaUtils { * Points to note: * - No receivers: This stream does not use any receiver. It directly queries Kafka * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked - * by the stream itself. For interoperability with Kafka monitoring tools that depend on + * by the stream itself. For interoperability with Kafka monitoring tools that depend on * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. * You can access the offsets used in each batch from the generated RDDs (see * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]). diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala index ea87e960379f1..75f0dfc22b9dc 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala @@ -267,7 +267,7 @@ class ReliableKafkaReceiver[ } } catch { case e: Exception => - logError("Error handling message", e) + reportError("Error handling message", e) } } } diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala index b6d314dfc7783..47bbfb605850a 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala @@ -28,10 +28,10 @@ import scala.language.postfixOps import kafka.common.TopicAndPartition import kafka.message.MessageAndMetadata import kafka.serializer.StringDecoder -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.scalatest.concurrent.Eventually -import org.apache.spark.{Logging, SparkConf, SparkContext} +import org.apache.spark.{Logging, SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time} import org.apache.spark.streaming.dstream.DStream @@ -39,7 +39,7 @@ import org.apache.spark.streaming.scheduler._ import org.apache.spark.util.Utils class DirectKafkaStreamSuite - extends FunSuite + extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll with Eventually diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala index 7fb841b79cb65..d66830cbacdee 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala @@ -20,9 +20,11 @@ package org.apache.spark.streaming.kafka import scala.util.Random import kafka.common.TopicAndPartition -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll -class KafkaClusterSuite extends FunSuite with BeforeAndAfterAll { +import org.apache.spark.SparkFunSuite + +class KafkaClusterSuite extends SparkFunSuite with BeforeAndAfterAll { private val topic = "kcsuitetopic" + Random.nextInt(10000) private val topicAndPartition = TopicAndPartition(topic, 0) private var kc: KafkaCluster = null diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala index 39c3fb448ff57..054487269a935 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala @@ -22,11 +22,11 @@ import scala.util.Random import kafka.serializer.StringDecoder import kafka.common.TopicAndPartition import kafka.message.MessageAndMetadata -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll import org.apache.spark._ -class KafkaRDDSuite extends FunSuite with BeforeAndAfterAll { +class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll { private var kafkaTestUtils: KafkaTestUtils = _ @@ -65,7 +65,7 @@ class KafkaRDDSuite extends FunSuite with BeforeAndAfterAll { val offsetRanges = Array(OffsetRange(topic, 0, 0, messages.size)) - val rdd = KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder]( + val rdd = KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder]( sc, kafkaParams, offsetRanges) val received = rdd.map(_._2).collect.toSet diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala index 24699dfc33adb..8ee2cc660f849 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala @@ -23,14 +23,14 @@ import scala.language.postfixOps import scala.util.Random import kafka.serializer.StringDecoder -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll import org.scalatest.concurrent.Eventually -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext} -class KafkaStreamSuite extends FunSuite with Eventually with BeforeAndAfterAll { +class KafkaStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfterAll { private var ssc: StreamingContext = _ private var kafkaTestUtils: KafkaTestUtils = _ diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala index 38548dd73b82c..80e2df62de3fe 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala @@ -26,15 +26,15 @@ import scala.util.Random import kafka.serializer.StringDecoder import kafka.utils.{ZKGroupTopicDirs, ZkUtils} -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.scalatest.concurrent.Eventually -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext} import org.apache.spark.util.Utils -class ReliableKafkaStreamSuite extends FunSuite +class ReliableKafkaStreamSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter with Eventually { private val sparkConf = new SparkConf() diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 98f95a9a64fa0..7d102e10ab60f 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -41,6 +41,13 @@ ${project.version} provided + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.eclipse.paho org.eclipse.paho.client.mqttv3 diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala index 3c0ef94cb0fab..7c2f18cb35bda 100644 --- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala +++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala @@ -17,25 +17,12 @@ package org.apache.spark.streaming.mqtt -import java.io.IOException -import java.util.concurrent.Executors -import java.util.Properties - -import scala.collection.JavaConversions._ -import scala.collection.Map -import scala.collection.mutable.HashMap -import scala.reflect.ClassTag - import org.eclipse.paho.client.mqttv3.IMqttDeliveryToken import org.eclipse.paho.client.mqttv3.MqttCallback import org.eclipse.paho.client.mqttv3.MqttClient -import org.eclipse.paho.client.mqttv3.MqttClientPersistence -import org.eclipse.paho.client.mqttv3.MqttException import org.eclipse.paho.client.mqttv3.MqttMessage -import org.eclipse.paho.client.mqttv3.MqttTopic import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence -import org.apache.spark.Logging import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.dstream._ @@ -57,6 +44,8 @@ class MQTTInputDStream( storageLevel: StorageLevel ) extends ReceiverInputDStream[String](ssc_) { + private[streaming] override def name: String = s"MQTT stream [$id]" + def getReceiver(): Receiver[String] = { new MQTTReceiver(brokerUrl, topic, storageLevel) } @@ -86,7 +75,7 @@ class MQTTReceiver( // Handles Mqtt message override def messageArrived(topic: String, message: MqttMessage) { - store(new String(message.getPayload(),"utf-8")) + store(new String(message.getPayload(), "utf-8")) } override def deliveryComplete(token: IMqttDeliveryToken) { diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala index a19a72c58a705..c4bf5aa7869bb 100644 --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala @@ -29,7 +29,7 @@ import org.apache.commons.lang3.RandomUtils import org.eclipse.paho.client.mqttv3._ import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually import org.apache.spark.streaming.{Milliseconds, StreamingContext} @@ -37,10 +37,10 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.scheduler.StreamingListener import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.util.Utils -class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { +class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter { private val batchDuration = Milliseconds(500) private val master = "local[2]" diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index 8b6a8959ac4cf..d28e3e1846d70 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -41,6 +41,13 @@ ${project.version} provided + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.twitter4j twitter4j-stream diff --git a/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala b/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala index 9ee57d7581d85..d9acb568879fe 100644 --- a/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala +++ b/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala @@ -18,16 +18,16 @@ package org.apache.spark.streaming.twitter -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import twitter4j.Status import twitter4j.auth.{NullAuthorization, Authorization} -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream -class TwitterStreamSuite extends FunSuite with BeforeAndAfter with Logging { +class TwitterStreamSuite extends SparkFunSuite with BeforeAndAfter with Logging { val batchDuration = Seconds(1) diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index a50d378b34335..9998c11c85171 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -41,6 +41,13 @@ ${project.version} provided + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + ${akka.group} akka-zeromq_${scala.binary.version} diff --git a/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala b/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala index a7566e733d891..35d2e62c68480 100644 --- a/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala +++ b/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala @@ -20,13 +20,13 @@ package org.apache.spark.streaming.zeromq import akka.actor.SupervisorStrategy import akka.util.ByteString import akka.zeromq.Subscribe -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.dstream.ReceiverInputDStream -class ZeroMQStreamSuite extends FunSuite { +class ZeroMQStreamSuite extends SparkFunSuite { val batchDuration = Seconds(1) diff --git a/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java b/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java index b0bff27a61c19..06e0ff28afd95 100644 --- a/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java +++ b/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java @@ -20,6 +20,7 @@ import java.util.List; import java.util.regex.Pattern; +import com.amazonaws.regions.RegionUtils; import org.apache.log4j.Logger; import org.apache.spark.SparkConf; import org.apache.spark.api.java.function.FlatMapFunction; @@ -40,140 +41,146 @@ import com.google.common.collect.Lists; /** - * Java-friendly Kinesis Spark Streaming WordCount example + * Consumes messages from a Amazon Kinesis streams and does wordcount. * - * See http://spark.apache.org/docs/latest/streaming-kinesis.html for more details - * on the Kinesis Spark Streaming integration. + * This example spins up 1 Kinesis Receiver per shard for the given stream. + * It then starts pulling from the last checkpointed sequence number of the given stream. * - * This example spins up 1 Kinesis Worker (Spark Streaming Receiver) per shard - * for the given stream. - * It then starts pulling from the last checkpointed sequence number of the given - * and . + * Usage: JavaKinesisWordCountASL [app-name] [stream-name] [endpoint-url] [region-name] + * [app-name] is the name of the consumer app, used to track the read data in DynamoDB + * [stream-name] name of the Kinesis stream (ie. mySparkStream) + * [endpoint-url] endpoint of the Kinesis service + * (e.g. https://kinesis.us-east-1.amazonaws.com) * - * Valid endpoint urls: http://docs.aws.amazon.com/general/latest/gr/rande.html#ak_region - * - * This code uses the DefaultAWSCredentialsProviderChain and searches for credentials - * in the following order of precedence: - * Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY - * Java System Properties - aws.accessKeyId and aws.secretKey - * Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs - * Instance profile credentials - delivered through the Amazon EC2 metadata service - * - * Usage: JavaKinesisWordCountASL - * is the name of the Kinesis stream (ie. mySparkStream) - * is the endpoint of the Kinesis service - * (ie. https://kinesis.us-east-1.amazonaws.com) * * Example: - * $ export AWS_ACCESS_KEY_ID= + * # export AWS keys if necessary + * $ export AWS_ACCESS_KEY_ID=[your-access-key] * $ export AWS_SECRET_KEY= - * $ $SPARK_HOME/bin/run-example \ - * org.apache.spark.examples.streaming.JavaKinesisWordCountASL mySparkStream \ - * https://kinesis.us-east-1.amazonaws.com * - * Note that number of workers/threads should be 1 more than the number of receivers. - * This leaves one thread available for actually processing the data. + * # run the example + * $ SPARK_HOME/bin/run-example streaming.JavaKinesisWordCountASL myAppName mySparkStream \ + * https://kinesis.us-east-1.amazonaws.com + * + * There is a companion helper class called KinesisWordProducerASL which puts dummy data + * onto the Kinesis stream. * - * There is a companion helper class called KinesisWordCountProducerASL which puts dummy data - * onto the Kinesis stream. - * Usage instructions for KinesisWordCountProducerASL are provided in the class definition. + * This code uses the DefaultAWSCredentialsProviderChain to find credentials + * in the following order: + * Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY + * Java System Properties - aws.accessKeyId and aws.secretKey + * Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs + * Instance profile credentials - delivered through the Amazon EC2 metadata service + * For more information, see + * http://docs.aws.amazon.com/AWSSdkDocsJava/latest/DeveloperGuide/credentials.html + * + * See http://spark.apache.org/docs/latest/streaming-kinesis-integration.html for more details on + * the Kinesis Spark Streaming integration. */ public final class JavaKinesisWordCountASL { // needs to be public for access from run-example - private static final Pattern WORD_SEPARATOR = Pattern.compile(" "); - private static final Logger logger = Logger.getLogger(JavaKinesisWordCountASL.class); - - /* Make the constructor private to enforce singleton */ - private JavaKinesisWordCountASL() { + private static final Pattern WORD_SEPARATOR = Pattern.compile(" "); + private static final Logger logger = Logger.getLogger(JavaKinesisWordCountASL.class); + + public static void main(String[] args) { + // Check that all required args were passed in. + if (args.length != 3) { + System.err.println( + "Usage: JavaKinesisWordCountASL \n\n" + + " is the name of the app, used to track the read data in DynamoDB\n" + + " is the name of the Kinesis stream\n" + + " is the endpoint of the Kinesis service\n" + + " (e.g. https://kinesis.us-east-1.amazonaws.com)\n" + + "Generate data for the Kinesis stream using the example KinesisWordProducerASL.\n" + + "See http://spark.apache.org/docs/latest/streaming-kinesis-integration.html for more\n" + + "details.\n" + ); + System.exit(1); } - public static void main(String[] args) { - /* Check that all required args were passed in. */ - if (args.length < 2) { - System.err.println( - "Usage: JavaKinesisWordCountASL \n" + - " is the name of the Kinesis stream\n" + - " is the endpoint of the Kinesis service\n" + - " (e.g. https://kinesis.us-east-1.amazonaws.com)\n"); - System.exit(1); - } - - StreamingExamples.setStreamingLogLevels(); - - /* Populate the appropriate variables from the given args */ - String streamName = args[0]; - String endpointUrl = args[1]; - /* Set the batch interval to a fixed 2000 millis (2 seconds) */ - Duration batchInterval = new Duration(2000); - - /* Create a Kinesis client in order to determine the number of shards for the given stream */ - AmazonKinesisClient kinesisClient = new AmazonKinesisClient( - new DefaultAWSCredentialsProviderChain()); - kinesisClient.setEndpoint(endpointUrl); - - /* Determine the number of shards from the stream */ - int numShards = kinesisClient.describeStream(streamName) - .getStreamDescription().getShards().size(); - - /* In this example, we're going to create 1 Kinesis Worker/Receiver/DStream for each shard */ - int numStreams = numShards; - - /* Setup the Spark config. */ - SparkConf sparkConfig = new SparkConf().setAppName("KinesisWordCount"); - - /* Kinesis checkpoint interval. Same as batchInterval for this example. */ - Duration checkpointInterval = batchInterval; + // Set default log4j logging level to WARN to hide Spark logs + StreamingExamples.setStreamingLogLevels(); + + // Populate the appropriate variables from the given args + String kinesisAppName = args[0]; + String streamName = args[1]; + String endpointUrl = args[2]; + + // Create a Kinesis client in order to determine the number of shards for the given stream + AmazonKinesisClient kinesisClient = + new AmazonKinesisClient(new DefaultAWSCredentialsProviderChain()); + kinesisClient.setEndpoint(endpointUrl); + int numShards = + kinesisClient.describeStream(streamName).getStreamDescription().getShards().size(); + + + // In this example, we're going to create 1 Kinesis Receiver/input DStream for each shard. + // This is not a necessity; if there are less receivers/DStreams than the number of shards, + // then the shards will be automatically distributed among the receivers and each receiver + // will receive data from multiple shards. + int numStreams = numShards; + + // Spark Streaming batch interval + Duration batchInterval = new Duration(2000); + + // Kinesis checkpoint interval. Same as batchInterval for this example. + Duration kinesisCheckpointInterval = batchInterval; + + // Get the region name from the endpoint URL to save Kinesis Client Library metadata in + // DynamoDB of the same region as the Kinesis stream + String regionName = RegionUtils.getRegionByEndpoint(endpointUrl).getName(); + + // Setup the Spark config and StreamingContext + SparkConf sparkConfig = new SparkConf().setAppName("JavaKinesisWordCountASL"); + JavaStreamingContext jssc = new JavaStreamingContext(sparkConfig, batchInterval); + + // Create the Kinesis DStreams + List> streamsList = new ArrayList>(numStreams); + for (int i = 0; i < numStreams; i++) { + streamsList.add( + KinesisUtils.createStream(jssc, kinesisAppName, streamName, endpointUrl, regionName, + InitialPositionInStream.LATEST, kinesisCheckpointInterval, StorageLevel.MEMORY_AND_DISK_2()) + ); + } - /* Setup the StreamingContext */ - JavaStreamingContext jssc = new JavaStreamingContext(sparkConfig, batchInterval); + // Union all the streams if there is more than 1 stream + JavaDStream unionStreams; + if (streamsList.size() > 1) { + unionStreams = jssc.union(streamsList.get(0), streamsList.subList(1, streamsList.size())); + } else { + // Otherwise, just use the 1 stream + unionStreams = streamsList.get(0); + } - /* Create the same number of Kinesis DStreams/Receivers as Kinesis stream's shards */ - List> streamsList = new ArrayList>(numStreams); - for (int i = 0; i < numStreams; i++) { - streamsList.add( - KinesisUtils.createStream(jssc, streamName, endpointUrl, checkpointInterval, - InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2()) - ); + // Convert each line of Array[Byte] to String, and split into words + JavaDStream words = unionStreams.flatMap(new FlatMapFunction() { + @Override + public Iterable call(byte[] line) { + return Lists.newArrayList(WORD_SEPARATOR.split(new String(line))); + } + }); + + // Map each word to a (word, 1) tuple so we can reduce by key to count the words + JavaPairDStream wordCounts = words.mapToPair( + new PairFunction() { + @Override + public Tuple2 call(String s) { + return new Tuple2(s, 1); + } } - - /* Union all the streams if there is more than 1 stream */ - JavaDStream unionStreams; - if (streamsList.size() > 1) { - unionStreams = jssc.union(streamsList.get(0), streamsList.subList(1, streamsList.size())); - } else { - /* Otherwise, just use the 1 stream */ - unionStreams = streamsList.get(0); + ).reduceByKey( + new Function2() { + @Override + public Integer call(Integer i1, Integer i2) { + return i1 + i2; + } } + ); - /* - * Split each line of the union'd DStreams into multiple words using flatMap to produce the collection. - * Convert lines of byte[] to multiple Strings by first converting to String, then splitting on WORD_SEPARATOR. - */ - JavaDStream words = unionStreams.flatMap(new FlatMapFunction() { - @Override - public Iterable call(byte[] line) { - return Lists.newArrayList(WORD_SEPARATOR.split(new String(line))); - } - }); - - /* Map each word to a (word, 1) tuple, then reduce/aggregate by word. */ - JavaPairDStream wordCounts = words.mapToPair( - new PairFunction() { - @Override - public Tuple2 call(String s) { - return new Tuple2(s, 1); - } - }).reduceByKey(new Function2() { - @Override - public Integer call(Integer i1, Integer i2) { - return i1 + i2; - } - }); - - /* Print the first 10 wordCounts */ - wordCounts.print(); - - /* Start the streaming context and await termination */ - jssc.start(); - jssc.awaitTermination(); - } + // Print the first 10 wordCounts + wordCounts.print(); + + // Start the streaming context and await termination + jssc.start(); + jssc.awaitTermination(); + } } diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala index 32da0858d1a1d..be8b62d3cc6ba 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala @@ -18,223 +18,249 @@ package org.apache.spark.examples.streaming import java.nio.ByteBuffer + import scala.util.Random -import org.apache.spark.Logging -import org.apache.spark.SparkConf -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.Milliseconds -import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.StreamingContext.toPairDStreamFunctions -import org.apache.spark.streaming.kinesis.KinesisUtils -import com.amazonaws.auth.DefaultAWSCredentialsProviderChain + +import com.amazonaws.auth.{DefaultAWSCredentialsProviderChain, BasicAWSCredentials} +import com.amazonaws.regions.RegionUtils import com.amazonaws.services.kinesis.AmazonKinesisClient import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream import com.amazonaws.services.kinesis.model.PutRecordRequest -import org.apache.log4j.Logger -import org.apache.log4j.Level +import org.apache.log4j.{Level, Logger} + +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Milliseconds, StreamingContext} +import org.apache.spark.streaming.dstream.DStream.toPairDStreamFunctions +import org.apache.spark.streaming.kinesis.KinesisUtils + /** - * Kinesis Spark Streaming WordCount example. + * Consumes messages from a Amazon Kinesis streams and does wordcount. * - * See http://spark.apache.org/docs/latest/streaming-kinesis.html for more details on - * the Kinesis Spark Streaming integration. + * This example spins up 1 Kinesis Receiver per shard for the given stream. + * It then starts pulling from the last checkpointed sequence number of the given stream. * - * This example spins up 1 Kinesis Worker (Spark Streaming Receiver) per shard - * for the given stream. - * It then starts pulling from the last checkpointed sequence number of the given - * and . + * Usage: KinesisWordCountASL + * is the name of the consumer app, used to track the read data in DynamoDB + * name of the Kinesis stream (ie. mySparkStream) + * endpoint of the Kinesis service + * (e.g. https://kinesis.us-east-1.amazonaws.com) * - * Valid endpoint urls: http://docs.aws.amazon.com/general/latest/gr/rande.html#ak_region - * - * This code uses the DefaultAWSCredentialsProviderChain and searches for credentials - * in the following order of precedence: - * Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY - * Java System Properties - aws.accessKeyId and aws.secretKey - * Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs - * Instance profile credentials - delivered through the Amazon EC2 metadata service - * - * Usage: KinesisWordCountASL - * is the name of the Kinesis stream (ie. mySparkStream) - * is the endpoint of the Kinesis service - * (ie. https://kinesis.us-east-1.amazonaws.com) * * Example: - * $ export AWS_ACCESS_KEY_ID= - * $ export AWS_SECRET_KEY= - * $ $SPARK_HOME/bin/run-example \ - * org.apache.spark.examples.streaming.KinesisWordCountASL mySparkStream \ - * https://kinesis.us-east-1.amazonaws.com + * # export AWS keys if necessary + * $ export AWS_ACCESS_KEY_ID= + * $ export AWS_SECRET_KEY= + * + * # run the example + * $ SPARK_HOME/bin/run-example streaming.KinesisWordCountASL myAppName mySparkStream \ + * https://kinesis.us-east-1.amazonaws.com * - * - * Note that number of workers/threads should be 1 more than the number of receivers. - * This leaves one thread available for actually processing the data. + * There is a companion helper class called KinesisWordProducerASL which puts dummy data + * onto the Kinesis stream. * - * There is a companion helper class below called KinesisWordCountProducerASL which puts - * dummy data onto the Kinesis stream. - * Usage instructions for KinesisWordCountProducerASL are provided in that class definition. + * This code uses the DefaultAWSCredentialsProviderChain to find credentials + * in the following order: + * Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY + * Java System Properties - aws.accessKeyId and aws.secretKey + * Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs + * Instance profile credentials - delivered through the Amazon EC2 metadata service + * For more information, see + * http://docs.aws.amazon.com/AWSSdkDocsJava/latest/DeveloperGuide/credentials.html + * + * See http://spark.apache.org/docs/latest/streaming-kinesis-integration.html for more details on + * the Kinesis Spark Streaming integration. */ -private object KinesisWordCountASL extends Logging { +object KinesisWordCountASL extends Logging { def main(args: Array[String]) { - /* Check that all required args were passed in. */ - if (args.length < 2) { + // Check that all required args were passed in. + if (args.length != 3) { System.err.println( """ - |Usage: KinesisWordCount + |Usage: KinesisWordCountASL + | + | is the name of the consumer app, used to track the read data in DynamoDB | is the name of the Kinesis stream | is the endpoint of the Kinesis service | (e.g. https://kinesis.us-east-1.amazonaws.com) + | + |Generate input data for Kinesis stream using the example KinesisWordProducerASL. + |See http://spark.apache.org/docs/latest/streaming-kinesis-integration.html for more + |details. """.stripMargin) System.exit(1) } StreamingExamples.setStreamingLogLevels() - /* Populate the appropriate variables from the given args */ - val Array(streamName, endpointUrl) = args + // Populate the appropriate variables from the given args + val Array(appName, streamName, endpointUrl) = args - /* Determine the number of shards from the stream */ - val kinesisClient = new AmazonKinesisClient(new DefaultAWSCredentialsProviderChain()) + + // Determine the number of shards from the stream using the low-level Kinesis Client + // from the AWS Java SDK. + val credentials = new DefaultAWSCredentialsProviderChain().getCredentials() + require(credentials != null, + "No AWS credentials found. Please specify credentials using one of the methods specified " + + "in http://docs.aws.amazon.com/AWSSdkDocsJava/latest/DeveloperGuide/credentials.html") + val kinesisClient = new AmazonKinesisClient(credentials) kinesisClient.setEndpoint(endpointUrl) - val numShards = kinesisClient.describeStream(streamName).getStreamDescription().getShards() - .size() + val numShards = kinesisClient.describeStream(streamName).getStreamDescription().getShards().size + - /* In this example, we're going to create 1 Kinesis Worker/Receiver/DStream for each shard. */ + // In this example, we're going to create 1 Kinesis Receiver/input DStream for each shard. + // This is not a necessity; if there are less receivers/DStreams than the number of shards, + // then the shards will be automatically distributed among the receivers and each receiver + // will receive data from multiple shards. val numStreams = numShards - /* Setup the and SparkConfig and StreamingContext */ - /* Spark Streaming batch interval */ + // Spark Streaming batch interval val batchInterval = Milliseconds(2000) - val sparkConfig = new SparkConf().setAppName("KinesisWordCount") - val ssc = new StreamingContext(sparkConfig, batchInterval) - /* Kinesis checkpoint interval. Same as batchInterval for this example. */ + // Kinesis checkpoint interval is the interval at which the DynamoDB is updated with information + // on sequence number of records that have been received. Same as batchInterval for this + // example. val kinesisCheckpointInterval = batchInterval - /* Create the same number of Kinesis DStreams/Receivers as Kinesis stream's shards */ + // Get the region name from the endpoint URL to save Kinesis Client Library metadata in + // DynamoDB of the same region as the Kinesis stream + val regionName = RegionUtils.getRegionByEndpoint(endpointUrl).getName() + + // Setup the SparkConfig and StreamingContext + val sparkConfig = new SparkConf().setAppName("KinesisWordCountASL") + val ssc = new StreamingContext(sparkConfig, batchInterval) + + // Create the Kinesis DStreams val kinesisStreams = (0 until numStreams).map { i => - KinesisUtils.createStream(ssc, streamName, endpointUrl, kinesisCheckpointInterval, - InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2) + KinesisUtils.createStream(ssc, appName, streamName, endpointUrl, regionName, + InitialPositionInStream.LATEST, kinesisCheckpointInterval, StorageLevel.MEMORY_AND_DISK_2) } - /* Union all the streams */ + // Union all the streams val unionStreams = ssc.union(kinesisStreams) - /* Convert each line of Array[Byte] to String, split into words, and count them */ - val words = unionStreams.flatMap(byteArray => new String(byteArray) - .split(" ")) + // Convert each line of Array[Byte] to String, and split into words + val words = unionStreams.flatMap(byteArray => new String(byteArray).split(" ")) - /* Map each word to a (word, 1) tuple so we can reduce/aggregate by key. */ + // Map each word to a (word, 1) tuple so we can reduce by key to count the words val wordCounts = words.map(word => (word, 1)).reduceByKey(_ + _) - /* Print the first 10 wordCounts */ + // Print the first 10 wordCounts wordCounts.print() - /* Start the streaming context and await termination */ + // Start the streaming context and await termination ssc.start() ssc.awaitTermination() } } /** - * Usage: KinesisWordCountProducerASL - * + * Usage: KinesisWordProducerASL \ + * + * * is the name of the Kinesis stream (ie. mySparkStream) - * is the endpoint of the Kinesis service + * is the endpoint of the Kinesis service * (ie. https://kinesis.us-east-1.amazonaws.com) * is the rate of records per second to put onto the stream * is the rate of records per second to put onto the stream * * Example: - * $ export AWS_ACCESS_KEY_ID= - * $ export AWS_SECRET_KEY= - * $ $SPARK_HOME/bin/run-example \ - * org.apache.spark.examples.streaming.KinesisWordCountProducerASL mySparkStream \ - * https://kinesis.us-east-1.amazonaws.com 10 5 + * $ SPARK_HOME/bin/run-example streaming.KinesisWordProducerASL mySparkStream \ + * https://kinesis.us-east-1.amazonaws.com us-east-1 10 5 */ -private object KinesisWordCountProducerASL { +object KinesisWordProducerASL { def main(args: Array[String]) { - if (args.length < 4) { - System.err.println("Usage: KinesisWordCountProducerASL " + - " ") + if (args.length != 4) { + System.err.println( + """ + |Usage: KinesisWordProducerASL + + | + | is the name of the Kinesis stream + | is the endpoint of the Kinesis service + | (e.g. https://kinesis.us-east-1.amazonaws.com) + | is the rate of records per second to put onto the stream + | is the rate of records per second to put onto the stream + | + """.stripMargin) + System.exit(1) } + // Set default log4j logging level to WARN to hide Spark logs StreamingExamples.setStreamingLogLevels() - /* Populate the appropriate variables from the given args */ + // Populate the appropriate variables from the given args val Array(stream, endpoint, recordsPerSecond, wordsPerRecord) = args - /* Generate the records and return the totals */ - val totals = generate(stream, endpoint, recordsPerSecond.toInt, wordsPerRecord.toInt) + // Generate the records and return the totals + val totals = generate(stream, endpoint, recordsPerSecond.toInt, + wordsPerRecord.toInt) - /* Print the array of (index, total) tuples */ - println("Totals") - totals.foreach(total => println(total.toString())) + // Print the array of (word, total) tuples + println("Totals for the words sent") + totals.foreach(println(_)) } def generate(stream: String, endpoint: String, recordsPerSecond: Int, - wordsPerRecord: Int): Seq[(Int, Int)] = { + wordsPerRecord: Int): Seq[(String, Int)] = { - val MaxRandomInts = 10 + val randomWords = List("spark", "you", "are", "my", "father") + val totals = scala.collection.mutable.Map[String, Int]() - /* Create the Kinesis client */ + // Create the low-level Kinesis Client from the AWS Java SDK. val kinesisClient = new AmazonKinesisClient(new DefaultAWSCredentialsProviderChain()) kinesisClient.setEndpoint(endpoint) println(s"Putting records onto stream $stream and endpoint $endpoint at a rate of" + - s" $recordsPerSecond records per second and $wordsPerRecord words per record"); - - val totals = new Array[Int](MaxRandomInts) - /* Put String records onto the stream per the given recordPerSec and wordsPerRecord */ - for (i <- 1 to 5) { - - /* Generate recordsPerSec records to put onto the stream */ - val records = (1 to recordsPerSecond.toInt).map { recordNum => - /* - * Randomly generate each wordsPerRec words between 0 (inclusive) - * and MAX_RANDOM_INTS (exclusive) - */ + s" $recordsPerSecond records per second and $wordsPerRecord words per record") + + // Iterate and put records onto the stream per the given recordPerSec and wordsPerRecord + for (i <- 1 to 10) { + // Generate recordsPerSec records to put onto the stream + val records = (1 to recordsPerSecond.toInt).foreach { recordNum => + // Randomly generate wordsPerRecord number of words val data = (1 to wordsPerRecord.toInt).map(x => { - /* Generate the random int */ - val randomInt = Random.nextInt(MaxRandomInts) + // Get a random index to a word + val randomWordIdx = Random.nextInt(randomWords.size) + val randomWord = randomWords(randomWordIdx) - /* Keep track of the totals */ - totals(randomInt) += 1 + // Increment total count to compare to server counts later + totals(randomWord) = totals.getOrElse(randomWord, 0) + 1 - randomInt.toString() + randomWord }).mkString(" ") - /* Create a partitionKey based on recordNum */ + // Create a partitionKey based on recordNum val partitionKey = s"partitionKey-$recordNum" - /* Create a PutRecordRequest with an Array[Byte] version of the data */ + // Create a PutRecordRequest with an Array[Byte] version of the data val putRecordRequest = new PutRecordRequest().withStreamName(stream) .withPartitionKey(partitionKey) - .withData(ByteBuffer.wrap(data.getBytes())); + .withData(ByteBuffer.wrap(data.getBytes())) - /* Put the record onto the stream and capture the PutRecordResult */ - val putRecordResult = kinesisClient.putRecord(putRecordRequest); + // Put the record onto the stream and capture the PutRecordResult + val putRecordResult = kinesisClient.putRecord(putRecordRequest) } - /* Sleep for a second */ + // Sleep for a second Thread.sleep(1000) println("Sent " + recordsPerSecond + " records") } - - /* Convert the totals to (index, total) tuple */ - (0 to (MaxRandomInts - 1)).zip(totals) + // Convert the totals to (index, total) tuple + totals.toSeq.sortBy(_._1) } } -/** - * Utility functions for Spark Streaming examples. +/** + * Utility functions for Spark Streaming examples. * This has been lifted from the examples/ project to remove the circular dependency. */ private[streaming] object StreamingExamples extends Logging { - - /** Set reasonable logging levels for streaming if the user has not configured log4j. */ + // Set reasonable logging levels for streaming if the user has not configured log4j. def setStreamingLogLevels() { val log4jInitialized = Logger.getRootLogger.getAllAppenders.hasMoreElements if (!log4jInitialized) { diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala index 588e86a1887ec..83a4537559512 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala @@ -23,20 +23,20 @@ import org.apache.spark.util.{Clock, ManualClock, SystemClock} /** * This is a helper class for managing checkpoint clocks. * - * @param checkpointInterval + * @param checkpointInterval * @param currentClock. Default to current SystemClock if none is passed in (mocking purposes) */ private[kinesis] class KinesisCheckpointState( - checkpointInterval: Duration, + checkpointInterval: Duration, currentClock: Clock = new SystemClock()) extends Logging { - + /* Initialize the checkpoint clock using the given currentClock + checkpointInterval millis */ val checkpointClock = new ManualClock() checkpointClock.setTime(currentClock.getTimeMillis() + checkpointInterval.milliseconds) /** - * Check if it's time to checkpoint based on the current time and the derived time + * Check if it's time to checkpoint based on the current time and the derived time * for the next checkpoint * * @return true if it's time to checkpoint @@ -48,7 +48,7 @@ private[kinesis] class KinesisCheckpointState( /** * Advance the checkpoint clock by the checkpoint interval. */ - def advanceCheckpoint() = { + def advanceCheckpoint(): Unit = { checkpointClock.advance(checkpointInterval.milliseconds) } } diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index a7fe4476cacb8..1a8a4cecc1141 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -16,39 +16,45 @@ */ package org.apache.spark.streaming.kinesis -import java.net.InetAddress import java.util.UUID +import scala.util.control.NonFatal + +import com.amazonaws.auth.{AWSCredentials, AWSCredentialsProvider, BasicAWSCredentials, DefaultAWSCredentialsProviderChain} +import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorFactory} +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.{InitialPositionInStream, KinesisClientLibConfiguration, Worker} + import org.apache.spark.Logging import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.Duration import org.apache.spark.streaming.receiver.Receiver import org.apache.spark.util.Utils -import com.amazonaws.auth.AWSCredentialsProvider -import com.amazonaws.auth.DefaultAWSCredentialsProviderChain -import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessor -import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorFactory -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.KinesisClientLibConfiguration -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.Worker + +private[kinesis] +case class SerializableAWSCredentials(accessKeyId: String, secretKey: String) + extends AWSCredentials { + override def getAWSAccessKeyId: String = accessKeyId + override def getAWSSecretKey: String = secretKey +} /** * Custom AWS Kinesis-specific implementation of Spark Streaming's Receiver. * This implementation relies on the Kinesis Client Library (KCL) Worker as described here: * https://github.com/awslabs/amazon-kinesis-client - * This is a custom receiver used with StreamingContext.receiverStream(Receiver) - * as described here: - * http://spark.apache.org/docs/latest/streaming-custom-receivers.html - * Instances of this class will get shipped to the Spark Streaming Workers - * to run within a Spark Executor. + * This is a custom receiver used with StreamingContext.receiverStream(Receiver) as described here: + * http://spark.apache.org/docs/latest/streaming-custom-receivers.html + * Instances of this class will get shipped to the Spark Streaming Workers to run within a + * Spark Executor. * * @param appName Kinesis application name. Kinesis Apps are mapped to Kinesis Streams * by the Kinesis Client Library. If you change the App name or Stream name, - * the KCL will throw errors. This usually requires deleting the backing + * the KCL will throw errors. This usually requires deleting the backing * DynamoDB table with the same name this Kinesis application. * @param streamName Kinesis stream name * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param regionName Region name used by the Kinesis Client Library for + * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. * See the Kinesis Spark Streaming documentation for more * details on the different types of checkpoints. @@ -59,92 +65,121 @@ import com.amazonaws.services.kinesis.clientlibrary.lib.worker.Worker * (InitialPositionInStream.TRIM_HORIZON) or * the tip of the stream (InitialPositionInStream.LATEST). * @param storageLevel Storage level to use for storing the received objects - * - * @return ReceiverInputDStream[Array[Byte]] + * @param awsCredentialsOption Optional AWS credentials, used when user directly specifies + * the credentials */ private[kinesis] class KinesisReceiver( appName: String, streamName: String, endpointUrl: String, - checkpointInterval: Duration, + regionName: String, initialPositionInStream: InitialPositionInStream, - storageLevel: StorageLevel) - extends Receiver[Array[Byte]](storageLevel) with Logging { receiver => - - /* - * The following vars are built in the onStart() method which executes in the Spark Worker after - * this code is serialized and shipped remotely. - */ + checkpointInterval: Duration, + storageLevel: StorageLevel, + awsCredentialsOption: Option[SerializableAWSCredentials] + ) extends Receiver[Array[Byte]](storageLevel) with Logging { receiver => /* - * workerId should be based on the ip address of the actual Spark Worker where this code runs - * (not the Driver's ip address.) + * ================================================================================= + * The following vars are initialize in the onStart() method which executes in the + * Spark worker after this Receiver is serialized and shipped to the worker. + * ================================================================================= */ - var workerId: String = null - /* - * This impl uses the DefaultAWSCredentialsProviderChain and searches for credentials - * in the following order of precedence: - * Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY - * Java System Properties - aws.accessKeyId and aws.secretKey - * Credential profiles file at the default location (~/.aws/credentials) shared by all - * AWS SDKs and the AWS CLI - * Instance profile credentials delivered through the Amazon EC2 metadata service + /** + * workerId is used by the KCL should be based on the ip address of the actual Spark Worker + * where this code runs (not the driver's IP address.) */ - var credentialsProvider: AWSCredentialsProvider = null - - /* KCL config instance. */ - var kinesisClientLibConfiguration: KinesisClientLibConfiguration = null + private var workerId: String = null - /* - * RecordProcessorFactory creates impls of IRecordProcessor. - * IRecordProcessor adapts the KCL to our Spark KinesisReceiver via the - * IRecordProcessor.processRecords() method. - * We're using our custom KinesisRecordProcessor in this case. + /** + * Worker is the core client abstraction from the Kinesis Client Library (KCL). + * A worker can process more than one shards from the given stream. + * Each shard is assigned its own IRecordProcessor and the worker run multiple such + * processors. */ - var recordProcessorFactory: IRecordProcessorFactory = null + private var worker: Worker = null - /* - * Create a Kinesis Worker. - * This is the core client abstraction from the Kinesis Client Library (KCL). - * We pass the RecordProcessorFactory from above as well as the KCL config instance. - * A Kinesis Worker can process 1..* shards from the given stream - each with its - * own RecordProcessor. - */ - var worker: Worker = null + /** Thread running the worker */ + private var workerThread: Thread = null /** - * This is called when the KinesisReceiver starts and must be non-blocking. - * The KCL creates and manages the receiving/processing thread pool through the Worker.run() - * method. + * This is called when the KinesisReceiver starts and must be non-blocking. + * The KCL creates and manages the receiving/processing thread pool through Worker.run(). */ override def onStart() { workerId = Utils.localHostName() + ":" + UUID.randomUUID() - credentialsProvider = new DefaultAWSCredentialsProviderChain() - kinesisClientLibConfiguration = new KinesisClientLibConfiguration(appName, streamName, - credentialsProvider, workerId).withKinesisEndpoint(endpointUrl) - .withInitialPositionInStream(initialPositionInStream).withTaskBackoffTimeMillis(500) - recordProcessorFactory = new IRecordProcessorFactory { + + // KCL config instance + val awsCredProvider = resolveAWSCredentialsProvider() + val kinesisClientLibConfiguration = + new KinesisClientLibConfiguration(appName, streamName, awsCredProvider, workerId) + .withKinesisEndpoint(endpointUrl) + .withInitialPositionInStream(initialPositionInStream) + .withTaskBackoffTimeMillis(500) + .withRegionName(regionName) + + /* + * RecordProcessorFactory creates impls of IRecordProcessor. + * IRecordProcessor adapts the KCL to our Spark KinesisReceiver via the + * IRecordProcessor.processRecords() method. + * We're using our custom KinesisRecordProcessor in this case. + */ + val recordProcessorFactory = new IRecordProcessorFactory { override def createProcessor: IRecordProcessor = new KinesisRecordProcessor(receiver, workerId, new KinesisCheckpointState(checkpointInterval)) } + worker = new Worker(recordProcessorFactory, kinesisClientLibConfiguration) - worker.run() + workerThread = new Thread() { + override def run(): Unit = { + try { + worker.run() + } catch { + case NonFatal(e) => + restart("Error running the KCL worker in Receiver", e) + } + } + } + workerThread.setName(s"Kinesis Receiver ${streamId}") + workerThread.setDaemon(true) + workerThread.start() logInfo(s"Started receiver with workerId $workerId") } /** - * This is called when the KinesisReceiver stops. - * The KCL worker.shutdown() method stops the receiving/processing threads. - * The KCL will do its best to drain and checkpoint any in-flight records upon shutdown. + * This is called when the KinesisReceiver stops. + * The KCL worker.shutdown() method stops the receiving/processing threads. + * The KCL will do its best to drain and checkpoint any in-flight records upon shutdown. */ override def onStop() { - worker.shutdown() - logInfo(s"Shut down receiver with workerId $workerId") + if (workerThread != null) { + if (worker != null) { + worker.shutdown() + worker = null + } + workerThread.join() + workerThread = null + logInfo(s"Stopped receiver for workerId $workerId") + } workerId = null - credentialsProvider = null - kinesisClientLibConfiguration = null - recordProcessorFactory = null - worker = null + } + + /** + * If AWS credential is provided, return a AWSCredentialProvider returning that credential. + * Otherwise, return the DefaultAWSCredentialsProviderChain. + */ + private def resolveAWSCredentialsProvider(): AWSCredentialsProvider = { + awsCredentialsOption match { + case Some(awsCredentials) => + logInfo("Using provided AWS credentials") + new AWSCredentialsProvider { + override def getCredentials: AWSCredentials = awsCredentials + override def refresh(): Unit = { } + } + case None => + logInfo("Using DefaultAWSCredentialsProviderChain") + new DefaultAWSCredentialsProviderChain() + } } } diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala index af8cd875b4541..fe9e3a0c793e2 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala @@ -35,7 +35,10 @@ import com.amazonaws.services.kinesis.model.Record /** * Kinesis-specific implementation of the Kinesis Client Library (KCL) IRecordProcessor. * This implementation operates on the Array[Byte] from the KinesisReceiver. - * The Kinesis Worker creates an instance of this KinesisRecordProcessor upon startup. + * The Kinesis Worker creates an instance of this KinesisRecordProcessor for each + * shard in the Kinesis stream upon startup. This is normally done in separate threads, + * but the KCLs within the KinesisReceivers will balance themselves out if you create + * multiple Receivers. * * @param receiver Kinesis receiver * @param workerId for logging purposes @@ -47,8 +50,8 @@ private[kinesis] class KinesisRecordProcessor( workerId: String, checkpointState: KinesisCheckpointState) extends IRecordProcessor with Logging { - /* shardId to be populated during initialize() */ - var shardId: String = _ + // shardId to be populated during initialize() + private var shardId: String = _ /** * The Kinesis Client Library calls this method during IRecordProcessor initialization. @@ -56,8 +59,8 @@ private[kinesis] class KinesisRecordProcessor( * @param shardId assigned by the KCL to this particular RecordProcessor. */ override def initialize(shardId: String) { - logInfo(s"Initialize: Initializing workerId $workerId with shardId $shardId") this.shardId = shardId + logInfo(s"Initialized workerId $workerId with shardId $shardId") } /** @@ -66,29 +69,34 @@ private[kinesis] class KinesisRecordProcessor( * and Spark Streaming's Receiver.store(). * * @param batch list of records from the Kinesis stream shard - * @param checkpointer used to update Kinesis when this batch has been processed/stored + * @param checkpointer used to update Kinesis when this batch has been processed/stored * in the DStream */ override def processRecords(batch: List[Record], checkpointer: IRecordProcessorCheckpointer) { if (!receiver.isStopped()) { try { /* - * Note: If we try to store the raw ByteBuffer from record.getData(), the Spark Streaming - * Receiver.store(ByteBuffer) attempts to deserialize the ByteBuffer using the - * internally-configured Spark serializer (kryo, etc). - * This is not desirable, so we instead store a raw Array[Byte] and decouple - * ourselves from Spark's internal serialization strategy. - */ + * Notes: + * 1) If we try to store the raw ByteBuffer from record.getData(), the Spark Streaming + * Receiver.store(ByteBuffer) attempts to deserialize the ByteBuffer using the + * internally-configured Spark serializer (kryo, etc). + * 2) This is not desirable, so we instead store a raw Array[Byte] and decouple + * ourselves from Spark's internal serialization strategy. + * 3) For performance, the BlockGenerator is asynchronously queuing elements within its + * memory before creating blocks. This prevents the small block scenario, but requires + * that you register callbacks to know when a block has been generated and stored + * (WAL is sufficient for storage) before can checkpoint back to the source. + */ batch.foreach(record => receiver.store(record.getData().array())) - + logDebug(s"Stored: Worker $workerId stored ${batch.size} records for shardId $shardId") /* - * Checkpoint the sequence number of the last record successfully processed/stored + * Checkpoint the sequence number of the last record successfully processed/stored * in the batch. * In this implementation, we're checkpointing after the given checkpointIntervalMillis. - * Note that this logic requires that processRecords() be called AND that it's time to - * checkpoint. I point this out because there is no background thread running the + * Note that this logic requires that processRecords() be called AND that it's time to + * checkpoint. I point this out because there is no background thread running the * checkpointer. Checkpointing is tested and trigger only when a new batch comes in. * If the worker is shutdown cleanly, checkpoint will happen (see shutdown() below). * However, if the worker dies unexpectedly, a checkpoint may not happen. @@ -116,22 +124,22 @@ private[kinesis] class KinesisRecordProcessor( logError(s"Exception: WorkerId $workerId encountered and exception while storing " + " or checkpointing a batch for workerId $workerId and shardId $shardId.", e) - /* Rethrow the exception to the Kinesis Worker that is managing this RecordProcessor.*/ + /* Rethrow the exception to the Kinesis Worker that is managing this RecordProcessor. */ throw e } } } else { /* RecordProcessor has been stopped. */ - logInfo(s"Stopped: The Spark KinesisReceiver has stopped for workerId $workerId" + + logInfo(s"Stopped: The Spark KinesisReceiver has stopped for workerId $workerId" + s" and shardId $shardId. No more records will be processed.") } } /** * Kinesis Client Library is shutting down this Worker for 1 of 2 reasons: - * 1) the stream is resharding by splitting or merging adjacent shards + * 1) the stream is resharding by splitting or merging adjacent shards * (ShutdownReason.TERMINATE) - * 2) the failed or latent Worker has stopped sending heartbeats for whatever reason + * 2) the failed or latent Worker has stopped sending heartbeats for whatever reason * (ShutdownReason.ZOMBIE) * * @param checkpointer used to perform a Kinesis checkpoint for ShutdownReason.TERMINATE @@ -145,7 +153,7 @@ private[kinesis] class KinesisRecordProcessor( * Checkpoint to indicate that all records from the shard have been drained and processed. * It's now OK to read from the new shards that resulted from a resharding event. */ - case ShutdownReason.TERMINATE => + case ShutdownReason.TERMINATE => KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(), 4, 100) /* @@ -190,7 +198,7 @@ private[kinesis] object KinesisRecordProcessor extends Logging { logError(s"Retryable Exception: Random backOffMillis=${backOffMillis}", e) retryRandom(expression, numRetriesLeft - 1, maxBackOffMillis) } - /* Throw: Shutdown has been requested by the Kinesis Client Library.*/ + /* Throw: Shutdown has been requested by the Kinesis Client Library. */ case _: ShutdownException => { logError(s"ShutdownException: Caught shutdown exception, skipping checkpoint.", e) throw e diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala index 96f4399accd3a..e5acab50181e1 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -16,29 +16,78 @@ */ package org.apache.spark.streaming.kinesis -import org.apache.spark.annotation.Experimental +import com.amazonaws.regions.RegionUtils +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream + import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.Duration -import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.JavaReceiverInputDStream -import org.apache.spark.streaming.api.java.JavaStreamingContext +import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext} import org.apache.spark.streaming.dstream.ReceiverInputDStream - -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream +import org.apache.spark.streaming.{Duration, StreamingContext} -/** - * Helper class to create Amazon Kinesis Input Stream - * :: Experimental :: - */ -@Experimental object KinesisUtils { /** - * Create an InputDStream that pulls messages from a Kinesis stream. - * :: Experimental :: - * @param ssc StreamingContext object + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. + * + * Note: The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets the AWS credentials. + * + * @param ssc StreamingContext object + * @param kinesisAppName Kinesis application name used by the Kinesis Client Library + * (KCL) to update DynamoDB + * @param streamName Kinesis stream name + * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param regionName Name of region used by the Kinesis Client Library (KCL) to update + * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param storageLevel Storage level to use for storing the received objects. + * StorageLevel.MEMORY_AND_DISK_2 is recommended. + */ + def createStream( + ssc: StreamingContext, + kinesisAppName: String, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: InitialPositionInStream, + checkpointInterval: Duration, + storageLevel: StorageLevel + ): ReceiverInputDStream[Array[Byte]] = { + // Setting scope to override receiver stream's scope of "receiver stream" + ssc.withNamedScope("kinesis stream") { + ssc.receiverStream( + new KinesisReceiver(kinesisAppName, streamName, endpointUrl, validateRegion(regionName), + initialPositionInStream, checkpointInterval, storageLevel, None)) + } + } + + /** + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. + * + * Note: + * The given AWS credentials will get saved in DStream checkpoints if checkpointing + * is enabled. Make sure that your checkpoint directory is secure. + * + * @param ssc StreamingContext object + * @param kinesisAppName Kinesis application name used by the Kinesis Client Library + * (KCL) to update DynamoDB * @param streamName Kinesis stream name * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param regionName Name of region used by the Kinesis Client Library (KCL) to update + * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) + * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) + * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. * See the Kinesis Spark Streaming documentation for more * details on the different types of checkpoints. @@ -48,28 +97,84 @@ object KinesisUtils { * per Kinesis' limit of 24 hours * (InitialPositionInStream.TRIM_HORIZON) or * the tip of the stream (InitialPositionInStream.LATEST). - * @param storageLevel Storage level to use for storing the received objects + * @param storageLevel Storage level to use for storing the received objects. + * StorageLevel.MEMORY_AND_DISK_2 is recommended. + */ + def createStream( + ssc: StreamingContext, + kinesisAppName: String, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: InitialPositionInStream, + checkpointInterval: Duration, + storageLevel: StorageLevel, + awsAccessKeyId: String, + awsSecretKey: String + ): ReceiverInputDStream[Array[Byte]] = { + ssc.receiverStream( + new KinesisReceiver(kinesisAppName, streamName, endpointUrl, validateRegion(regionName), + initialPositionInStream, checkpointInterval, storageLevel, + Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey)))) + } + + /** + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * - * @return ReceiverInputDStream[Array[Byte]] + * Note: + * - The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets AWS credentials. + * - The region of the `endpointUrl` will be used for DynamoDB and CloudWatch. + * - The Kinesis application name used by the Kinesis Client Library (KCL) will be the app name in + * [[org.apache.spark.SparkConf]]. + * + * @param ssc Java StreamingContext object + * @param streamName Kinesis stream name + * @param endpointUrl Endpoint url of Kinesis service + * (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param storageLevel Storage level to use for storing the received objects + * StorageLevel.MEMORY_AND_DISK_2 is recommended. */ - @Experimental + @deprecated("use other forms of createStream", "1.4.0") def createStream( ssc: StreamingContext, streamName: String, endpointUrl: String, checkpointInterval: Duration, initialPositionInStream: InitialPositionInStream, - storageLevel: StorageLevel): ReceiverInputDStream[Array[Byte]] = { - ssc.receiverStream(new KinesisReceiver(ssc.sc.appName, streamName, endpointUrl, - checkpointInterval, initialPositionInStream, storageLevel)) + storageLevel: StorageLevel + ): ReceiverInputDStream[Array[Byte]] = { + ssc.receiverStream( + new KinesisReceiver(ssc.sc.appName, streamName, endpointUrl, getRegionByEndpoint(endpointUrl), + initialPositionInStream, checkpointInterval, storageLevel, None)) } /** - * Create a Java-friendly InputDStream that pulls messages from a Kinesis stream. - * :: Experimental :: + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. + * + * Note: The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets the AWS credentials. + * * @param jssc Java StreamingContext object + * @param kinesisAppName Kinesis application name used by the Kinesis Client Library + * (KCL) to update DynamoDB * @param streamName Kinesis stream name * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param regionName Name of region used by the Kinesis Client Library (KCL) to update + * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. * See the Kinesis Spark Streaming documentation for more * details on the different types of checkpoints. @@ -79,19 +184,116 @@ object KinesisUtils { * per Kinesis' limit of 24 hours * (InitialPositionInStream.TRIM_HORIZON) or * the tip of the stream (InitialPositionInStream.LATEST). - * @param storageLevel Storage level to use for storing the received objects + * @param storageLevel Storage level to use for storing the received objects. + * StorageLevel.MEMORY_AND_DISK_2 is recommended. + */ + def createStream( + jssc: JavaStreamingContext, + kinesisAppName: String, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: InitialPositionInStream, + checkpointInterval: Duration, + storageLevel: StorageLevel + ): JavaReceiverInputDStream[Array[Byte]] = { + createStream(jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, + initialPositionInStream, checkpointInterval, storageLevel) + } + + /** + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * - * @return JavaReceiverInputDStream[Array[Byte]] + * Note: + * The given AWS credentials will get saved in DStream checkpoints if checkpointing + * is enabled. Make sure that your checkpoint directory is secure. + * + * @param jssc Java StreamingContext object + * @param kinesisAppName Kinesis application name used by the Kinesis Client Library + * (KCL) to update DynamoDB + * @param streamName Kinesis stream name + * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param regionName Name of region used by the Kinesis Client Library (KCL) to update + * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) + * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) + * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param storageLevel Storage level to use for storing the received objects. + * StorageLevel.MEMORY_AND_DISK_2 is recommended. */ - @Experimental def createStream( - jssc: JavaStreamingContext, - streamName: String, - endpointUrl: String, + jssc: JavaStreamingContext, + kinesisAppName: String, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: InitialPositionInStream, + checkpointInterval: Duration, + storageLevel: StorageLevel, + awsAccessKeyId: String, + awsSecretKey: String + ): JavaReceiverInputDStream[Array[Byte]] = { + createStream(jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, + initialPositionInStream, checkpointInterval, storageLevel, awsAccessKeyId, awsSecretKey) + } + + /** + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. + * + * Note: + * - The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets AWS credentials. + * - The region of the `endpointUrl` will be used for DynamoDB and CloudWatch. + * - The Kinesis application name used by the Kinesis Client Library (KCL) will be the app name in + * [[org.apache.spark.SparkConf]]. + * + * @param jssc Java StreamingContext object + * @param streamName Kinesis stream name + * @param endpointUrl Endpoint url of Kinesis service + * (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param storageLevel Storage level to use for storing the received objects + * StorageLevel.MEMORY_AND_DISK_2 is recommended. + */ + @deprecated("use other forms of createStream", "1.4.0") + def createStream( + jssc: JavaStreamingContext, + streamName: String, + endpointUrl: String, checkpointInterval: Duration, initialPositionInStream: InitialPositionInStream, - storageLevel: StorageLevel): JavaReceiverInputDStream[Array[Byte]] = { - jssc.receiverStream(new KinesisReceiver(jssc.ssc.sc.appName, streamName, - endpointUrl, checkpointInterval, initialPositionInStream, storageLevel)) + storageLevel: StorageLevel + ): JavaReceiverInputDStream[Array[Byte]] = { + createStream( + jssc.ssc, streamName, endpointUrl, checkpointInterval, initialPositionInStream, storageLevel) + } + + private def getRegionByEndpoint(endpointUrl: String): String = { + RegionUtils.getRegionByEndpoint(endpointUrl).getName() + } + + private def validateRegion(regionName: String): String = { + Option(RegionUtils.getRegion(regionName)).map { _.getName }.getOrElse { + throw new IllegalArgumentException(s"Region name '$regionName' is not valid") + } } } diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala index 255fe65819608..2103dca6b766f 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala @@ -20,26 +20,18 @@ import java.nio.ByteBuffer import scala.collection.JavaConversions.seqAsJavaList -import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.Milliseconds -import org.apache.spark.streaming.Seconds -import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.TestSuiteBase -import org.apache.spark.util.{ManualClock, Clock} - -import org.mockito.Mockito._ -import org.scalatest.BeforeAndAfter -import org.scalatest.Matchers -import org.scalatest.mock.MockitoSugar - -import com.amazonaws.services.kinesis.clientlibrary.exceptions.InvalidStateException -import com.amazonaws.services.kinesis.clientlibrary.exceptions.KinesisClientLibDependencyException -import com.amazonaws.services.kinesis.clientlibrary.exceptions.ShutdownException -import com.amazonaws.services.kinesis.clientlibrary.exceptions.ThrottlingException +import com.amazonaws.services.kinesis.clientlibrary.exceptions.{InvalidStateException, KinesisClientLibDependencyException, ShutdownException, ThrottlingException} import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason import com.amazonaws.services.kinesis.model.Record +import org.mockito.Mockito._ +import org.scalatest.{BeforeAndAfter, Matchers} +import org.scalatest.mock.MockitoSugar + +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Milliseconds, Seconds, StreamingContext, TestSuiteBase} +import org.apache.spark.util.{Clock, ManualClock, Utils} /** * Suite of Kinesis streaming receiver tests focusing mostly on the KinesisRecordProcessor @@ -65,7 +57,7 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft var checkpointStateMock: KinesisCheckpointState = _ var currentClockMock: Clock = _ - override def beforeFunction() = { + override def beforeFunction(): Unit = { receiverMock = mock[KinesisReceiver] checkpointerMock = mock[IRecordProcessorCheckpointer] checkpointClockMock = mock[ManualClock] @@ -81,15 +73,28 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft checkpointStateMock, currentClockMock) } - test("kinesis utils api") { + test("KinesisUtils API") { val ssc = new StreamingContext(master, framework, batchDuration) // Tests the API, does not actually test data receiving - val kinesisStream = KinesisUtils.createStream(ssc, "mySparkStream", + val kinesisStream1 = KinesisUtils.createStream(ssc, "mySparkStream", "https://kinesis.us-west-2.amazonaws.com", Seconds(2), - InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2); + InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2) + val kinesisStream2 = KinesisUtils.createStream(ssc, "myAppNam", "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", "us-west-2", + InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2) + val kinesisStream3 = KinesisUtils.createStream(ssc, "myAppNam", "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", "us-west-2", + InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2, + "awsAccessKey", "awsSecretKey") + ssc.stop() } + test("check serializability of SerializableAWSCredentials") { + Utils.deserialize[SerializableAWSCredentials]( + Utils.serialize(new SerializableAWSCredentials("x", "y"))) + } + test("process records including store and checkpoint") { when(receiverMock.isStopped()).thenReturn(false) when(checkpointStateMock.shouldCheckpoint()).thenReturn(true) diff --git a/graphx/pom.xml b/graphx/pom.xml index d38a3aa8256b7..28b41228feb3d 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -40,6 +40,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + com.google.guava guava diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala index 058c8c8aa1b24..ce1054ed92ba1 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeDirection.scala @@ -26,8 +26,8 @@ class EdgeDirection private (private val name: String) extends Serializable { * out becomes in and both and either remain the same. */ def reverse: EdgeDirection = this match { - case EdgeDirection.In => EdgeDirection.Out - case EdgeDirection.Out => EdgeDirection.In + case EdgeDirection.In => EdgeDirection.Out + case EdgeDirection.Out => EdgeDirection.In case EdgeDirection.Either => EdgeDirection.Either case EdgeDirection.Both => EdgeDirection.Both } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala index cc70b396a8dd4..4611a3ace219b 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala @@ -41,14 +41,16 @@ abstract class EdgeRDD[ED]( @transient sc: SparkContext, @transient deps: Seq[Dependency[_]]) extends RDD[Edge[ED]](sc, deps) { + // scalastyle:off structural.type private[graphx] def partitionsRDD: RDD[(PartitionID, EdgePartition[ED, VD])] forSome { type VD } + // scalastyle:on structural.type override protected def getPartitions: Array[Partition] = partitionsRDD.partitions override def compute(part: Partition, context: TaskContext): Iterator[Edge[ED]] = { val p = firstParent[(PartitionID, EdgePartition[ED, _])].iterator(part, context) if (p.hasNext) { - p.next._2.iterator.map(_.copy()) + p.next()._2.iterator.map(_.copy()) } else { Iterator.empty } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala index c8790cac3d8a0..65f82429d2029 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeTriplet.scala @@ -37,7 +37,7 @@ class EdgeTriplet[VD, ED] extends Edge[ED] { /** * Set the edge properties of this triplet. */ - protected[spark] def set(other: Edge[ED]): EdgeTriplet[VD,ED] = { + protected[spark] def set(other: Edge[ED]): EdgeTriplet[VD, ED] = { srcId = other.srcId dstId = other.dstId attr = other.attr diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala index 36dc7b0f86c89..db73a8abc5733 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala @@ -316,7 +316,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * satisfy the predicates */ def subgraph( - epred: EdgeTriplet[VD,ED] => Boolean = (x => true), + epred: EdgeTriplet[VD, ED] => Boolean = (x => true), vpred: (VertexId, VD) => Boolean = ((v, d) => true)) : Graph[VD, ED] diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala index 7edd627b20918..9451ff1e5c0e2 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala @@ -124,18 +124,18 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali def collectNeighbors(edgeDirection: EdgeDirection): VertexRDD[Array[(VertexId, VD)]] = { val nbrs = edgeDirection match { case EdgeDirection.Either => - graph.aggregateMessages[Array[(VertexId,VD)]]( + graph.aggregateMessages[Array[(VertexId, VD)]]( ctx => { ctx.sendToSrc(Array((ctx.dstId, ctx.dstAttr))) ctx.sendToDst(Array((ctx.srcId, ctx.srcAttr))) }, (a, b) => a ++ b, TripletFields.All) case EdgeDirection.In => - graph.aggregateMessages[Array[(VertexId,VD)]]( + graph.aggregateMessages[Array[(VertexId, VD)]]( ctx => ctx.sendToDst(Array((ctx.srcId, ctx.srcAttr))), (a, b) => a ++ b, TripletFields.Src) case EdgeDirection.Out => - graph.aggregateMessages[Array[(VertexId,VD)]]( + graph.aggregateMessages[Array[(VertexId, VD)]]( ctx => ctx.sendToSrc(Array((ctx.dstId, ctx.dstAttr))), (a, b) => a ++ b, TripletFields.Dst) case EdgeDirection.Both => @@ -253,7 +253,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali def filter[VD2: ClassTag, ED2: ClassTag]( preprocess: Graph[VD, ED] => Graph[VD2, ED2], epred: (EdgeTriplet[VD2, ED2]) => Boolean = (x: EdgeTriplet[VD2, ED2]) => true, - vpred: (VertexId, VD2) => Boolean = (v:VertexId, d:VD2) => true): Graph[VD, ED] = { + vpred: (VertexId, VD2) => Boolean = (v: VertexId, d: VD2) => true): Graph[VD, ED] = { graph.mask(preprocess(graph).subgraph(epred, vpred)) } @@ -356,7 +356,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali maxIterations: Int = Int.MaxValue, activeDirection: EdgeDirection = EdgeDirection.Either)( vprog: (VertexId, VD, A) => VD, - sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId,A)], + sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)], mergeMsg: (A, A) => A) : Graph[VD, ED] = { Pregel(graph, initialMsg, maxIterations, activeDirection)(vprog, sendMsg, mergeMsg) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala index 01b013ff716fc..cfcf7244eaed5 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala @@ -147,10 +147,10 @@ object Pregel extends Logging { logInfo("Pregel finished iteration " + i) // Unpersist the RDDs hidden by newly-materialized RDDs - oldMessages.unpersist(blocking=false) - newVerts.unpersist(blocking=false) - prevG.unpersistVertices(blocking=false) - prevG.edges.unpersist(blocking=false) + oldMessages.unpersist(blocking = false) + newVerts.unpersist(blocking = false) + prevG.unpersistVertices(blocking = false) + prevG.edges.unpersist(blocking = false) // count the iteration i += 1 } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala index c561570809253..ab021a252eb8a 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala @@ -156,8 +156,8 @@ class EdgePartition[ val size = data.size var i = 0 while (i < size) { - edge.srcId = srcIds(i) - edge.dstId = dstIds(i) + edge.srcId = srcIds(i) + edge.dstId = dstIds(i) edge.attr = data(i) newData(i) = f(edge) i += 1 diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index bc974b2f04e70..8c0a461e99fa4 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -116,7 +116,7 @@ object PageRank extends Logging { val personalized = srcId isDefined val src: VertexId = srcId.getOrElse(-1L) - def delta(u: VertexId, v: VertexId):Double = { if (u == v) 1.0 else 0.0 } + def delta(u: VertexId, v: VertexId): Double = { if (u == v) 1.0 else 0.0 } var iteration = 0 var prevRankGraph: Graph[Double, Double] = null @@ -133,13 +133,13 @@ object PageRank extends Logging { // edge partitions. prevRankGraph = rankGraph val rPrb = if (personalized) { - (src: VertexId ,id: VertexId) => resetProb * delta(src,id) + (src: VertexId , id: VertexId) => resetProb * delta(src, id) } else { (src: VertexId, id: VertexId) => resetProb } rankGraph = rankGraph.joinVertices(rankUpdates) { - (id, oldRank, msgSum) => rPrb(src,id) + (1.0 - resetProb) * msgSum + (id, oldRank, msgSum) => rPrb(src, id) + (1.0 - resetProb) * msgSum }.cache() rankGraph.edges.foreachPartition(x => {}) // also materializes rankGraph.vertices @@ -243,7 +243,7 @@ object PageRank extends Logging { // Execute a dynamic version of Pregel. val vp = if (personalized) { - (id: VertexId, attr: (Double, Double),msgSum: Double) => + (id: VertexId, attr: (Double, Double), msgSum: Double) => personalizedVertexProgram(id, attr, msgSum) } else { (id: VertexId, attr: (Double, Double), msgSum: Double) => diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala index 3b0e1628d86b5..9cb24ed080e1c 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala @@ -210,7 +210,7 @@ object SVDPlusPlus { /** * Forces materialization of a Graph by count()ing its RDDs. */ - private def materialize(g: Graph[_,_]): Unit = { + private def materialize(g: Graph[_, _]): Unit = { g.vertices.count() g.edges.count() } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala index daf162085e3e4..a5d598053f9ca 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala @@ -38,7 +38,7 @@ import org.apache.spark.graphx._ */ object TriangleCount { - def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD,ED]): Graph[Int, ED] = { + def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]): Graph[Int, ED] = { // Remove redundant edges val g = graph.groupEdges((a, b) => a).cache() @@ -49,7 +49,7 @@ object TriangleCount { var i = 0 while (i < nbrs.size) { // prevent self cycle - if(nbrs(i) != vid) { + if (nbrs(i) != vid) { set.add(nbrs(i)) } i += 1 diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala index 2d6a825b61726..9591c4e9b8f4e 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala @@ -243,14 +243,15 @@ object GraphGenerators { * @return A graph containing vertices with the row and column ids * as their attributes and edge values as 1.0. */ - def gridGraph(sc: SparkContext, rows: Int, cols: Int): Graph[(Int,Int), Double] = { + def gridGraph(sc: SparkContext, rows: Int, cols: Int): Graph[(Int, Int), Double] = { // Convert row column address into vertex ids (row major order) def sub2ind(r: Int, c: Int): VertexId = r * cols + c - val vertices: RDD[(VertexId, (Int,Int))] = - sc.parallelize(0 until rows).flatMap( r => (0 until cols).map( c => (sub2ind(r,c), (r,c)) ) ) + val vertices: RDD[(VertexId, (Int, Int))] = sc.parallelize(0 until rows).flatMap { r => + (0 until cols).map( c => (sub2ind(r, c), (r, c)) ) + } val edges: RDD[Edge[Double]] = - vertices.flatMap{ case (vid, (r,c)) => + vertices.flatMap{ case (vid, (r, c)) => (if (r + 1 < rows) { Seq( (sub2ind(r, c), sub2ind(r + 1, c))) } else { Seq.empty }) ++ (if (c + 1 < cols) { Seq( (sub2ind(r, c), sub2ind(r, c + 1))) } else { Seq.empty }) }.map{ case (src, dst) => Edge(src, dst, 1.0) } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala index eb1dbe52c2fda..f1ecc9e2219d1 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/EdgeRDDSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.graphx -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.storage.StorageLevel -class EdgeRDDSuite extends FunSuite with LocalSparkContext { +class EdgeRDDSuite extends SparkFunSuite with LocalSparkContext { test("cache, getStorageLevel") { // test to see if getStorageLevel returns correct value after caching diff --git a/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala index 5a2c73b414279..094a63472eaab 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/EdgeSuite.scala @@ -17,21 +17,21 @@ package org.apache.spark.graphx -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class EdgeSuite extends FunSuite { +class EdgeSuite extends SparkFunSuite { test ("compare") { // decending order val testEdges: Array[Edge[Int]] = Array( - Edge(0x7FEDCBA987654321L, -0x7FEDCBA987654321L, 1), - Edge(0x2345L, 0x1234L, 1), - Edge(0x1234L, 0x5678L, 1), - Edge(0x1234L, 0x2345L, 1), + Edge(0x7FEDCBA987654321L, -0x7FEDCBA987654321L, 1), + Edge(0x2345L, 0x1234L, 1), + Edge(0x1234L, 0x5678L, 1), + Edge(0x1234L, 0x2345L, 1), Edge(-0x7FEDCBA987654321L, 0x7FEDCBA987654321L, 1) ) // to ascending order val sortedEdges = testEdges.sorted(Edge.lexicographicOrdering[Int]) - + for (i <- 0 until testEdges.length) { assert(sortedEdges(i) == testEdges(testEdges.length - i - 1)) } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala index 9bc8007ce49cd..57a8b95dd12e9 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.graphx -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.graphx.Graph._ import org.apache.spark.graphx.impl.EdgePartition import org.apache.spark.rdd._ -import org.scalatest.FunSuite -class GraphOpsSuite extends FunSuite with LocalSparkContext { +class GraphOpsSuite extends SparkFunSuite with LocalSparkContext { test("joinVertices") { withSpark { sc => @@ -59,7 +58,7 @@ class GraphOpsSuite extends FunSuite with LocalSparkContext { test ("filter") { withSpark { sc => val n = 5 - val vertices = sc.parallelize((0 to n).map(x => (x:VertexId, x))) + val vertices = sc.parallelize((0 to n).map(x => (x: VertexId, x))) val edges = sc.parallelize((1 to n).map(x => Edge(0, x, x))) val graph: Graph[Int, Int] = Graph(vertices, edges).cache() val filteredGraph = graph.filter( @@ -67,11 +66,11 @@ class GraphOpsSuite extends FunSuite with LocalSparkContext { val degrees: VertexRDD[Int] = graph.outDegrees graph.outerJoinVertices(degrees) {(vid, data, deg) => deg.getOrElse(0)} }, - vpred = (vid: VertexId, deg:Int) => deg > 0 + vpred = (vid: VertexId, deg: Int) => deg > 0 ).cache() val v = filteredGraph.vertices.collect().toSet - assert(v === Set((0,0))) + assert(v === Set((0, 0))) // the map is necessary because of object-reuse in the edge iterator val e = filteredGraph.edges.map(e => Edge(e.srcId, e.dstId, e.attr)).collect().toSet diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala index a570e4ed75fc3..1f5e27d5508b8 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala @@ -17,16 +17,14 @@ package org.apache.spark.graphx -import org.scalatest.FunSuite - -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.graphx.Graph._ import org.apache.spark.graphx.PartitionStrategy._ import org.apache.spark.rdd._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils -class GraphSuite extends FunSuite with LocalSparkContext { +class GraphSuite extends SparkFunSuite with LocalSparkContext { def starGraph(sc: SparkContext, n: Int): Graph[String, Int] = { Graph.fromEdgeTuples(sc.parallelize((1 to n).map(x => (0: VertexId, x: VertexId)), 3), "v") @@ -248,7 +246,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { test("mask") { withSpark { sc => val n = 5 - val vertices = sc.parallelize((0 to n).map(x => (x:VertexId, x))) + val vertices = sc.parallelize((0 to n).map(x => (x: VertexId, x))) val edges = sc.parallelize((1 to n).map(x => Edge(0, x, x))) val graph: Graph[Int, Int] = Graph(vertices, edges).cache() @@ -260,11 +258,11 @@ class GraphSuite extends FunSuite with LocalSparkContext { val projectedGraph = graph.mask(subgraph) val v = projectedGraph.vertices.collect().toSet - assert(v === Set((0,0), (1,1), (2,2), (4,4), (5,5))) + assert(v === Set((0, 0), (1, 1), (2, 2), (4, 4), (5, 5))) // the map is necessary because of object-reuse in the edge iterator val e = projectedGraph.edges.map(e => Edge(e.srcId, e.dstId, e.attr)).collect().toSet - assert(e === Set(Edge(0,1,1), Edge(0,2,2), Edge(0,5,5))) + assert(e === Set(Edge(0, 1, 1), Edge(0, 2, 2), Edge(0, 5, 5))) } } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala index 490b94429ea1f..8afa2d403b53f 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala @@ -17,12 +17,10 @@ package org.apache.spark.graphx -import org.scalatest.FunSuite - -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.rdd._ -class PregelSuite extends FunSuite with LocalSparkContext { +class PregelSuite extends SparkFunSuite with LocalSparkContext { test("1 iteration") { withSpark { sc => diff --git a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala index d0a7198d691d7..f1aa685a79c98 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/VertexRDDSuite.scala @@ -17,13 +17,11 @@ package org.apache.spark.graphx -import org.scalatest.FunSuite - -import org.apache.spark.{HashPartitioner, SparkContext} +import org.apache.spark.{HashPartitioner, SparkContext, SparkFunSuite} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -class VertexRDDSuite extends FunSuite with LocalSparkContext { +class VertexRDDSuite extends SparkFunSuite with LocalSparkContext { private def vertices(sc: SparkContext, n: Int) = { VertexRDD(sc.parallelize((0 to n).map(x => (x.toLong, x)), 5)) diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala index 515f3a9cd02eb..7435647c6d9ee 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala @@ -20,15 +20,13 @@ package org.apache.spark.graphx.impl import scala.reflect.ClassTag import scala.util.Random -import org.scalatest.FunSuite - -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.serializer.KryoSerializer import org.apache.spark.graphx._ -class EdgePartitionSuite extends FunSuite { +class EdgePartitionSuite extends SparkFunSuite { def makeEdgePartition[A: ClassTag](xs: Iterable[(Int, Int, A)]): EdgePartition[A, Int] = { val builder = new EdgePartitionBuilder[A, Int] diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala index fe8304c1cdc32..1203f8959f506 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala @@ -17,15 +17,13 @@ package org.apache.spark.graphx.impl -import org.scalatest.FunSuite - -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.serializer.KryoSerializer import org.apache.spark.graphx._ -class VertexPartitionSuite extends FunSuite { +class VertexPartitionSuite extends SparkFunSuite { test("isDefined, filter") { val vp = VertexPartition(Iterator((0L, 1), (1L, 1))).filter { (vid, attr) => vid == 0 } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala index 4cc30a96408f8..c965a6eb8df13 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/ConnectedComponentsSuite.scala @@ -17,16 +17,14 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.SparkContext._ import org.apache.spark.graphx._ import org.apache.spark.graphx.util.GraphGenerators import org.apache.spark.rdd._ -class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { +class ConnectedComponentsSuite extends SparkFunSuite with LocalSparkContext { test("Grid Connected Components") { withSpark { sc => @@ -52,13 +50,16 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { withSpark { sc => val chain1 = (0 until 9).map(x => (x, x + 1)) val chain2 = (10 until 20).map(x => (x, x + 1)) - val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) } + val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s, d) => (s.toLong, d.toLong) } val twoChains = Graph.fromEdgeTuples(rawEdges, 1.0) val ccGraph = twoChains.connectedComponents() val vertices = ccGraph.vertices.collect() for ( (id, cc) <- vertices ) { - if(id < 10) { assert(cc === 0) } - else { assert(cc === 10) } + if (id < 10) { + assert(cc === 0) + } else { + assert(cc === 10) + } } val ccMap = vertices.toMap for (id <- 0 until 20) { @@ -75,7 +76,7 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { withSpark { sc => val chain1 = (0 until 9).map(x => (x, x + 1)) val chain2 = (10 until 20).map(x => (x, x + 1)) - val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) } + val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s, d) => (s.toLong, d.toLong) } val twoChains = Graph.fromEdgeTuples(rawEdges, true).reverse val ccGraph = twoChains.connectedComponents() val vertices = ccGraph.vertices.collect() @@ -106,9 +107,9 @@ class ConnectedComponentsSuite extends FunSuite with LocalSparkContext { (4L, ("peter", "student")))) // Create an RDD for edges val relationships: RDD[Edge[String]] = - sc.parallelize(Array(Edge(3L, 7L, "collab"), Edge(5L, 3L, "advisor"), + sc.parallelize(Array(Edge(3L, 7L, "collab"), Edge(5L, 3L, "advisor"), Edge(2L, 5L, "colleague"), Edge(5L, 7L, "pi"), - Edge(4L, 0L, "student"), Edge(5L, 0L, "colleague"))) + Edge(4L, 0L, "student"), Edge(5L, 0L, "colleague"))) // Edges are: // 2 ---> 5 ---> 3 // | \ diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/LabelPropagationSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/LabelPropagationSuite.scala index 61fd0c4605568..808877f0590f8 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/LabelPropagationSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/LabelPropagationSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.graphx._ -class LabelPropagationSuite extends FunSuite with LocalSparkContext { +class LabelPropagationSuite extends SparkFunSuite with LocalSparkContext { test("Label Propagation") { withSpark { sc => // Construct a graph with two cliques connected by a single edge diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala index 3f3c9dfd7b3dd..45f1e3011035e 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/PageRankSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.graphx._ import org.apache.spark.graphx.util.GraphGenerators @@ -31,14 +30,14 @@ object GridPageRank { def sub2ind(r: Int, c: Int): Int = r * nCols + c // Make the grid graph for (r <- 0 until nRows; c <- 0 until nCols) { - val ind = sub2ind(r,c) + val ind = sub2ind(r, c) if (r + 1 < nRows) { outDegree(ind) += 1 - inNbrs(sub2ind(r + 1,c)) += ind + inNbrs(sub2ind(r + 1, c)) += ind } if (c + 1 < nCols) { outDegree(ind) += 1 - inNbrs(sub2ind(r,c + 1)) += ind + inNbrs(sub2ind(r, c + 1)) += ind } } // compute the pagerank @@ -57,7 +56,7 @@ object GridPageRank { } -class PageRankSuite extends FunSuite with LocalSparkContext { +class PageRankSuite extends SparkFunSuite with LocalSparkContext { def compareRanks(a: VertexRDD[Double], b: VertexRDD[Double]): Double = { a.leftJoin(b) { case (id, a, bOpt) => (a - bOpt.getOrElse(0.0)) * (a - bOpt.getOrElse(0.0)) } @@ -99,8 +98,8 @@ class PageRankSuite extends FunSuite with LocalSparkContext { val resetProb = 0.15 val errorTol = 1.0e-5 - val staticRanks1 = starGraph.staticPersonalizedPageRank(0,numIter = 1, resetProb).vertices - val staticRanks2 = starGraph.staticPersonalizedPageRank(0,numIter = 2, resetProb) + val staticRanks1 = starGraph.staticPersonalizedPageRank(0, numIter = 1, resetProb).vertices + val staticRanks2 = starGraph.staticPersonalizedPageRank(0, numIter = 2, resetProb) .vertices.cache() // Static PageRank should only take 2 iterations to converge @@ -117,7 +116,7 @@ class PageRankSuite extends FunSuite with LocalSparkContext { } assert(staticErrors.sum === 0) - val dynamicRanks = starGraph.personalizedPageRank(0,0, resetProb).vertices.cache() + val dynamicRanks = starGraph.personalizedPageRank(0, 0, resetProb).vertices.cache() assert(compareRanks(staticRanks2, dynamicRanks) < errorTol) } } // end of test Star PageRank @@ -162,7 +161,7 @@ class PageRankSuite extends FunSuite with LocalSparkContext { test("Chain PersonalizedPageRank") { withSpark { sc => val chain1 = (0 until 9).map(x => (x, x + 1) ) - val rawEdges = sc.parallelize(chain1, 1).map { case (s,d) => (s.toLong, d.toLong) } + val rawEdges = sc.parallelize(chain1, 1).map { case (s, d) => (s.toLong, d.toLong) } val chain = Graph.fromEdgeTuples(rawEdges, 1.0).cache() val resetProb = 0.15 val tol = 0.0001 diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala index 7bd6b7f3c4ab2..2991438f5e57e 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.graphx._ -class SVDPlusPlusSuite extends FunSuite with LocalSparkContext { +class SVDPlusPlusSuite extends SparkFunSuite with LocalSparkContext { test("Test SVD++ with mean square error on training set") { withSpark { sc => diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala index f2c38e79c452c..d7eaa70ce6407 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/ShortestPathsSuite.scala @@ -17,16 +17,14 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.SparkContext._ import org.apache.spark.graphx._ import org.apache.spark.graphx.lib._ import org.apache.spark.graphx.util.GraphGenerators import org.apache.spark.rdd._ -class ShortestPathsSuite extends FunSuite with LocalSparkContext { +class ShortestPathsSuite extends SparkFunSuite with LocalSparkContext { test("Shortest Path Computations") { withSpark { sc => diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala index 1f658c371ffcf..d6b03208180db 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/StronglyConnectedComponentsSuite.scala @@ -17,16 +17,14 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.SparkContext._ import org.apache.spark.graphx._ import org.apache.spark.graphx.util.GraphGenerators import org.apache.spark.rdd._ -class StronglyConnectedComponentsSuite extends FunSuite with LocalSparkContext { +class StronglyConnectedComponentsSuite extends SparkFunSuite with LocalSparkContext { test("Island Strongly Connected Components") { withSpark { sc => diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala index 293c7f3ba4c21..c47552cf3a3bd 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.graphx.lib -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.graphx._ import org.apache.spark.graphx.PartitionStrategy.RandomVertexCut -class TriangleCountSuite extends FunSuite with LocalSparkContext { +class TriangleCountSuite extends SparkFunSuite with LocalSparkContext { test("Count a single triangle") { withSpark { sc => @@ -58,7 +57,7 @@ class TriangleCountSuite extends FunSuite with LocalSparkContext { val triangles = Array(0L -> 1L, 1L -> 2L, 2L -> 0L) ++ Array(0L -> -1L, -1L -> -2L, -2L -> 0L) - val revTriangles = triangles.map { case (a,b) => (b,a) } + val revTriangles = triangles.map { case (a, b) => (b, a) } val rawEdges = sc.parallelize(triangles ++ revTriangles, 2) val graph = Graph.fromEdgeTuples(rawEdges, true).cache() val triangleCount = graph.triangleCount() diff --git a/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala index f3b3738db0dad..186d0cc2a977b 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/util/BytecodeUtilsSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.graphx.util -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class BytecodeUtilsSuite extends FunSuite { +class BytecodeUtilsSuite extends SparkFunSuite { import BytecodeUtilsSuite.TestClass diff --git a/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala index 8d9c8ddccbb3c..32e0c841c6997 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.graphx.util -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.graphx.LocalSparkContext -class GraphGeneratorsSuite extends FunSuite with LocalSparkContext { +class GraphGeneratorsSuite extends SparkFunSuite with LocalSparkContext { test("GraphGenerators.generateRandomEdges") { val src = 5 diff --git a/launcher/pom.xml b/launcher/pom.xml index ebfa7685eaa18..cc177d23dff77 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -29,7 +29,7 @@ org.apache.spark spark-launcher_2.10 jar - Spark Launcher Project + Spark Project Launcher http://spark.apache.org/ launcher diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index b8f02b961113d..33d65d13f0d25 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -121,7 +121,10 @@ List buildJavaCommand(String extraClassPath) throws IOException { * set it. */ void addPermGenSizeOpt(List cmd) { - // Don't set MaxPermSize for Java 8 and later. + // Don't set MaxPermSize for IBM Java, or Oracle Java 8 and later. + if (getJavaVendor() == JavaVendor.IBM) { + return; + } String[] version = System.getProperty("java.version").split("\\."); if (Integer.parseInt(version[0]) > 1 || Integer.parseInt(version[1]) > 7) { return; @@ -293,6 +296,9 @@ Properties loadPropertiesFile() throws IOException { try { fd = new FileInputStream(propsFile); props.load(new InputStreamReader(fd, "UTF-8")); + for (Map.Entry e : props.entrySet()) { + e.setValue(e.getValue().toString().trim()); + } } finally { if (fd != null) { try { diff --git a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java index 261402856ac5e..2665a700fe1f5 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java +++ b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java @@ -32,6 +32,11 @@ class CommandBuilderUtils { static final String ENV_SPARK_HOME = "SPARK_HOME"; static final String ENV_SPARK_ASSEMBLY = "_SPARK_ASSEMBLY"; + /** The set of known JVM vendors. */ + static enum JavaVendor { + Oracle, IBM, OpenJDK, Unknown + }; + /** Returns whether the given string is null or empty. */ static boolean isEmpty(String s) { return s == null || s.isEmpty(); @@ -108,6 +113,21 @@ static boolean isWindows() { return os.startsWith("Windows"); } + /** Returns an enum value indicating whose JVM is being used. */ + static JavaVendor getJavaVendor() { + String vendorString = System.getProperty("java.vendor"); + if (vendorString.contains("Oracle")) { + return JavaVendor.Oracle; + } + if (vendorString.contains("IBM")) { + return JavaVendor.IBM; + } + if (vendorString.contains("OpenJDK")) { + return JavaVendor.OpenJDK; + } + return JavaVendor.Unknown; + } + /** * Updates the user environment, appending the given pathList to the existing value of the given * environment variable (or setting it if it hasn't yet been set). diff --git a/make-distribution.sh b/make-distribution.sh index 1bfa9acb1fe6e..a2b0c431fb4d0 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -58,7 +58,7 @@ while (( "$#" )); do --hadoop) echo "Error: '--hadoop' is no longer supported:" echo "Error: use Maven profiles and options -Dhadoop.version and -Dyarn.version instead." - echo "Error: Related profiles include hadoop-2.2, hadoop-2.3 and hadoop-2.4." + echo "Error: Related profiles include hadoop-1, hadoop-2.2, hadoop-2.3 and hadoop-2.4." exit_with_usage ;; --with-yarn) @@ -231,6 +231,11 @@ cp -r "$SPARK_HOME/bin" "$DISTDIR" cp -r "$SPARK_HOME/python" "$DISTDIR" cp -r "$SPARK_HOME/sbin" "$DISTDIR" cp -r "$SPARK_HOME/ec2" "$DISTDIR" +# Copy SparkR if it exists +if [ -d "$SPARK_HOME"/R/lib/SparkR ]; then + mkdir -p "$DISTDIR"/R/lib + cp -r "$SPARK_HOME/R/lib/SparkR" "$DISTDIR"/R/lib +fi # Download and copy in tachyon, if requested if [ "$SPARK_TACHYON" == "true" ]; then diff --git a/mllib/pom.xml b/mllib/pom.xml index 0c07ca1a62fd3..65c647a91d192 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -40,6 +40,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-streaming_${scala.binary.version} diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala index 7f3f3262a644f..e9a5d7c0e7988 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala @@ -19,16 +19,16 @@ package org.apache.spark.ml import scala.annotation.varargs -import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.param.{ParamMap, ParamPair, Params} +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.ml.param.{ParamMap, ParamPair} import org.apache.spark.sql.DataFrame /** - * :: AlphaComponent :: + * :: DeveloperApi :: * Abstract class for estimators that fit models to data. */ -@AlphaComponent -abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { +@DeveloperApi +abstract class Estimator[M <: Model[M]] extends PipelineStage { /** * Fits a single model to the input data with optional parameters. diff --git a/mllib/src/main/scala/org/apache/spark/ml/Model.scala b/mllib/src/main/scala/org/apache/spark/ml/Model.scala index 9974efe7b1d25..186bf7ae7a2f6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Model.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Model.scala @@ -17,22 +17,33 @@ package org.apache.spark.ml -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.param.ParamMap /** - * :: AlphaComponent :: + * :: DeveloperApi :: * A fitted model, i.e., a [[Transformer]] produced by an [[Estimator]]. * * @tparam M model type */ -@AlphaComponent +@DeveloperApi abstract class Model[M <: Model[M]] extends Transformer { /** * The parent estimator that produced this model. * Note: For ensembles' component Models, this value can be null. */ - val parent: Estimator[M] + @transient var parent: Estimator[M] = _ + + /** + * Sets the parent of this model (Java API). + */ + def setParent(parent: Estimator[M]): M = { + this.parent = parent + this.asInstanceOf[M] + } + + /** Indicates whether this [[Model]] has a corresponding parent. */ + def hasParent: Boolean = parent != null override def copy(extra: ParamMap): M = { // The default implementation of Params.copy doesn't work for models. diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 33d430f5671ee..11a4722722ea1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -20,16 +20,17 @@ package org.apache.spark.ml import scala.collection.mutable.ListBuffer import org.apache.spark.Logging -import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} +import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.ml.param.{Param, ParamMap, Params} +import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType /** - * :: AlphaComponent :: + * :: DeveloperApi :: * A stage in a pipeline, either an [[Estimator]] or a [[Transformer]]. */ -@AlphaComponent +@DeveloperApi abstract class PipelineStage extends Params with Logging { /** @@ -68,7 +69,7 @@ abstract class PipelineStage extends Params with Logging { } /** - * :: AlphaComponent :: + * :: Experimental :: * A simple pipeline, which acts as an estimator. A Pipeline consists of a sequence of stages, each * of which is either an [[Estimator]] or a [[Transformer]]. When [[Pipeline#fit]] is called, the * stages are executed in order. If a stage is an [[Estimator]], its [[Estimator#fit]] method will @@ -79,8 +80,10 @@ abstract class PipelineStage extends Params with Logging { * transformers, corresponding to the pipeline stages. If there are no stages, the pipeline acts as * an identity transformer. */ -@AlphaComponent -class Pipeline extends Estimator[PipelineModel] { +@Experimental +class Pipeline(override val uid: String) extends Estimator[PipelineModel] { + + def this() = this(Identifiable.randomUID("pipeline")) /** * param for pipeline stages @@ -94,12 +97,9 @@ class Pipeline extends Estimator[PipelineModel] { /** @group getParam */ def getStages: Array[PipelineStage] = $(stages).clone() - override def validateParams(paramMap: ParamMap): Unit = { - val map = extractParamMap(paramMap) - getStages.foreach { - case pStage: Params => pStage.validateParams(map) - case _ => - } + override def validateParams(): Unit = { + super.validateParams() + $(stages).foreach(_.validateParams()) } /** @@ -148,7 +148,7 @@ class Pipeline extends Estimator[PipelineModel] { } } - new PipelineModel(this, transformers.toArray) + new PipelineModel(uid, transformers.toArray).setParent(this) } override def copy(extra: ParamMap): Pipeline = { @@ -166,12 +166,12 @@ class Pipeline extends Estimator[PipelineModel] { } /** - * :: AlphaComponent :: - * Represents a compiled pipeline. + * :: Experimental :: + * Represents a fitted pipeline. */ -@AlphaComponent +@Experimental class PipelineModel private[ml] ( - override val parent: Pipeline, + override val uid: String, val stages: Array[Transformer]) extends Model[PipelineModel] with Logging { @@ -190,6 +190,6 @@ class PipelineModel private[ml] ( } override def copy(extra: ParamMap): PipelineModel = { - new PipelineModel(parent, stages) + new PipelineModel(uid, stages) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index f6a5f27425d1f..e752b81a14282 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -58,7 +58,6 @@ private[ml] trait PredictorParams extends Params /** * :: DeveloperApi :: - * * Abstraction for prediction problems (regression and classification). * * @tparam FeaturesType Type of features. @@ -88,7 +87,7 @@ abstract class Predictor[ // This handles a few items such as schema validation. // Developers only need to implement train(). transformSchema(dataset.schema, logging = true) - copyValues(train(dataset)) + copyValues(train(dataset).setParent(this)) } override def copy(extra: ParamMap): Learner = { @@ -113,7 +112,6 @@ abstract class Predictor[ * * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector. */ - @DeveloperApi private[ml] def featuresDataType: DataType = new VectorUDT override def transformSchema(schema: StructType): StructType = { @@ -134,7 +132,6 @@ abstract class Predictor[ /** * :: DeveloperApi :: - * * Abstraction for a model for prediction tasks (regression and classification). * * @tparam FeaturesType Type of features. diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index d96b54e511e9c..f07f733a5ddb5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml import scala.annotation.varargs import org.apache.spark.Logging -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.sql.DataFrame @@ -28,11 +28,11 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ /** - * :: AlphaComponent :: + * :: DeveloperApi :: * Abstract class for transformers that transform one dataset into another. */ -@AlphaComponent -abstract class Transformer extends PipelineStage with Params { +@DeveloperApi +abstract class Transformer extends PipelineStage { /** * Transforms the dataset with optional parameters @@ -73,10 +73,12 @@ abstract class Transformer extends PipelineStage with Params { } /** + * :: DeveloperApi :: * Abstract class for transformers that take one input column, apply transformation, and output the * result as a new column. */ -private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]] +@DeveloperApi +abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]] extends Transformer with HasInputCol with HasOutputCol with Logging { /** @group setParam */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala index f5f37aa77929c..457c15830fd38 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala @@ -19,10 +19,12 @@ package org.apache.spark.ml.attribute import scala.collection.mutable.ArrayBuffer +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.linalg.VectorUDT import org.apache.spark.sql.types.{Metadata, MetadataBuilder, StructField} /** + * :: DeveloperApi :: * Attributes that describe a vector ML column. * * @param name name of the attribute group (the ML column name) @@ -31,6 +33,7 @@ import org.apache.spark.sql.types.{Metadata, MetadataBuilder, StructField} * @param attrs optional array of attributes. Attribute will be copied with their corresponding * indices in the array. */ +@DeveloperApi class AttributeGroup private ( val name: String, val numAttributes: Option[Int], @@ -182,7 +185,11 @@ class AttributeGroup private ( } } -/** Factory methods to create attribute groups. */ +/** + * :: DeveloperApi :: + * Factory methods to create attribute groups. + */ +@DeveloperApi object AttributeGroup { import AttributeKeys._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala index a83febd7de2cc..5c7089b491677 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala @@ -17,12 +17,17 @@ package org.apache.spark.ml.attribute +import org.apache.spark.annotation.DeveloperApi + /** + * :: DeveloperApi :: * An enum-like type for attribute types: [[AttributeType$#Numeric]], [[AttributeType$#Nominal]], * and [[AttributeType$#Binary]]. */ +@DeveloperApi sealed abstract class AttributeType(val name: String) +@DeveloperApi object AttributeType { /** Numeric type. */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala index e8f7f152784a1..ce43a450daad0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala @@ -19,11 +19,14 @@ package org.apache.spark.ml.attribute import scala.annotation.varargs +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.types.{DoubleType, Metadata, MetadataBuilder, StructField} /** + * :: DeveloperApi :: * Abstract class for ML attributes. */ +@DeveloperApi sealed abstract class Attribute extends Serializable { name.foreach { n => @@ -135,6 +138,10 @@ private[attribute] trait AttributeFactory { } } +/** + * :: DeveloperApi :: + */ +@DeveloperApi object Attribute extends AttributeFactory { private[attribute] override def fromMetadata(metadata: Metadata): Attribute = { @@ -163,6 +170,7 @@ object Attribute extends AttributeFactory { /** + * :: DeveloperApi :: * A numeric attribute with optional summary statistics. * @param name optional name * @param index optional index @@ -171,6 +179,7 @@ object Attribute extends AttributeFactory { * @param std optional standard deviation * @param sparsity optional sparsity (ratio of zeros) */ +@DeveloperApi class NumericAttribute private[ml] ( override val name: Option[String] = None, override val index: Option[Int] = None, @@ -278,8 +287,10 @@ class NumericAttribute private[ml] ( } /** + * :: DeveloperApi :: * Factory methods for numeric attributes. */ +@DeveloperApi object NumericAttribute extends AttributeFactory { /** The default numeric attribute. */ @@ -298,6 +309,7 @@ object NumericAttribute extends AttributeFactory { } /** + * :: DeveloperApi :: * A nominal attribute. * @param name optional name * @param index optional index @@ -306,6 +318,7 @@ object NumericAttribute extends AttributeFactory { * defined. * @param values optional values. At most one of `numValues` and `values` can be defined. */ +@DeveloperApi class NominalAttribute private[ml] ( override val name: Option[String] = None, override val index: Option[Int] = None, @@ -430,7 +443,11 @@ class NominalAttribute private[ml] ( } } -/** Factory methods for nominal attributes. */ +/** + * :: DeveloperApi :: + * Factory methods for nominal attributes. + */ +@DeveloperApi object NominalAttribute extends AttributeFactory { /** The default nominal attribute. */ @@ -450,11 +467,13 @@ object NominalAttribute extends AttributeFactory { } /** + * :: DeveloperApi :: * A binary attribute. * @param name optional name * @param index optional index * @param values optionla values. If set, its size must be 2. */ +@DeveloperApi class BinaryAttribute private[ml] ( override val name: Option[String] = None, override val index: Option[Int] = None, @@ -526,7 +545,11 @@ class BinaryAttribute private[ml] ( } } -/** Factory methods for binary attributes. */ +/** + * :: DeveloperApi :: + * Factory methods for binary attributes. + */ +@DeveloperApi object BinaryAttribute extends AttributeFactory { /** The default binary attribute. */ @@ -543,8 +566,10 @@ object BinaryAttribute extends AttributeFactory { } /** + * :: DeveloperApi :: * An unresolved attribute. */ +@DeveloperApi object UnresolvedAttribute extends Attribute { override def attrType: AttributeType = AttributeType.Unresolved diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index dcebea1d4b015..8030e0728a56c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -17,11 +17,11 @@ package org.apache.spark.ml.classification -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.tree.{TreeClassifierParams, DecisionTreeParams, DecisionTreeModel, Node} -import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams} +import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree} @@ -31,18 +31,19 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm * for classification. * It supports both binary and multiclass labels, as well as both continuous and categorical * features. */ -@AlphaComponent -final class DecisionTreeClassifier +@Experimental +final class DecisionTreeClassifier(override val uid: String) extends Predictor[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel] with DecisionTreeParams with TreeClassifierParams { + def this() = this(Identifiable.randomUID("dtc")) + // Override parameter setters from parent trait for Java API compatibility. override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) @@ -87,21 +88,21 @@ final class DecisionTreeClassifier } } +@Experimental object DecisionTreeClassifier { /** Accessor for supported impurities: entropy, gini */ final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities } /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] model for classification. * It supports both binary and multiclass labels, as well as both continuous and categorical * features. */ -@AlphaComponent +@Experimental final class DecisionTreeClassificationModel private[ml] ( - override val parent: DecisionTreeClassifier, + override val uid: String, override val rootNode: Node) extends PredictionModel[Vector, DecisionTreeClassificationModel] with DecisionTreeModel with Serializable { @@ -114,7 +115,7 @@ final class DecisionTreeClassificationModel private[ml] ( } override def copy(extra: ParamMap): DecisionTreeClassificationModel = { - copyValues(new DecisionTreeClassificationModel(parent, rootNode), extra) + copyValues(new DecisionTreeClassificationModel(uid, rootNode), extra) } override def toString: String = { @@ -138,6 +139,7 @@ private[ml] object DecisionTreeClassificationModel { s"Cannot convert non-classification DecisionTreeModel (old API) to" + s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}") val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures) - new DecisionTreeClassificationModel(parent, rootNode) + val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc") + new DecisionTreeClassificationModel(uid, rootNode) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index ae51b05a0c42d..62f4b51f770e9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -20,12 +20,12 @@ package org.apache.spark.ml.classification import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.Logging -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.regression.DecisionTreeRegressionModel -import org.apache.spark.ml.tree.{GBTParams, TreeClassifierParams, DecisionTreeModel, TreeEnsembleModel} -import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeClassifierParams, TreeEnsembleModel} +import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{GradientBoostedTrees => OldGBT} @@ -36,18 +36,19 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]] * learning algorithm for classification. * It supports binary labels, as well as both continuous and categorical features. * Note: Multiclass labels are not currently supported. */ -@AlphaComponent -final class GBTClassifier +@Experimental +final class GBTClassifier(override val uid: String) extends Predictor[Vector, GBTClassifier, GBTClassificationModel] with GBTParams with TreeClassifierParams with Logging { + def this() = this(Identifiable.randomUID("gbtc")) + // Override parameter setters from parent trait for Java API compatibility. // Parameters from TreeClassifierParams: @@ -142,6 +143,7 @@ final class GBTClassifier } } +@Experimental object GBTClassifier { // The losses below should be lowercase. /** Accessor for supported loss settings: logistic */ @@ -149,8 +151,7 @@ object GBTClassifier { } /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]] * model for classification. * It supports binary labels, as well as both continuous and categorical features. @@ -158,9 +159,9 @@ object GBTClassifier { * @param _trees Decision trees in the ensemble. * @param _treeWeights Weights for the decision trees in the ensemble. */ -@AlphaComponent +@Experimental final class GBTClassificationModel( - override val parent: GBTClassifier, + override val uid: String, private val _trees: Array[DecisionTreeRegressionModel], private val _treeWeights: Array[Double]) extends PredictionModel[Vector, GBTClassificationModel] @@ -184,7 +185,7 @@ final class GBTClassificationModel( } override def copy(extra: ParamMap): GBTClassificationModel = { - copyValues(new GBTClassificationModel(parent, _trees, _treeWeights), extra) + copyValues(new GBTClassificationModel(uid, _trees, _treeWeights), extra) } override def toString: String = { @@ -207,9 +208,10 @@ private[ml] object GBTClassificationModel { require(oldModel.algo == OldAlgo.Classification, "Cannot convert GradientBoostedTreesModel" + s" with algo=${oldModel.algo} (old API) to GBTClassificationModel (new API).") val newTrees = oldModel.trees.map { tree => - // parent, fittingParamMap for each tree is null since there are no good ways to set these. + // parent for each tree is null since there is no good way to set this. DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } - new GBTClassificationModel(parent, newTrees, oldModel.treeWeights) + val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc") + new GBTClassificationModel(parent.uid, newTrees, oldModel.treeWeights) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 93ba91167bfad..f136bcee9cf2b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -19,13 +19,14 @@ package org.apache.spark.ml.classification import scala.collection.mutable -import breeze.linalg.{norm => brzNorm, DenseVector => BDV} -import breeze.optimize.{LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} -import breeze.optimize.{CachedDiffFunction, DiffFunction} +import breeze.linalg.{DenseVector => BDV, norm => brzNorm} +import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.{Logging, SparkException} +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.linalg.BLAS._ import org.apache.spark.mllib.regression.LabeledPoint @@ -34,7 +35,6 @@ import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame import org.apache.spark.storage.StorageLevel -import org.apache.spark.{SparkException, Logging} /** * Params for logistic regression. @@ -44,16 +44,17 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas with HasThreshold /** - * :: AlphaComponent :: - * + * :: Experimental :: * Logistic regression. * Currently, this class only supports binary classification. */ -@AlphaComponent -class LogisticRegression +@Experimental +class LogisticRegression(override val uid: String) extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel] with LogisticRegressionParams with Logging { + def this() = this(Identifiable.randomUID("logreg")) + /** * Set the regularization parameter. * Default is 0.0. @@ -73,7 +74,7 @@ class LogisticRegression setDefault(elasticNetParam -> 0.0) /** - * Set the maximal number of iterations. + * Set the maximum number of iterations. * Default is 100. * @group setParam */ @@ -89,7 +90,11 @@ class LogisticRegression def setTol(value: Double): this.type = set(tol, value) setDefault(tol -> 1E-6) - /** @group setParam */ + /** + * Whether to fit an intercept term. + * Default is true. + * @group setParam + * */ def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value) setDefault(fitIntercept -> true) @@ -213,18 +218,17 @@ class LogisticRegression (weightsWithIntercept, 0.0) } - new LogisticRegressionModel(this, weights.compressed, intercept) + new LogisticRegressionModel(uid, weights.compressed, intercept) } } /** - * :: AlphaComponent :: - * + * :: Experimental :: * Model produced by [[LogisticRegression]]. */ -@AlphaComponent +@Experimental class LogisticRegressionModel private[ml] ( - override val parent: LogisticRegression, + override val uid: String, val weights: Vector, val intercept: Double) extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel] @@ -258,7 +262,8 @@ class LogisticRegressionModel private[ml] ( rawPrediction match { case dv: DenseVector => var i = 0 - while (i < dv.size) { + val size = dv.size + while (i < size) { dv.values(i) = 1.0 / (1.0 + math.exp(-dv.values(i))) i += 1 } @@ -275,7 +280,7 @@ class LogisticRegressionModel private[ml] ( } override def copy(extra: ParamMap): LogisticRegressionModel = { - copyValues(new LogisticRegressionModel(parent, weights, intercept), extra) + copyValues(new LogisticRegressionModel(uid, weights, intercept), extra) } override protected def raw2prediction(rawPrediction: Vector): Double = { @@ -357,7 +362,8 @@ private[classification] class MultiClassSummarizer extends Serializable { def histogram: Array[Long] = { val result = Array.ofDim[Long](numClasses) var i = 0 - while (i < result.length) { + val len = result.length + while (i < len) { result(i) = distinctMap.getOrElse(i, 0L) i += 1 } @@ -480,7 +486,8 @@ private class LogisticAggregator( var i = 0 val localThisGradientSumArray = this.gradientSumArray val localOtherGradientSumArray = other.gradientSumArray - while (i < localThisGradientSumArray.length) { + val len = localThisGradientSumArray.length + while (i < len) { localThisGradientSumArray(i) += localOtherGradientSumArray(i) i += 1 } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index afb8d75d57384..825f9ed1b54b2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -21,11 +21,11 @@ import java.util.UUID import scala.language.existentials -import org.apache.spark.annotation.{AlphaComponent, Experimental} +import org.apache.spark.annotation.Experimental import org.apache.spark.ml._ import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param.Param -import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ @@ -37,27 +37,26 @@ import org.apache.spark.storage.StorageLevel */ private[ml] trait OneVsRestParams extends PredictorParams { + // scalastyle:off structural.type type ClassifierType = Classifier[F, E, M] forSome { type F type M <: ClassificationModel[F, M] - type E <: Classifier[F, E, M] + type E <: Classifier[F, E, M] } + // scalastyle:on structural.type /** * param for the base binary classifier that we reduce multiclass classification into. * @group param */ - val classifier: Param[ClassifierType] = - new Param(this, "classifier", "base binary classifier ") + val classifier: Param[ClassifierType] = new Param(this, "classifier", "base binary classifier") /** @group getParam */ def getClassifier: ClassifierType = $(classifier) - } /** - * :: AlphaComponent :: - * + * :: Experimental :: * Model produced by [[OneVsRest]]. * This stores the models resulting from training k binary classifiers: one for each class. * Each example is scored against all k models, and the model with the highest score @@ -69,11 +68,11 @@ private[ml] trait OneVsRestParams extends PredictorParams { * The i-th model is produced by testing the i-th class (taking label 1) vs the rest * (taking label 0). */ -@AlphaComponent -class OneVsRestModel private[ml] ( - override val parent: OneVsRest, - labelMetadata: Metadata, - val models: Array[_ <: ClassificationModel[_,_]]) +@Experimental +final class OneVsRestModel private[ml] ( + override val uid: String, + labelMetadata: Metadata, + val models: Array[_ <: ClassificationModel[_, _]]) extends Model[OneVsRestModel] with OneVsRestParams { override def transformSchema(schema: StructType): StructType = { @@ -107,17 +106,17 @@ class OneVsRestModel private[ml] ( // add temporary column to store intermediate scores and update val tmpColName = "mbc$tmp" + UUID.randomUUID().toString - val update: (Map[Int, Double], Vector) => Map[Int, Double] = + val update: (Map[Int, Double], Vector) => Map[Int, Double] = (predictions: Map[Int, Double], prediction: Vector) => { predictions + ((index, prediction(1))) } val updateUdf = callUDF(update, mapType, col(accColName), col(rawPredictionCol)) - val transformedDataset = model.transform(df).select(columns:_*) + val transformedDataset = model.transform(df).select(columns : _*) val updatedDataset = transformedDataset.withColumn(tmpColName, updateUdf) val newColumns = origCols ++ List(col(tmpColName)) // switch out the intermediate column with the accumulator column - updatedDataset.select(newColumns:_*).withColumnRenamed(tmpColName, accColName) + updatedDataset.select(newColumns : _*).withColumnRenamed(tmpColName, accColName) } if (handlePersistence) { @@ -132,6 +131,7 @@ class OneVsRestModel private[ml] ( // output label and label metadata as prediction val labelUdf = callUDF(label, DoubleType, col(accColName)) aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata)) + .drop(accColName) } } @@ -145,11 +145,13 @@ class OneVsRestModel private[ml] ( * is picked to label the example. */ @Experimental -final class OneVsRest extends Estimator[OneVsRestModel] with OneVsRestParams { +final class OneVsRest(override val uid: String) + extends Estimator[OneVsRestModel] with OneVsRestParams { + + def this() = this(Identifiable.randomUID("oneVsRest")) /** @group setParam */ - def setClassifier(value: Classifier[_,_,_]): this.type = { - // TODO: Find a better way to do this. Existential Types don't work with Java API so cast needed + def setClassifier(value: Classifier[_, _, _]): this.type = { set(classifier, value.asInstanceOf[ClassifierType]) } @@ -191,7 +193,7 @@ final class OneVsRest extends Estimator[OneVsRestModel] with OneVsRestParams { val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta) val classifier = getClassifier classifier.fit(trainingDataset, classifier.labelCol -> labelColName) - }.toArray[ClassificationModel[_,_]] + }.toArray[ClassificationModel[_, _]] if (handlePersistence) { multiclassLabeled.unpersist() @@ -204,6 +206,7 @@ final class OneVsRest extends Estimator[OneVsRestModel] with OneVsRestParams { NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses) case attr: Attribute => attr } - copyValues(new OneVsRestModel(this, labelAttribute.toMetadata(), models)) + val model = new OneVsRestModel(uid, labelAttribute.toMetadata(), models).setParent(this) + copyValues(model) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 9954893f14359..852a67e066322 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -19,11 +19,11 @@ package org.apache.spark.ml.classification import scala.collection.mutable -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.tree.{RandomForestParams, TreeClassifierParams, DecisionTreeModel, TreeEnsembleModel} -import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel} +import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{RandomForest => OldRandomForest} @@ -33,18 +33,19 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] learning algorithm for * classification. * It supports both binary and multiclass labels, as well as both continuous and categorical * features. */ -@AlphaComponent -final class RandomForestClassifier +@Experimental +final class RandomForestClassifier(override val uid: String) extends Predictor[Vector, RandomForestClassifier, RandomForestClassificationModel] with RandomForestParams with TreeClassifierParams { + def this() = this(Identifiable.randomUID("rfc")) + // Override parameter setters from parent trait for Java API compatibility. // Parameters from TreeClassifierParams: @@ -98,6 +99,7 @@ final class RandomForestClassifier } } +@Experimental object RandomForestClassifier { /** Accessor for supported impurity settings: entropy, gini */ final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities @@ -108,17 +110,16 @@ object RandomForestClassifier { } /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for classification. * It supports both binary and multiclass labels, as well as both continuous and categorical * features. * @param _trees Decision trees in the ensemble. * Warning: These have null parents. */ -@AlphaComponent +@Experimental final class RandomForestClassificationModel private[ml] ( - override val parent: RandomForestClassifier, + override val uid: String, private val _trees: Array[DecisionTreeClassificationModel]) extends PredictionModel[Vector, RandomForestClassificationModel] with TreeEnsembleModel with Serializable { @@ -146,7 +147,7 @@ final class RandomForestClassificationModel private[ml] ( } override def copy(extra: ParamMap): RandomForestClassificationModel = { - copyValues(new RandomForestClassificationModel(parent, _trees), extra) + copyValues(new RandomForestClassificationModel(uid, _trees), extra) } override def toString: String = { @@ -169,9 +170,10 @@ private[ml] object RandomForestClassificationModel { require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" + s" with algo=${oldModel.algo} (old API) to RandomForestClassificationModel (new API).") val newTrees = oldModel.trees.map { tree => - // parent, fittingParamMap for each tree is null since there are no good ways to set these. + // parent for each tree is null since there is no good way to set this. DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures) } - new RandomForestClassificationModel(parent, newTrees) + val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc") + new RandomForestClassificationModel(uid, newTrees) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index e5a73c6087a11..f695ddaeefc72 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -17,23 +17,24 @@ package org.apache.spark.ml.evaluation -import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.Evaluator +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.types.DoubleType /** - * :: AlphaComponent :: - * + * :: Experimental :: * Evaluator for binary classification, which expects two input columns: score and label. */ -@AlphaComponent -class BinaryClassificationEvaluator extends Evaluator with HasRawPredictionCol with HasLabelCol { +@Experimental +class BinaryClassificationEvaluator(override val uid: String) + extends Evaluator with HasRawPredictionCol with HasLabelCol { + + def this() = this(Identifiable.randomUID("binEval")) /** * param for metric name in evaluation diff --git a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala similarity index 89% rename from mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala rename to mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala index 5f2f8c94e9ff7..61e937e693699 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala @@ -15,21 +15,21 @@ * limitations under the License. */ -package org.apache.spark.ml +package org.apache.spark.ml.evaluation -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.param.{ParamMap, Params} import org.apache.spark.sql.DataFrame /** - * :: AlphaComponent :: + * :: DeveloperApi :: * Abstract class for evaluators that compute metrics from predictions. */ -@AlphaComponent +@DeveloperApi abstract class Evaluator extends Params { /** - * Evaluates the output. + * Evaluates model output and returns a scalar metric (larger is better). * * @param dataset a dataset that contains labels/observations and predictions. * @param paramMap parameter map that specifies the input columns and output metrics diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala new file mode 100644 index 0000000000000..abb1b35bedea5 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.evaluation + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.param.{Param, ParamValidators} +import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.mllib.evaluation.RegressionMetrics +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.types.DoubleType + +/** + * :: Experimental :: + * Evaluator for regression, which expects two input columns: prediction and label. + */ +@Experimental +final class RegressionEvaluator(override val uid: String) + extends Evaluator with HasPredictionCol with HasLabelCol { + + def this() = this(Identifiable.randomUID("regEval")) + + /** + * param for metric name in evaluation (supports `"rmse"` (default), `"mse"`, `"r2"`, and `"mae"`) + * @group param + */ + val metricName: Param[String] = { + val allowedParams = ParamValidators.inArray(Array("mse", "rmse", "r2", "mae")) + new Param(this, "metricName", "metric name in evaluation (mse|rmse|r2|mae)", allowedParams) + } + + /** @group getParam */ + def getMetricName: String = $(metricName) + + /** @group setParam */ + def setMetricName(value: String): this.type = set(metricName, value) + + /** @group setParam */ + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ + def setLabelCol(value: String): this.type = set(labelCol, value) + + setDefault(metricName -> "rmse") + + override def evaluate(dataset: DataFrame): Double = { + val schema = dataset.schema + SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType) + SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + + val predictionAndLabels = dataset.select($(predictionCol), $(labelCol)) + .map { case Row(prediction: Double, label: Double) => + (prediction, label) + } + val metrics = new RegressionMetrics(predictionAndLabels) + val metric = $(metricName) match { + case "rmse" => + metrics.rootMeanSquaredError + case "mse" => + metrics.meanSquaredError + case "r2" => + metrics.r2 + case "mae" => + metrics.meanAbsoluteError + } + metric + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index 6eb1db6971111..b06122d733853 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -17,22 +17,25 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.BinaryAttribute import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructType} /** - * :: AlphaComponent :: + * :: Experimental :: * Binarize a column of continuous features given a threshold. */ -@AlphaComponent -final class Binarizer extends Transformer with HasInputCol with HasOutputCol { +@Experimental +final class Binarizer(override val uid: String) + extends Transformer with HasInputCol with HasOutputCol { + + def this() = this(Identifiable.randomUID("binarizer")) /** * Param for threshold used to binarize continuous features. diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index b28c88aaaecbc..a3d1f6f65ccaf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -20,25 +20,25 @@ package org.apache.spark.ml.feature import java.{util => ju} import org.apache.spark.SparkException -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.Model import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.util.SchemaUtils -import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} /** - * :: AlphaComponent :: + * :: Experimental :: * `Bucketizer` maps a column of continuous features to a column of feature buckets. */ -@AlphaComponent -final class Bucketizer private[ml] (override val parent: Estimator[Bucketizer]) +@Experimental +final class Bucketizer(override val uid: String) extends Model[Bucketizer] with HasInputCol with HasOutputCol { - def this() = this(null) + def this() = this(Identifiable.randomUID("bucketizer")) /** * Parameter for mapping continuous features into buckets. With n+1 splits, there are n buckets. @@ -48,7 +48,7 @@ final class Bucketizer private[ml] (override val parent: Estimator[Bucketizer]) * otherwise, values outside the splits specified will be treated as errors. * @group param */ - val splits: Param[Array[Double]] = new Param[Array[Double]](this, "splits", + val splits: DoubleArrayParam = new DoubleArrayParam(this, "splits", "Split points for mapping continuous features into buckets. With n+1 splits, there are n " + "buckets. A bucket defined by splits x,y holds values in the range [x,y) except the last " + "bucket, which also includes y. The splits should be strictly increasing. " + @@ -98,7 +98,8 @@ private[feature] object Bucketizer { false } else { var i = 0 - while (i < splits.length - 1) { + val n = splits.length - 1 + while (i < n) { if (splits(i) >= splits(i + 1)) return false i += 1 } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala index f8b56293e3ccc..1e758cb775de7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala @@ -17,27 +17,31 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param.Param +import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.types.DataType /** - * :: AlphaComponent :: + * :: Experimental :: * Outputs the Hadamard product (i.e., the element-wise product) of each input vector with a * provided "weight" vector. In other words, it scales each column of the dataset by a scalar * multiplier. */ -@AlphaComponent -class ElementwiseProduct extends UnaryTransformer[Vector, Vector, ElementwiseProduct] { +@Experimental +class ElementwiseProduct(override val uid: String) + extends UnaryTransformer[Vector, Vector, ElementwiseProduct] { + + def this() = this(Identifiable.randomUID("elemProd")) /** * the vector to multiply with input vectors * @group param */ - val scalingVec: Param[Vector] = new Param(this, "scalingVector", "vector for hadamard product") + val scalingVec: Param[Vector] = new Param(this, "scalingVec", "vector for hadamard product") /** @group setParam */ def setScalingVec(value: Vector): this.type = set(scalingVec, value) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index c305a819a8966..f936aef80f8af 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -17,19 +17,31 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.UnaryTransformer +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.param.{IntParam, ParamValidators} +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.mllib.feature -import org.apache.spark.mllib.linalg.{Vector, VectorUDT} -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.types.{ArrayType, StructType} /** - * :: AlphaComponent :: + * :: Experimental :: * Maps a sequence of terms to their term frequencies using the hashing trick. */ -@AlphaComponent -class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] { +@Experimental +class HashingTF(override val uid: String) extends Transformer with HasInputCol with HasOutputCol { + + def this() = this(Identifiable.randomUID("hashingTF")) + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) /** * Number of features. Should be > 0. @@ -47,10 +59,19 @@ class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] { /** @group setParam */ def setNumFeatures(value: Int): this.type = set(numFeatures, value) - override protected def createTransformFunc: Iterable[_] => Vector = { + override def transform(dataset: DataFrame): DataFrame = { + val outputSchema = transformSchema(dataset.schema) val hashingTF = new feature.HashingTF($(numFeatures)) - hashingTF.transform + val t = udf { terms: Seq[_] => hashingTF.transform(terms) } + val metadata = outputSchema($(outputCol)).metadata + dataset.select(col("*"), t(col($(inputCol))).as($(outputCol), metadata)) } - override protected def outputDataType: DataType = new VectorUDT() + override def transformSchema(schema: StructType): StructType = { + val inputType = schema($(inputCol)).dataType + require(inputType.isInstanceOf[ArrayType], + s"The input column must be ArrayType, but got $inputType.") + val attrGroup = new AttributeGroup($(outputCol), $(numFeatures)) + SchemaUtils.appendColumn(schema, attrGroup.toStructField()) + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index d901a20aed002..376b84530cd57 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -17,11 +17,11 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql._ @@ -58,11 +58,13 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol } /** - * :: AlphaComponent :: + * :: Experimental :: * Compute the Inverse Document Frequency (IDF) given a collection of documents. */ -@AlphaComponent -final class IDF extends Estimator[IDFModel] with IDFBase { +@Experimental +final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase { + + def this() = this(Identifiable.randomUID("idf")) /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -74,7 +76,7 @@ final class IDF extends Estimator[IDFModel] with IDFBase { transformSchema(dataset.schema, logging = true) val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v } val idf = new feature.IDF($(minDocFreq)).fit(input) - copyValues(new IDFModel(this, idf)) + copyValues(new IDFModel(uid, idf).setParent(this)) } override def transformSchema(schema: StructType): StructType = { @@ -83,12 +85,12 @@ final class IDF extends Estimator[IDFModel] with IDFBase { } /** - * :: AlphaComponent :: + * :: Experimental :: * Model fitted by [[IDF]]. */ -@AlphaComponent +@Experimental class IDFModel private[ml] ( - override val parent: IDF, + override val uid: String, idfModel: feature.IDFModel) extends Model[IDFModel] with IDFBase { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala index 755b46a64c7f1..8282e5ffa17f7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala @@ -17,19 +17,22 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param.{DoubleParam, ParamValidators} +import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.types.DataType /** - * :: AlphaComponent :: + * :: Experimental :: * Normalize a vector to have unit norm using the given p-norm. */ -@AlphaComponent -class Normalizer extends UnaryTransformer[Vector, Vector, Normalizer] { +@Experimental +class Normalizer(override val uid: String) extends UnaryTransformer[Vector, Vector, Normalizer] { + + def this() = this(Identifiable.randomUID("normalizer")) /** * Normalization in L^p^ space. Must be >= 1. diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 46514ae5f0e84..8f34878c8d329 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -17,91 +17,152 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkException -import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.UnaryTransformer -import org.apache.spark.ml.attribute.{Attribute, BinaryAttribute, NominalAttribute} -import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.util.SchemaUtils -import org.apache.spark.sql.types.{DataType, DoubleType, StructType} +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.types.{DoubleType, StructType} /** - * A one-hot encoder that maps a column of label indices to a column of binary vectors, with - * at most a single one-value. By default, the binary vector has an element for each category, so - * with 5 categories, an input value of 2.0 would map to an output vector of - * (0.0, 0.0, 1.0, 0.0, 0.0). If includeFirst is set to false, the first category is omitted, so the - * output vector for the previous example would be (0.0, 1.0, 0.0, 0.0) and an input value - * of 0.0 would map to a vector of all zeros. Including the first category makes the vector columns - * linearly dependent because they sum up to one. + * :: Experimental :: + * A one-hot encoder that maps a column of category indices to a column of binary vectors, with + * at most a single one-value per row that indicates the input category index. + * For example with 5 categories, an input value of 2.0 would map to an output vector of + * `[0.0, 0.0, 1.0, 0.0]`. + * The last category is not included by default (configurable via [[OneHotEncoder!.dropLast]] + * because it makes the vector entries sum up to one, and hence linearly dependent. + * So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`. + * Note that this is different from scikit-learn's OneHotEncoder, which keeps all categories. + * The output vectors are sparse. + * + * @see [[StringIndexer]] for converting categorical values into category indices */ -@AlphaComponent -class OneHotEncoder extends UnaryTransformer[Double, Vector, OneHotEncoder] +@Experimental +class OneHotEncoder(override val uid: String) extends Transformer with HasInputCol with HasOutputCol { + def this() = this(Identifiable.randomUID("oneHot")) + /** - * Whether to include a component in the encoded vectors for the first category, defaults to true. + * Whether to drop the last category in the encoded vector (default: true) * @group param */ - final val includeFirst: BooleanParam = - new BooleanParam(this, "includeFirst", "include first category") - setDefault(includeFirst -> true) - - private var categories: Array[String] = _ + final val dropLast: BooleanParam = + new BooleanParam(this, "dropLast", "whether to drop the last category") + setDefault(dropLast -> true) /** @group setParam */ - def setIncludeFirst(value: Boolean): this.type = set(includeFirst, value) + def setDropLast(value: Boolean): this.type = set(dropLast, value) /** @group setParam */ - override def setInputCol(value: String): this.type = set(inputCol, value) + def setInputCol(value: String): this.type = set(inputCol, value) /** @group setParam */ - override def setOutputCol(value: String): this.type = set(outputCol, value) + def setOutputCol(value: String): this.type = set(outputCol, value) override def transformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType) - val inputFields = schema.fields + val is = "_is_" + val inputColName = $(inputCol) val outputColName = $(outputCol) - require(inputFields.forall(_.name != $(outputCol)), - s"Output column ${$(outputCol)} already exists.") - val inputColAttr = Attribute.fromStructField(schema($(inputCol))) - categories = inputColAttr match { + SchemaUtils.checkColumnType(schema, inputColName, DoubleType) + val inputFields = schema.fields + require(!inputFields.exists(_.name == outputColName), + s"Output column $outputColName already exists.") + + val inputAttr = Attribute.fromStructField(schema(inputColName)) + val outputAttrNames: Option[Array[String]] = inputAttr match { case nominal: NominalAttribute => - nominal.values.getOrElse((0 until nominal.numValues.get).map(_.toString).toArray) - case binary: BinaryAttribute => binary.values.getOrElse(Array("0", "1")) + if (nominal.values.isDefined) { + nominal.values.map(_.map(v => inputColName + is + v)) + } else if (nominal.numValues.isDefined) { + nominal.numValues.map(n => Array.tabulate(n)(i => inputColName + is + i)) + } else { + None + } + case binary: BinaryAttribute => + if (binary.values.isDefined) { + binary.values.map(_.map(v => inputColName + is + v)) + } else { + Some(Array.tabulate(2)(i => inputColName + is + i)) + } + case _: NumericAttribute => + throw new RuntimeException( + s"The input column $inputColName cannot be numeric.") case _ => - throw new SparkException(s"OneHotEncoder input column ${$(inputCol)} is not nominal") + None // optimistic about unknown attributes + } + + val filteredOutputAttrNames = outputAttrNames.map { names => + if ($(dropLast)) { + require(names.length > 1, + s"The input column $inputColName should have at least two distinct values.") + names.dropRight(1) + } else { + names + } } - val attrValues = (if ($(includeFirst)) categories else categories.drop(1)).toArray - val attr = NominalAttribute.defaultAttr.withName(outputColName).withValues(attrValues) - val outputFields = inputFields :+ attr.toStructField() + val outputAttrGroup = if (filteredOutputAttrNames.isDefined) { + val attrs: Array[Attribute] = filteredOutputAttrNames.get.map { name => + BinaryAttribute.defaultAttr.withName(name) + } + new AttributeGroup($(outputCol), attrs) + } else { + new AttributeGroup($(outputCol)) + } + + val outputFields = inputFields :+ outputAttrGroup.toStructField() StructType(outputFields) } - protected override def createTransformFunc(): (Double) => Vector = { - val first = $(includeFirst) - val vecLen = if (first) categories.length else categories.length - 1 + override def transform(dataset: DataFrame): DataFrame = { + // schema transformation + val is = "_is_" + val inputColName = $(inputCol) + val outputColName = $(outputCol) + val shouldDropLast = $(dropLast) + var outputAttrGroup = AttributeGroup.fromStructField( + transformSchema(dataset.schema)(outputColName)) + if (outputAttrGroup.size < 0) { + // If the number of attributes is unknown, we check the values from the input column. + val numAttrs = dataset.select(col(inputColName).cast(DoubleType)).map(_.getDouble(0)) + .aggregate(0.0)( + (m, x) => { + assert(x >=0.0 && x == x.toInt, + s"Values from column $inputColName must be indices, but got $x.") + math.max(m, x) + }, + (m0, m1) => { + math.max(m0, m1) + } + ).toInt + 1 + val outputAttrNames = Array.tabulate(numAttrs)(i => inputColName + is + i) + val filtered = if (shouldDropLast) outputAttrNames.dropRight(1) else outputAttrNames + val outputAttrs: Array[Attribute] = + filtered.map(name => BinaryAttribute.defaultAttr.withName(name)) + outputAttrGroup = new AttributeGroup(outputColName, outputAttrs) + } + val metadata = outputAttrGroup.toMetadata() + + // data transformation + val size = outputAttrGroup.size val oneValue = Array(1.0) val emptyValues = Array[Double]() val emptyIndices = Array[Int]() - label: Double => { - val values = if (first || label != 0.0) oneValue else emptyValues - val indices = if (first) { - Array(label.toInt) - } else if (label != 0.0) { - Array(label.toInt - 1) + val encode = udf { label: Double => + if (label < size) { + Vectors.sparse(size, Array(label.toInt), oneValue) } else { - emptyIndices + Vectors.sparse(size, emptyIndices, emptyValues) } - Vectors.sparse(vecLen, indices, values) } - } - /** - * Returns the data type of the output column. - */ - protected def outputDataType: DataType = new VectorUDT + dataset.select(col("*"), encode(col(inputColName).cast(DoubleType)).as(outputColName, metadata)) + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala index 9e6177ca27e4a..442e95820217a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala @@ -19,22 +19,26 @@ package org.apache.spark.ml.feature import scala.collection.mutable -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param.{IntParam, ParamValidators} +import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.linalg._ import org.apache.spark.sql.types.DataType /** - * :: AlphaComponent :: + * :: Experimental :: * Perform feature expansion in a polynomial space. As said in wikipedia of Polynomial Expansion, * which is available at [[http://en.wikipedia.org/wiki/Polynomial_expansion]], "In mathematics, an * expansion of a product of sums expresses it as a sum of products by using the fact that * multiplication distributes over addition". Take a 2-variable feature vector as an example: * `(x, y)`, if we want to expand it with degree 2, then we get `(x, x * x, y, x * y, y * y)`. */ -@AlphaComponent -class PolynomialExpansion extends UnaryTransformer[Vector, Vector, PolynomialExpansion] { +@Experimental +class PolynomialExpansion(override val uid: String) + extends UnaryTransformer[Vector, Vector, PolynomialExpansion] { + + def this() = this(Identifiable.randomUID("poly")) /** * The polynomial degree to expand, which should be >= 1. A value of 1 means no expansion. @@ -71,7 +75,7 @@ class PolynomialExpansion extends UnaryTransformer[Vector, Vector, PolynomialExp * To handle sparsity, if c is zero, we can skip all monomials that contain it. We remember the * current index and increment it properly for sparse input. */ -object PolynomialExpansion { +private[feature] object PolynomialExpansion { private def choose(n: Int, k: Int): Int = { Range(n, n - k, -1).product / Range(k, 1, -1).product diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 7cad59ff3fa37..b0fd06d84fdb3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -17,10 +17,11 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql._ @@ -34,13 +35,13 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with /** * Centers the data with mean before scaling. - * It will build a dense output, so this does not work on sparse input + * It will build a dense output, so this does not work on sparse input * and will raise an exception. * Default: false * @group param */ val withMean: BooleanParam = new BooleanParam(this, "withMean", "Center data with mean") - + /** * Scales the data to unit standard deviation. * Default: true @@ -50,12 +51,15 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with } /** - * :: AlphaComponent :: + * :: Experimental :: * Standardizes features by removing the mean and scaling to unit variance using column summary * statistics on the samples in the training set. */ -@AlphaComponent -class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerParams { +@Experimental +class StandardScaler(override val uid: String) extends Estimator[StandardScalerModel] + with StandardScalerParams { + + def this() = this(Identifiable.randomUID("stdScal")) setDefault(withMean -> false, withStd -> true) @@ -64,19 +68,19 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - + /** @group setParam */ def setWithMean(value: Boolean): this.type = set(withMean, value) - + /** @group setParam */ def setWithStd(value: Boolean): this.type = set(withStd, value) - + override def fit(dataset: DataFrame): StandardScalerModel = { transformSchema(dataset.schema, logging = true) val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v } val scaler = new feature.StandardScaler(withMean = $(withMean), withStd = $(withStd)) val scalerModel = scaler.fit(input) - copyValues(new StandardScalerModel(this, scalerModel)) + copyValues(new StandardScalerModel(uid, scalerModel).setParent(this)) } override def transformSchema(schema: StructType): StructType = { @@ -91,12 +95,12 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP } /** - * :: AlphaComponent :: + * :: Experimental :: * Model fitted by [[StandardScaler]]. */ -@AlphaComponent +@Experimental class StandardScalerModel private[ml] ( - override val parent: StandardScaler, + override val uid: String, scaler: feature.StandardScalerModel) extends Model[StandardScalerModel] with StandardScalerParams { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 3d78537ad84cb..a2dc8a8b960c5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -18,11 +18,12 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkException -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{NumericType, StringType, StructType} @@ -51,14 +52,17 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha } /** - * :: AlphaComponent :: + * :: Experimental :: * A label indexer that maps a string column of labels to an ML column of label indices. * If the input column is numeric, we cast it to string and index the string values. * The indices are in [0, numLabels), ordered by label frequencies. * So the most frequent label gets index 0. */ -@AlphaComponent -class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase { +@Experimental +class StringIndexer(override val uid: String) extends Estimator[StringIndexerModel] + with StringIndexerBase { + + def this() = this(Identifiable.randomUID("strIdx")) /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -73,7 +77,7 @@ class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase .map(_.getString(0)) .countByValue() val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray - copyValues(new StringIndexerModel(this, labels)) + copyValues(new StringIndexerModel(uid, labels).setParent(this)) } override def transformSchema(schema: StructType): StructType = { @@ -82,12 +86,12 @@ class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase } /** - * :: AlphaComponent :: + * :: Experimental :: * Model fitted by [[StringIndexer]]. */ -@AlphaComponent +@Experimental class StringIndexerModel private[ml] ( - override val parent: StringIndexer, + override val uid: String, labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase { private val labelToIndex: OpenHashMap[String, Double] = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 649c217b16590..21c15b6c33f6c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -17,17 +17,22 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param._ +import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.types.{ArrayType, DataType, StringType} /** - * :: AlphaComponent :: + * :: Experimental :: * A tokenizer that converts the input string to lowercase and then splits it by white spaces. + * + * @see [[RegexTokenizer]] */ -@AlphaComponent -class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] { +@Experimental +class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[String], Tokenizer] { + + def this() = this(Identifiable.randomUID("tok")) override protected def createTransformFunc: String => Seq[String] = { _.toLowerCase.split("\\s") @@ -41,21 +46,24 @@ class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] { } /** - * :: AlphaComponent :: - * A regex based tokenizer that extracts tokens either by repeatedly matching the regex(default) - * or using it to split the text (set matching to false). Optional parameters also allow filtering - * tokens using a minimal length. + * :: Experimental :: + * A regex based tokenizer that extracts tokens either by using the provided regex pattern to split + * the text (default) or repeatedly matching the regex (if `gaps` is true). + * Optional parameters also allow filtering tokens using a minimal length. * It returns an array of strings that can be empty. */ -@AlphaComponent -class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenizer] { +@Experimental +class RegexTokenizer(override val uid: String) + extends UnaryTransformer[String, Seq[String], RegexTokenizer] { + + def this() = this(Identifiable.randomUID("regexTok")) /** * Minimum token length, >= 0. * Default: 1, to avoid returning empty strings * @group param */ - val minTokenLength: IntParam = new IntParam(this, "minLength", "minimum token length (>= 0)", + val minTokenLength: IntParam = new IntParam(this, "minTokenLength", "minimum token length (>= 0)", ParamValidators.gtEq(0)) /** @group setParam */ @@ -65,8 +73,8 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize def getMinTokenLength: Int = $(minTokenLength) /** - * Indicates whether regex splits on gaps (true) or matching tokens (false). - * Default: false + * Indicates whether regex splits on gaps (true) or matches tokens (false). + * Default: true * @group param */ val gaps: BooleanParam = new BooleanParam(this, "gaps", "Set regex to match gaps or tokens") @@ -78,8 +86,8 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize def getGaps: Boolean = $(gaps) /** - * Regex pattern used by tokenizer. - * Default: `"\\p{L}+|[^\\p{L}\\s]+"` + * Regex pattern used to match delimiters if [[gaps]] is true or tokens if [[gaps]] is false. + * Default: `"\\s+"` * @group param */ val pattern: Param[String] = new Param(this, "pattern", "regex pattern used for tokenizing") @@ -90,7 +98,7 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize /** @group getParam */ def getPattern: String = $(pattern) - setDefault(minTokenLength -> 1, gaps -> false, pattern -> "\\p{L}+|[^\\p{L}\\s]+") + setDefault(minTokenLength -> 1, gaps -> true, pattern -> "\\s+") override protected def createTransformFunc: String => Seq[String] = { str => val re = $(pattern).r diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 796758a70ef18..229ee27ec5942 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -20,20 +20,25 @@ package org.apache.spark.ml.feature import scala.collection.mutable.ArrayBuilder import org.apache.spark.SparkException -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.Transformer +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute} import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ /** - * :: AlphaComponent :: + * :: Experimental :: * A feature transformer that merges multiple columns into a vector column. */ -@AlphaComponent -class VectorAssembler extends Transformer with HasInputCols with HasOutputCol { +@Experimental +class VectorAssembler(override val uid: String) + extends Transformer with HasInputCols with HasOutputCol { + + def this() = this(Identifiable.randomUID("vecAssembler")) /** @group setParam */ def setInputCols(value: Array[String]): this.type = set(inputCols, value) @@ -42,19 +47,59 @@ class VectorAssembler extends Transformer with HasInputCols with HasOutputCol { def setOutputCol(value: String): this.type = set(outputCol, value) override def transform(dataset: DataFrame): DataFrame = { + // Schema transformation. + val schema = dataset.schema + lazy val first = dataset.first() + val attrs = $(inputCols).flatMap { c => + val field = schema(c) + val index = schema.fieldIndex(c) + field.dataType match { + case DoubleType => + val attr = Attribute.fromStructField(field) + // If the input column doesn't have ML attribute, assume numeric. + if (attr == UnresolvedAttribute) { + Some(NumericAttribute.defaultAttr.withName(c)) + } else { + Some(attr.withName(c)) + } + case _: NumericType | BooleanType => + // If the input column type is a compatible scalar type, assume numeric. + Some(NumericAttribute.defaultAttr.withName(c)) + case _: VectorUDT => + val group = AttributeGroup.fromStructField(field) + if (group.attributes.isDefined) { + // If attributes are defined, copy them with updated names. + group.attributes.get.map { attr => + if (attr.name.isDefined) { + // TODO: Define a rigorous naming scheme. + attr.withName(c + "_" + attr.name.get) + } else { + attr + } + } + } else { + // Otherwise, treat all attributes as numeric. If we cannot get the number of attributes + // from metadata, check the first row. + val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size) + Array.fill(numAttrs)(NumericAttribute.defaultAttr) + } + } + } + val metadata = new AttributeGroup($(outputCol), attrs).toMetadata() + + // Data transformation. val assembleFunc = udf { r: Row => VectorAssembler.assemble(r.toSeq: _*) } - val schema = dataset.schema - val inputColNames = $(inputCols) - val args = inputColNames.map { c => + val args = $(inputCols).map { c => schema(c).dataType match { case DoubleType => dataset(c) case _: VectorUDT => dataset(c) case _: NumericType | BooleanType => dataset(c).cast(DoubleType).as(s"${c}_double_$uid") } } - dataset.select(col("*"), assembleFunc(struct(args : _*)).as($(outputCol))) + + dataset.select(col("*"), assembleFunc(struct(args : _*)).as($(outputCol), metadata)) } override def transformSchema(schema: StructType): StructType = { @@ -74,8 +119,7 @@ class VectorAssembler extends Transformer with HasInputCols with HasOutputCol { } } -@AlphaComponent -object VectorAssembler { +private object VectorAssembler { private[feature] def assemble(vv: Any*): Vector = { val indices = ArrayBuilder.make[Int] diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index 2e6313ac14485..1d0f23b4fb3db 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -17,12 +17,17 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.AlphaComponent +import java.lang.{Double => JDouble, Integer => JInt} +import java.util.{Map => JMap} + +import scala.collection.JavaConverters._ + +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param.{IntParam, ParamValidators, Params} import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, VectorUDT} import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions.callUDF @@ -51,8 +56,7 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu } /** - * :: AlphaComponent :: - * + * :: Experimental :: * Class for indexing categorical feature columns in a dataset of [[Vector]]. * * This has 2 usage modes: @@ -86,8 +90,11 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu * - Add warning if a categorical feature has only 1 category. * - Add option for allowing unknown categories. */ -@AlphaComponent -class VectorIndexer extends Estimator[VectorIndexerModel] with VectorIndexerParams { +@Experimental +class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerModel] + with VectorIndexerParams { + + def this() = this(Identifiable.randomUID("vecIdx")) /** @group setParam */ def setMaxCategories(value: Int): this.type = set(maxCategories, value) @@ -110,7 +117,9 @@ class VectorIndexer extends Estimator[VectorIndexerModel] with VectorIndexerPara iter.foreach(localCatStats.addVector) Iterator(localCatStats) }.reduce((stats1, stats2) => stats1.merge(stats2)) - copyValues(new VectorIndexerModel(this, numFeatures, categoryStats.getCategoryMaps)) + val model = new VectorIndexerModel(uid, numFeatures, categoryStats.getCategoryMaps) + .setParent(this) + copyValues(model) } override def transformSchema(schema: StructType): StructType = { @@ -189,7 +198,8 @@ private object VectorIndexer { private def addDenseVector(dv: DenseVector): Unit = { var i = 0 - while (i < dv.size) { + val size = dv.size + while (i < size) { if (featureValueSets(i).size <= maxCategories) { featureValueSets(i).add(dv(i)) } @@ -201,7 +211,8 @@ private object VectorIndexer { // TODO: This might be able to handle 0's more efficiently. var vecIndex = 0 // index into vector var k = 0 // index into non-zero elements - while (vecIndex < sv.size) { + val size = sv.size + while (vecIndex < size) { val featureValue = if (k < sv.indices.length && vecIndex == sv.indices(k)) { k += 1 sv.values(k - 1) @@ -218,8 +229,7 @@ private object VectorIndexer { } /** - * :: AlphaComponent :: - * + * :: Experimental :: * Transform categorical features to use 0-based indices instead of their original values. * - Categorical features are mapped to indices. * - Continuous features (columns) are left unchanged. @@ -234,13 +244,18 @@ private object VectorIndexer { * Values are maps from original features values to 0-based category indices. * If a feature is not in this map, it is treated as continuous. */ -@AlphaComponent +@Experimental class VectorIndexerModel private[ml] ( - override val parent: VectorIndexer, + override val uid: String, val numFeatures: Int, val categoryMaps: Map[Int, Map[Double, Int]]) extends Model[VectorIndexerModel] with VectorIndexerParams { + /** Java-friendly version of [[categoryMaps]] */ + def javaCategoryMaps: JMap[JInt, JMap[JDouble, JInt]] = { + categoryMaps.mapValues(_.asJava).asJava.asInstanceOf[JMap[JInt, JMap[JDouble, JInt]]] + } + /** * Pre-computed feature attributes, with some missing info. * In transform(), set attribute name and other info, if available. diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 34ff92970129f..36f19509f0cfb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -17,11 +17,11 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} import org.apache.spark.mllib.linalg.BLAS._ @@ -37,6 +37,7 @@ private[feature] trait Word2VecBase extends Params /** * The dimension of the code that you want to transform from words. + * @group param */ final val vectorSize = new IntParam( this, "vectorSize", "the dimension of codes after transforming from words") @@ -47,6 +48,7 @@ private[feature] trait Word2VecBase extends Params /** * Number of partitions for sentences of words. + * @group param */ final val numPartitions = new IntParam( this, "numPartitions", "number of partitions for sentences of words") @@ -58,6 +60,7 @@ private[feature] trait Word2VecBase extends Params /** * The minimum number of times a token must appear to be included in the word2vec model's * vocabulary. + * @group param */ final val minCount = new IntParam(this, "minCount", "the minimum number of times a token must " + "appear to be included in the word2vec model's vocabulary") @@ -68,7 +71,6 @@ private[feature] trait Word2VecBase extends Params setDefault(stepSize -> 0.025) setDefault(maxIter -> 1) - setDefault(seed -> 42L) /** * Validate and transform the input schema. @@ -80,12 +82,14 @@ private[feature] trait Word2VecBase extends Params } /** - * :: AlphaComponent :: + * :: Experimental :: * Word2Vec trains a model of `Map(String, Vector)`, i.e. transforms a word into a code for further * natural language processing or machine learning process. */ -@AlphaComponent -final class Word2Vec extends Estimator[Word2VecModel] with Word2VecBase { +@Experimental +final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] with Word2VecBase { + + def this() = this(Identifiable.randomUID("w2v")) /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) @@ -122,7 +126,7 @@ final class Word2Vec extends Estimator[Word2VecModel] with Word2VecBase { .setSeed($(seed)) .setVectorSize($(vectorSize)) .fit(input) - copyValues(new Word2VecModel(this, wordVectors)) + copyValues(new Word2VecModel(uid, wordVectors).setParent(this)) } override def transformSchema(schema: StructType): StructType = { @@ -131,12 +135,12 @@ final class Word2Vec extends Estimator[Word2VecModel] with Word2VecBase { } /** - * :: AlphaComponent :: + * :: Experimental :: * Model fitted by [[Word2Vec]]. */ -@AlphaComponent +@Experimental class Word2VecModel private[ml] ( - override val parent: Word2Vec, + override val uid: String, wordVectors: feature.Word2VecModel) extends Model[Word2VecModel] with Word2VecBase { diff --git a/mllib/src/main/scala/org/apache/spark/ml/package-info.java b/mllib/src/main/scala/org/apache/spark/ml/package-info.java index 00d9c802e930d..87f4223964ada 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/package-info.java +++ b/mllib/src/main/scala/org/apache/spark/ml/package-info.java @@ -16,10 +16,10 @@ */ /** - * Spark ML is an ALPHA component that adds a new set of machine learning APIs to let users quickly + * Spark ML is a BETA component that adds a new set of machine learning APIs to let users quickly * assemble and configure practical machine learning pipelines. */ -@AlphaComponent +@Experimental package org.apache.spark.ml; -import org.apache.spark.annotation.AlphaComponent; +import org.apache.spark.annotation.Experimental; diff --git a/mllib/src/main/scala/org/apache/spark/ml/package.scala b/mllib/src/main/scala/org/apache/spark/ml/package.scala index ac75e9de1a8f2..c589d06d9f7e4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/package.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/package.scala @@ -18,7 +18,7 @@ package org.apache.spark /** - * Spark ML is an ALPHA component that adds a new set of machine learning APIs to let users quickly + * Spark ML is a BETA component that adds a new set of machine learning APIs to let users quickly * assemble and configure practical machine learning pipelines. * * @groupname param Parameters diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 7ebbf106ee753..473488dce9b0d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -24,11 +24,11 @@ import scala.annotation.varargs import scala.collection.mutable import scala.collection.JavaConverters._ -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.ml.util.Identifiable /** - * :: AlphaComponent :: + * :: DeveloperApi :: * A param with self-contained documentation and optionally default value. Primitive-typed param * should use the specialized versions, which are more friendly to Java users. * @@ -39,13 +39,18 @@ import org.apache.spark.ml.util.Identifiable * See [[ParamValidators]] for factory methods for common validation functions. * @tparam T param value type */ -@AlphaComponent -class Param[T] (val parent: Params, val name: String, val doc: String, val isValid: T => Boolean) +@DeveloperApi +class Param[T](val parent: String, val name: String, val doc: String, val isValid: T => Boolean) extends Serializable { - def this(parent: Params, name: String, doc: String) = + def this(parent: Identifiable, name: String, doc: String, isValid: T => Boolean) = + this(parent.uid, name, doc, isValid) + + def this(parent: String, name: String, doc: String) = this(parent, name, doc, ParamValidators.alwaysTrue[T]) + def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc) + /** * Assert that the given value is valid for this parameter. * @@ -60,8 +65,7 @@ class Param[T] (val parent: Params, val name: String, val doc: String, val isVal */ private[param] def validate(value: T): Unit = { if (!isValid(value)) { - throw new IllegalArgumentException(s"$parent parameter $name given invalid value $value." + - s" Parameter description: $toString") + throw new IllegalArgumentException(s"$parent parameter $name given invalid value $value.") } } @@ -75,26 +79,24 @@ class Param[T] (val parent: Params, val name: String, val doc: String, val isVal */ def ->(value: T): ParamPair[T] = ParamPair(this, value) - /** - * Converts this param's name, doc, and optionally its default value and the user-supplied - * value in its parent to string. - */ - override def toString: String = { - val valueStr = if (parent.isDefined(this)) { - val defaultValueStr = parent.getDefault(this).map("default: " + _) - val currentValueStr = parent.get(this).map("current: " + _) - (defaultValueStr ++ currentValueStr).mkString("(", ", ", ")") - } else { - "(undefined)" + override final def toString: String = s"${parent}__$name" + + override final def hashCode: Int = toString.## + + override final def equals(obj: Any): Boolean = { + obj match { + case p: Param[_] => (p.parent == parent) && (p.name == name) + case _ => false } - s"$name: $doc $valueStr" } } /** + * :: DeveloperApi :: * Factory methods for common validation functions for [[Param.isValid]]. * The numerical methods only support Int, Long, Float, and Double. */ +@DeveloperApi object ParamValidators { /** (private[param]) Default validation always return true */ @@ -172,54 +174,100 @@ object ParamValidators { // specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ... -/** Specialized version of [[Param[Double]]] for Java. */ -class DoubleParam(parent: Params, name: String, doc: String, isValid: Double => Boolean) +/** + * :: DeveloperApi :: + * Specialized version of [[Param[Double]]] for Java. + */ +@DeveloperApi +class DoubleParam(parent: String, name: String, doc: String, isValid: Double => Boolean) extends Param[Double](parent, name, doc, isValid) { - def this(parent: Params, name: String, doc: String) = + def this(parent: String, name: String, doc: String) = this(parent, name, doc, ParamValidators.alwaysTrue) + def this(parent: Identifiable, name: String, doc: String, isValid: Double => Boolean) = + this(parent.uid, name, doc, isValid) + + def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc) + override def w(value: Double): ParamPair[Double] = super.w(value) } -/** Specialized version of [[Param[Int]]] for Java. */ -class IntParam(parent: Params, name: String, doc: String, isValid: Int => Boolean) +/** + * :: DeveloperApi :: + * Specialized version of [[Param[Int]]] for Java. + */ +@DeveloperApi +class IntParam(parent: String, name: String, doc: String, isValid: Int => Boolean) extends Param[Int](parent, name, doc, isValid) { - def this(parent: Params, name: String, doc: String) = + def this(parent: String, name: String, doc: String) = this(parent, name, doc, ParamValidators.alwaysTrue) + def this(parent: Identifiable, name: String, doc: String, isValid: Int => Boolean) = + this(parent.uid, name, doc, isValid) + + def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc) + override def w(value: Int): ParamPair[Int] = super.w(value) } -/** Specialized version of [[Param[Float]]] for Java. */ -class FloatParam(parent: Params, name: String, doc: String, isValid: Float => Boolean) +/** + * :: DeveloperApi :: + * Specialized version of [[Param[Float]]] for Java. + */ +@DeveloperApi +class FloatParam(parent: String, name: String, doc: String, isValid: Float => Boolean) extends Param[Float](parent, name, doc, isValid) { - def this(parent: Params, name: String, doc: String) = + def this(parent: String, name: String, doc: String) = this(parent, name, doc, ParamValidators.alwaysTrue) + def this(parent: Identifiable, name: String, doc: String, isValid: Float => Boolean) = + this(parent.uid, name, doc, isValid) + + def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc) + override def w(value: Float): ParamPair[Float] = super.w(value) } -/** Specialized version of [[Param[Long]]] for Java. */ -class LongParam(parent: Params, name: String, doc: String, isValid: Long => Boolean) +/** + * :: DeveloperApi :: + * Specialized version of [[Param[Long]]] for Java. + */ +@DeveloperApi +class LongParam(parent: String, name: String, doc: String, isValid: Long => Boolean) extends Param[Long](parent, name, doc, isValid) { - def this(parent: Params, name: String, doc: String) = + def this(parent: String, name: String, doc: String) = this(parent, name, doc, ParamValidators.alwaysTrue) + def this(parent: Identifiable, name: String, doc: String, isValid: Long => Boolean) = + this(parent.uid, name, doc, isValid) + + def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc) + override def w(value: Long): ParamPair[Long] = super.w(value) } -/** Specialized version of [[Param[Boolean]]] for Java. */ -class BooleanParam(parent: Params, name: String, doc: String) // No need for isValid +/** + * :: DeveloperApi :: + * Specialized version of [[Param[Boolean]]] for Java. + */ +@DeveloperApi +class BooleanParam(parent: String, name: String, doc: String) // No need for isValid extends Param[Boolean](parent, name, doc) { + def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc) + override def w(value: Boolean): ParamPair[Boolean] = super.w(value) } -/** Specialized version of [[Param[Array[T]]]] for Java. */ +/** + * :: DeveloperApi :: + * Specialized version of [[Param[Array[String]]]] for Java. + */ +@DeveloperApi class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array[String] => Boolean) extends Param[Array[String]](parent, name, doc, isValid) { @@ -233,8 +281,27 @@ class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array } /** + * :: DeveloperApi :: + * Specialized version of [[Param[Array[Double]]]] for Java. + */ +@DeveloperApi +class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array[Double] => Boolean) + extends Param[Array[Double]](parent, name, doc, isValid) { + + def this(parent: Params, name: String, doc: String) = + this(parent, name, doc, ParamValidators.alwaysTrue) + + override def w(value: Array[Double]): ParamPair[Array[Double]] = super.w(value) + + /** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */ + def w(value: java.util.List[Double]): ParamPair[Array[Double]] = w(value.asScala.toArray) +} + +/** + * :: Experimental :: * A param amd its value. */ +@Experimental case class ParamPair[T](param: Param[T], value: T) { // This is *the* place Param.validate is called. Whenever a parameter is specified, we should // always construct a ParamPair so that validate is called. @@ -242,16 +309,19 @@ case class ParamPair[T](param: Param[T], value: T) { } /** - * :: AlphaComponent :: + * :: DeveloperApi :: * Trait for components that take parameters. This also provides an internal param map to store * parameter values attached to the instance. */ -@AlphaComponent +@DeveloperApi trait Params extends Identifiable with Serializable { /** * Returns all params sorted by their names. The default implementation uses Java reflection to * list all public methods that have no arguments and return [[Param]]. + * + * Note: Developer should not use this method in constructor because we cannot guarantee that + * this variable gets initialized before other params. */ lazy val params: Array[Param[_]] = { val methods = this.getClass.getMethods @@ -263,19 +333,6 @@ trait Params extends Identifiable with Serializable { .map(m => m.invoke(this).asInstanceOf[Param[_]]) } - /** - * Validates parameter values stored internally plus the input parameter map. - * Raises an exception if any parameter is invalid. - * - * This only needs to check for interactions between parameters. - * Parameter value checks which do not depend on other parameters are handled by - * [[Param.validate()]]. This method does not handle input/output column parameters; - * those are checked during schema validation. - */ - def validateParams(paramMap: ParamMap): Unit = { - copy(paramMap).validateParams() - } - /** * Validates parameter values stored internally. * Raise an exception if any parameter value is invalid. @@ -286,15 +343,36 @@ trait Params extends Identifiable with Serializable { * those are checked during schema validation. */ def validateParams(): Unit = { - params.filter(isDefined _).foreach { param => + params.filter(isDefined).foreach { param => param.asInstanceOf[Param[Any]].validate($(param)) } } /** - * Returns the documentation of all params. + * Explains a param. + * @param param input param, must belong to this instance. + * @return a string that contains the input param name, doc, and optionally its default value and + * the user-supplied value */ - def explainParams(): String = params.mkString("\n") + def explainParam(param: Param[_]): String = { + shouldOwn(param) + val valueStr = if (isDefined(param)) { + val defaultValueStr = getDefault(param).map("default: " + _) + val currentValueStr = get(param).map("current: " + _) + (defaultValueStr ++ currentValueStr).mkString("(", ", ", ")") + } else { + "(undefined)" + } + s"${param.name}: ${param.doc} $valueStr" + } + + /** + * Explains all params of this instance. + * @see [[explainParam()]] + */ + def explainParams(): String = { + params.map(explainParam).mkString("\n") + } /** Checks whether a param is explicitly set. */ final def isSet(param: Param[_]): Boolean = { @@ -379,20 +457,18 @@ trait Params extends Identifiable with Serializable { * @param value the default value */ protected final def setDefault[T](param: Param[T], value: T): this.type = { - shouldOwn(param) - defaultParamMap.put(param, value) + defaultParamMap.put(param -> value) this } /** * Sets default values for a list of params. * - * Note: Java developers should use the single-parameter [[setDefault()]]. - * Annotating this with varargs causes compilation failures. See SPARK-7498. * @param paramPairs a list of param pairs that specify params and their default values to set * respectively. Make sure that the params are initialized before this method * gets called. */ + @varargs protected final def setDefault(paramPairs: ParamPair[_]*): this.type = { paramPairs.foreach { p => setDefault(p.param.asInstanceOf[Param[Any]], p.value) @@ -417,24 +493,23 @@ trait Params extends Identifiable with Serializable { } /** - * Creates a copy of this instance with a randomly generated uid and some extra params. - * The default implementation calls the default constructor to create a new instance, then - * copies the embedded and extra parameters over and returns the new instance. + * Creates a copy of this instance with the same UID and some extra params. + * The default implementation tries to create a new instance with the same UID. + * Then it copies the embedded and extra parameters over and returns the new instance. * Subclasses should override this method if the default approach is not sufficient. */ def copy(extra: ParamMap): Params = { - val that = this.getClass.newInstance() + val that = this.getClass.getConstructor(classOf[String]).newInstance(uid) copyValues(that, extra) - that } /** * Extracts the embedded default param values and user-supplied values, and then merges them with * extra values from input into a flat param map, where the latter value is used if there exist - * conflicts, i.e., with ordering: default param values < user-supplied values < extraParamMap. + * conflicts, i.e., with ordering: default param values < user-supplied values < extra. */ - final def extractParamMap(extraParamMap: ParamMap): ParamMap = { - defaultParamMap ++ paramMap ++ extraParamMap + final def extractParamMap(extra: ParamMap): ParamMap = { + defaultParamMap ++ paramMap ++ extra } /** @@ -452,7 +527,7 @@ trait Params extends Identifiable with Serializable { /** Validates that the input param belongs to this instance. */ private def shouldOwn(param: Param[_]): Unit = { - require(param.parent.eq(this), s"Param $param does not belong to $this.") + require(param.parent == uid && hasParam(param.name), s"Param $param does not belong to $this.") } /** @@ -473,18 +548,20 @@ trait Params extends Identifiable with Serializable { } /** + * :: DeveloperApi :: * Java-friendly wrapper for [[Params]]. * Java developers who need to extend [[Params]] should use this class instead. * If you need to extend a abstract class which already extends [[Params]], then that abstract * class should be Java-friendly as well. */ +@DeveloperApi abstract class JavaParams extends Params /** - * :: AlphaComponent :: + * :: Experimental :: * A param to value map. */ -@AlphaComponent +@Experimental final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) extends Serializable { @@ -502,7 +579,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) /** * Puts a (param, value) pair (overwrites if the input param exists). */ - def put[T](param: Param[T], value: T): this.type = put(ParamPair(param, value)) + def put[T](param: Param[T], value: T): this.type = put(param -> value) /** * Puts a list of param pairs (overwrites if the input params exists). @@ -568,7 +645,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) override def toString: String = { map.toSeq.sortBy(_._1.name).map { case (param, value) => - s"\t${param.parent.uid}-${param.name}: $value" + s"\t${param.parent}-${param.name}: $value" }.mkString("{\n", ",\n", "\n}") } @@ -605,6 +682,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) def size: Int = map.size } +@Experimental object ParamMap { /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 5085b798daa17..8ffbcf0d8bc71 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -33,7 +33,7 @@ private[shared] object SharedParamsCodeGen { val params = Seq( ParamDesc[Double]("regParam", "regularization parameter (>= 0)", isValid = "ParamValidators.gtEq(0)"), - ParamDesc[Int]("maxIter", "max number of iterations (>= 0)", + ParamDesc[Int]("maxIter", "maximum number of iterations (>= 0)", isValid = "ParamValidators.gtEq(0)"), ParamDesc[String]("featuresCol", "features column name", Some("\"features\"")), ParamDesc[String]("labelCol", "label column name", Some("\"label\"")), @@ -49,11 +49,11 @@ private[shared] object SharedParamsCodeGen { isValid = "ParamValidators.inRange(0, 1)"), ParamDesc[String]("inputCol", "input column name"), ParamDesc[Array[String]]("inputCols", "input column names"), - ParamDesc[String]("outputCol", "output column name"), + ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")), ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1)", isValid = "ParamValidators.gtEq(1)"), ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")), - ParamDesc[Long]("seed", "random seed", Some("Utils.random.nextLong()")), + ParamDesc[Long]("seed", "random seed", Some("this.getClass.getName.hashCode.toLong")), ParamDesc[Double]("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]." + " For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", isValid = "ParamValidators.inRange(0, 1)"), diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 7525d37007377..a0c8ccdac9ad9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -45,10 +45,10 @@ private[ml] trait HasRegParam extends Params { private[ml] trait HasMaxIter extends Params { /** - * Param for max number of iterations (>= 0). + * Param for maximum number of iterations (>= 0). * @group param */ - final val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations (>= 0)", ParamValidators.gtEq(0)) + final val maxIter: IntParam = new IntParam(this, "maxIter", "maximum number of iterations (>= 0)", ParamValidators.gtEq(0)) /** @group getParam */ final def getMaxIter: Int = $(maxIter) @@ -185,7 +185,7 @@ private[ml] trait HasInputCols extends Params { } /** - * (private[ml]) Trait for shared param outputCol. + * (private[ml]) Trait for shared param outputCol (default: uid + "__output"). */ private[ml] trait HasOutputCol extends Params { @@ -195,6 +195,8 @@ private[ml] trait HasOutputCol extends Params { */ final val outputCol: Param[String] = new Param[String](this, "outputCol", "output column name") + setDefault(outputCol, uid + "__output") + /** @group getParam */ final def getOutputCol: String = $(outputCol) } @@ -232,7 +234,7 @@ private[ml] trait HasFitIntercept extends Params { } /** - * (private[ml]) Trait for shared param seed (default: Utils.random.nextLong()). + * (private[ml]) Trait for shared param seed (default: this.getClass.getName.hashCode.toLong). */ private[ml] trait HasSeed extends Params { @@ -242,7 +244,7 @@ private[ml] trait HasSeed extends Params { */ final val seed: LongParam = new LongParam(this, "seed", "random seed") - setDefault(seed, Utils.random.nextLong()) + setDefault(seed, this.getClass.getName.hashCode.toLong) /** @group getParam */ final def getSeed: Long = $(seed) diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index d7cbffc3be26f..df009d855ecbb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -31,24 +31,50 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.netlib.util.intW import org.apache.spark.{Logging, Partitioner} -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.mllib.optimization.NNLS import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructField, StructType} +import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructType} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter} import org.apache.spark.util.random.XORShiftRandom +/** + * Common params for ALS and ALSModel. + */ +private[recommendation] trait ALSModelParams extends Params with HasPredictionCol { + /** + * Param for the column name for user ids. + * Default: "user" + * @group param + */ + val userCol = new Param[String](this, "userCol", "column name for user ids") + + /** @group getParam */ + def getUserCol: String = $(userCol) + + /** + * Param for the column name for item ids. + * Default: "item" + * @group param + */ + val itemCol = new Param[String](this, "itemCol", "column name for item ids") + + /** @group getParam */ + def getItemCol: String = $(itemCol) +} + /** * Common params for ALS. */ -private[recommendation] trait ALSParams extends Params with HasMaxIter with HasRegParam +private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter with HasRegParam with HasPredictionCol with HasCheckpointInterval with HasSeed { /** @@ -104,26 +130,6 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR /** @group getParam */ def getAlpha: Double = $(alpha) - /** - * Param for the column name for user ids. - * Default: "user" - * @group param - */ - val userCol = new Param[String](this, "userCol", "column name for user ids") - - /** @group getParam */ - def getUserCol: String = $(userCol) - - /** - * Param for the column name for item ids. - * Default: "item" - * @group param - */ - val itemCol = new Param[String](this, "itemCol", "column name for item ids") - - /** @group getParam */ - def getItemCol: String = $(itemCol) - /** * Param for the column name for ratings. * Default: "rating" @@ -147,7 +153,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10, implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item", - ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10, seed -> 0L) + ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10) /** * Validates and transforms the input schema. @@ -155,58 +161,66 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { - require(schema($(userCol)).dataType == IntegerType) - require(schema($(itemCol)).dataType== IntegerType) + SchemaUtils.checkColumnType(schema, $(userCol), IntegerType) + SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType) val ratingType = schema($(ratingCol)).dataType require(ratingType == FloatType || ratingType == DoubleType) - val predictionColName = $(predictionCol) - require(!schema.fieldNames.contains(predictionColName), - s"Prediction column $predictionColName already exists.") - val newFields = schema.fields :+ StructField($(predictionCol), FloatType, nullable = false) - StructType(newFields) + SchemaUtils.appendColumn(schema, $(predictionCol), FloatType) } } /** + * :: Experimental :: * Model fitted by ALS. + * + * @param rank rank of the matrix factorization model + * @param userFactors a DataFrame that stores user factors in two columns: `id` and `features` + * @param itemFactors a DataFrame that stores item factors in two columns: `id` and `features` */ +@Experimental class ALSModel private[ml] ( - override val parent: ALS, - k: Int, - userFactors: RDD[(Int, Array[Float])], - itemFactors: RDD[(Int, Array[Float])]) - extends Model[ALSModel] with ALSParams { + override val uid: String, + val rank: Int, + @transient val userFactors: DataFrame, + @transient val itemFactors: DataFrame) + extends Model[ALSModel] with ALSModelParams { + + /** @group setParam */ + def setUserCol(value: String): this.type = set(userCol, value) + + /** @group setParam */ + def setItemCol(value: String): this.type = set(itemCol, value) /** @group setParam */ def setPredictionCol(value: String): this.type = set(predictionCol, value) override def transform(dataset: DataFrame): DataFrame = { - import dataset.sqlContext.implicits._ - val users = userFactors.toDF("id", "features") - val items = itemFactors.toDF("id", "features") - // Register a UDF for DataFrame, and then // create a new column named map(predictionCol) by running the predict UDF. val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) => if (userFeatures != null && itemFeatures != null) { - blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1) + blas.sdot(rank, userFeatures.toArray, 1, itemFeatures.toArray, 1) } else { Float.NaN } } dataset - .join(users, dataset($(userCol)) === users("id"), "left") - .join(items, dataset($(itemCol)) === items("id"), "left") - .select(dataset("*"), predict(users("features"), items("features")).as($(predictionCol))) + .join(userFactors, dataset($(userCol)) === userFactors("id"), "left") + .join(itemFactors, dataset($(itemCol)) === itemFactors("id"), "left") + .select(dataset("*"), + predict(userFactors("features"), itemFactors("features")).as($(predictionCol))) } override def transformSchema(schema: StructType): StructType = { - validateAndTransformSchema(schema) + SchemaUtils.checkColumnType(schema, $(userCol), IntegerType) + SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType) + SchemaUtils.appendColumn(schema, $(predictionCol), FloatType) } } /** + * :: Experimental :: * Alternating Least Squares (ALS) matrix factorization. * * ALS attempts to estimate the ratings matrix `R` as the product of two lower-rank matrices, @@ -235,10 +249,13 @@ class ALSModel private[ml] ( * indicated user * preferences rather than explicit ratings given to items. */ -class ALS extends Estimator[ALSModel] with ALSParams { +@Experimental +class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams { import org.apache.spark.ml.recommendation.ALS.Rating + def this() = this(Identifiable.randomUID("als")) + /** @group setParam */ def setRank(value: Int): this.type = set(rank, value) @@ -292,6 +309,7 @@ class ALS extends Estimator[ALSModel] with ALSParams { } override def fit(dataset: DataFrame): ALSModel = { + import dataset.sqlContext.implicits._ val ratings = dataset .select(col($(userCol)).cast(IntegerType), col($(itemCol)).cast(IntegerType), col($(ratingCol)).cast(FloatType)) @@ -303,7 +321,10 @@ class ALS extends Estimator[ALSModel] with ALSParams { maxIter = $(maxIter), regParam = $(regParam), implicitPrefs = $(implicitPrefs), alpha = $(alpha), nonnegative = $(nonnegative), checkpointInterval = $(checkpointInterval), seed = $(seed)) - copyValues(new ALSModel(this, $(rank), userFactors, itemFactors)) + val userDF = userFactors.toDF("id", "features") + val itemDF = itemFactors.toDF("id", "features") + val model = new ALSModel(uid, $(rank), userDF, itemDF).setParent(this) + copyValues(model) } override def transformSchema(schema: StructType): StructType = { @@ -322,7 +343,11 @@ class ALS extends Estimator[ALSModel] with ALSParams { @DeveloperApi object ALS extends Logging { - /** Rating class for better code readability. */ + /** + * :: DeveloperApi :: + * Rating class for better code readability. + */ + @DeveloperApi case class Rating[@specialized(Int, Long) ID](user: ID, item: ID, rating: Float) /** Trait for least squares solvers applied to the normal equation. */ @@ -483,8 +508,10 @@ object ALS extends Logging { } /** + * :: DeveloperApi :: * Implementation of the ALS algorithm. */ + @DeveloperApi def train[ID: ClassTag]( // scalastyle:ignore ratings: RDD[Rating[ID]], rank: Int = 10, diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index f8f0b161a4812..43b68e7bb20fa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -17,11 +17,11 @@ package org.apache.spark.ml.regression -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.tree.{TreeRegressorParams, DecisionTreeParams, DecisionTreeModel, Node} -import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeRegressorParams} +import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree} @@ -31,17 +31,18 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm * for regression. * It supports both continuous and categorical features. */ -@AlphaComponent -final class DecisionTreeRegressor +@Experimental +final class DecisionTreeRegressor(override val uid: String) extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel] with DecisionTreeParams with TreeRegressorParams { + def this() = this(Identifiable.randomUID("dtr")) + // Override parameter setters from parent trait for Java API compatibility. override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) @@ -77,21 +78,21 @@ final class DecisionTreeRegressor } } +@Experimental object DecisionTreeRegressor { /** Accessor for supported impurities: variance */ final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities } /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] model for regression. * It supports both continuous and categorical features. * @param rootNode Root of the decision tree */ -@AlphaComponent +@Experimental final class DecisionTreeRegressionModel private[ml] ( - override val parent: DecisionTreeRegressor, + override val uid: String, override val rootNode: Node) extends PredictionModel[Vector, DecisionTreeRegressionModel] with DecisionTreeModel with Serializable { @@ -104,7 +105,7 @@ final class DecisionTreeRegressionModel private[ml] ( } override def copy(extra: ParamMap): DecisionTreeRegressionModel = { - copyValues(new DecisionTreeRegressionModel(parent, rootNode), extra) + copyValues(new DecisionTreeRegressionModel(uid, rootNode), extra) } override def toString: String = { @@ -128,6 +129,7 @@ private[ml] object DecisionTreeRegressionModel { s"Cannot convert non-regression DecisionTreeModel (old API) to" + s" DecisionTreeRegressionModel (new API). Algo is: ${oldModel.algo}") val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures) - new DecisionTreeRegressionModel(parent, rootNode) + val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtr") + new DecisionTreeRegressionModel(uid, rootNode) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 461905c12701a..b7e374bb6cb49 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -20,11 +20,11 @@ package org.apache.spark.ml.regression import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.Logging -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.{Param, ParamMap} -import org.apache.spark.ml.tree.{GBTParams, TreeRegressorParams, DecisionTreeModel, TreeEnsembleModel} -import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeEnsembleModel, TreeRegressorParams} +import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{GradientBoostedTrees => OldGBT} @@ -35,17 +35,18 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]] * learning algorithm for regression. * It supports both continuous and categorical features. */ -@AlphaComponent -final class GBTRegressor +@Experimental +final class GBTRegressor(override val uid: String) extends Predictor[Vector, GBTRegressor, GBTRegressionModel] with GBTParams with TreeRegressorParams with Logging { + def this() = this(Identifiable.randomUID("gbtr")) + // Override parameter setters from parent trait for Java API compatibility. // Parameters from TreeRegressorParams: @@ -132,6 +133,7 @@ final class GBTRegressor } } +@Experimental object GBTRegressor { // The losses below should be lowercase. /** Accessor for supported loss settings: squared (L2), absolute (L1) */ @@ -139,7 +141,7 @@ object GBTRegressor { } /** - * :: AlphaComponent :: + * :: Experimental :: * * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]] * model for regression. @@ -147,9 +149,9 @@ object GBTRegressor { * @param _trees Decision trees in the ensemble. * @param _treeWeights Weights for the decision trees in the ensemble. */ -@AlphaComponent +@Experimental final class GBTRegressionModel( - override val parent: GBTRegressor, + override val uid: String, private val _trees: Array[DecisionTreeRegressionModel], private val _treeWeights: Array[Double]) extends PredictionModel[Vector, GBTRegressionModel] @@ -173,7 +175,7 @@ final class GBTRegressionModel( } override def copy(extra: ParamMap): GBTRegressionModel = { - copyValues(new GBTRegressionModel(parent, _trees, _treeWeights), extra) + copyValues(new GBTRegressionModel(uid, _trees, _treeWeights), extra) } override def toString: String = { @@ -196,9 +198,10 @@ private[ml] object GBTRegressionModel { require(oldModel.algo == OldAlgo.Regression, "Cannot convert GradientBoostedTreesModel" + s" with algo=${oldModel.algo} (old API) to GBTRegressionModel (new API).") val newTrees = oldModel.trees.map { tree => - // parent, fittingParamMap for each tree is null since there are no good ways to set these. + // parent for each tree is null since there is no good way to set this. DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } - new GBTRegressionModel(parent, newTrees, oldModel.treeWeights) + val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtr") + new GBTRegressionModel(parent.uid, newTrees, oldModel.treeWeights) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 6377923afc0c4..70cd8e9e87fae 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -20,14 +20,14 @@ package org.apache.spark.ml.regression import scala.collection.mutable import breeze.linalg.{DenseVector => BDV, norm => brzNorm} -import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, - OWLQN => BreezeOWLQN} +import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} import org.apache.spark.Logging -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.{HasElasticNetParam, HasMaxIter, HasRegParam, HasTol} +import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS._ import org.apache.spark.mllib.regression.LabeledPoint @@ -44,8 +44,7 @@ private[regression] trait LinearRegressionParams extends PredictorParams with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol /** - * :: AlphaComponent :: - * + * :: Experimental :: * Linear regression. * * The learning objective is to minimize the squared error, with regularization. @@ -58,10 +57,13 @@ private[regression] trait LinearRegressionParams extends PredictorParams * - L1 (Lasso) * - L2 + L1 (elastic net) */ -@AlphaComponent -class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel] +@Experimental +class LinearRegression(override val uid: String) + extends Regressor[Vector, LinearRegression, LinearRegressionModel] with LinearRegressionParams with Logging { + def this() = this(Identifiable.randomUID("linReg")) + /** * Set the regularization parameter. * Default is 0.0. @@ -81,7 +83,7 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress setDefault(elasticNetParam -> 0.0) /** - * Set the maximal number of iterations. + * Set the maximum number of iterations. * Default is 100. * @group setParam */ @@ -128,7 +130,7 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress logWarning(s"The standard deviation of the label is zero, so the weights will be zeros " + s"and the intercept will be the mean of the label; as a result, training is not needed.") if (handlePersistence) instances.unpersist() - return new LinearRegressionModel(this, Vectors.sparse(numFeatures, Seq()), yMean) + return new LinearRegressionModel(uid, Vectors.sparse(numFeatures, Seq()), yMean) } val featuresMean = summarizer.mean.toArray @@ -167,7 +169,8 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress val weights = { val rawWeights = state.x.toArray.clone() var i = 0 - while (i < rawWeights.length) { + val len = rawWeights.length + while (i < len) { rawWeights(i) *= { if (featuresStd(i) != 0.0) yStd / featuresStd(i) else 0.0 } i += 1 } @@ -181,18 +184,17 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress if (handlePersistence) instances.unpersist() // TODO: Converts to sparse format based on the storage, but may base on the scoring speed. - new LinearRegressionModel(this, weights.compressed, intercept) + copyValues(new LinearRegressionModel(uid, weights.compressed, intercept)) } } /** - * :: AlphaComponent :: - * + * :: Experimental :: * Model produced by [[LinearRegression]]. */ -@AlphaComponent +@Experimental class LinearRegressionModel private[ml] ( - override val parent: LinearRegression, + override val uid: String, val weights: Vector, val intercept: Double) extends RegressionModel[Vector, LinearRegressionModel] @@ -203,7 +205,7 @@ class LinearRegressionModel private[ml] ( } override def copy(extra: ParamMap): LinearRegressionModel = { - copyValues(new LinearRegressionModel(parent, weights, intercept), extra) + copyValues(new LinearRegressionModel(uid, weights, intercept), extra) } } @@ -307,7 +309,8 @@ private class LeastSquaresAggregator( val weightsArray = weights.toArray.clone() var sum = 0.0 var i = 0 - while (i < weightsArray.length) { + val len = weightsArray.length + while (i < len) { if (featuresStd(i) != 0.0) { weightsArray(i) /= featuresStd(i) sum += weightsArray(i) * featuresMean(i) @@ -318,7 +321,7 @@ private class LeastSquaresAggregator( } (weightsArray, -sum + labelMean / labelStd, weightsArray.length) } - + private val effectiveWeightsVector = Vectors.dense(effectiveWeightsArray) private val gradientSumArray = Array.ofDim[Double](dim) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index dbc628927433d..49a1f7ce8c995 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -17,11 +17,11 @@ package org.apache.spark.ml.regression -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.tree.{RandomForestParams, TreeRegressorParams, DecisionTreeModel, TreeEnsembleModel} -import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeEnsembleModel, TreeRegressorParams} +import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{RandomForest => OldRandomForest} @@ -31,16 +31,17 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] learning algorithm for regression. * It supports both continuous and categorical features. */ -@AlphaComponent -final class RandomForestRegressor +@Experimental +final class RandomForestRegressor(override val uid: String) extends Predictor[Vector, RandomForestRegressor, RandomForestRegressionModel] with RandomForestParams with TreeRegressorParams { + def this() = this(Identifiable.randomUID("rfr")) + // Override parameter setters from parent trait for Java API compatibility. // Parameters from TreeRegressorParams: @@ -87,6 +88,7 @@ final class RandomForestRegressor } } +@Experimental object RandomForestRegressor { /** Accessor for supported impurity settings: variance */ final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities @@ -97,15 +99,14 @@ object RandomForestRegressor { } /** - * :: AlphaComponent :: - * + * :: Experimental :: * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for regression. * It supports both continuous and categorical features. * @param _trees Decision trees in the ensemble. */ -@AlphaComponent +@Experimental final class RandomForestRegressionModel private[ml] ( - override val parent: RandomForestRegressor, + override val uid: String, private val _trees: Array[DecisionTreeRegressionModel]) extends PredictionModel[Vector, RandomForestRegressionModel] with TreeEnsembleModel with Serializable { @@ -128,7 +129,7 @@ final class RandomForestRegressionModel private[ml] ( } override def copy(extra: ParamMap): RandomForestRegressionModel = { - copyValues(new RandomForestRegressionModel(parent, _trees), extra) + copyValues(new RandomForestRegressionModel(uid, _trees), extra) } override def toString: String = { @@ -151,9 +152,9 @@ private[ml] object RandomForestRegressionModel { require(oldModel.algo == OldAlgo.Regression, "Cannot convert RandomForestModel" + s" with algo=${oldModel.algo} (old API) to RandomForestRegressionModel (new API).") val newTrees = oldModel.trees.map { tree => - // parent, fittingParamMap for each tree is null since there are no good ways to set these. + // parent for each tree is null since there is no good way to set this. DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } - new RandomForestRegressionModel(parent, newTrees) + new RandomForestRegressionModel(parent.uid, newTrees) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index d2dec0c76cb12..4242154be14ce 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -17,14 +17,16 @@ package org.apache.spark.ml.tree +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.tree.model.{InformationGainStats => OldInformationGainStats, Node => OldNode, Predict => OldPredict} - /** + * :: DeveloperApi :: * Decision tree node interface. */ +@DeveloperApi sealed abstract class Node extends Serializable { // TODO: Add aggregate stats (once available). This will happen after we move the DecisionTree @@ -89,10 +91,12 @@ private[ml] object Node { } /** + * :: DeveloperApi :: * Decision tree leaf node. * @param prediction Prediction this node makes * @param impurity Impurity measure at this node (for training data) */ +@DeveloperApi final class LeafNode private[ml] ( override val prediction: Double, override val impurity: Double) extends Node { @@ -118,6 +122,7 @@ final class LeafNode private[ml] ( } /** + * :: DeveloperApi :: * Internal Decision Tree node. * @param prediction Prediction this node would make if it were a leaf node * @param impurity Impurity measure at this node (for training data) @@ -127,6 +132,7 @@ final class LeafNode private[ml] ( * @param rightChild Right-hand child node * @param split Information about the test used to split to the left or right child. */ +@DeveloperApi final class InternalNode private[ml] ( override val prediction: Double, override val impurity: Double, @@ -153,9 +159,9 @@ final class InternalNode private[ml] ( override private[tree] def subtreeToString(indentFactor: Int = 0): String = { val prefix: String = " " * indentFactor - prefix + s"If (${InternalNode.splitToString(split, left=true)})\n" + + prefix + s"If (${InternalNode.splitToString(split, left = true)})\n" + leftChild.subtreeToString(indentFactor + 1) + - prefix + s"Else (${InternalNode.splitToString(split, left=false)})\n" + + prefix + s"Else (${InternalNode.splitToString(split, left = false)})\n" + rightChild.subtreeToString(indentFactor + 1) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala index 90f1d052764d3..7acdeeee72d23 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala @@ -17,15 +17,18 @@ package org.apache.spark.ml.tree +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.tree.configuration.{FeatureType => OldFeatureType} import org.apache.spark.mllib.tree.model.{Split => OldSplit} /** + * :: DeveloperApi :: * Interface for a "Split," which specifies a test made at a decision tree node * to choose the left or right path. */ +@DeveloperApi sealed trait Split extends Serializable { /** Index of feature which this split tests */ @@ -52,12 +55,14 @@ private[tree] object Split { } /** + * :: DeveloperApi :: * Split which tests a categorical feature. * @param featureIndex Index of the feature to test * @param _leftCategories If the feature value is in this set of categories, then the split goes * left. Otherwise, it goes right. * @param numCategories Number of categories for this feature. */ +@DeveloperApi final class CategoricalSplit private[ml] ( override val featureIndex: Int, _leftCategories: Array[Double], @@ -125,11 +130,13 @@ final class CategoricalSplit private[ml] ( } /** + * :: DeveloperApi :: * Split which tests a continuous feature. * @param featureIndex Index of the feature to test * @param threshold If the feature value is <= this threshold, then the split goes left. * Otherwise, it goes right. */ +@DeveloperApi final class ContinuousSplit private[ml] (override val featureIndex: Int, val threshold: Double) extends Split { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 816fcedf2efb3..a0c5238d966bf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -17,7 +17,6 @@ package org.apache.spark.ml.tree -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed} @@ -26,12 +25,10 @@ import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldG import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} /** - * :: DeveloperApi :: * Parameters for Decision Tree-based algorithms. * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -@DeveloperApi private[ml] trait DecisionTreeParams extends PredictorParams { /** @@ -265,12 +262,10 @@ private[ml] object TreeRegressorParams { } /** - * :: DeveloperApi :: * Parameters for Decision Tree-based ensemble algorithms. * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -@DeveloperApi private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed { /** @@ -307,12 +302,10 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed { } /** - * :: DeveloperApi :: * Parameters for Random Forest algorithms. * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -@DeveloperApi private[ml] trait RandomForestParams extends TreeEnsembleParams { /** @@ -377,12 +370,10 @@ private[ml] object RandomForestParams { } /** - * :: DeveloperApi :: * Parameters for Gradient-Boosted Tree algorithms. * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -@DeveloperApi private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter { /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index ac0d1fed84b2e..6434b64aed15d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -20,9 +20,11 @@ package org.apache.spark.ml.tuning import com.github.fommil.netlib.F2jBLAS import org.apache.spark.Logging -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml._ +import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param._ +import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType @@ -77,11 +79,14 @@ private[ml] trait CrossValidatorParams extends Params { } /** - * :: AlphaComponent :: + * :: Experimental :: * K-fold cross validation. */ -@AlphaComponent -class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorParams with Logging { +@Experimental +class CrossValidator(override val uid: String) extends Estimator[CrossValidatorModel] + with CrossValidatorParams with Logging { + + def this() = this(Identifiable.randomUID("cv")) private val f2jBLAS = new F2jBLAS @@ -97,12 +102,6 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP /** @group setParam */ def setNumFolds(value: Int): this.type = set(numFolds, value) - override def validateParams(paramMap: ParamMap): Unit = { - getEstimatorParamMaps.foreach { eMap => - getEstimator.validateParams(eMap ++ paramMap) - } - } - override def fit(dataset: DataFrame): CrossValidatorModel = { val schema = dataset.schema transformSchema(schema, logging = true) @@ -136,26 +135,34 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP logInfo(s"Best set of parameters:\n${epm(bestIndex)}") logInfo(s"Best cross-validation metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] - copyValues(new CrossValidatorModel(this, bestModel)) + copyValues(new CrossValidatorModel(uid, bestModel).setParent(this)) } override def transformSchema(schema: StructType): StructType = { $(estimator).transformSchema(schema) } + + override def validateParams(): Unit = { + super.validateParams() + val est = $(estimator) + for (paramMap <- $(estimatorParamMaps)) { + est.copy(paramMap).validateParams() + } + } } /** - * :: AlphaComponent :: + * :: Experimental :: * Model from k-fold cross validation. */ -@AlphaComponent +@Experimental class CrossValidatorModel private[ml] ( - override val parent: CrossValidator, + override val uid: String, val bestModel: Model[_]) extends Model[CrossValidatorModel] with CrossValidatorParams { - override def validateParams(paramMap: ParamMap): Unit = { - bestModel.validateParams(paramMap) + override def validateParams(): Unit = { + bestModel.validateParams() } override def transform(dataset: DataFrame): DataFrame = { @@ -166,4 +173,9 @@ class CrossValidatorModel private[ml] ( override def transformSchema(schema: StructType): StructType = { bestModel.transformSchema(schema) } + + override def copy(extra: ParamMap): CrossValidatorModel = { + val copied = new CrossValidatorModel(uid, bestModel.copy(extra).asInstanceOf[Model[_]]) + copyValues(copied, extra) + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala index dafe73d82c00a..98a8f0330ca45 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala @@ -20,14 +20,14 @@ package org.apache.spark.ml.tuning import scala.annotation.varargs import scala.collection.mutable -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.param._ /** - * :: AlphaComponent :: + * :: Experimental :: * Builder for a param grid used in grid search-based model selection. */ -@AlphaComponent +@Experimental class ParamGridBuilder { private val paramGrid = mutable.Map.empty[Param[_], Iterable[_]] diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala index 8a56748ab0a02..ddd34a54503a6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala @@ -19,15 +19,26 @@ package org.apache.spark.ml.util import java.util.UUID + /** - * Object with a unique id. + * Trait for an object with an immutable unique ID that identifies itself and its derivatives. */ -private[ml] trait Identifiable extends Serializable { +private[spark] trait Identifiable { + + /** + * An immutable unique ID for the object and its derivatives. + */ + val uid: String + + override def toString: String = uid +} + +private[spark] object Identifiable { /** - * A unique id for the object. The default implementation concatenates the class name, "_", and 8 - * random hex chars. + * Returns a random UID that concatenates the given prefix, "_", and 12 random hex chars. */ - private[ml] val uid: String = - this.getClass.getSimpleName + "_" + UUID.randomUUID().toString.take(8) + def randomUID(prefix: String): String = { + prefix + "_" + UUID.randomUUID().toString.takeRight(12) + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala index 56075c9a6b39f..2a1db90f2ca2b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala @@ -19,18 +19,14 @@ package org.apache.spark.ml.util import scala.collection.immutable.HashMap -import org.apache.spark.annotation.Experimental import org.apache.spark.ml.attribute._ import org.apache.spark.sql.types.StructField /** - * :: Experimental :: - * * Helper utilities for tree-based algorithms */ -@Experimental -object MetadataUtils { +private[spark] object MetadataUtils { /** * Examine a schema to identify the number of classes in a label column. diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala index 11592b77eb356..7cd53c6d7ef79 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -17,15 +17,13 @@ package org.apache.spark.ml.util -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.types.{DataType, StructField, StructType} + /** - * :: DeveloperApi :: * Utils for handling schemas. */ -@DeveloperApi -object SchemaUtils { +private[spark] object SchemaUtils { // TODO: Move the utility methods to SQL. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index f4c477596557f..16f3131796709 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -345,28 +345,40 @@ private[python] class PythonMLLibAPI extends Serializable { * Returns a list containing weights, mean and covariance of each mixture component. */ def trainGaussianMixture( - data: JavaRDD[Vector], - k: Int, - convergenceTol: Double, + data: JavaRDD[Vector], + k: Int, + convergenceTol: Double, maxIterations: Int, - seed: java.lang.Long): JList[Object] = { + seed: java.lang.Long, + initialModelWeights: java.util.ArrayList[Double], + initialModelMu: java.util.ArrayList[Vector], + initialModelSigma: java.util.ArrayList[Matrix]): JList[Object] = { val gmmAlg = new GaussianMixture() .setK(k) .setConvergenceTol(convergenceTol) .setMaxIterations(maxIterations) + if (initialModelWeights != null && initialModelMu != null && initialModelSigma != null) { + val gaussians = initialModelMu.asScala.toSeq.zip(initialModelSigma.asScala.toSeq).map { + case (x, y) => new MultivariateGaussian(x.asInstanceOf[Vector], y.asInstanceOf[Matrix]) + } + val initialModel = new GaussianMixtureModel( + initialModelWeights.asScala.toArray, gaussians.toArray) + gmmAlg.setInitialModel(initialModel) + } + if (seed != null) gmmAlg.setSeed(seed) try { val model = gmmAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK)) var wt = ArrayBuffer.empty[Double] - var mu = ArrayBuffer.empty[Vector] + var mu = ArrayBuffer.empty[Vector] var sigma = ArrayBuffer.empty[Matrix] for (i <- 0 until model.k) { wt += model.weights(i) mu += model.gaussians(i).mu sigma += model.gaussians(i).sigma - } + } List(Vectors.dense(wt.toArray), mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava } finally { data.rdd.unpersist(blocking = false) @@ -380,14 +392,14 @@ private[python] class PythonMLLibAPI extends Serializable { data: JavaRDD[Vector], wt: Vector, mu: Array[Object], - si: Array[Object]): RDD[Vector] = { + si: Array[Object]): RDD[Vector] = { val weight = wt.toArray val mean = mu.map(_.asInstanceOf[DenseVector]) val sigma = si.map(_.asInstanceOf[DenseMatrix]) val gaussians = Array.tabulate(weight.length){ i => new MultivariateGaussian(mean(i), sigma(i)) - } + } val model = new GaussianMixtureModel(weight, gaussians) model.predictSoft(data).map(Vectors.dense) } @@ -416,7 +428,7 @@ private[python] class PythonMLLibAPI extends Serializable { if (seed != null) als.setSeed(seed) - val model = als.run(ratingsJRDD.rdd) + val model = als.run(ratingsJRDD.rdd) new MatrixFactorizationModelWrapper(model) } @@ -447,7 +459,7 @@ private[python] class PythonMLLibAPI extends Serializable { if (seed != null) als.setSeed(seed) - val model = als.run(ratingsJRDD.rdd) + val model = als.run(ratingsJRDD.rdd) new MatrixFactorizationModelWrapper(model) } @@ -482,7 +494,7 @@ private[python] class PythonMLLibAPI extends Serializable { def normalizeVector(p: Double, rdd: JavaRDD[Vector]): JavaRDD[Vector] = { new Normalizer(p).transform(rdd) } - + /** * Java stub for StandardScaler.fit(). This stub returns a * handle to the Java object instead of the content of the Java object. @@ -1230,7 +1242,7 @@ private[spark] object SerDe extends Serializable { } /* convert RDD[Tuple2[,]] to RDD[Array[Any]] */ - def fromTuple2RDD(rdd: RDD[(Any, Any)]): RDD[Array[Any]] = { + def fromTuple2RDD(rdd: RDD[(Any, Any)]): RDD[Array[Any]] = { rdd.map(x => Array(x._1, x._2)) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index bd2e9079ce1ae..2df4d21e8cd55 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -163,7 +163,7 @@ class LogisticRegressionModel ( override protected def formatVersion: String = "1.0" override def toString: String = { - s"${super.toString}, numClasses = ${numClasses}, threshold = ${threshold.get}" + s"${super.toString}, numClasses = ${numClasses}, threshold = ${threshold.getOrElse("None")}" } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index c9b3ff0172e2e..f51ee36d0dfcb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -21,21 +21,16 @@ import java.lang.{Iterable => JIterable} import scala.collection.JavaConverters._ -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum, Axis} -import breeze.numerics.{exp => brzExp, log => brzLog} - import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ -import org.json4s.{DefaultFormats, JValue} import org.apache.spark.{Logging, SparkContext, SparkException} -import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector} +import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, SparseVector, Vector} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, SQLContext} - /** * Model for Naive Bayes Classifiers. * @@ -43,7 +38,7 @@ import org.apache.spark.sql.{DataFrame, SQLContext} * @param pi log of class priors, whose dimension is C, number of labels * @param theta log of class conditional probabilities, whose dimension is C-by-D, * where D is number of features - * @param modelType The type of NB model to fit can be "Multinomial" or "Bernoulli" + * @param modelType The type of NB model to fit can be "multinomial" or "bernoulli" */ class NaiveBayesModel private[mllib] ( val labels: Array[Double], @@ -52,8 +47,13 @@ class NaiveBayesModel private[mllib] ( val modelType: String) extends ClassificationModel with Serializable with Saveable { + import NaiveBayes.{Bernoulli, Multinomial, supportedModelTypes} + + private val piVector = new DenseVector(pi) + private val thetaMatrix = new DenseMatrix(labels.length, theta(0).length, theta.flatten, true) + private[mllib] def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) = - this(labels, pi, theta, "Multinomial") + this(labels, pi, theta, NaiveBayes.Multinomial) /** A Java-friendly constructor that takes three Iterable parameters. */ private[mllib] def this( @@ -62,20 +62,24 @@ class NaiveBayesModel private[mllib] ( theta: JIterable[JIterable[Double]]) = this(labels.asScala.toArray, pi.asScala.toArray, theta.asScala.toArray.map(_.asScala.toArray)) - private val brzPi = new BDV[Double](pi) - private val brzTheta = new BDM(theta(0).length, theta.length, theta.flatten).t + require(supportedModelTypes.contains(modelType), + s"Invalid modelType $modelType. Supported modelTypes are $supportedModelTypes.") // Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0. - // This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra + // This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra // application of this condition (in predict function). - private val (brzNegTheta, brzNegThetaSum) = modelType match { - case "Multinomial" => (None, None) - case "Bernoulli" => - val negTheta = brzLog((brzExp(brzTheta.copy) :*= (-1.0)) :+= 1.0) // log(1.0 - exp(x)) - (Option(negTheta), Option(brzSum(negTheta, Axis._1))) + private val (thetaMinusNegTheta, negThetaSum) = modelType match { + case Multinomial => (None, None) + case Bernoulli => + val negTheta = thetaMatrix.map(value => math.log(1.0 - math.exp(value))) + val ones = new DenseVector(Array.fill(thetaMatrix.numCols){1.0}) + val thetaMinusNegTheta = thetaMatrix.map { value => + value - math.log(1.0 - math.exp(value)) + } + (Option(thetaMinusNegTheta), Option(negTheta.multiply(ones))) case _ => // This should never happen. - throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType") + throw new UnknownError(s"Invalid modelType: $modelType.") } override def predict(testData: RDD[Vector]): RDD[Double] = { @@ -88,14 +92,24 @@ class NaiveBayesModel private[mllib] ( override def predict(testData: Vector): Double = { modelType match { - case "Multinomial" => - labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) ) - case "Bernoulli" => - labels (brzArgmax (brzPi + - (brzTheta - brzNegTheta.get) * testData.toBreeze + brzNegThetaSum.get)) + case Multinomial => + val prob = thetaMatrix.multiply(testData) + BLAS.axpy(1.0, piVector, prob) + labels(prob.argmax) + case Bernoulli => + testData.foreachActive { (index, value) => + if (value != 0.0 && value != 1.0) { + throw new SparkException( + s"Bernoulli naive Bayes requires 0 or 1 feature values but found $testData.") + } + } + val prob = thetaMinusNegTheta.get.multiply(testData) + BLAS.axpy(1.0, piVector, prob) + BLAS.axpy(1.0, negThetaSum.get, prob) + labels(prob.argmax) case _ => // This should never happen. - throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType") + throw new UnknownError(s"Invalid modelType: $modelType.") } } @@ -137,17 +151,17 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { // Create Parquet data. val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF() - dataRDD.saveAsParquetFile(dataPath(path)) + dataRDD.write.parquet(dataPath(path)) } def load(sc: SparkContext, path: String): NaiveBayesModel = { val sqlContext = new SQLContext(sc) // Load Parquet data. - val dataRDD = sqlContext.parquetFile(dataPath(path)) + val dataRDD = sqlContext.read.parquet(dataPath(path)) // Check schema explicitly since erasure makes it hard to use match-case for checking. checkSchema[Data](dataRDD.schema) val dataArray = dataRDD.select("labels", "pi", "theta", "modelType").take(1) - assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}") + assert(dataArray.length == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}") val data = dataArray(0) val labels = data.getAs[Seq[Double]](0).toArray val pi = data.getAs[Seq[Double]](1).toArray @@ -183,17 +197,17 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { // Create Parquet data. val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF() - dataRDD.saveAsParquetFile(dataPath(path)) + dataRDD.write.parquet(dataPath(path)) } def load(sc: SparkContext, path: String): NaiveBayesModel = { val sqlContext = new SQLContext(sc) // Load Parquet data. - val dataRDD = sqlContext.parquetFile(dataPath(path)) + val dataRDD = sqlContext.read.parquet(dataPath(path)) // Check schema explicitly since erasure makes it hard to use match-case for checking. checkSchema[Data](dataRDD.schema) val dataArray = dataRDD.select("labels", "pi", "theta").take(1) - assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}") + assert(dataArray.length == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}") val data = dataArray(0) val labels = data.getAs[Seq[Double]](0).toArray val pi = data.getAs[Seq[Double]](1).toArray @@ -220,16 +234,16 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] { s"($loadedClassName, $version). Supported:\n" + s" ($classNameV1_0, 1.0)") } - assert(model.pi.size == numClasses, + assert(model.pi.length == numClasses, s"NaiveBayesModel.load expected $numClasses classes," + - s" but class priors vector pi had ${model.pi.size} elements") - assert(model.theta.size == numClasses, + s" but class priors vector pi had ${model.pi.length} elements") + assert(model.theta.length == numClasses, s"NaiveBayesModel.load expected $numClasses classes," + - s" but class conditionals array theta had ${model.theta.size} elements") - assert(model.theta.forall(_.size == numFeatures), + s" but class conditionals array theta had ${model.theta.length} elements") + assert(model.theta.forall(_.length == numFeatures), s"NaiveBayesModel.load expected $numFeatures features," + s" but class conditionals array theta had elements of size:" + - s" ${model.theta.map(_.size).mkString(",")}") + s" ${model.theta.map(_.length).mkString(",")}") model } } @@ -247,9 +261,11 @@ class NaiveBayes private ( private var lambda: Double, private var modelType: String) extends Serializable with Logging { - def this(lambda: Double) = this(lambda, "Multinomial") + import NaiveBayes.{Bernoulli, Multinomial} + + def this(lambda: Double) = this(lambda, NaiveBayes.Multinomial) - def this() = this(1.0, "Multinomial") + def this() = this(1.0, NaiveBayes.Multinomial) /** Set the smoothing parameter. Default: 1.0. */ def setLambda(lambda: Double): NaiveBayes = { @@ -262,12 +278,11 @@ class NaiveBayes private ( /** * Set the model type using a string (case-sensitive). - * Supported options: "Multinomial" and "Bernoulli". - * (default: Multinomial) + * Supported options: "multinomial" (default) and "bernoulli". */ - def setModelType(modelType:String): NaiveBayes = { + def setModelType(modelType: String): NaiveBayes = { require(NaiveBayes.supportedModelTypes.contains(modelType), - s"NaiveBayes was created with an unknown ModelType: $modelType") + s"NaiveBayes was created with an unknown modelType: $modelType.") this.modelType = modelType this } @@ -283,30 +298,46 @@ class NaiveBayes private ( def run(data: RDD[LabeledPoint]): NaiveBayesModel = { val requireNonnegativeValues: Vector => Unit = (v: Vector) => { val values = v match { - case SparseVector(size, indices, values) => - values - case DenseVector(values) => - values + case sv: SparseVector => sv.values + case dv: DenseVector => dv.values } if (!values.forall(_ >= 0.0)) { throw new SparkException(s"Naive Bayes requires nonnegative feature values but found $v.") } } + val requireZeroOneBernoulliValues: Vector => Unit = (v: Vector) => { + val values = v match { + case sv: SparseVector => sv.values + case dv: DenseVector => dv.values + } + if (!values.forall(v => v == 0.0 || v == 1.0)) { + throw new SparkException( + s"Bernoulli naive Bayes requires 0 or 1 feature values but found $v.") + } + } + // Aggregates term frequencies per label. // TODO: Calling combineByKey and collect creates two stages, we can implement something // TODO: similar to reduceByKeyLocally to save one stage. - val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, BDV[Double])]( + val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, DenseVector)]( createCombiner = (v: Vector) => { - requireNonnegativeValues(v) - (1L, v.toBreeze.toDenseVector) + if (modelType == Bernoulli) { + requireZeroOneBernoulliValues(v) + } else { + requireNonnegativeValues(v) + } + (1L, v.copy.toDense) }, - mergeValue = (c: (Long, BDV[Double]), v: Vector) => { + mergeValue = (c: (Long, DenseVector), v: Vector) => { requireNonnegativeValues(v) - (c._1 + 1L, c._2 += v.toBreeze) + BLAS.axpy(1.0, v, c._2) + (c._1 + 1L, c._2) }, - mergeCombiners = (c1: (Long, BDV[Double]), c2: (Long, BDV[Double])) => - (c1._1 + c2._1, c1._2 += c2._2) + mergeCombiners = (c1: (Long, DenseVector), c2: (Long, DenseVector)) => { + BLAS.axpy(1.0, c2._2, c1._2) + (c1._1 + c2._1, c1._2) + } ).collect() val numLabels = aggregated.length @@ -326,11 +357,11 @@ class NaiveBayes private ( labels(i) = label pi(i) = math.log(n + lambda) - piLogDenom val thetaLogDenom = modelType match { - case "Multinomial" => math.log(brzSum(sumTermFreqs) + numFeatures * lambda) - case "Bernoulli" => math.log(n + 2.0 * lambda) + case Multinomial => math.log(sumTermFreqs.values.sum + numFeatures * lambda) + case Bernoulli => math.log(n + 2.0 * lambda) case _ => // This should never happen. - throw new UnknownError(s"NaiveBayes was created with an unknown ModelType: $modelType") + throw new UnknownError(s"Invalid modelType: $modelType.") } var j = 0 while (j < numFeatures) { @@ -349,8 +380,14 @@ class NaiveBayes private ( */ object NaiveBayes { + /** String name for multinomial model type. */ + private[classification] val Multinomial: String = "multinomial" + + /** String name for Bernoulli model type. */ + private[classification] val Bernoulli: String = "bernoulli" + /* Set of modelTypes that NaiveBayes supports */ - private[mllib] val supportedModelTypes = Set("Multinomial", "Bernoulli") + private[classification] val supportedModelTypes = Set(Multinomial, Bernoulli) /** * Trains a Naive Bayes model given an RDD of `(label, features)` pairs. @@ -380,7 +417,7 @@ object NaiveBayes { * @param lambda The smoothing parameter */ def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = { - new NaiveBayes(lambda, "Multinomial").run(input) + new NaiveBayes(lambda, Multinomial).run(input) } /** @@ -403,7 +440,7 @@ object NaiveBayes { */ def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = { require(supportedModelTypes.contains(modelType), - s"NaiveBayes was created with an unknown ModelType: $modelType") + s"NaiveBayes was created with an unknown modelType: $modelType.") new NaiveBayes(lambda, modelType).run(input) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index 33104cf06c6ea..348485560713e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -89,7 +89,7 @@ class SVMModel ( override protected def formatVersion: String = "1.0" override def toString: String = { - s"${super.toString}, numClasses = 2, threshold = ${threshold.get}" + s"${super.toString}, numClasses = 2, threshold = ${threshold.getOrElse("None")}" } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala index 3b6790cce47c6..fe09f6b75d28b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala @@ -62,7 +62,7 @@ private[classification] object GLMClassificationModel { // Create Parquet data. val data = Data(weights, intercept, threshold) - sc.parallelize(Seq(data), 1).toDF().saveAsParquetFile(Loader.dataPath(path)) + sc.parallelize(Seq(data), 1).toDF().write.parquet(Loader.dataPath(path)) } /** @@ -75,7 +75,7 @@ private[classification] object GLMClassificationModel { def loadData(sc: SparkContext, path: String, modelClass: String): Data = { val datapath = Loader.dataPath(path) val sqlContext = new SQLContext(sc) - val dataRDD = sqlContext.parquetFile(datapath) + val dataRDD = sqlContext.read.parquet(datapath) val dataArray = dataRDD.select("weights", "intercept", "threshold").take(1) assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath") val data = dataArray(0) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala index c88410ac0ff43..70b0e40948e51 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala @@ -36,11 +36,11 @@ import org.apache.spark.util.Utils * independent Gaussian distributions with associated "mixing" weights * specifying each's contribution to the composite. * - * Given a set of sample points, this class will maximize the log-likelihood - * for a mixture of k Gaussians, iterating until the log-likelihood changes by + * Given a set of sample points, this class will maximize the log-likelihood + * for a mixture of k Gaussians, iterating until the log-likelihood changes by * less than convergenceTol, or until it has reached the max number of iterations. * While this process is generally guaranteed to converge, it is not guaranteed - * to find a global optimum. + * to find a global optimum. * * Note: For high-dimensional data (with many features), this algorithm may perform poorly. * This is due to high-dimensional data (a) making it difficult to cluster at all (based @@ -53,24 +53,24 @@ import org.apache.spark.util.Utils */ @Experimental class GaussianMixture private ( - private var k: Int, - private var convergenceTol: Double, + private var k: Int, + private var convergenceTol: Double, private var maxIterations: Int, private var seed: Long) extends Serializable { - + /** * Constructs a default instance. The default parameters are {k: 2, convergenceTol: 0.01, * maxIterations: 100, seed: random}. */ def this() = this(2, 0.01, 100, Utils.random.nextLong()) - + // number of samples per cluster to use when initializing Gaussians private val nSamples = 5 - - // an initializing GMM can be provided rather than using the + + // an initializing GMM can be provided rather than using the // default random starting point private var initialModel: Option[GaussianMixtureModel] = None - + /** Set the initial GMM starting point, bypassing the random initialization. * You must call setK() prior to calling this method, and the condition * (model.k == this.k) must be met; failure will result in an IllegalArgumentException @@ -83,37 +83,37 @@ class GaussianMixture private ( } this } - + /** Return the user supplied initial GMM, if supplied */ def getInitialModel: Option[GaussianMixtureModel] = initialModel - + /** Set the number of Gaussians in the mixture model. Default: 2 */ def setK(k: Int): this.type = { this.k = k this } - + /** Return the number of Gaussians in the mixture model */ def getK: Int = k - + /** Set the maximum number of iterations to run. Default: 100 */ def setMaxIterations(maxIterations: Int): this.type = { this.maxIterations = maxIterations this } - + /** Return the maximum number of iterations to run */ def getMaxIterations: Int = maxIterations - + /** - * Set the largest change in log-likelihood at which convergence is + * Set the largest change in log-likelihood at which convergence is * considered to have occurred. */ def setConvergenceTol(convergenceTol: Double): this.type = { this.convergenceTol = convergenceTol this } - + /** * Return the largest change in log-likelihood at which convergence is * considered to have occurred. @@ -132,41 +132,41 @@ class GaussianMixture private ( /** Perform expectation maximization */ def run(data: RDD[Vector]): GaussianMixtureModel = { val sc = data.sparkContext - + // we will operate on the data as breeze data val breezeData = data.map(_.toBreeze).cache() - + // Get length of the input vectors val d = breezeData.first().length - + // Determine initial weights and corresponding Gaussians. // If the user supplied an initial GMM, we use those values, otherwise // we start with uniform weights, a random mean from the data, and // diagonal covariance matrices using component variances - // derived from the samples + // derived from the samples val (weights, gaussians) = initialModel match { case Some(gmm) => (gmm.weights, gmm.gaussians) - + case None => { val samples = breezeData.takeSample(withReplacement = true, k * nSamples, seed) - (Array.fill(k)(1.0 / k), Array.tabulate(k) { i => + (Array.fill(k)(1.0 / k), Array.tabulate(k) { i => val slice = samples.view(i * nSamples, (i + 1) * nSamples) - new MultivariateGaussian(vectorMean(slice), initCovariance(slice)) + new MultivariateGaussian(vectorMean(slice), initCovariance(slice)) }) } } - - var llh = Double.MinValue // current log-likelihood + + var llh = Double.MinValue // current log-likelihood var llhp = 0.0 // previous log-likelihood - + var iter = 0 while (iter < maxIterations && math.abs(llh-llhp) > convergenceTol) { // create and broadcast curried cluster contribution function val compute = sc.broadcast(ExpectationSum.add(weights, gaussians)_) - + // aggregate the cluster contribution for all sample points val sums = breezeData.aggregate(ExpectationSum.zero(k, d))(compute.value, _ += _) - + // Create new distributions based on the partial assignments // (often referred to as the "M" step in literature) val sumWeights = sums.weights.sum @@ -179,22 +179,22 @@ class GaussianMixture private ( gaussians(i) = new MultivariateGaussian(mu, sums.sigmas(i) / sums.weights(i)) i = i + 1 } - + llhp = llh // current becomes previous llh = sums.logLikelihood // this is the freshly computed log-likelihood iter += 1 - } - + } + new GaussianMixtureModel(weights, gaussians) } - + /** Average of dense breeze vectors */ private def vectorMean(x: IndexedSeq[BV[Double]]): BDV[Double] = { val v = BDV.zeros[Double](x(0).length) x.foreach(xi => v += xi) - v / x.length.toDouble + v / x.length.toDouble } - + /** * Construct matrix where diagonal entries are element-wise * variance of input vectors (computes biased variance) @@ -210,14 +210,14 @@ class GaussianMixture private ( // companion class to provide zero constructor for ExpectationSum private object ExpectationSum { def zero(k: Int, d: Int): ExpectationSum = { - new ExpectationSum(0.0, Array.fill(k)(0.0), - Array.fill(k)(BDV.zeros(d)), Array.fill(k)(BreezeMatrix.zeros(d,d))) + new ExpectationSum(0.0, Array.fill(k)(0.0), + Array.fill(k)(BDV.zeros(d)), Array.fill(k)(BreezeMatrix.zeros(d, d))) } - + // compute cluster contributions for each input point // (U, T) => U for aggregation def add( - weights: Array[Double], + weights: Array[Double], dists: Array[MultivariateGaussian]) (sums: ExpectationSum, x: BV[Double]): ExpectationSum = { val p = weights.zip(dists).map { @@ -235,7 +235,7 @@ private object ExpectationSum { i = i + 1 } sums - } + } } // Aggregation class for partial expectation results @@ -244,9 +244,9 @@ private class ExpectationSum( val weights: Array[Double], val means: Array[BDV[Double]], val sigmas: Array[BreezeMatrix[Double]]) extends Serializable { - + val k = weights.length - + def +=(x: ExpectationSum): ExpectationSum = { var i = 0 while (i < k) { @@ -257,5 +257,5 @@ private class ExpectationSum( } logLikelihood += x.logLikelihood this - } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index ec65a3da689de..5fc2cb1b62d33 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -34,21 +34,20 @@ import org.apache.spark.sql.{SQLContext, Row} /** * :: Experimental :: * - * Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points - * are drawn from each Gaussian i=1..k with probability w(i); mu(i) and sigma(i) are - * the respective mean and covariance for each Gaussian distribution i=1..k. - * - * @param weight Weights for each Gaussian distribution in the mixture, where weight(i) is - * the weight for Gaussian i, and weight.sum == 1 - * @param mu Means for each Gaussian in the mixture, where mu(i) is the mean for Gaussian i - * @param sigma Covariance maxtrix for each Gaussian in the mixture, where sigma(i) is the - * covariance matrix for Gaussian i + * Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points + * are drawn from each Gaussian i=1..k with probability w(i); mu(i) and sigma(i) are + * the respective mean and covariance for each Gaussian distribution i=1..k. + * + * @param weights Weights for each Gaussian distribution in the mixture, where weights(i) is + * the weight for Gaussian i, and weights.sum == 1 + * @param gaussians Array of MultivariateGaussian where gaussians(i) represents + * the Multivariate Gaussian (Normal) Distribution for Gaussian i */ @Experimental class GaussianMixtureModel( - val weights: Array[Double], + val weights: Array[Double], val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable{ - + require(weights.length == gaussians.length, "Length of weight and Gaussian arrays must match") override protected def formatVersion = "1.0" @@ -65,20 +64,20 @@ class GaussianMixtureModel( val responsibilityMatrix = predictSoft(points) responsibilityMatrix.map(r => r.indexOf(r.max)) } - + /** * Given the input vectors, return the membership value of each vector - * to all mixture components. + * to all mixture components. */ def predictSoft(points: RDD[Vector]): RDD[Array[Double]] = { val sc = points.sparkContext val bcDists = sc.broadcast(gaussians) val bcWeights = sc.broadcast(weights) - points.map { x => + points.map { x => computeSoftAssignments(x.toBreeze.toDenseVector, bcDists.value, bcWeights.value, k) } } - + /** * Compute the partial assignments for each vector */ @@ -90,7 +89,7 @@ class GaussianMixtureModel( val p = weights.zip(dists).map { case (weight, dist) => MLUtils.EPSILON + weight * dist.pdf(pt) } - val pSum = p.sum + val pSum = p.sum for (i <- 0 until k) { p(i) /= pSum } @@ -127,13 +126,13 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] { val dataArray = Array.tabulate(weights.length) { i => Data(weights(i), gaussians(i).mu, gaussians(i).sigma) } - sc.parallelize(dataArray, 1).toDF().saveAsParquetFile(Loader.dataPath(path)) + sc.parallelize(dataArray, 1).toDF().write.parquet(Loader.dataPath(path)) } def load(sc: SparkContext, path: String): GaussianMixtureModel = { val dataPath = Loader.dataPath(path) val sqlContext = new SQLContext(sc) - val dataFrame = sqlContext.parquetFile(dataPath) + val dataFrame = sqlContext.read.parquet(dataPath) val dataArray = dataFrame.select("weight", "mu", "sigma").collect() // Check schema explicitly since erasure makes it hard to use match-case for checking. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index ba228b11fcec3..8ecb3df11d95e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -110,7 +110,7 @@ object KMeansModel extends Loader[KMeansModel] { val dataRDD = sc.parallelize(model.clusterCenters.zipWithIndex).map { case (point, id) => Cluster(id, point) }.toDF() - dataRDD.saveAsParquetFile(Loader.dataPath(path)) + dataRDD.write.parquet(Loader.dataPath(path)) } def load(sc: SparkContext, path: String): KMeansModel = { @@ -120,7 +120,7 @@ object KMeansModel extends Loader[KMeansModel] { assert(className == thisClassName) assert(formatVersion == thisFormatVersion) val k = (metadata \ "k").extract[Int] - val centriods = sqlContext.parquetFile(Loader.dataPath(path)) + val centriods = sqlContext.read.parquet(Loader.dataPath(path)) Loader.checkSchema[Cluster](centriods.schema) val localCentriods = centriods.map(Cluster.apply).collect() assert(k == localCentriods.size) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index 6fa2fe053c6a4..8e5154b902d1d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -273,7 +273,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { * Default: 1024, following the original Online LDA paper. */ def setTau0(tau0: Double): this.type = { - require(tau0 > 0, s"LDA tau0 must be positive, but was set to $tau0") + require(tau0 > 0, s"LDA tau0 must be positive, but was set to $tau0") this.tau0 = tau0 this } @@ -339,7 +339,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { override private[clustering] def initialize( docs: RDD[(Long, Vector)], - lda: LDA): OnlineLDAOptimizer = { + lda: LDA): OnlineLDAOptimizer = { this.k = lda.getK this.corpusSize = docs.count() this.vocabSize = docs.first()._2.size @@ -458,7 +458,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { * uses digamma which is accurate but expensive. */ private def dirichletExpectation(alpha: BDM[Double]): BDM[Double] = { - val rowSum = sum(alpha(breeze.linalg.*, ::)) + val rowSum = sum(alpha(breeze.linalg.*, ::)) val digAlpha = digamma(alpha) val digRowSum = digamma(rowSum) val result = digAlpha(::, breeze.linalg.*) - digRowSum diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index aa53e88d59856..e7a243f854e33 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -74,7 +74,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) val dataRDD = model.assignments.toDF() - dataRDD.saveAsParquetFile(Loader.dataPath(path)) + dataRDD.write.parquet(Loader.dataPath(path)) } def load(sc: SparkContext, path: String): PowerIterationClusteringModel = { @@ -86,7 +86,7 @@ object PowerIterationClusteringModel extends Loader[PowerIterationClusteringMode assert(formatVersion == thisFormatVersion) val k = (metadata \ "k").extract[Int] - val assignments = sqlContext.parquetFile(Loader.dataPath(path)) + val assignments = sqlContext.read.parquet(Loader.dataPath(path)) Loader.checkSchema[PowerIterationClustering.Assignment](assignments.schema) val assignmentsRDD = assignments.map { @@ -121,7 +121,7 @@ class PowerIterationClustering private[clustering] ( import org.apache.spark.mllib.clustering.PowerIterationClustering._ /** Constructs a PIC instance with default parameters: {k: 2, maxIterations: 100, - * initMode: "random"}. + * initMode: "random"}. */ def this() = this(k = 2, maxIterations = 100, initMode = "random") @@ -243,7 +243,7 @@ object PowerIterationClustering extends Logging { /** * Generates random vertex properties (v0) to start power iteration. - * + * * @param g a graph representing the normalized affinity matrix (W) * @return a graph with edges representing W and vertices representing a random vector * with unit 1-norm @@ -266,7 +266,7 @@ object PowerIterationClustering extends Logging { * Generates the degree vector as the vertex properties (v0) to start power iteration. * It is not exactly the node degrees but just the normalized sum similarities. Call it * as degree vector because it is used in the PIC paper. - * + * * @param g a graph representing the normalized affinity matrix (W) * @return a graph with edges representing W and vertices representing the degree vector */ @@ -276,7 +276,7 @@ object PowerIterationClustering extends Logging { val v0 = g.vertices.mapValues(_ / sum) GraphImpl.fromExistingRDDs(VertexRDD(v0), g.edges) } - + /** * Runs power iteration. * @param g input graph with edges representing the normalized affinity matrix (W) and vertices diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index 812014a041719..c21e4fe7dc9b6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -178,7 +178,7 @@ class StreamingKMeans( /** Set the decay factor directly (for forgetful algorithms). */ def setDecayFactor(a: Double): this.type = { - this.decayFactor = decayFactor + this.decayFactor = a this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala index a8378a76d20ae..bf6eb1d5bd2ab 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala @@ -19,6 +19,7 @@ package org.apache.spark.mllib.evaluation import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext._ +import org.apache.spark.sql.DataFrame /** * Evaluator for multilabel classification. @@ -27,6 +28,13 @@ import org.apache.spark.SparkContext._ */ class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]) { + /** + * An auxiliary constructor taking a DataFrame. + * @param predictionAndLabels a DataFrame with two double array columns: prediction and label + */ + private[mllib] def this(predictionAndLabels: DataFrame) = + this(predictionAndLabels.map(r => (r.getSeq[Double](0).toArray, r.getSeq[Double](1).toArray))) + private lazy val numDocs: Long = predictionAndLabels.count() private lazy val numLabels: Long = predictionAndLabels.flatMap { case (_, labels) => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala index b9b54b93c27fa..5b5a2a1450f7f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala @@ -31,6 +31,8 @@ import org.apache.spark.rdd.RDD * ::Experimental:: * Evaluator for ranking algorithms. * + * Java users should use [[RankingMetrics$.of]] to create a [[RankingMetrics]] instance. + * * @param predictionAndLabels an RDD of (predicted ranking, ground truth set) pairs. */ @Experimental diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index c6057c7f837b1..5f8c1dea237b4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -38,7 +38,8 @@ class ChiSqSelectorModel (val selectedFeatures: Array[Int]) extends VectorTransf protected def isSorted(array: Array[Int]): Boolean = { var i = 1 - while (i < array.length) { + val len = array.length + while (i < len) { if (array(i) < array(i-1)) return false i += 1 } @@ -107,7 +108,7 @@ class ChiSqSelectorModel (val selectedFeatures: Array[Int]) extends VectorTransf * (ordered by statistic value descending) */ @Experimental -class ChiSqSelector (val numTopFeatures: Int) { +class ChiSqSelector (val numTopFeatures: Int) extends Serializable { /** * Returns a ChiSquared feature selector. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala index b0985baf9b278..d67fe6c3ee4f8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ElementwiseProduct.scala @@ -25,10 +25,10 @@ import org.apache.spark.mllib.linalg._ * Outputs the Hadamard product (i.e., the element-wise product) of each input vector with a * provided "weight" vector. In other words, it scales each column of the dataset by a scalar * multiplier. - * @param scalingVector The values used to scale the reference vector's individual components. + * @param scalingVec The values used to scale the reference vector's individual components. */ @Experimental -class ElementwiseProduct(val scalingVector: Vector) extends VectorTransformer { +class ElementwiseProduct(val scalingVec: Vector) extends VectorTransformer { /** * Does the hadamard product transformation. @@ -37,15 +37,15 @@ class ElementwiseProduct(val scalingVector: Vector) extends VectorTransformer { * @return transformed vector. */ override def transform(vector: Vector): Vector = { - require(vector.size == scalingVector.size, - s"vector sizes do not match: Expected ${scalingVector.size} but found ${vector.size}") + require(vector.size == scalingVec.size, + s"vector sizes do not match: Expected ${scalingVec.size} but found ${vector.size}") vector match { case dv: DenseVector => val values: Array[Double] = dv.values.clone() - val dim = scalingVector.size + val dim = scalingVec.size var i = 0 while (i < dim) { - values(i) *= scalingVector(i) + values(i) *= scalingVec(i) i += 1 } Vectors.dense(values) @@ -54,7 +54,7 @@ class ElementwiseProduct(val scalingVector: Vector) extends VectorTransformer { val dim = values.length var i = 0 while (i < dim) { - values(i) *= scalingVector(indices(i)) + values(i) *= scalingVec(indices(i)) i += 1 } Vectors.sparse(size, indices, values) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala index a89eea0e21be2..efbfeb4059f5a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala @@ -144,7 +144,7 @@ private object IDF { * Since arrays are initialized to 0 by default, * we just omit changing those entries. */ - if(df(j) >= minDocFreq) { + if (df(j) >= minDocFreq) { inv(j) = math.log((m + 1.0) / (df(j) + 1.0)) } j += 1 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala index 6ae6917eae595..c73b8f258060d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala @@ -90,7 +90,7 @@ class StandardScalerModel ( @DeveloperApi def setWithMean(withMean: Boolean): this.type = { - require(!(withMean && this.mean == null),"cannot set withMean to true while mean is null") + require(!(withMean && this.mean == null), "cannot set withMean to true while mean is null") this.withMean = withMean this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 98e83112f52ae..51546d41c36a6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -42,32 +42,32 @@ import org.apache.spark.util.random.XORShiftRandom import org.apache.spark.sql.{SQLContext, Row} /** - * Entry in vocabulary + * Entry in vocabulary */ private case class VocabWord( var word: String, var cn: Int, var point: Array[Int], var code: Array[Int], - var codeLen:Int + var codeLen: Int ) /** * :: Experimental :: * Word2Vec creates vector representation of words in a text corpus. * The algorithm first constructs a vocabulary from the corpus - * and then learns vector representation of words in the vocabulary. - * The vector representation can be used as features in + * and then learns vector representation of words in the vocabulary. + * The vector representation can be used as features in * natural language processing and machine learning algorithms. - * - * We used skip-gram model in our implementation and hierarchical softmax + * + * We used skip-gram model in our implementation and hierarchical softmax * method to train the model. The variable names in the implementation * matches the original C implementation. * - * For original C implementation, see https://code.google.com/p/word2vec/ - * For research papers, see + * For original C implementation, see https://code.google.com/p/word2vec/ + * For research papers, see * Efficient Estimation of Word Representations in Vector Space - * and + * and * Distributed Representations of Words and Phrases and their Compositionality. */ @Experimental @@ -79,7 +79,7 @@ class Word2Vec extends Serializable with Logging { private var numIterations = 1 private var seed = Utils.random.nextLong() private var minCount = 5 - + /** * Sets vector size (default: 100). */ @@ -122,15 +122,15 @@ class Word2Vec extends Serializable with Logging { this } - /** - * Sets minCount, the minimum number of times a token must appear to be included in the word2vec + /** + * Sets minCount, the minimum number of times a token must appear to be included in the word2vec * model's vocabulary (default: 5). */ def setMinCount(minCount: Int): this.type = { this.minCount = minCount this } - + private val EXP_TABLE_SIZE = 1000 private val MAX_EXP = 6 private val MAX_CODE_LENGTH = 40 @@ -150,14 +150,17 @@ class Word2Vec extends Serializable with Logging { .map(x => VocabWord( x._1, x._2, - new Array[Int](MAX_CODE_LENGTH), - new Array[Int](MAX_CODE_LENGTH), + new Array[Int](MAX_CODE_LENGTH), + new Array[Int](MAX_CODE_LENGTH), 0)) .filter(_.cn >= minCount) .collect() .sortWith((a, b) => a.cn > b.cn) - + vocabSize = vocab.length + require(vocabSize > 0, "The vocabulary size should be > 0. You may need to check " + + "the setting of minCount, which could be large enough to remove all your words in sentences.") + var a = 0 while (a < vocabSize) { vocabHash += vocab(a).word -> a @@ -195,8 +198,8 @@ class Word2Vec extends Serializable with Logging { } var pos1 = vocabSize - 1 var pos2 = vocabSize - - var min1i = 0 + + var min1i = 0 var min2i = 0 a = 0 @@ -265,15 +268,15 @@ class Word2Vec extends Serializable with Logging { val words = dataset.flatMap(x => x) learnVocab(words) - + createBinaryTree() - + val sc = dataset.context val expTable = sc.broadcast(createExpTable()) val bcVocab = sc.broadcast(vocab) val bcVocabHash = sc.broadcast(vocabHash) - + val sentences: RDD[Array[Int]] = words.mapPartitions { iter => new Iterator[Array[Int]] { def hasNext: Boolean = iter.hasNext @@ -294,7 +297,7 @@ class Word2Vec extends Serializable with Logging { } } } - + val newSentences = sentences.repartition(numPartitions).cache() val initRandom = new XORShiftRandom(seed) @@ -399,7 +402,7 @@ class Word2Vec extends Serializable with Logging { } } newSentences.unpersist() - + val word2VecMap = mutable.HashMap.empty[String, Array[Float]] var i = 0 while (i < vocabSize) { @@ -466,7 +469,7 @@ class Word2VecModel private[mllib] ( val norm1 = blas.snrm2(n, v1, 1) val norm2 = blas.snrm2(n, v2, 1) if (norm1 == 0 || norm2 == 0) return 0.0 - blas.sdot(n, v1, 1, v2,1) / norm1 / norm2 + blas.sdot(n, v1, 1, v2, 1) / norm1 / norm2 } override protected def formatVersion = "1.0" @@ -477,7 +480,7 @@ class Word2VecModel private[mllib] ( /** * Transforms a word to its vector representation - * @param word a word + * @param word a word * @return vector representation of word */ def transform(word: String): Vector = { @@ -492,18 +495,18 @@ class Word2VecModel private[mllib] ( /** * Find synonyms of a word * @param word a word - * @param num number of synonyms to find + * @param num number of synonyms to find * @return array of (word, cosineSimilarity) */ def findSynonyms(word: String, num: Int): Array[(String, Double)] = { val vector = transform(word) - findSynonyms(vector,num) + findSynonyms(vector, num) } /** * Find synonyms of the vector representation of a word * @param vector vector representation of a word - * @param num number of synonyms to find + * @param num number of synonyms to find * @return array of (word, cosineSimilarity) */ def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = { @@ -556,7 +559,7 @@ object Word2VecModel extends Loader[Word2VecModel] { def load(sc: SparkContext, path: String): Word2VecModel = { val dataPath = Loader.dataPath(path) val sqlContext = new SQLContext(sc) - val dataFrame = sqlContext.parquetFile(dataPath) + val dataFrame = sqlContext.read.parquet(dataPath) val dataArray = dataFrame.select("word", "vector").collect() @@ -580,7 +583,7 @@ object Word2VecModel extends Loader[Word2VecModel] { sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) val dataArray = model.toSeq.map { case (w, v) => Data(w, v) } - sc.parallelize(dataArray.toSeq, 1).toDF().saveAsParquetFile(Loader.dataPath(path)) + sc.parallelize(dataArray.toSeq, 1).toDF().write.parquet(Loader.dataPath(path)) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index 87052e1ba8539..557119f7b1cd1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -228,7 +228,7 @@ private[spark] object BLAS extends Serializable with Logging { } _nativeBLAS } - + /** * A := alpha * x * x^T^ + A * @param alpha a real scalar that will be multiplied to x * x^T^. @@ -264,7 +264,7 @@ private[spark] object BLAS extends Serializable with Logging { j += 1 } i += 1 - } + } } private def syr(alpha: Double, x: SparseVector, A: DenseMatrix) { @@ -463,7 +463,7 @@ private[spark] object BLAS extends Serializable with Logging { def gemv( alpha: Double, A: Matrix, - x: DenseVector, + x: Vector, beta: Double, y: DenseVector): Unit = { require(A.numCols == x.size, @@ -473,27 +473,32 @@ private[spark] object BLAS extends Serializable with Logging { if (alpha == 0.0) { logDebug("gemv: alpha is equal to 0. Returning y.") } else { - A match { - case sparse: SparseMatrix => - gemv(alpha, sparse, x, beta, y) - case dense: DenseMatrix => - gemv(alpha, dense, x, beta, y) + (A, x) match { + case (smA: SparseMatrix, dvx: DenseVector) => + gemv(alpha, smA, dvx, beta, y) + case (smA: SparseMatrix, svx: SparseVector) => + gemv(alpha, smA, svx, beta, y) + case (dmA: DenseMatrix, dvx: DenseVector) => + gemv(alpha, dmA, dvx, beta, y) + case (dmA: DenseMatrix, svx: SparseVector) => + gemv(alpha, dmA, svx, beta, y) case _ => - throw new IllegalArgumentException(s"gemv doesn't support matrix type ${A.getClass}.") + throw new IllegalArgumentException(s"gemv doesn't support running on matrix type " + + s"${A.getClass} and vector type ${x.getClass}.") } } } /** * y := alpha * A * x + beta * y - * For `DenseMatrix` A. + * For `DenseMatrix` A and `DenseVector` x. */ private def gemv( alpha: Double, A: DenseMatrix, x: DenseVector, beta: Double, - y: DenseVector): Unit = { + y: DenseVector): Unit = { val tStrA = if (A.isTransposed) "T" else "N" val mA = if (!A.isTransposed) A.numRows else A.numCols val nA = if (!A.isTransposed) A.numCols else A.numRows @@ -503,14 +508,134 @@ private[spark] object BLAS extends Serializable with Logging { /** * y := alpha * A * x + beta * y - * For `SparseMatrix` A. + * For `DenseMatrix` A and `SparseVector` x. + */ + private def gemv( + alpha: Double, + A: DenseMatrix, + x: SparseVector, + beta: Double, + y: DenseVector): Unit = { + val mA: Int = A.numRows + val nA: Int = A.numCols + + val Avals = A.values + + val xIndices = x.indices + val xNnz = xIndices.length + val xValues = x.values + val yValues = y.values + + if (alpha == 0.0) { + scal(beta, y) + return + } + + if (A.isTransposed) { + var rowCounterForA = 0 + while (rowCounterForA < mA) { + var sum = 0.0 + var k = 0 + while (k < xNnz) { + sum += xValues(k) * Avals(xIndices(k) + rowCounterForA * nA) + k += 1 + } + yValues(rowCounterForA) = sum * alpha + beta * yValues(rowCounterForA) + rowCounterForA += 1 + } + } else { + var rowCounterForA = 0 + while (rowCounterForA < mA) { + var sum = 0.0 + var k = 0 + while (k < xNnz) { + sum += xValues(k) * Avals(xIndices(k) * mA + rowCounterForA) + k += 1 + } + yValues(rowCounterForA) = sum * alpha + beta * yValues(rowCounterForA) + rowCounterForA += 1 + } + } + } + + /** + * y := alpha * A * x + beta * y + * For `SparseMatrix` A and `SparseVector` x. + */ + private def gemv( + alpha: Double, + A: SparseMatrix, + x: SparseVector, + beta: Double, + y: DenseVector): Unit = { + val xValues = x.values + val xIndices = x.indices + val xNnz = xIndices.length + + val yValues = y.values + + val mA: Int = A.numRows + val nA: Int = A.numCols + + val Avals = A.values + val Arows = if (!A.isTransposed) A.rowIndices else A.colPtrs + val Acols = if (!A.isTransposed) A.colPtrs else A.rowIndices + + if (alpha == 0.0) { + scal(beta, y) + return + } + + if (A.isTransposed) { + var rowCounter = 0 + while (rowCounter < mA) { + var i = Arows(rowCounter) + val indEnd = Arows(rowCounter + 1) + var sum = 0.0 + var k = 0 + while (k < xNnz && i < indEnd) { + if (xIndices(k) == Acols(i)) { + sum += Avals(i) * xValues(k) + i += 1 + } + k += 1 + } + yValues(rowCounter) = sum * alpha + beta * yValues(rowCounter) + rowCounter += 1 + } + } else { + scal(beta, y) + + var colCounterForA = 0 + var k = 0 + while (colCounterForA < nA && k < xNnz) { + if (xIndices(k) == colCounterForA) { + var i = Acols(colCounterForA) + val indEnd = Acols(colCounterForA + 1) + + val xTemp = xValues(k) * alpha + while (i < indEnd) { + val rowIndex = Arows(i) + yValues(Arows(i)) += Avals(i) * xTemp + i += 1 + } + k += 1 + } + colCounterForA += 1 + } + } + } + + /** + * y := alpha * A * x + beta * y + * For `SparseMatrix` A and `DenseVector` x. */ private def gemv( alpha: Double, A: SparseMatrix, x: DenseVector, beta: Double, - y: DenseVector): Unit = { + y: DenseVector): Unit = { val xValues = x.values val yValues = y.values val mA: Int = A.numRows @@ -534,10 +659,7 @@ private[spark] object BLAS extends Serializable with Logging { rowCounter += 1 } } else { - // Scale vector first if `beta` is not equal to 0.0 - if (beta != 0.0) { - scal(beta, y) - } + scal(beta, y) // Perform matrix-vector multiplication and add to y var colCounterForA = 0 while (colCounterForA < nA) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala index 866936aa4f118..ae3ba3099c878 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala @@ -81,7 +81,7 @@ private[mllib] object EigenValueDecomposition { require(n * ncv.toLong <= Integer.MAX_VALUE && ncv * (ncv.toLong + 8) <= Integer.MAX_VALUE, s"k = $k and/or n = $n are too large to compute an eigendecomposition") - + var ido = new intW(0) var info = new intW(0) var resid = new Array[Double](n) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 3fa5e068d16d4..9584da8e3a0f9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -77,8 +77,13 @@ sealed trait Matrix extends Serializable { C } - /** Convenience method for `Matrix`-`DenseVector` multiplication. */ + /** Convenience method for `Matrix`-`DenseVector` multiplication. For binary compatibility. */ def multiply(y: DenseVector): DenseVector = { + multiply(y.asInstanceOf[Vector]) + } + + /** Convenience method for `Matrix`-`Vector` multiplication. */ + def multiply(y: Vector): DenseVector = { val output = new DenseVector(new Array[Double](numRows)) BLAS.gemv(1.0, this, y, 0.0, output) output @@ -273,7 +278,8 @@ class DenseMatrix( override def copy: DenseMatrix = new DenseMatrix(numRows, numCols, values.clone()) - private[mllib] def map(f: Double => Double) = new DenseMatrix(numRows, numCols, values.map(f)) + private[mllib] def map(f: Double => Double) = new DenseMatrix(numRows, numCols, values.map(f), + isTransposed) private[mllib] def update(f: Double => Double): DenseMatrix = { val len = values.length @@ -535,7 +541,7 @@ class SparseMatrix( } private[mllib] def map(f: Double => Double) = - new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.map(f)) + new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values.map(f), isTransposed) private[mllib] def update(f: Double => Double): SparseMatrix = { val len = values.length diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index f6bcdf83cd337..2ffa497a99d93 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -176,27 +176,31 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { } override def serialize(obj: Any): Row = { - val row = new GenericMutableRow(4) obj match { case SparseVector(size, indices, values) => + val row = new GenericMutableRow(4) row.setByte(0, 0) row.setInt(1, size) row.update(2, indices.toSeq) row.update(3, values.toSeq) + row case DenseVector(values) => + val row = new GenericMutableRow(4) row.setByte(0, 1) row.setNullAt(1) row.setNullAt(2) row.update(3, values.toSeq) + row + // TODO: There are bugs in UDT serialization because we don't have a clear separation between + // TODO: internal SQL types and language specific types (including UDT). UDT serialize and + // TODO: deserialize may get called twice. See SPARK-7186. + case row: Row => + row } - row } override def deserialize(datum: Any): Vector = { datum match { - // TODO: something wrong with UDT serialization - case v: Vector => - v case row: Row => require(row.length == 4, s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4") @@ -211,6 +215,11 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { val values = row.getAs[Iterable[Double]](3).toArray new DenseVector(values) } + // TODO: There are bugs in UDT serialization because we don't have a clear separation between + // TODO: internal SQL types and language specific types (including UDT). UDT serialize and + // TODO: deserialize may get called twice. See SPARK-7186. + case v: Vector => + v } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 9a89a6f3a515f..1626da9c3d2ee 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -219,7 +219,7 @@ class RowMatrix( val computeMode = mode match { case "auto" => - if(k > 5000) { + if (k > 5000) { logWarning(s"computing svd with k=$k and n=$n, please check necessity") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala index 3ed3a5b9b3843..9f463e0cafb6f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala @@ -116,7 +116,8 @@ class L1Updater extends Updater { // Apply proximal operator (soft thresholding) val shrinkageVal = regParam * thisIterStepSize var i = 0 - while (i < brzWeights.length) { + val len = brzWeights.length + while (i < len) { val wi = brzWeights(i) brzWeights(i) = signum(wi) * max(0.0, abs(wi) - shrinkageVal) i += 1 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala index 354e90f3eeaa6..5e882d4ebb10b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala @@ -23,13 +23,16 @@ import javax.xml.transform.stream.StreamResult import org.jpmml.model.JAXBUtil import org.apache.spark.SparkContext +import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.mllib.pmml.export.PMMLModelExportFactory /** + * :: DeveloperApi :: * Export model to the PMML format * Predictive Model Markup Language (PMML) is an XML-based file format * developed by the Data Mining Group (www.dmg.org). */ +@DeveloperApi trait PMMLExportable { /** @@ -41,30 +44,38 @@ trait PMMLExportable { } /** + * :: Experimental :: * Export the model to a local file in PMML format */ + @Experimental def toPMML(localPath: String): Unit = { toPMML(new StreamResult(new File(localPath))) } /** + * :: Experimental :: * Export the model to a directory on a distributed file system in PMML format */ + @Experimental def toPMML(sc: SparkContext, path: String): Unit = { val pmml = toPMML() sc.parallelize(Array(pmml), 1).saveAsTextFile(path) } /** + * :: Experimental :: * Export the model to the OutputStream in PMML format */ + @Experimental def toPMML(outputStream: OutputStream): Unit = { toPMML(new StreamResult(outputStream)) } /** + * :: Experimental :: * Export the model to a String in PMML format */ + @Experimental def toPMML(): String = { val writer = new StringWriter toPMML(new StreamResult(writer)) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala index 34b447584e521..622b53a252ac5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExport.scala @@ -27,10 +27,10 @@ import org.apache.spark.mllib.regression.GeneralizedLinearModel * PMML Model Export for GeneralizedLinearModel class with binary ClassificationModel */ private[mllib] class BinaryClassificationPMMLModelExport( - model : GeneralizedLinearModel, + model : GeneralizedLinearModel, description : String, normalizationMethod : RegressionNormalizationMethodType, - threshold: Double) + threshold: Double) extends PMMLModelExport { populateBinaryClassificationPMML() @@ -72,7 +72,7 @@ private[mllib] class BinaryClassificationPMMLModelExport( .withUsageType(FieldUsageType.ACTIVE)) regressionTableYES.withNumericPredictors(new NumericPredictor(fields(i), model.weights(i))) } - + // add target field val targetField = FieldName.create("target") dataDictionary @@ -80,9 +80,9 @@ private[mllib] class BinaryClassificationPMMLModelExport( miningSchema .withMiningFields(new MiningField(targetField) .withUsageType(FieldUsageType.TARGET)) - + dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size) - + pmml.setDataDictionary(dataDictionary) pmml.withModels(regressionModel) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala index ebdeae50bb32f..c5fdecd3ca17f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala @@ -25,7 +25,7 @@ import scala.beans.BeanProperty import org.dmg.pmml.{Application, Header, PMML, Timestamp} private[mllib] trait PMMLModelExport { - + /** * Holder of the exported model in PMML format */ @@ -33,7 +33,7 @@ private[mllib] trait PMMLModelExport { val pmml: PMML = new PMML setHeader(pmml) - + private def setHeader(pmml: PMML): Unit = { val version = getClass.getPackage.getImplementationVersion val app = new Application().withName("Apache Spark MLlib").withVersion(version) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala index c16e83d6a067d..29bd689e1185a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala @@ -27,9 +27,9 @@ import org.apache.spark.mllib.regression.LinearRegressionModel import org.apache.spark.mllib.regression.RidgeRegressionModel private[mllib] object PMMLModelExportFactory { - + /** - * Factory object to help creating the necessary PMMLModelExport implementation + * Factory object to help creating the necessary PMMLModelExport implementation * taking as input the machine learning model (for example KMeansModel). */ def createPMMLModelExport(model: Any): PMMLModelExport = { @@ -44,7 +44,7 @@ private[mllib] object PMMLModelExportFactory { new GeneralizedLinearPMMLModelExport(lasso, "lasso regression") case svm: SVMModel => new BinaryClassificationPMMLModelExport( - svm, "linear SVM", RegressionNormalizationMethodType.NONE, + svm, "linear SVM", RegressionNormalizationMethodType.NONE, svm.getThreshold.getOrElse(0.0)) case logistic: LogisticRegressionModel => if (logistic.numClasses == 2) { @@ -60,5 +60,5 @@ private[mllib] object PMMLModelExportFactory { "PMML Export not supported for model: " + model.getClass.getName) } } - + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala index 8341bb86afd71..174d5e0f6c9f0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala @@ -52,7 +52,7 @@ object RandomRDDs { numPartitions: Int = 0, seed: Long = Utils.random.nextLong()): RDD[Double] = { val uniform = new UniformGenerator() - randomRDD(sc, uniform, size, numPartitionsOrDefault(sc, numPartitions), seed) + randomRDD(sc, uniform, size, numPartitionsOrDefault(sc, numPartitions), seed) } /** @@ -234,7 +234,7 @@ object RandomRDDs { * * @param sc SparkContext used to create the RDD. * @param shape shape parameter (> 0) for the gamma distribution - * @param scale scale parameter (> 0) for the gamma distribution + * @param scale scale parameter (> 0) for the gamma distribution * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). @@ -293,7 +293,7 @@ object RandomRDDs { * * @param sc SparkContext used to create the RDD. * @param mean mean for the log normal distribution - * @param std standard deviation for the log normal distribution + * @param std standard deviation for the log normal distribution * @param size Size of the RDD. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). * @param seed Random seed (default: a random long integer). @@ -671,7 +671,7 @@ object RandomRDDs { * * @param sc SparkContext used to create the RDD. * @param shape shape parameter (> 0) for the gamma distribution. - * @param scale scale parameter (> 0) for the gamma distribution. + * @param scale scale parameter (> 0) for the gamma distribution. * @param numRows Number of Vectors in the RDD. * @param numCols Number of elements in each Vector. * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index dddefe1944e9d..93290e6508529 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -175,7 +175,7 @@ class ALS private ( /** * :: DeveloperApi :: * Sets storage level for final RDDs (user/product used in MatrixFactorizationModel). The default - * value is `MEMORY_AND_DISK`. Users can change it to a serialized storage, e.g. + * value is `MEMORY_AND_DISK`. Users can change it to a serialized storage, e.g. * `MEMORY_AND_DISK_SER` and set `spark.rdd.compress` to `true` to reduce the space requirement, * at the cost of speed. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index 88c2148403313..93aa41e49961e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -281,8 +281,8 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { val metadata = compact(render( ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("rank" -> model.rank))) sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path)) - model.userFeatures.toDF("id", "features").saveAsParquetFile(userPath(path)) - model.productFeatures.toDF("id", "features").saveAsParquetFile(productPath(path)) + model.userFeatures.toDF("id", "features").write.parquet(userPath(path)) + model.productFeatures.toDF("id", "features").write.parquet(productPath(path)) } def load(sc: SparkContext, path: String): MatrixFactorizationModel = { @@ -292,11 +292,11 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { assert(className == thisClassName) assert(formatVersion == thisFormatVersion) val rank = (metadata \ "rank").extract[Int] - val userFeatures = sqlContext.parquetFile(userPath(path)) + val userFeatures = sqlContext.read.parquet(userPath(path)) .map { case Row(id: Int, features: Seq[_]) => (id, features.asInstanceOf[Seq[Double]].toArray) } - val productFeatures = sqlContext.parquetFile(productPath(path)) + val productFeatures = sqlContext.read.parquet(productPath(path)) .map { case Row(id: Int, features: Seq[_]) => (id, features.asInstanceOf[Seq[Double]].toArray) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala index be2a00c2dfea4..f3b46c75c05f3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala @@ -69,7 +69,8 @@ class IsotonicRegressionModel ( /** Asserts the input array is monotone with the given ordering. */ private def assertOrdered(xs: Array[Double])(implicit ord: Ordering[Double]): Unit = { var i = 1 - while (i < xs.length) { + val len = xs.length + while (i < len) { require(ord.compare(xs(i - 1), xs(i)) <= 0, s"Elements (${xs(i - 1)}, ${xs(i)}) are not ordered.") i += 1 @@ -169,26 +170,26 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { case class Data(boundary: Double, prediction: Double) def save( - sc: SparkContext, - path: String, - boundaries: Array[Double], - predictions: Array[Double], + sc: SparkContext, + path: String, + boundaries: Array[Double], + predictions: Array[Double], isotonic: Boolean): Unit = { val sqlContext = new SQLContext(sc) val metadata = compact(render( - ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("isotonic" -> isotonic))) sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path)) sqlContext.createDataFrame( boundaries.toSeq.zip(predictions).map { case (b, p) => Data(b, p) } - ).saveAsParquetFile(dataPath(path)) + ).write.parquet(dataPath(path)) } def load(sc: SparkContext, path: String): (Array[Double], Array[Double]) = { val sqlContext = new SQLContext(sc) - val dataRDD = sqlContext.parquetFile(dataPath(path)) + val dataRDD = sqlContext.read.parquet(dataPath(path)) checkSchema[Data](dataRDD.schema) val dataArray = dataRDD.select("boundary", "prediction").collect() @@ -202,7 +203,7 @@ object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] { override def load(sc: SparkContext, path: String): IsotonicRegressionModel = { implicit val formats = DefaultFormats val (loadedClassName, version, metadata) = loadMetadata(sc, path) - val isotonic = (metadata \ "isotonic").extract[Boolean] + val isotonic = (metadata \ "isotonic").extract[Boolean] val classNameV1_0 = SaveLoadV1_0.thisClassName (loadedClassName, version) match { case (className, "1.0") if className == classNameV1_0 => @@ -329,11 +330,12 @@ class IsotonicRegression private (private var isotonic: Boolean) extends Seriali } var i = 0 - while (i < input.length) { + val len = input.length + while (i < len) { var j = i // Find monotonicity violating sequence, if any. - while (j < input.length - 1 && input(j)._1 > input(j + 1)._1) { + while (j < len - 1 && input(j)._1 > input(j + 1)._1) { j = j + 1 } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala index b55944f74f623..317d3a5702636 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala @@ -60,7 +60,7 @@ private[regression] object GLMRegressionModel { val data = Data(weights, intercept) val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF() // TODO: repartition with 1 partition after SPARK-5532 gets fixed - dataRDD.saveAsParquetFile(Loader.dataPath(path)) + dataRDD.write.parquet(Loader.dataPath(path)) } /** @@ -72,7 +72,7 @@ private[regression] object GLMRegressionModel { def loadData(sc: SparkContext, path: String, modelClass: String, numFeatures: Int): Data = { val datapath = Loader.dataPath(path) val sqlContext = new SQLContext(sc) - val dataRDD = sqlContext.parquetFile(datapath) + val dataRDD = sqlContext.read.parquet(datapath) val dataArray = dataRDD.select("weights", "intercept").take(1) assert(dataArray.size == 1, s"Unable to load $modelClass data from: $datapath") val data = dataArray(0) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala index 79747cc5d7d74..58a50f9c19f14 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala @@ -17,52 +17,101 @@ package org.apache.spark.mllib.stat +import com.github.fommil.netlib.BLAS.{getInstance => blas} + +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD import org.apache.spark.rdd.RDD -private[stat] object KernelDensity { +/** + * :: Experimental :: + * Kernel density estimation. Given a sample from a population, estimate its probability density + * function at each of the given evaluation points using kernels. Only Gaussian kernel is supported. + * + * Scala example: + * + * {{{ + * val sample = sc.parallelize(Seq(0.0, 1.0, 4.0, 4.0)) + * val kd = new KernelDensity() + * .setSample(sample) + * .setBandwidth(3.0) + * val densities = kd.estimate(Array(-1.0, 2.0, 5.0)) + * }}} + */ +@Experimental +class KernelDensity extends Serializable { + + import KernelDensity._ + + /** Bandwidth of the kernel function. */ + private var bandwidth: Double = 1.0 + + /** A sample from a population. */ + private var sample: RDD[Double] = _ + /** - * Given a set of samples from a distribution, estimates its density at the set of given points. - * Uses a Gaussian kernel with the given standard deviation. + * Sets the bandwidth (standard deviation) of the Gaussian kernel (default: `1.0`). */ - def estimate(samples: RDD[Double], standardDeviation: Double, - evaluationPoints: Array[Double]): Array[Double] = { - if (standardDeviation <= 0.0) { - throw new IllegalArgumentException("Standard deviation must be positive") - } + def setBandwidth(bandwidth: Double): this.type = { + require(bandwidth > 0, s"Bandwidth must be positive, but got $bandwidth.") + this.bandwidth = bandwidth + this + } - // This gets used in each Gaussian PDF computation, so compute it up front - val logStandardDeviationPlusHalfLog2Pi = - math.log(standardDeviation) + 0.5 * math.log(2 * math.Pi) + /** + * Sets the sample to use for density estimation. + */ + def setSample(sample: RDD[Double]): this.type = { + this.sample = sample + this + } + + /** + * Sets the sample to use for density estimation (for Java users). + */ + def setSample(sample: JavaRDD[java.lang.Double]): this.type = { + this.sample = sample.rdd.asInstanceOf[RDD[Double]] + this + } + + /** + * Estimates probability density function at the given array of points. + */ + def estimate(points: Array[Double]): Array[Double] = { + val sample = this.sample + val bandwidth = this.bandwidth + + require(sample != null, "Must set sample before calling estimate.") - val (points, count) = samples.aggregate((new Array[Double](evaluationPoints.length), 0))( + val n = points.length + // This gets used in each Gaussian PDF computation, so compute it up front + val logStandardDeviationPlusHalfLog2Pi = math.log(bandwidth) + 0.5 * math.log(2 * math.Pi) + val (densities, count) = sample.aggregate((new Array[Double](n), 0L))( (x, y) => { var i = 0 - while (i < evaluationPoints.length) { - x._1(i) += normPdf(y, standardDeviation, logStandardDeviationPlusHalfLog2Pi, - evaluationPoints(i)) + while (i < n) { + x._1(i) += normPdf(y, bandwidth, logStandardDeviationPlusHalfLog2Pi, points(i)) i += 1 } - (x._1, i) + (x._1, x._2 + 1) }, (x, y) => { - var i = 0 - while (i < evaluationPoints.length) { - x._1(i) += y._1(i) - i += 1 - } + blas.daxpy(n, 1.0, y._1, 1, x._1, 1) (x._1, x._2 + y._2) }) - - var i = 0 - while (i < points.length) { - points(i) /= count - i += 1 - } - points + blas.dscal(n, 1.0 / count, densities, 1) + densities } +} + +private object KernelDensity { - private def normPdf(mean: Double, standardDeviation: Double, - logStandardDeviationPlusHalfLog2Pi: Double, x: Double): Double = { + /** Evaluates the PDF of a normal distribution. */ + def normPdf( + mean: Double, + standardDeviation: Double, + logStandardDeviationPlusHalfLog2Pi: Double, + x: Double): Double = { val x0 = x - mean val x1 = x0 / standardDeviation val logDensity = -0.5 * x1 * x1 - logStandardDeviationPlusHalfLog2Pi diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index fcc2a148791bd..d321cc554c1cc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -70,23 +70,30 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S require(n == sample.size, s"Dimensions mismatch when adding new sample." + s" Expecting $n but got ${sample.size}.") + val localCurrMean = currMean + val localCurrM2n = currM2n + val localCurrM2 = currM2 + val localCurrL1 = currL1 + val localNnz = nnz + val localCurrMax = currMax + val localCurrMin = currMin sample.foreachActive { (index, value) => if (value != 0.0) { - if (currMax(index) < value) { - currMax(index) = value + if (localCurrMax(index) < value) { + localCurrMax(index) = value } - if (currMin(index) > value) { - currMin(index) = value + if (localCurrMin(index) > value) { + localCurrMin(index) = value } - val prevMean = currMean(index) + val prevMean = localCurrMean(index) val diff = value - prevMean - currMean(index) = prevMean + diff / (nnz(index) + 1.0) - currM2n(index) += (value - currMean(index)) * diff - currM2(index) += value * value - currL1(index) += math.abs(value) + localCurrMean(index) = prevMean + diff / (localNnz(index) + 1.0) + localCurrM2n(index) += (value - localCurrMean(index)) * diff + localCurrM2(index) += value * value + localCurrL1(index) += math.abs(value) - nnz(index) += 1.0 + localNnz(index) += 1.0 } } @@ -130,14 +137,14 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S } } else if (totalCnt == 0 && other.totalCnt != 0) { this.n = other.n - this.currMean = other.currMean.clone - this.currM2n = other.currM2n.clone - this.currM2 = other.currM2.clone - this.currL1 = other.currL1.clone + this.currMean = other.currMean.clone() + this.currM2n = other.currM2n.clone() + this.currM2 = other.currM2.clone() + this.currL1 = other.currL1.clone() this.totalCnt = other.totalCnt - this.nnz = other.nnz.clone - this.currMax = other.currMax.clone - this.currMin = other.currMin.clone + this.nnz = other.nnz.clone() + this.currMax = other.currMax.clone() + this.currMin = other.currMin.clone() } this } @@ -165,7 +172,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S if (denominator > 0.0) { val deltaMean = currMean var i = 0 - while (i < currM2n.size) { + val len = currM2n.length + while (i < len) { realVariance(i) = currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt - nnz(i)) / totalCnt realVariance(i) /= denominator @@ -211,7 +219,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S val realMagnitude = Array.ofDim[Double](n) var i = 0 - while (i < currM2.size) { + val len = currM2.length + while (i < len) { realMagnitude(i) = math.sqrt(currM2(i)) i += 1 } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index 32561620ac914..b3fad0c52d655 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -149,18 +149,4 @@ object Statistics { def chiSqTest(data: RDD[LabeledPoint]): Array[ChiSqTestResult] = { ChiSqTest.chiSquaredFeatures(data) } - - /** - * Given an empirical distribution defined by the input RDD of samples, estimate its density at - * each of the given evaluation points using a Gaussian kernel. - * - * @param samples The samples RDD used to define the empirical distribution. - * @param standardDeviation The standard deviation of the kernel Gaussians. - * @param evaluationPoints The points at which to estimate densities. - * @return An array the same size as evaluationPoints with the density at each point. - */ - def kernelDensity(samples: RDD[Double], standardDeviation: Double, - evaluationPoints: Iterable[Double]): Array[Double] = { - KernelDensity.estimate(samples, standardDeviation, evaluationPoints.toArray) - } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala index cd6add9d60b0d..cf51b24ff777f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala @@ -29,102 +29,102 @@ import org.apache.spark.mllib.util.MLUtils * the event that the covariance matrix is singular, the density will be computed in a * reduced dimensional subspace under which the distribution is supported. * (see [[http://en.wikipedia.org/wiki/Multivariate_normal_distribution#Degenerate_case]]) - * + * * @param mu The mean vector of the distribution * @param sigma The covariance matrix of the distribution */ @DeveloperApi class MultivariateGaussian ( - val mu: Vector, + val mu: Vector, val sigma: Matrix) extends Serializable { require(sigma.numCols == sigma.numRows, "Covariance matrix must be square") require(mu.size == sigma.numCols, "Mean vector length must match covariance matrix size") - + private val breezeMu = mu.toBreeze.toDenseVector - + /** * private[mllib] constructor - * + * * @param mu The mean vector of the distribution * @param sigma The covariance matrix of the distribution */ private[mllib] def this(mu: DBV[Double], sigma: DBM[Double]) = { this(Vectors.fromBreeze(mu), Matrices.fromBreeze(sigma)) } - + /** * Compute distribution dependent constants: * rootSigmaInv = D^(-1/2)^ * U, where sigma = U * D * U.t - * u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^) + * u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^) */ private val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants - + /** Returns density of this multivariate Gaussian at given point, x */ def pdf(x: Vector): Double = { pdf(x.toBreeze) } - + /** Returns the log-density of this multivariate Gaussian at given point, x */ def logpdf(x: Vector): Double = { logpdf(x.toBreeze) } - + /** Returns density of this multivariate Gaussian at given point, x */ private[mllib] def pdf(x: BV[Double]): Double = { math.exp(logpdf(x)) } - + /** Returns the log-density of this multivariate Gaussian at given point, x */ private[mllib] def logpdf(x: BV[Double]): Double = { val delta = x - breezeMu val v = rootSigmaInv * delta u + v.t * v * -0.5 } - + /** * Calculate distribution dependent components used for the density function: * pdf(x) = (2*pi)^(-k/2)^ * det(sigma)^(-1/2)^ * exp((-1/2) * (x-mu).t * inv(sigma) * (x-mu)) * where k is length of the mean vector. - * - * We here compute distribution-fixed parts + * + * We here compute distribution-fixed parts * log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^) * and * D^(-1/2)^ * U, where sigma = U * D * U.t - * + * * Both the determinant and the inverse can be computed from the singular value decomposition * of sigma. Noting that covariance matrices are always symmetric and positive semi-definite, * we can use the eigendecomposition. We also do not compute the inverse directly; noting - * that - * + * that + * * sigma = U * D * U.t - * inv(Sigma) = U * inv(D) * U.t + * inv(Sigma) = U * inv(D) * U.t * = (D^{-1/2}^ * U).t * (D^{-1/2}^ * U) - * + * * and thus - * + * * -0.5 * (x-mu).t * inv(Sigma) * (x-mu) = -0.5 * norm(D^{-1/2}^ * U * (x-mu))^2^ - * - * To guard against singular covariance matrices, this method computes both the + * + * To guard against singular covariance matrices, this method computes both the * pseudo-determinant and the pseudo-inverse (Moore-Penrose). Singular values are considered * to be non-zero only if they exceed a tolerance based on machine precision, matrix size, and * relation to the maximum singular value (same tolerance used by, e.g., Octave). */ private def calculateCovarianceConstants: (DBM[Double], Double) = { val eigSym.EigSym(d, u) = eigSym(sigma.toBreeze.toDenseMatrix) // sigma = u * diag(d) * u.t - + // For numerical stability, values are considered to be non-zero only if they exceed tol. // This prevents any inverted value from exceeding (eps * n * max(d))^-1 val tol = MLUtils.EPSILON * max(d) * d.length - + try { // log(pseudo-determinant) is sum of the logs of all non-zero singular values val logPseudoDetSigma = d.activeValuesIterator.filter(_ > tol).map(math.log).sum - - // calculate the root-pseudo-inverse of the diagonal matrix of singular values + + // calculate the root-pseudo-inverse of the diagonal matrix of singular values // by inverting the square root of all non-zero values val pinvS = diag(new DBV(d.map(v => if (v > tol) math.sqrt(1.0 / v) else 0.0).toArray)) - + (pinvS * u, -0.5 * (mu.size * math.log(2.0 * math.Pi) + logPseudoDetSigma)) } catch { case uex: UnsupportedOperationException => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala index ea82d39b72c03..23c8d7c7c8075 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala @@ -196,7 +196,7 @@ private[stat] object ChiSqTest extends Logging { * Pearson's independence test on the input contingency matrix. * TODO: optimize for SparseMatrix when it becomes supported. */ - def chiSquaredMatrix(counts: Matrix, methodName:String = PEARSON.name): ChiSqTestResult = { + def chiSquaredMatrix(counts: Matrix, methodName: String = PEARSON.name): ChiSqTestResult = { val method = methodFromString(methodName) val numRows = counts.numRows val numCols = counts.numCols @@ -205,8 +205,10 @@ private[stat] object ChiSqTest extends Logging { val colSums = new Array[Double](numCols) val rowSums = new Array[Double](numRows) val colMajorArr = counts.toArray + val colMajorArrLen = colMajorArr.length + var i = 0 - while (i < colMajorArr.size) { + while (i < colMajorArrLen) { val elem = colMajorArr(i) if (elem < 0.0) { throw new IllegalArgumentException("Contingency table cannot contain negative entries.") @@ -220,7 +222,7 @@ private[stat] object ChiSqTest extends Logging { // second pass to collect statistic var statistic = 0.0 var j = 0 - while (j < colMajorArr.size) { + while (j < colMajorArrLen) { val col = j / numRows val colSum = colSums(col) if (colSum == 0.0) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index dfe3a0b6913ef..cecd1fed896d5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -169,7 +169,7 @@ object DecisionTree extends Serializable with Logging { numClasses: Int, maxBins: Int, quantileCalculationStrategy: QuantileStrategy, - categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = { + categoricalFeaturesInfo: Map[Int, Int]): DecisionTreeModel = { val strategy = new Strategy(algo, impurity, maxDepth, numClasses, maxBins, quantileCalculationStrategy, categoricalFeaturesInfo) new DecisionTree(strategy).run(input) @@ -768,7 +768,7 @@ object DecisionTree extends Serializable with Logging { */ private def calculatePredictImpurity( leftImpurityCalculator: ImpurityCalculator, - rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = { + rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = { val parentNodeAgg = leftImpurityCalculator.copy parentNodeAgg.add(rightImpurityCalculator) val predict = calculatePredict(parentNodeAgg) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index 1f779584dcffd..a835f96d5d0e3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -60,12 +60,12 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = { val algo = boostingStrategy.treeStrategy.algo algo match { - case Regression => GradientBoostedTrees.boost(input, input, boostingStrategy, validate=false) + case Regression => + GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false) case Classification => // Map labels to -1, +1 so binary classification can be treated as regression. val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) - GradientBoostedTrees.boost(remappedInput, - remappedInput, boostingStrategy, validate=false) + GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false) case _ => throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") } @@ -93,8 +93,8 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) validationInput: RDD[LabeledPoint]): GradientBoostedTreesModel = { val algo = boostingStrategy.treeStrategy.algo algo match { - case Regression => GradientBoostedTrees.boost( - input, validationInput, boostingStrategy, validate=true) + case Regression => + GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true) case Classification => // Map labels to -1, +1 so binary classification can be treated as regression. val remappedInput = input.map( @@ -102,7 +102,7 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) val remappedValidationInput = validationInput.map( x => new LabeledPoint((x.label * 2) - 1, x.features)) GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy, - validate=true) + validate = true) case _ => throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") } @@ -270,7 +270,7 @@ object GradientBoostedTrees extends Logging { logInfo(s"$timer") if (persistedInput) input.unpersist() - + if (validate) { new GradientBoostedTreesModel( boostingStrategy.treeStrategy.algo, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index 055e60c7d9c95..069959976a188 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -36,6 +36,7 @@ import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils +import org.apache.spark.util.random.SamplingUtils /** * :: Experimental :: @@ -248,7 +249,7 @@ private class RandomForest ( try { nodeIdCache.get.deleteAllCheckpoints() } catch { - case e:IOException => + case e: IOException => logWarning(s"delete all checkpoints failed. Error reason: ${e.getMessage}") } } @@ -473,9 +474,8 @@ object RandomForest extends Serializable with Logging { val (treeIndex, node) = nodeQueue.head // Choose subset of features for node (if subsampling). val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) { - // TODO: Use more efficient subsampling? (use selection-and-rejection or reservoir) - Some(rng.shuffle(Range(0, metadata.numFeatures).toList) - .take(metadata.numFeaturesPerNode).toArray) + Some(SamplingUtils.reservoirSampleAndCount(Range(0, + metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong)._1) } else { None } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 60e2ab2bb829e..72eb24c49264a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -111,11 +111,12 @@ private[tree] abstract class ImpurityCalculator(val stats: Array[Double]) { * Add the stats from another calculator into this one, modifying and returning this calculator. */ def add(other: ImpurityCalculator): ImpurityCalculator = { - require(stats.size == other.stats.size, + require(stats.length == other.stats.length, s"Two ImpurityCalculator instances cannot be added with different counts sizes." + - s" Sizes are ${stats.size} and ${other.stats.size}.") + s" Sizes are ${stats.length} and ${other.stats.length}.") var i = 0 - while (i < other.stats.size) { + val len = other.stats.length + while (i < len) { stats(i) += other.stats(i) i += 1 } @@ -127,11 +128,12 @@ private[tree] abstract class ImpurityCalculator(val stats: Array[Double]) { * calculator. */ def subtract(other: ImpurityCalculator): ImpurityCalculator = { - require(stats.size == other.stats.size, + require(stats.length == other.stats.length, s"Two ImpurityCalculator instances cannot be subtracted with different counts sizes." + - s" Sizes are ${stats.size} and ${other.stats.size}.") + s" Sizes are ${stats.length} and ${other.stats.length}.") var i = 0 - while (i < other.stats.size) { + val len = other.stats.length + while (i < len) { stats(i) -= other.stats(i) i += 1 } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index 331af428533de..25bb1453db404 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -223,14 +223,14 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { val dataRDD: DataFrame = sc.parallelize(nodes) .map(NodeData.apply(0, _)) .toDF() - dataRDD.saveAsParquetFile(Loader.dataPath(path)) + dataRDD.write.parquet(Loader.dataPath(path)) } def load(sc: SparkContext, path: String, algo: String, numNodes: Int): DecisionTreeModel = { val datapath = Loader.dataPath(path) val sqlContext = new SQLContext(sc) // Load Parquet data. - val dataRDD = sqlContext.parquetFile(datapath) + val dataRDD = sqlContext.read.parquet(datapath) // Check schema explicitly since erasure makes it hard to use match-case for checking. Loader.checkSchema[NodeData](dataRDD.schema) val nodes = dataRDD.map(NodeData.apply) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 431a839817eac..a6d1398fc267b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -83,7 +83,7 @@ class Node ( def predict(features: Vector) : Double = { if (isLeaf) { predict.predict - } else{ + } else { if (split.get.featureType == Continuous) { if (features(split.get.feature) <= split.get.threshold) { leftNode.get.predict(features) @@ -151,9 +151,9 @@ class Node ( s"(feature ${split.feature} > ${split.threshold})" } case Categorical => if (left) { - s"(feature ${split.feature} in ${split.categories.mkString("{",",","}")})" + s"(feature ${split.feature} in ${split.categories.mkString("{", ",", "}")})" } else { - s"(feature ${split.feature} not in ${split.categories.mkString("{",",","}")})" + s"(feature ${split.feature} not in ${split.categories.mkString("{", ",", "}")})" } } } @@ -161,9 +161,9 @@ class Node ( if (isLeaf) { prefix + s"Predict: ${predict.predict}\n" } else { - prefix + s"If ${splitToString(split.get, left=true)}\n" + + prefix + s"If ${splitToString(split.get, left = true)}\n" + leftNode.get.subtreeToString(indentFactor + 1) + - prefix + s"Else ${splitToString(split.get, left=false)}\n" + + prefix + s"Else ${splitToString(split.get, left = false)}\n" + rightNode.get.subtreeToString(indentFactor + 1) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index 8341219bfa71c..1e3333d8d81d0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -414,7 +414,7 @@ private[tree] object TreeEnsembleModel extends Logging { val dataRDD = sc.parallelize(model.trees.zipWithIndex).flatMap { case (tree, treeId) => tree.topNode.subtreeIterator.toSeq.map(node => NodeData(treeId, node)) }.toDF() - dataRDD.saveAsParquetFile(Loader.dataPath(path)) + dataRDD.write.parquet(Loader.dataPath(path)) } /** @@ -437,7 +437,7 @@ private[tree] object TreeEnsembleModel extends Logging { treeAlgo: String): Array[DecisionTreeModel] = { val datapath = Loader.dataPath(path) val sqlContext = new SQLContext(sc) - val nodes = sqlContext.parquetFile(datapath).map(NodeData.apply) + val nodes = sqlContext.read.parquet(datapath).map(NodeData.apply) val trees = constructTrees(nodes) trees.map(new DecisionTreeModel(_, Algo.fromString(treeAlgo))) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala index b1a4517344970..b4e33c98ba7e5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala @@ -107,7 +107,8 @@ object LinearDataGenerator { x.foreach { v => var i = 0 - while (i < v.length) { + val len = v.length + while (i < len) { v(i) = (v(i) - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i) i += 1 } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala index 0c5b4f9d04a74..bd73a866c8a82 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala @@ -82,8 +82,7 @@ object MFDataGenerator { BLAS.gemm(z, A, B, 1.0, fullData) val df = rank * (m + n - rank) - val sampSize = scala.math.min(scala.math.round(trainSampFact * df), - scala.math.round(.99 * m * n)).toInt + val sampSize = math.min(math.round(trainSampFact * df), math.round(.99 * m * n)).toInt val rand = new Random() val mn = m * n val shuffled = rand.shuffle((0 until mn).toList) @@ -102,8 +101,8 @@ object MFDataGenerator { // optionally generate testing data if (test) { - val testSampSize = scala.math - .min(scala.math.round(sampSize * testSampFact),scala.math.round(mn - sampSize)).toInt + val testSampSize = math.min( + math.round(sampSize * testSampFact), math.round(mn - sampSize)).toInt val testOmega = shuffled.slice(sampSize, sampSize + testSampSize) val testOrdered = testOmega.sortWith(_ < _).toArray val testData: RDD[(Int, Int, Double)] = sc.parallelize(testOrdered) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 681f4c618d302..52d6468a72af7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -82,6 +82,18 @@ object MLUtils { val value = indexAndValue(1).toDouble (index, value) }.unzip + + // check if indices are one-based and in ascending order + var previous = -1 + var i = 0 + val indicesLength = indices.length + while (i < indicesLength) { + val current = indices(i) + require(current > previous, "indices should be one-based and in ascending order" ) + previous = current + i += 1 + } + (label, indices.toArray, values.toArray) } @@ -265,7 +277,7 @@ object MLUtils { } Vectors.fromBreeze(vector1) } - + /** * Returns the squared Euclidean distance between two vectors. The following formula will be used * if it does not introduce too much numerical error: diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java index 7e7189a2b1d53..f75e024a713ee 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -84,7 +84,7 @@ public void logisticRegressionWithSetters() { .setThreshold(0.6) .setProbabilityCol("myProbability"); LogisticRegressionModel model = lr.fit(dataset); - LogisticRegression parent = model.parent(); + LogisticRegression parent = (LogisticRegression) model.parent(); assert(parent.getMaxIter() == 10); assert(parent.getRegParam() == 1.0); assert(parent.getThreshold() == 0.6); @@ -110,7 +110,7 @@ public void logisticRegressionWithSetters() { // Call fit() with new params, and check as many params as we can. LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.threshold().w(0.4), lr.probabilityCol().w("theProb")); - LogisticRegression parent2 = model2.parent(); + LogisticRegression parent2 = (LogisticRegression) model2.parent(); assert(parent2.getMaxIter() == 5); assert(parent2.getRegParam() == 0.1); assert(parent2.getThreshold() == 0.4); diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java new file mode 100644 index 0000000000000..d5bd230a957a1 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +public class JavaBucketizerSuite { + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaBucketizerSuite"); + jsql = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void bucketizerTest() { + double[] splits = {-0.5, 0.0, 0.5}; + + JavaRDD data = jsc.parallelize(Lists.newArrayList( + RowFactory.create(-0.5), + RowFactory.create(-0.3), + RowFactory.create(0.0), + RowFactory.create(0.2) + )); + StructType schema = new StructType(new StructField[] { + new StructField("feature", DataTypes.DoubleType, false, Metadata.empty()) + }); + DataFrame dataset = jsql.createDataFrame(data, schema); + + Bucketizer bucketizer = new Bucketizer() + .setInputCol("feature") + .setOutputCol("result") + .setSplits(splits); + + Row[] result = bucketizer.transform(dataset).select("result").collect(); + + for (Row r : result) { + double index = r.getDouble(0); + Assert.assertTrue((index >= 0) && (index <= 1)); + } + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java index 23463ab5fe848..599e9cfd23ad4 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java @@ -55,25 +55,30 @@ public void tearDown() { @Test public void hashingTF() { JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( - RowFactory.create(0, "Hi I heard about Spark"), - RowFactory.create(0, "I wish Java could use case classes"), - RowFactory.create(1, "Logistic regression models are neat") + RowFactory.create(0.0, "Hi I heard about Spark"), + RowFactory.create(0.0, "I wish Java could use case classes"), + RowFactory.create(1.0, "Logistic regression models are neat") )); StructType schema = new StructType(new StructField[]{ new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), new StructField("sentence", DataTypes.StringType, false, Metadata.empty()) }); - DataFrame sentenceDataFrame = jsql.createDataFrame(jrdd, schema); - Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words"); - DataFrame wordsDataFrame = tokenizer.transform(sentenceDataFrame); + DataFrame sentenceData = jsql.createDataFrame(jrdd, schema); + Tokenizer tokenizer = new Tokenizer() + .setInputCol("sentence") + .setOutputCol("words"); + DataFrame wordsData = tokenizer.transform(sentenceData); int numFeatures = 20; HashingTF hashingTF = new HashingTF() .setInputCol("words") - .setOutputCol("features") + .setOutputCol("rawFeatures") .setNumFeatures(numFeatures); - DataFrame featurized = hashingTF.transform(wordsDataFrame); - for (Row r : featurized.select("features", "words", "label").take(3)) { + DataFrame featurizedData = hashingTF.transform(wordsData); + IDF idf = new IDF().setInputCol("rawFeatures").setOutputCol("features"); + IDFModel idfModel = idf.fit(featurizedData); + DataFrame rescaledData = idfModel.transform(featurizedData); + for (Row r : rescaledData.select("features", "label").take(3)) { Vector features = r.getAs(0); Assert.assertEquals(features.size(), numFeatures); } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java new file mode 100644 index 0000000000000..d82f3b7e8c076 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import java.util.List; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; + +public class JavaNormalizerSuite { + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaNormalizerSuite"); + jsql = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void normalizer() { + // The tests are to check Java compatibility. + List points = Lists.newArrayList( + new VectorIndexerSuite.FeatureData(Vectors.dense(0.0, -2.0)), + new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)), + new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0)) + ); + DataFrame dataFrame = jsql.createDataFrame(jsc.parallelize(points, 2), + VectorIndexerSuite.FeatureData.class); + Normalizer normalizer = new Normalizer() + .setInputCol("features") + .setOutputCol("normFeatures"); + + // Normalize each Vector using $L^2$ norm. + DataFrame l2NormData = normalizer.transform(dataFrame, normalizer.p().w(2)); + l2NormData.count(); + + // Normalize each Vector using $L^\infty$ norm. + DataFrame lInfNormData = + normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY)); + lInfNormData.count(); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java new file mode 100644 index 0000000000000..5e8211c2c5118 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +public class JavaPolynomialExpansionSuite { + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaPolynomialExpansionSuite"); + jsql = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void polynomialExpansionTest() { + PolynomialExpansion polyExpansion = new PolynomialExpansion() + .setInputCol("features") + .setOutputCol("polyFeatures") + .setDegree(3); + + JavaRDD data = jsc.parallelize(Lists.newArrayList( + RowFactory.create( + Vectors.dense(-2.0, 2.3), + Vectors.dense(-2.0, 4.0, -8.0, 2.3, -4.6, 9.2, 5.29, -10.58, 12.17) + ), + RowFactory.create(Vectors.dense(0.0, 0.0), Vectors.dense(new double[9])), + RowFactory.create( + Vectors.dense(0.6, -1.1), + Vectors.dense(0.6, 0.36, 0.216, -1.1, -0.66, -0.396, 1.21, 0.726, -1.331) + ) + )); + + StructType schema = new StructType(new StructField[] { + new StructField("features", new VectorUDT(), false, Metadata.empty()), + new StructField("expected", new VectorUDT(), false, Metadata.empty()) + }); + + DataFrame dataset = jsql.createDataFrame(data, schema); + + Row[] pairs = polyExpansion.transform(dataset) + .select("polyFeatures", "expected") + .collect(); + + for (Row r : pairs) { + double[] polyFeatures = ((Vector)r.get(0)).toArray(); + double[] expected = ((Vector)r.get(1)).toArray(); + Assert.assertArrayEquals(polyFeatures, expected, 1e-1); + } + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java new file mode 100644 index 0000000000000..74eb2733f06ef --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import java.util.List; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; + +public class JavaStandardScalerSuite { + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaStandardScalerSuite"); + jsql = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void standardScaler() { + // The tests are to check Java compatibility. + List points = Lists.newArrayList( + new VectorIndexerSuite.FeatureData(Vectors.dense(0.0, -2.0)), + new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)), + new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0)) + ); + DataFrame dataFrame = jsql.createDataFrame(jsc.parallelize(points, 2), + VectorIndexerSuite.FeatureData.class); + StandardScaler scaler = new StandardScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures") + .setWithStd(true) + .setWithMean(false); + + // Compute summary statistics by fitting the StandardScaler + StandardScalerModel scalerModel = scaler.fit(dataFrame); + + // Normalize each feature to have unit standard deviation. + DataFrame scaledData = scalerModel.transform(dataFrame); + scaledData.count(); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java new file mode 100644 index 0000000000000..35b18c5308f61 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import java.util.Arrays; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import static org.apache.spark.sql.types.DataTypes.*; + +public class JavaStringIndexerSuite { + private transient JavaSparkContext jsc; + private transient SQLContext sqlContext; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaStringIndexerSuite"); + sqlContext = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + sqlContext = null; + } + + @Test + public void testStringIndexer() { + StructType schema = createStructType(new StructField[] { + createStructField("id", IntegerType, false), + createStructField("label", StringType, false) + }); + JavaRDD rdd = jsc.parallelize( + Arrays.asList(c(0, "a"), c(1, "b"), c(2, "c"), c(3, "a"), c(4, "a"), c(5, "c"))); + DataFrame dataset = sqlContext.createDataFrame(rdd, schema); + + StringIndexer indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex"); + DataFrame output = indexer.fit(dataset).transform(dataset); + + Assert.assertArrayEquals( + new Row[] { c(0, 0.0), c(1, 2.0), c(2, 1.0), c(3, 0.0), c(4, 0.0), c(5, 1.0) }, + output.orderBy("id").select("id", "labelIndex").collect()); + } + + /** An alias for RowFactory.create. */ + private Row c(Object... values) { + return RowFactory.create(values); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java new file mode 100644 index 0000000000000..b7c564caad3bd --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import java.util.Arrays; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.*; +import static org.apache.spark.sql.types.DataTypes.*; + +public class JavaVectorAssemblerSuite { + private transient JavaSparkContext jsc; + private transient SQLContext sqlContext; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaVectorAssemblerSuite"); + sqlContext = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void testVectorAssembler() { + StructType schema = createStructType(new StructField[] { + createStructField("id", IntegerType, false), + createStructField("x", DoubleType, false), + createStructField("y", new VectorUDT(), false), + createStructField("name", StringType, false), + createStructField("z", new VectorUDT(), false), + createStructField("n", LongType, false) + }); + Row row = RowFactory.create( + 0, 0.0, Vectors.dense(1.0, 2.0), "a", + Vectors.sparse(2, new int[] {1}, new double[] {3.0}), 10L); + JavaRDD rdd = jsc.parallelize(Arrays.asList(row)); + DataFrame dataset = sqlContext.createDataFrame(rdd, schema); + VectorAssembler assembler = new VectorAssembler() + .setInputCols(new String[] {"x", "y", "z", "n"}) + .setOutputCol("features"); + DataFrame output = assembler.transform(dataset); + Assert.assertEquals( + Vectors.sparse(6, new int[] {1, 2, 4, 5}, new double[] {1.0, 2.0, 3.0, 10.0}), + output.select("features").first().getAs(0)); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java index 161100134c92d..c7ae5468b9429 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java @@ -19,6 +19,7 @@ import java.io.Serializable; import java.util.List; +import java.util.Map; import org.junit.After; import org.junit.Assert; @@ -64,7 +65,8 @@ public void vectorIndexerAPI() { .setMaxCategories(2); VectorIndexerModel model = indexer.fit(data); Assert.assertEquals(model.numFeatures(), 2); - Assert.assertEquals(model.categoryMaps().size(), 1); + Map> categoryMaps = model.javaCategoryMaps(); + Assert.assertEquals(categoryMaps.size(), 1); DataFrame indexedData = model.transform(data); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java new file mode 100644 index 0000000000000..39c70157f83c0 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.*; + +public class JavaWord2VecSuite { + private transient JavaSparkContext jsc; + private transient SQLContext sqlContext; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaWord2VecSuite"); + sqlContext = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void testJavaWord2Vec() { + JavaRDD jrdd = jsc.parallelize(Lists.newArrayList( + RowFactory.create(Lists.newArrayList("Hi I heard about Spark".split(" "))), + RowFactory.create(Lists.newArrayList("I wish Java could use case classes".split(" "))), + RowFactory.create(Lists.newArrayList("Logistic regression models are neat".split(" "))) + )); + StructType schema = new StructType(new StructField[]{ + new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()) + }); + DataFrame documentDF = sqlContext.createDataFrame(jrdd, schema); + + Word2Vec word2Vec = new Word2Vec() + .setInputCol("text") + .setOutputCol("result") + .setVectorSize(3) + .setMinCount(0); + Word2VecModel model = word2Vec.fit(documentDF); + DataFrame result = model.transform(documentDF); + + for (Row r: result.select("result").collect()) { + double[] polyFeatures = ((Vector)r.get(0)).toArray(); + Assert.assertEquals(polyFeatures.length, 3); + } + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java index 8abe575610d19..947ae3a2ce06f 100644 --- a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java @@ -21,43 +21,66 @@ import com.google.common.collect.Lists; +import org.apache.spark.ml.util.Identifiable$; + /** * A subclass of Params for testing. */ public class JavaTestParams extends JavaParams { - public IntParam myIntParam; + public JavaTestParams() { + this.uid_ = Identifiable$.MODULE$.randomUID("javaTestParams"); + init(); + } + + public JavaTestParams(String uid) { + this.uid_ = uid; + init(); + } + + private String uid_; + + @Override + public String uid() { + return uid_; + } - public int getMyIntParam() { return (Integer)getOrDefault(myIntParam); } + private IntParam myIntParam_; + public IntParam myIntParam() { return myIntParam_; } + + public int getMyIntParam() { return (Integer)getOrDefault(myIntParam_); } public JavaTestParams setMyIntParam(int value) { - set(myIntParam, value); return this; + set(myIntParam_, value); return this; } - public DoubleParam myDoubleParam; + private DoubleParam myDoubleParam_; + public DoubleParam myDoubleParam() { return myDoubleParam_; } - public double getMyDoubleParam() { return (Double)getOrDefault(myDoubleParam); } + public double getMyDoubleParam() { return (Double)getOrDefault(myDoubleParam_); } public JavaTestParams setMyDoubleParam(double value) { - set(myDoubleParam, value); return this; + set(myDoubleParam_, value); return this; } - public Param myStringParam; + private Param myStringParam_; + public Param myStringParam() { return myStringParam_; } - public String getMyStringParam() { return (String)getOrDefault(myStringParam); } + public String getMyStringParam() { return getOrDefault(myStringParam_); } public JavaTestParams setMyStringParam(String value) { - set(myStringParam, value); return this; + set(myStringParam_, value); return this; } - public JavaTestParams() { - myIntParam = new IntParam(this, "myIntParam", "this is an int param", ParamValidators.gt(0)); - myDoubleParam = new DoubleParam(this, "myDoubleParam", "this is a double param", + private void init() { + myIntParam_ = new IntParam(this, "myIntParam", "this is an int param", ParamValidators.gt(0)); + myDoubleParam_ = new DoubleParam(this, "myDoubleParam", "this is a double param", ParamValidators.inRange(0.0, 1.0)); List validStrings = Lists.newArrayList("a", "b"); - myStringParam = new Param(this, "myStringParam", "this is a string param", + myStringParam_ = new Param(this, "myStringParam", "this is a string param", ParamValidators.inArray(validStrings)); - setDefault(myIntParam, 1); - setDefault(myDoubleParam, 0.5); + setDefault(myIntParam_, 1); + setDefault(myDoubleParam_, 0.5); + setDefault(myIntParam().w(1), myDoubleParam().w(0.5)); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java index a82b86d560b6e..d591a456864e4 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java @@ -77,14 +77,14 @@ public void linearRegressionWithSetters() { .setMaxIter(10) .setRegParam(1.0); LinearRegressionModel model = lr.fit(dataset); - LinearRegression parent = model.parent(); + LinearRegression parent = (LinearRegression) model.parent(); assertEquals(10, parent.getMaxIter()); assertEquals(1.0, parent.getRegParam(), 0.0); // Call fit() with new params, and check as many params as we can. LinearRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred")); - LinearRegression parent2 = model2.parent(); + LinearRegression parent2 = (LinearRegression) model2.parent(); assertEquals(5, parent2.getMaxIter()); assertEquals(0.1, parent2.getRegParam(), 0.0); assertEquals("thePred", model2.getPredictionCol()); diff --git a/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala b/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala new file mode 100644 index 0000000000000..928301523fba9 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import org.apache.spark.SparkFunSuite + +class IdentifiableSuite extends SparkFunSuite { + + import IdentifiableSuite.Test + + test("Identifiable") { + val test0 = new Test("test_0") + assert(test0.uid === "test_0") + + val test1 = new Test + assert(test1.uid.startsWith("test_")) + } +} + +object IdentifiableSuite { + + class Test(override val uid: String) extends Identifiable { + def this() = this(Identifiable.randomUID("test")) + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java index 71fb7f13c39c2..3771c0ea7ad83 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java @@ -108,7 +108,7 @@ public Vector call(LabeledPoint v) throws Exception { @Test public void testModelTypeSetters() { NaiveBayes nb = new NaiveBayes() - .setModelType("Bernoulli") - .setModelType("Multinomial"); + .setModelType("bernoulli") + .setModelType("multinomial"); } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 2b04a3034782e..05bf58e63abaf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -19,13 +19,13 @@ package org.apache.spark.ml import org.mockito.Matchers.{any, eq => meq} import org.mockito.Mockito.when -import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar.mock +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamMap import org.apache.spark.sql.DataFrame -class PipelineSuite extends FunSuite { +class PipelineSuite extends SparkFunSuite { abstract class MyModel extends Model[MyModel] diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala index 17ddd335deb6d..512cffb1acb66 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.ml.attribute -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class AttributeGroupSuite extends FunSuite { +class AttributeGroupSuite extends SparkFunSuite { test("attribute group") { val attrs = Array( diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala index ec9b717e41ce8..72b575d022547 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.ml.attribute -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ -class AttributeSuite extends FunSuite { +class AttributeSuite extends SparkFunSuite { test("default numeric attribute") { val attr: NumericAttribute = NumericAttribute.defaultAttr diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 03af4ecd7a7e0..ae40b0b8ff854 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.ml.classification -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint @@ -29,7 +28,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame -class DecisionTreeClassifierSuite extends FunSuite with MLlibTestSparkContext { +class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { import DecisionTreeClassifierSuite.compareAPIs @@ -251,7 +250,7 @@ class DecisionTreeClassifierSuite extends FunSuite with MLlibTestSparkContext { */ } -private[ml] object DecisionTreeClassifierSuite extends FunSuite { +private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite { /** * Train 2 decision trees on the given dataset, one using the old API and one using the new API. @@ -266,9 +265,9 @@ private[ml] object DecisionTreeClassifierSuite extends FunSuite { val oldTree = OldDecisionTree.train(data, oldStrategy) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses) val newTree = dt.fit(newData) - // Use parent, fittingParamMap from newTree since these are not checked anyways. + // Use parent from newTree since this is not checked anyways. val oldTreeAsNew = DecisionTreeClassificationModel.fromOld( - oldTree, newTree.parent, categoricalFeatures) + oldTree, newTree.parent.asInstanceOf[DecisionTreeClassifier], categoricalFeatures) TreeTests.checkEqual(oldTreeAsNew, newTree) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 16c758b82c7cd..1302da3c373ff 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.ml.classification -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} @@ -31,7 +30,7 @@ import org.apache.spark.sql.DataFrame /** * Test suite for [[GBTClassifier]]. */ -class GBTClassifierSuite extends FunSuite with MLlibTestSparkContext { +class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { import GBTClassifierSuite.compareAPIs @@ -128,9 +127,9 @@ private object GBTClassifierSuite { val oldModel = oldGBT.run(data) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2) val newModel = gbt.fit(newData) - // Use parent, fittingParamMap from newTree since these are not checked anyways. + // Use parent from newTree since this is not checked anyways. val oldModelAsNew = GBTClassificationModel.fromOld( - oldModel, newModel.parent, categoricalFeatures) + oldModel, newModel.parent.asInstanceOf[GBTClassifier], categoricalFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 4df8016009171..a755cac3ea76e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -17,28 +17,23 @@ package org.apache.spark.ml.classification -import org.scalatest.FunSuite - -import org.apache.spark.mllib.classification.LogisticRegressionSuite +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.{DataFrame, Row, SQLContext} - +import org.apache.spark.sql.{DataFrame, Row} -class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { +class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { - @transient var sqlContext: SQLContext = _ @transient var dataset: DataFrame = _ @transient var binaryDataset: DataFrame = _ private val eps: Double = 1e-5 override def beforeAll(): Unit = { super.beforeAll() - sqlContext = new SQLContext(sc) - dataset = sqlContext.createDataFrame(sc.parallelize(LogisticRegressionSuite - .generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42), 4)) + dataset = sqlContext.createDataFrame(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42)) /** * Here is the instruction describing how to export the test data into CSV format @@ -60,32 +55,32 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { val xMean = Array(5.843, 3.057, 3.758, 1.199) val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) - val testData = LogisticRegressionSuite.generateMultinomialLogisticInput( - weights, xMean, xVariance, true, nPoints, 42) + val testData = generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42) - sqlContext.createDataFrame(sc.parallelize(LogisticRegressionSuite - .generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42), 4)) + sqlContext.createDataFrame( + generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42)) } } test("logistic regression: default params") { val lr = new LogisticRegression - assert(lr.getLabelCol == "label") - assert(lr.getFeaturesCol == "features") - assert(lr.getPredictionCol == "prediction") - assert(lr.getRawPredictionCol == "rawPrediction") - assert(lr.getProbabilityCol == "probability") - assert(lr.getFitIntercept == true) + assert(lr.getLabelCol === "label") + assert(lr.getFeaturesCol === "features") + assert(lr.getPredictionCol === "prediction") + assert(lr.getRawPredictionCol === "rawPrediction") + assert(lr.getProbabilityCol === "probability") + assert(lr.getFitIntercept) val model = lr.fit(dataset) model.transform(dataset) .select("label", "probability", "prediction", "rawPrediction") .collect() assert(model.getThreshold === 0.5) - assert(model.getFeaturesCol == "features") - assert(model.getPredictionCol == "prediction") - assert(model.getRawPredictionCol == "rawPrediction") - assert(model.getProbabilityCol == "probability") + assert(model.getFeaturesCol === "features") + assert(model.getPredictionCol === "prediction") + assert(model.getRawPredictionCol === "rawPrediction") + assert(model.getProbabilityCol === "probability") assert(model.intercept !== 0.0) + assert(model.hasParent) } test("logistic regression doesn't fit intercept when fitIntercept is off") { @@ -103,7 +98,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { .setThreshold(0.6) .setProbabilityCol("myProbability") val model = lr.fit(dataset) - val parent = model.parent + val parent = model.parent.asInstanceOf[LogisticRegression] assert(parent.getMaxIter === 10) assert(parent.getRegParam === 1.0) assert(parent.getThreshold === 0.6) @@ -129,12 +124,12 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { // Call fit() with new params, and check as many params as we can. val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.threshold -> 0.4, lr.probabilityCol -> "theProb") - val parent2 = model2.parent + val parent2 = model2.parent.asInstanceOf[LogisticRegression] assert(parent2.getMaxIter === 5) assert(parent2.getRegParam === 0.1) assert(parent2.getThreshold === 0.4) assert(model2.getThreshold === 0.4) - assert(model2.getProbabilityCol == "theProb") + assert(model2.getProbabilityCol === "theProb") } test("logistic regression: Predictor, Classifier methods") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index e65ffae918ca9..1d04ccb509057 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -17,28 +17,26 @@ package org.apache.spark.ml.classification -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.util.MetadataUtils -import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS +import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.DataFrame -class OneVsRestSuite extends FunSuite with MLlibTestSparkContext { +class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext { - @transient var sqlContext: SQLContext = _ @transient var dataset: DataFrame = _ @transient var rdd: RDD[LabeledPoint] = _ override def beforeAll(): Unit = { super.beforeAll() - sqlContext = new SQLContext(sc) + val nPoints = 1000 // The following weights and xMean/xVariance are computed from iris dataset with lambda=0.2. @@ -57,7 +55,7 @@ class OneVsRestSuite extends FunSuite with MLlibTestSparkContext { test("one-vs-rest: default params") { val numClasses = 3 val ova = new OneVsRest() - ova.setClassifier(new LogisticRegression) + .setClassifier(new LogisticRegression) assert(ova.getLabelCol === "label") assert(ova.getPredictionCol === "prediction") val ovaModel = ova.fit(dataset) @@ -95,9 +93,20 @@ class OneVsRestSuite extends FunSuite with MLlibTestSparkContext { val datasetWithLabelMetadata = dataset.select(labelWithMetadata, features) ova.fit(datasetWithLabelMetadata) } + + test("SPARK-8049: OneVsRest shouldn't output temp columns") { + val logReg = new LogisticRegression() + .setMaxIter(1) + val ovr = new OneVsRest() + .setClassifier(logReg) + val output = ovr.fit(dataset).transform(dataset) + assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction")) + } } -private class MockLogisticRegression extends LogisticRegression { +private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) { + + def this() = this("mockLogReg") setMaxIter(1) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index c41def9330504..eee9355a67be3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.ml.classification -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint @@ -32,7 +31,7 @@ import org.apache.spark.sql.DataFrame /** * Test suite for [[RandomForestClassifier]]. */ -class RandomForestClassifierSuite extends FunSuite with MLlibTestSparkContext { +class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { import RandomForestClassifierSuite.compareAPIs @@ -158,9 +157,11 @@ private object RandomForestClassifierSuite { data, oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses) val newModel = rf.fit(newData) - // Use parent, fittingParamMap from newTree since these are not checked anyways. + // Use parent from newTree since this is not checked anyways. val oldModelAsNew = RandomForestClassificationModel.fromOld( - oldModel, newModel.parent, categoricalFeatures) + oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) + assert(newModel.hasParent) + assert(!newModel.trees.head.asInstanceOf[DecisionTreeClassificationModel].hasParent) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala new file mode 100644 index 0000000000000..36a1ac6b7996d --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.evaluation + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.mllib.util.TestingUtils._ + +class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("Regression Evaluator: default params") { + /** + * Here is the instruction describing how to export the test data into CSV format + * so we can validate the metrics compared with R's mmetric package. + * + * import org.apache.spark.mllib.util.LinearDataGenerator + * val data = sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, + * Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1)) + * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)) + * .saveAsTextFile("path") + */ + val dataset = sqlContext.createDataFrame( + sc.parallelize(LinearDataGenerator.generateLinearInput( + 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2)) + + /** + * Using the following R code to load the data, train the model and evaluate metrics. + * + * > library("glmnet") + * > library("rminer") + * > data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) + * > features <- as.matrix(data.frame(as.numeric(data$V2), as.numeric(data$V3))) + * > label <- as.numeric(data$V1) + * > model <- glmnet(features, label, family="gaussian", alpha = 0, lambda = 0) + * > rmse <- mmetric(label, predict(model, features), metric='RMSE') + * > mae <- mmetric(label, predict(model, features), metric='MAE') + * > r2 <- mmetric(label, predict(model, features), metric='R2') + */ + val trainer = new LinearRegression + val model = trainer.fit(dataset) + val predictions = model.transform(dataset) + + // default = rmse + val evaluator = new RegressionEvaluator() + assert(evaluator.evaluate(predictions) ~== 0.1019382 absTol 0.001) + + // r2 score + evaluator.setMetricName("r2") + assert(evaluator.evaluate(predictions) ~== 0.9998196 absTol 0.001) + + // mae + evaluator.setMetricName("mae") + assert(evaluator.evaluate(predictions) ~== 0.08036075 absTol 0.001) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala index caf1b759593f3..7953bd0417191 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala @@ -17,21 +17,16 @@ package org.apache.spark.ml.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.{DataFrame, Row, SQLContext} - +import org.apache.spark.sql.{DataFrame, Row} -class BinarizerSuite extends FunSuite with MLlibTestSparkContext { +class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext { @transient var data: Array[Double] = _ - @transient var sqlContext: SQLContext = _ override def beforeAll(): Unit = { super.beforeAll() - sqlContext = new SQLContext(sc) data = Array(0.1, -0.5, 0.2, -0.3, 0.8, 0.7, -0.1, -0.4) } @@ -52,7 +47,7 @@ class BinarizerSuite extends FunSuite with MLlibTestSparkContext { test("Binarize continuous features with setter") { val threshold: Double = 0.2 - val thresholdBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0) + val thresholdBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0) val dataFrame: DataFrame = sqlContext.createDataFrame( data.zip(thresholdBinarized)).toDF("feature", "expected") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index 1900820400aee..507a8a7db24c7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -19,22 +19,13 @@ package org.apache.spark.ml.feature import scala.util.Random -import org.scalatest.FunSuite - -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.{DataFrame, Row, SQLContext} - -class BucketizerSuite extends FunSuite with MLlibTestSparkContext { +import org.apache.spark.sql.{DataFrame, Row} - @transient private var sqlContext: SQLContext = _ - - override def beforeAll(): Unit = { - super.beforeAll() - sqlContext = new SQLContext(sc) - } +class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext { test("Bucket continuous features, without -inf,inf") { // Check a set of valid feature values. @@ -117,12 +108,13 @@ class BucketizerSuite extends FunSuite with MLlibTestSparkContext { } } -private object BucketizerSuite extends FunSuite { +private object BucketizerSuite extends SparkFunSuite { /** Brute force search for buckets. Bucket i is defined by the range [split(i), split(i+1)). */ def linearSearchForBuckets(splits: Array[Double], feature: Double): Double = { require(feature >= splits.head) var i = 0 - while (i < splits.length - 1) { + val n = splits.length - 1 + while (i < n) { if (feature < splits(i + 1)) return i i += 1 } @@ -138,7 +130,8 @@ private object BucketizerSuite extends FunSuite { s" ${splits.mkString(", ")}") } var i = 0 - while (i < splits.length - 1) { + val n = splits.length - 1 + while (i < n) { // Split i should fall in bucket i. testFeature(splits(i), i) // Value between splits i,i+1 should be in i, which is also true if the (i+1)-th split is inf. diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala new file mode 100644 index 0000000000000..7b2d70e644005 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.Utils + +class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("params") { + val hashingTF = new HashingTF + ParamsSuite.checkParams(hashingTF, 3) + } + + test("hashingTF") { + val df = sqlContext.createDataFrame(Seq( + (0, "a a b b c d".split(" ").toSeq) + )).toDF("id", "words") + val n = 100 + val hashingTF = new HashingTF() + .setInputCol("words") + .setOutputCol("features") + .setNumFeatures(n) + val output = hashingTF.transform(df) + val attrGroup = AttributeGroup.fromStructField(output.schema("features")) + require(attrGroup.numAttributes === Some(n)) + val features = output.select("features").first().getAs[Vector](0) + // Assume perfect hash on "a", "b", "c", and "d". + def idx(any: Any): Int = Utils.nonNegativeMod(any.##, n) + val expected = Vectors.sparse(n, + Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0))) + assert(features ~== expected absTol 1e-14) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala index eaee3443c1f23..d83772e8be755 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala @@ -17,21 +17,13 @@ package org.apache.spark.ml.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.{Row, SQLContext} - -class IDFSuite extends FunSuite with MLlibTestSparkContext { +import org.apache.spark.sql.Row - @transient var sqlContext: SQLContext = _ - - override def beforeAll(): Unit = { - super.beforeAll() - sqlContext = new SQLContext(sc) - } +class IDFSuite extends SparkFunSuite with MLlibTestSparkContext { def scaleDataWithIDF(dataSet: Array[Vector], model: Vector): Array[Vector] = { dataSet.map { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala index 9d09f24709e23..9f03470b7f328 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.ml.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row, SQLContext} -class NormalizerSuite extends FunSuite with MLlibTestSparkContext { +class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext { @transient var data: Array[Vector] = _ @transient var dataFrame: DataFrame = _ diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala index 92ec407b98d69..2e5036a844562 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala @@ -17,20 +17,14 @@ package org.apache.spark.ml.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, SQLContext} - +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions.col -class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext { - private var sqlContext: SQLContext = _ - - override def beforeAll(): Unit = { - super.beforeAll() - sqlContext = new SQLContext(sc) - } +class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext { def stringIndexed(): DataFrame = { val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) @@ -42,15 +36,16 @@ class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext { indexer.transform(df) } - test("OneHotEncoder includeFirst = true") { + test("OneHotEncoder dropLast = false") { val transformed = stringIndexed() val encoder = new OneHotEncoder() .setInputCol("labelIndex") .setOutputCol("labelVec") + .setDropLast(false) val encoded = encoder.transform(transformed) val output = encoded.select("id", "labelVec").map { r => - val vec = r.get(1).asInstanceOf[Vector] + val vec = r.getAs[Vector](1) (r.getInt(0), vec(0), vec(1), vec(2)) }.collect().toSet // a -> 0, b -> 2, c -> 1 @@ -59,22 +54,46 @@ class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext { assert(output === expected) } - test("OneHotEncoder includeFirst = false") { + test("OneHotEncoder dropLast = true") { val transformed = stringIndexed() val encoder = new OneHotEncoder() - .setIncludeFirst(false) .setInputCol("labelIndex") .setOutputCol("labelVec") val encoded = encoder.transform(transformed) val output = encoded.select("id", "labelVec").map { r => - val vec = r.get(1).asInstanceOf[Vector] + val vec = r.getAs[Vector](1) (r.getInt(0), vec(0), vec(1)) }.collect().toSet // a -> 0, b -> 2, c -> 1 - val expected = Set((0, 0.0, 0.0), (1, 0.0, 1.0), (2, 1.0, 0.0), - (3, 0.0, 0.0), (4, 0.0, 0.0), (5, 1.0, 0.0)) + val expected = Set((0, 1.0, 0.0), (1, 0.0, 0.0), (2, 0.0, 1.0), + (3, 1.0, 0.0), (4, 1.0, 0.0), (5, 0.0, 1.0)) assert(output === expected) } + test("input column with ML attribute") { + val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large") + val df = sqlContext.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("size") + .select(col("size").as("size", attr.toMetadata())) + val encoder = new OneHotEncoder() + .setInputCol("size") + .setOutputCol("encoded") + val output = encoder.transform(df) + val group = AttributeGroup.fromStructField(output.schema("encoded")) + assert(group.size === 2) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("size_is_small").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("size_is_medium").withIndex(1)) + } + + test("input column without ML attribute") { + val df = sqlContext.createDataFrame(Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply)).toDF("index") + val encoder = new OneHotEncoder() + .setInputCol("index") + .setOutputCol("encoded") + val output = encoder.transform(df) + val group = AttributeGroup.fromStructField(output.schema("encoded")) + assert(group.size === 2) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("index_is_0").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("index_is_1").withIndex(1)) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala index c1d64fba0aa8f..feca866cd711d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala @@ -17,22 +17,15 @@ package org.apache.spark.ml.feature -import org.scalatest.FunSuite +import org.scalatest.exceptions.TestFailedException +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.{Row, SQLContext} -import org.scalatest.exceptions.TestFailedException - -class PolynomialExpansionSuite extends FunSuite with MLlibTestSparkContext { +import org.apache.spark.sql.Row - @transient var sqlContext: SQLContext = _ - - override def beforeAll(): Unit = { - super.beforeAll() - sqlContext = new SQLContext(sc) - } +class PolynomialExpansionSuite extends SparkFunSuite with MLlibTestSparkContext { test("Polynomial expansion with default parameter") { val data = Array( diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index b6939e5870410..cbf1e8ddcb48a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -17,19 +17,11 @@ package org.apache.spark.ml.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.SQLContext - -class StringIndexerSuite extends FunSuite with MLlibTestSparkContext { - private var sqlContext: SQLContext = _ - override def beforeAll(): Unit = { - super.beforeAll() - sqlContext = new SQLContext(sc) - } +class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { test("StringIndexer") { val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala index d186ead8f542f..ac279cb3215c2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala @@ -19,64 +19,54 @@ package org.apache.spark.ml.feature import scala.beans.BeanInfo -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Row} @BeanInfo case class TokenizerTestData(rawText: String, wantedTokens: Array[String]) -class RegexTokenizerSuite extends FunSuite with MLlibTestSparkContext { +class RegexTokenizerSuite extends SparkFunSuite with MLlibTestSparkContext { import org.apache.spark.ml.feature.RegexTokenizerSuite._ - - @transient var sqlContext: SQLContext = _ - - override def beforeAll(): Unit = { - super.beforeAll() - sqlContext = new SQLContext(sc) - } test("RegexTokenizer") { - val tokenizer = new RegexTokenizer() + val tokenizer0 = new RegexTokenizer() + .setGaps(false) + .setPattern("\\w+|\\p{Punct}") .setInputCol("rawText") .setOutputCol("tokens") - val dataset0 = sqlContext.createDataFrame(Seq( TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization", ".")), TokenizerTestData("Te,st. punct", Array("Te", ",", "st", ".", "punct")) )) - testRegexTokenizer(tokenizer, dataset0) + testRegexTokenizer(tokenizer0, dataset0) val dataset1 = sqlContext.createDataFrame(Seq( TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization")), TokenizerTestData("Te,st. punct", Array("punct")) )) + tokenizer0.setMinTokenLength(3) + testRegexTokenizer(tokenizer0, dataset1) - tokenizer.setMinTokenLength(3) - testRegexTokenizer(tokenizer, dataset1) - - tokenizer - .setPattern("\\s") - .setGaps(true) - .setMinTokenLength(0) + val tokenizer2 = new RegexTokenizer() + .setInputCol("rawText") + .setOutputCol("tokens") val dataset2 = sqlContext.createDataFrame(Seq( TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization.")), - TokenizerTestData("Te,st. punct", Array("Te,st.", "", "punct")) + TokenizerTestData("Te,st. punct", Array("Te,st.", "punct")) )) - testRegexTokenizer(tokenizer, dataset2) + testRegexTokenizer(tokenizer2, dataset2) } } -object RegexTokenizerSuite extends FunSuite { +object RegexTokenizerSuite extends SparkFunSuite { def testRegexTokenizer(t: RegexTokenizer, dataset: DataFrame): Unit = { t.transform(dataset) .select("tokens", "wantedTokens") .collect() - .foreach { - case Row(tokens, wantedTokens) => - assert(tokens === wantedTokens) - } + .foreach { case Row(tokens, wantedTokens) => + assert(tokens === wantedTokens) + } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index 0db27607bc274..489abb5af7130 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -17,21 +17,14 @@ package org.apache.spark.ml.feature -import org.scalatest.FunSuite - -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{Row, SQLContext} - -class VectorAssemblerSuite extends FunSuite with MLlibTestSparkContext { +import org.apache.spark.sql.Row +import org.apache.spark.sql.functions.col - @transient var sqlContext: SQLContext = _ - - override def beforeAll(): Unit = { - super.beforeAll() - sqlContext = new SQLContext(sc) - } +class VectorAssemblerSuite extends SparkFunSuite with MLlibTestSparkContext { test("assemble") { import org.apache.spark.ml.feature.VectorAssembler.assemble @@ -68,4 +61,39 @@ class VectorAssemblerSuite extends FunSuite with MLlibTestSparkContext { assert(v === Vectors.sparse(6, Array(1, 2, 4, 5), Array(1.0, 2.0, 3.0, 10.0))) } } + + test("ML attributes") { + val browser = NominalAttribute.defaultAttr.withValues("chrome", "firefox", "safari") + val hour = NumericAttribute.defaultAttr.withMin(0.0).withMax(24.0) + val user = new AttributeGroup("user", Array( + NominalAttribute.defaultAttr.withName("gender").withValues("male", "female"), + NumericAttribute.defaultAttr.withName("salary"))) + val row = (1.0, 0.5, 1, Vectors.dense(1.0, 1000.0), Vectors.sparse(2, Array(1), Array(2.0))) + val df = sqlContext.createDataFrame(Seq(row)).toDF("browser", "hour", "count", "user", "ad") + .select( + col("browser").as("browser", browser.toMetadata()), + col("hour").as("hour", hour.toMetadata()), + col("count"), // "count" is an integer column without ML attribute + col("user").as("user", user.toMetadata()), + col("ad")) // "ad" is a vector column without ML attribute + val assembler = new VectorAssembler() + .setInputCols(Array("browser", "hour", "count", "user", "ad")) + .setOutputCol("features") + val output = assembler.transform(df) + val schema = output.schema + val features = AttributeGroup.fromStructField(schema("features")) + assert(features.size === 7) + val browserOut = features.getAttr(0) + assert(browserOut === browser.withIndex(0).withName("browser")) + val hourOut = features.getAttr(1) + assert(hourOut === hour.withIndex(1).withName("hour")) + val countOut = features.getAttr(2) + assert(countOut === NumericAttribute.defaultAttr.withName("count").withIndex(2)) + val userGenderOut = features.getAttr(3) + assert(userGenderOut === user.getAttr("gender").withName("user_gender").withIndex(3)) + val userSalaryOut = features.getAttr(4) + assert(userSalaryOut === user.getAttr("salary").withName("user_salary").withIndex(4)) + assert(features.getAttr(5) === NumericAttribute.defaultAttr.withIndex(5)) + assert(features.getAttr(6) === NumericAttribute.defaultAttr.withIndex(6)) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index 38dc83b1241cf..06affc7305cf5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -19,22 +19,17 @@ package org.apache.spark.ml.feature import scala.beans.{BeanInfo, BeanProperty} -import org.scalatest.FunSuite - -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute._ import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, SQLContext} - +import org.apache.spark.sql.DataFrame -class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext { +class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { import VectorIndexerSuite.FeatureData - @transient var sqlContext: SQLContext = _ - // identical, of length 3 @transient var densePoints1: DataFrame = _ @transient var sparsePoints1: DataFrame = _ @@ -86,7 +81,6 @@ class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext { checkPair(densePoints1Seq, sparsePoints1Seq) checkPair(densePoints2Seq, sparsePoints2Seq) - sqlContext = new SQLContext(sc) densePoints1 = sqlContext.createDataFrame(sc.parallelize(densePoints1Seq, 2).map(FeatureData)) sparsePoints1 = sqlContext.createDataFrame(sc.parallelize(sparsePoints1Seq, 2).map(FeatureData)) densePoints2 = sqlContext.createDataFrame(sc.parallelize(densePoints2Seq, 2).map(FeatureData)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index 03ba86670d453..94ebc3aebfa37 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.ml.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{Row, SQLContext} -class Word2VecSuite extends FunSuite with MLlibTestSparkContext { +class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { test("Word2Vec") { val sqlContext = new SQLContext(sc) @@ -35,9 +34,9 @@ class Word2VecSuite extends FunSuite with MLlibTestSparkContext { val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) val codes = Map( - "a" -> Array(-0.2811822295188904,-0.6356269121170044,-0.3020961284637451), - "b" -> Array(1.0309048891067505,-1.29472815990448,0.22276712954044342), - "c" -> Array(-0.08456747233867645,0.5137411952018738,0.11731560528278351) + "a" -> Array(-0.2811822295188904, -0.6356269121170044, -0.3020961284637451), + "b" -> Array(1.0309048891067505, -1.29472815990448, 0.22276712954044342), + "c" -> Array(-0.08456747233867645, 0.5137411952018738, 0.11731560528278351) ) val expected = doc.map { sentence => @@ -52,6 +51,7 @@ class Word2VecSuite extends FunSuite with MLlibTestSparkContext { .setVectorSize(3) .setInputCol("text") .setOutputCol("result") + .setSeed(42L) .fit(docDF) model.transform(docDF).select("result", "expected").collect().foreach { diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala index 1505ad872536b..778abcba22c10 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala @@ -19,8 +19,7 @@ package org.apache.spark.ml.impl import scala.collection.JavaConverters._ -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.api.java.JavaRDD import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} import org.apache.spark.ml.tree._ @@ -29,7 +28,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{SQLContext, DataFrame} -private[ml] object TreeTests extends FunSuite { +private[ml] object TreeTests extends SparkFunSuite { /** * Convert the given data to a DataFrame, and set the features and label metadata. diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index 6056e7d3f6ff8..96094d7a099aa 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -17,27 +17,28 @@ package org.apache.spark.ml.param -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class ParamsSuite extends FunSuite { +class ParamsSuite extends SparkFunSuite { test("param") { val solver = new TestParams() + val uid = solver.uid import solver.{maxIter, inputCol} assert(maxIter.name === "maxIter") - assert(maxIter.doc === "max number of iterations (>= 0)") - assert(maxIter.parent.eq(solver)) - assert(maxIter.toString === "maxIter: max number of iterations (>= 0) (default: 10)") + assert(maxIter.doc === "maximum number of iterations (>= 0)") + assert(maxIter.parent === uid) + assert(maxIter.toString === s"${uid}__maxIter") assert(!maxIter.isValid(-1)) assert(maxIter.isValid(0)) assert(maxIter.isValid(1)) solver.setMaxIter(5) - assert(maxIter.toString === - "maxIter: max number of iterations (>= 0) (default: 10, current: 5)") + assert(solver.explainParam(maxIter) === + "maxIter: maximum number of iterations (>= 0) (default: 10, current: 5)") - assert(inputCol.toString === "inputCol: input column name (undefined)") + assert(inputCol.toString === s"${uid}__inputCol") intercept[IllegalArgumentException] { solver.setMaxIter(-1) @@ -118,7 +119,10 @@ class ParamsSuite extends FunSuite { assert(!solver.isDefined(inputCol)) intercept[NoSuchElementException](solver.getInputCol) - assert(solver.explainParams() === Seq(inputCol, maxIter).mkString("\n")) + assert(solver.explainParam(maxIter) === + "maxIter: maximum number of iterations (>= 0) (default: 10, current: 100)") + assert(solver.explainParams() === + Seq(inputCol, maxIter).map(solver.explainParam).mkString("\n")) assert(solver.getParam("inputCol").eq(inputCol)) assert(solver.getParam("maxIter").eq(maxIter)) @@ -131,7 +135,7 @@ class ParamsSuite extends FunSuite { intercept[IllegalArgumentException] { solver.validateParams() } - solver.validateParams(ParamMap(inputCol -> "input")) + solver.copy(ParamMap(inputCol -> "input")).validateParams() solver.setInputCol("input") assert(solver.isSet(inputCol)) assert(solver.isDefined(inputCol)) @@ -148,7 +152,7 @@ class ParamsSuite extends FunSuite { assert(!solver.isSet(maxIter)) val copied = solver.copy(ParamMap(solver.maxIter -> 50)) - assert(copied.uid !== solver.uid) + assert(copied.uid === solver.uid) assert(copied.getInputCol === solver.getInputCol) assert(copied.getMaxIter === 50) } @@ -197,3 +201,23 @@ class ParamsSuite extends FunSuite { assert(inArray(1) && inArray(2) && !inArray(0)) } } + +object ParamsSuite extends SparkFunSuite { + + /** + * Checks common requirements for [[Params.params]]: 1) number of params; 2) params are ordered + * by names; 3) param parent has the same UID as the object's UID; 4) param name is the same as + * the param method name. + */ + def checkParams(obj: Params, expectedNumParams: Int): Unit = { + val params = obj.params + require(params.length === expectedNumParams, + s"Expect $expectedNumParams params but got ${params.length}: ${params.map(_.name).toSeq}.") + val paramNames = params.map(_.name) + require(paramNames === paramNames.sorted) + params.foreach { p => + assert(p.parent === obj.uid) + assert(obj.getParam(p.name) === p) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala index dc16073640407..a9e78366ad98f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala @@ -18,9 +18,12 @@ package org.apache.spark.ml.param import org.apache.spark.ml.param.shared.{HasInputCol, HasMaxIter} +import org.apache.spark.ml.util.Identifiable /** A subclass of Params for testing. */ -class TestParams extends Params with HasMaxIter with HasInputCol { +class TestParams(override val uid: String) extends Params with HasMaxIter with HasInputCol { + + def this() = this(Identifiable.randomUID("testParams")) def setMaxIter(value: Int): this.type = { set(maxIter, value); this } diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala new file mode 100644 index 0000000000000..eb5408d3fee7c --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.param.shared + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.Params + +class SharedParamsSuite extends SparkFunSuite { + + test("outputCol") { + + class Obj(override val uid: String) extends Params with HasOutputCol + + val obj = new Obj("obj") + + assert(obj.hasDefault(obj.outputCol)) + assert(obj.getOrDefault(obj.outputCol) === "obj__output") + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index fc7349330cf86..2e5cfe7027eb6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -25,9 +25,8 @@ import scala.collection.mutable.ArrayBuffer import scala.language.existentials import com.github.fommil.netlib.BLAS.{getInstance => blas} -import org.scalatest.FunSuite -import org.apache.spark.{Logging, SparkException} +import org.apache.spark.{Logging, SparkException, SparkFunSuite} import org.apache.spark.ml.recommendation.ALS._ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -36,16 +35,14 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.util.Utils -class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { +class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { - private var sqlContext: SQLContext = _ private var tempDir: File = _ override def beforeAll(): Unit = { super.beforeAll() tempDir = Utils.createTempDir() sc.setCheckpointDir(tempDir.getAbsolutePath) - sqlContext = new SQLContext(sc) } override def afterAll(): Unit = { @@ -345,6 +342,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { .setImplicitPrefs(implicitPrefs) .setNumUserBlocks(numUserBlocks) .setNumItemBlocks(numItemBlocks) + .setSeed(0) val alpha = als.getAlpha val model = als.fit(training.toDF()) val predictions = model.transform(test.toDF()) @@ -425,17 +423,18 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { val (ratings, _) = genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01) val longRatings = ratings.map(r => Rating(r.user.toLong, r.item.toLong, r.rating)) - val (longUserFactors, _) = ALS.train(longRatings, rank = 2, maxIter = 4) + val (longUserFactors, _) = ALS.train(longRatings, rank = 2, maxIter = 4, seed = 0) assert(longUserFactors.first()._1.getClass === classOf[Long]) val strRatings = ratings.map(r => Rating(r.user.toString, r.item.toString, r.rating)) - val (strUserFactors, _) = ALS.train(strRatings, rank = 2, maxIter = 4) + val (strUserFactors, _) = ALS.train(strRatings, rank = 2, maxIter = 4, seed = 0) assert(strUserFactors.first()._1.getClass === classOf[String]) } test("nonnegative constraint") { val (ratings, _) = genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01) - val (userFactors, itemFactors) = ALS.train(ratings, rank = 2, maxIter = 4, nonnegative = true) + val (userFactors, itemFactors) = + ALS.train(ratings, rank = 2, maxIter = 4, nonnegative = true, seed = 0) def isNonnegative(factors: RDD[(Int, Array[Float])]): Boolean = { factors.values.map { _.forall(_ >= 0.0) }.reduce(_ && _) } @@ -459,7 +458,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { test("partitioner in returned factors") { val (ratings, _) = genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01) val (userFactors, itemFactors) = ALS.train( - ratings, rank = 2, maxIter = 4, numUserBlocks = 3, numItemBlocks = 4) + ratings, rank = 2, maxIter = 4, numUserBlocks = 3, numItemBlocks = 4, seed = 0) for ((tpe, factors) <- Seq(("User", userFactors), ("Item", itemFactors))) { assert(userFactors.partitioner.isDefined, s"$tpe factors should have partitioner.") val part = userFactors.partitioner.get @@ -476,8 +475,8 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { test("als with large number of iterations") { val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1) - ALS.train(ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2) - ALS.train( - ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2, implicitPrefs = true) + ALS.train(ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2, seed = 0) + ALS.train(ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2, + implicitPrefs = true, seed = 0) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 5aa81b44ddaf9..33aa9d0d62343 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.ml.regression -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, @@ -28,7 +27,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame -class DecisionTreeRegressorSuite extends FunSuite with MLlibTestSparkContext { +class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { import DecisionTreeRegressorSuite.compareAPIs @@ -69,7 +68,7 @@ class DecisionTreeRegressorSuite extends FunSuite with MLlibTestSparkContext { // TODO: test("model save/load") SPARK-6725 } -private[ml] object DecisionTreeRegressorSuite extends FunSuite { +private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite { /** * Train 2 decision trees on the given dataset, one using the old API and one using the new API. @@ -83,9 +82,9 @@ private[ml] object DecisionTreeRegressorSuite extends FunSuite { val oldTree = OldDecisionTree.train(data, oldStrategy) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) val newTree = dt.fit(newData) - // Use parent, fittingParamMap from newTree since these are not checked anyways. + // Use parent from newTree since this is not checked anyways. val oldTreeAsNew = DecisionTreeRegressionModel.fromOld( - oldTree, newTree.parent, categoricalFeatures) + oldTree, newTree.parent.asInstanceOf[DecisionTreeRegressor], categoricalFeatures) TreeTests.checkEqual(oldTreeAsNew, newTree) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 25b36ab08b67c..98fb3d3f5f22c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.ml.regression -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} @@ -31,7 +30,7 @@ import org.apache.spark.sql.DataFrame /** * Test suite for [[GBTRegressor]]. */ -class GBTRegressorSuite extends FunSuite with MLlibTestSparkContext { +class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { import GBTRegressorSuite.compareAPIs @@ -129,8 +128,9 @@ private object GBTRegressorSuite { val oldModel = oldGBT.run(data) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) val newModel = gbt.fit(newData) - // Use parent, fittingParamMap from newTree since these are not checked anyways. - val oldModelAsNew = GBTRegressionModel.fromOld(oldModel, newModel.parent, categoricalFeatures) + // Use parent from newTree since this is not checked anyways. + val oldModelAsNew = GBTRegressionModel.fromOld( + oldModel, newModel.parent.asInstanceOf[GBTRegressor], categoricalFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 80323ef5201a6..732e2c42be144 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -17,16 +17,14 @@ package org.apache.spark.ml.regression -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.DenseVector import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.{Row, SQLContext, DataFrame} +import org.apache.spark.sql.{DataFrame, Row} -class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { +class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { - @transient var sqlContext: SQLContext = _ @transient var dataset: DataFrame = _ /** @@ -41,7 +39,6 @@ class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { */ override def beforeAll(): Unit = { super.beforeAll() - sqlContext = new SQLContext(sc) dataset = sqlContext.createDataFrame( sc.parallelize(LinearDataGenerator.generateLinearInput( 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index 45f09f4fdab81..b24ecaa57c89b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.ml.regression -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} @@ -31,7 +30,7 @@ import org.apache.spark.sql.DataFrame /** * Test suite for [[RandomForestRegressor]]. */ -class RandomForestRegressorSuite extends FunSuite with MLlibTestSparkContext { +class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { import RandomForestRegressorSuite.compareAPIs @@ -98,7 +97,7 @@ class RandomForestRegressorSuite extends FunSuite with MLlibTestSparkContext { */ } -private object RandomForestRegressorSuite extends FunSuite { +private object RandomForestRegressorSuite extends SparkFunSuite { /** * Train 2 models on the given dataset, one using the old API and one using the new API. @@ -114,9 +113,9 @@ private object RandomForestRegressorSuite extends FunSuite { data, oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) val newModel = rf.fit(newData) - // Use parent, fittingParamMap from newTree since these are not checked anyways. + // Use parent from newTree since this is not checked anyways. val oldModelAsNew = RandomForestRegressionModel.fromOld( - oldModel, newModel.parent, categoricalFeatures) + oldModel, newModel.parent.asInstanceOf[RandomForestRegressor], categoricalFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 05313d440fbf6..5ba469c7b10a0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -17,15 +17,19 @@ package org.apache.spark.ml.tuning -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.classification.LogisticRegression -import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator +import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator} +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{SQLContext, DataFrame} +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.types.StructType -class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext { +class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { @transient var dataset: DataFrame = _ @@ -53,4 +57,54 @@ class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext { assert(parent.getRegParam === 0.001) assert(parent.getMaxIter === 10) } + + test("validateParams should check estimatorParamMaps") { + import CrossValidatorSuite._ + + val est = new MyEstimator("est") + val eval = new MyEvaluator + val paramMaps = new ParamGridBuilder() + .addGrid(est.inputCol, Array("input1", "input2")) + .build() + + val cv = new CrossValidator() + .setEstimator(est) + .setEstimatorParamMaps(paramMaps) + .setEvaluator(eval) + + cv.validateParams() // This should pass. + + val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "") + cv.setEstimatorParamMaps(invalidParamMaps) + intercept[IllegalArgumentException] { + cv.validateParams() + } + } +} + +object CrossValidatorSuite { + + abstract class MyModel extends Model[MyModel] + + class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol { + + override def validateParams(): Unit = require($(inputCol).nonEmpty) + + override def fit(dataset: DataFrame): MyModel = { + throw new UnsupportedOperationException + } + + override def transformSchema(schema: StructType): StructType = { + throw new UnsupportedOperationException + } + } + + class MyEvaluator extends Evaluator { + + override def evaluate(dataset: DataFrame): Double = { + throw new UnsupportedOperationException + } + + override val uid: String = "eval" + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala index 20aa100112bfe..810b70049ec15 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala @@ -19,11 +19,10 @@ package org.apache.spark.ml.tuning import scala.collection.mutable -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.{ParamMap, TestParams} -class ParamGridBuilderSuite extends FunSuite { +class ParamGridBuilderSuite extends SparkFunSuite { val solver = new TestParams() import solver.{inputCol, maxIter} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala index a629dba8a426f..59944416d96a6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.mllib.api.python -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Vectors, SparseMatrix} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.recommendation.Rating -class PythonMLLibAPISuite extends FunSuite { +class PythonMLLibAPISuite extends SparkFunSuite { SerDe.initialize() @@ -84,7 +83,7 @@ class PythonMLLibAPISuite extends FunSuite { val smt = new SparseMatrix( 3, 3, Array(0, 2, 3, 5), Array(0, 2, 1, 0, 2), Array(0.9, 1.2, 3.4, 5.7, 8.9), - isTransposed=true) + isTransposed = true) val nsmt = SerDe.loads(SerDe.dumps(smt)).asInstanceOf[SparseMatrix] assert(smt.toArray === nsmt.toArray) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index fb0a194718802..e8f3d0c4db20a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -21,9 +21,9 @@ import scala.collection.JavaConversions._ import scala.util.Random import scala.util.control.Breaks._ -import org.scalatest.FunSuite import org.scalatest.Matchers +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} @@ -101,7 +101,8 @@ object LogisticRegressionSuite { // This doesn't work if `vector` is a sparse vector. val vectorArray = vector.toArray var i = 0 - while (i < vectorArray.length) { + val len = vectorArray.length + while (i < len) { vectorArray(i) = vectorArray(i) * math.sqrt(xVariance(i)) + xMean(i) i += 1 } @@ -118,7 +119,7 @@ object LogisticRegressionSuite { } // Preventing the overflow when we compute the probability val maxMargin = margins.max - if (maxMargin > 0) for (i <-0 until nClasses) margins(i) -= maxMargin + if (maxMargin > 0) for (i <- 0 until nClasses) margins(i) -= maxMargin // Computing the probabilities for each class from the margins. val norm = { @@ -129,7 +130,7 @@ object LogisticRegressionSuite { } temp } - for (i <-0 until nClasses) probs(i) /= norm + for (i <- 0 until nClasses) probs(i) /= norm // Compute the cumulative probability so we can generate a random number and assign a label. for (i <- 1 until nClasses) probs(i) += probs(i - 1) @@ -168,7 +169,7 @@ object LogisticRegressionSuite { } -class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers { +class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers { def validatePrediction( predictions: Seq[Double], input: Seq[LabeledPoint], @@ -540,7 +541,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M } -class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkContext { +class LogisticRegressionClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small in both training and prediction using SGD optimizer") { val m = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index ea89b17b7c08f..f7fc8730606af 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -19,20 +19,19 @@ package org.apache.spark.mllib.classification import scala.util.Random -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum, Axis} +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum} import breeze.stats.distributions.{Multinomial => BrzMultinomial} -import org.scalatest.FunSuite - -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.util.Utils - object NaiveBayesSuite { + import NaiveBayes.{Multinomial, Bernoulli} + private def calcLabel(p: Double, pi: Array[Double]): Int = { var sum = 0.0 for (j <- 0 until pi.length) { @@ -48,7 +47,7 @@ object NaiveBayesSuite { theta: Array[Array[Double]], // CXD nPoints: Int, seed: Int, - modelType: String = "Multinomial", + modelType: String = Multinomial, sample: Int = 10): Seq[LabeledPoint] = { val D = theta(0).length val rnd = new Random(seed) @@ -58,10 +57,10 @@ object NaiveBayesSuite { for (i <- 0 until nPoints) yield { val y = calcLabel(rnd.nextDouble(), _pi) val xi = modelType match { - case "Bernoulli" => Array.tabulate[Double] (D) { j => + case Bernoulli => Array.tabulate[Double] (D) { j => if (rnd.nextDouble () < _theta(y)(j) ) 1 else 0 } - case "Multinomial" => + case Multinomial => val mult = BrzMultinomial(BDV(_theta(y))) val emptyMap = (0 until D).map(x => (x, 0.0)).toMap val counts = emptyMap ++ mult.sample(sample).groupBy(x => x).map { @@ -70,7 +69,7 @@ object NaiveBayesSuite { counts.toArray.sortBy(_._1).map(_._2) case _ => // This should never happen. - throw new UnknownError(s"NaiveBayesSuite found unknown ModelType: $modelType") + throw new UnknownError(s"Invalid modelType: $modelType.") } LabeledPoint(y, Vectors.dense(xi)) @@ -79,16 +78,16 @@ object NaiveBayesSuite { /** Bernoulli NaiveBayes with binary labels, 3 features */ private val binaryBernoulliModel = new NaiveBayesModel(labels = Array(0.0, 1.0), - pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), - "Bernoulli") + pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), Bernoulli) /** Multinomial NaiveBayes with binary labels, 3 features */ private val binaryMultinomialModel = new NaiveBayesModel(labels = Array(0.0, 1.0), - pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), - "Multinomial") + pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), Multinomial) } -class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { +class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { + + import NaiveBayes.{Multinomial, Bernoulli} def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { val numOfPredictions = predictions.zip(input).count { @@ -117,6 +116,11 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { } } + test("model types") { + assert(Multinomial === "multinomial") + assert(Bernoulli === "bernoulli") + } + test("get, set params") { val nb = new NaiveBayes() nb.setLambda(2.0) @@ -134,16 +138,15 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { Array(0.10, 0.10, 0.70, 0.10) // label 2 ).map(_.map(math.log)) - val testData = NaiveBayesSuite.generateNaiveBayesInput( - pi, theta, nPoints, 42, "Multinomial") + val testData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 42, Multinomial) val testRDD = sc.parallelize(testData, 2) testRDD.cache() - val model = NaiveBayes.train(testRDD, 1.0, "Multinomial") + val model = NaiveBayes.train(testRDD, 1.0, Multinomial) validateModelFit(pi, theta, model) val validationData = NaiveBayesSuite.generateNaiveBayesInput( - pi, theta, nPoints, 17, "Multinomial") + pi, theta, nPoints, 17, Multinomial) val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. @@ -159,19 +162,19 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { val theta = Array( Array(0.50, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.40), // label 0 Array(0.02, 0.70, 0.10, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02), // label 1 - Array(0.02, 0.02, 0.60, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.30) // label 2 + Array(0.02, 0.02, 0.60, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.30) // label 2 ).map(_.map(math.log)) val testData = NaiveBayesSuite.generateNaiveBayesInput( - pi, theta, nPoints, 45, "Bernoulli") + pi, theta, nPoints, 45, Bernoulli) val testRDD = sc.parallelize(testData, 2) testRDD.cache() - val model = NaiveBayes.train(testRDD, 1.0, "Bernoulli") + val model = NaiveBayes.train(testRDD, 1.0, Bernoulli) validateModelFit(pi, theta, model) val validationData = NaiveBayesSuite.generateNaiveBayesInput( - pi, theta, nPoints, 20, "Bernoulli") + pi, theta, nPoints, 20, Bernoulli) val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. @@ -208,6 +211,39 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { } } + test("detect non zero or one values in Bernoulli") { + val badTrain = Seq( + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(2.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(1.0, Vectors.dense(0.0))) + + intercept[SparkException] { + NaiveBayes.train(sc.makeRDD(badTrain, 2), 1.0, Bernoulli) + } + + val okTrain = Seq( + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(1.0, Vectors.dense(1.0)) + ) + + val badPredict = Seq( + Vectors.dense(1.0), + Vectors.dense(2.0), + Vectors.dense(1.0), + Vectors.dense(0.0)) + + val model = NaiveBayes.train(sc.makeRDD(okTrain, 2), 1.0, Bernoulli) + intercept[SparkException] { + model.predict(sc.makeRDD(badPredict, 2)).collect() + } + } + test("model save/load: 2.0 to 2.0") { val tempDir = Utils.createTempDir() val path = tempDir.toURI.toString @@ -242,14 +278,14 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { assert(model.labels === sameModel.labels) assert(model.pi === sameModel.pi) assert(model.theta === sameModel.theta) - assert(model.modelType === "Multinomial") + assert(model.modelType === Multinomial) } finally { Utils.deleteRecursively(tempDir) } } } -class NaiveBayesClusterSuite extends FunSuite with LocalClusterSparkContext { +class NaiveBayesClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small in both training and prediction") { val m = 10 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala index 6de098b383ba3..b1d78cba9e3dc 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala @@ -21,9 +21,8 @@ import scala.collection.JavaConversions._ import scala.util.Random import org.jblas.DoubleMatrix -import org.scalatest.FunSuite -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} @@ -46,7 +45,7 @@ object SVMSuite { nPoints: Int, seed: Int): Seq[LabeledPoint] = { val rnd = new Random(seed) - val weightsMat = new DoubleMatrix(1, weights.length, weights:_*) + val weightsMat = new DoubleMatrix(1, weights.length, weights : _*) val x = Array.fill[Array[Double]](nPoints)( Array.fill[Double](weights.length)(rnd.nextDouble() * 2.0 - 1.0)) val y = x.map { xi => @@ -62,7 +61,7 @@ object SVMSuite { } -class SVMSuite extends FunSuite with MLlibTestSparkContext { +class SVMSuite extends SparkFunSuite with MLlibTestSparkContext { def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => @@ -91,7 +90,7 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext { val model = svm.run(testRDD) val validationData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 17) - val validationRDD = sc.parallelize(validationData, 2) + val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. @@ -117,7 +116,7 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext { val B = -1.5 val C = 1.0 - val testData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 42) + val testData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 42) val testRDD = sc.parallelize(testData, 2) testRDD.cache() @@ -127,8 +126,8 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext { val model = svm.run(testRDD) - val validationData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 17) - val validationRDD = sc.parallelize(validationData, 2) + val validationData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 17) + val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) @@ -145,7 +144,7 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext { val B = -1.5 val C = 1.0 - val testData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 42) + val testData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 42) val initialB = -1.0 val initialC = -1.0 @@ -159,8 +158,8 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext { val model = svm.run(testRDD, initialWeights) - val validationData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 17) - val validationRDD = sc.parallelize(validationData,2) + val validationData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 17) + val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) @@ -177,7 +176,7 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext { val B = -1.5 val C = 1.0 - val testData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 42) + val testData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 42) val testRDD = sc.parallelize(testData, 2) val testRDDInvalid = testRDD.map { lp => @@ -229,7 +228,7 @@ class SVMSuite extends FunSuite with MLlibTestSparkContext { } } -class SVMClusterSuite extends FunSuite with LocalClusterSparkContext { +class SVMClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small in both training and prediction") { val m = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala index 5683b55e8500a..e98b61e13e21f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala @@ -19,15 +19,14 @@ package org.apache.spark.mllib.classification import scala.collection.mutable.ArrayBuffer -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.TestSuiteBase -class StreamingLogisticRegressionSuite extends FunSuite with TestSuiteBase { +class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase { // use longer wait time to ensure job completion override def maxWaitTimeMillis: Int = 30000 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala index f356ffa3e3a26..b218d72f1268a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.clustering -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Vectors, Matrices} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils -class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext { +class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext { test("single cluster") { val data = sc.parallelize(Array( Vectors.dense(6.0, 9.0), @@ -47,7 +46,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext { } } - + test("two clusters") { val data = sc.parallelize(GaussianTestData.data) @@ -63,7 +62,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext { val Ew = Array(1.0 / 3.0, 2.0 / 3.0) val Emu = Array(Vectors.dense(-4.3673), Vectors.dense(5.1604)) val Esigma = Array(Matrices.dense(1, 1, Array(1.1098)), Matrices.dense(1, 1, Array(0.86644))) - + val gmm = new GaussianMixture() .setK(2) .setInitialModel(initialGmm) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index 0f2b26d462ad2..0dbbd7127444f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -19,14 +19,13 @@ package org.apache.spark.mllib.clustering import scala.util.Random -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils -class KMeansSuite extends FunSuite with MLlibTestSparkContext { +class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext { import org.apache.spark.mllib.clustering.KMeans.{K_MEANS_PARALLEL, RANDOM} @@ -75,7 +74,7 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext { val center = Vectors.dense(1.0, 2.0, 3.0) // Make sure code runs. - var model = KMeans.train(data, k=2, maxIterations=1) + var model = KMeans.train(data, k = 2, maxIterations = 1) assert(model.clusterCenters.size === 2) } @@ -87,7 +86,7 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext { 2) // Make sure code runs. - var model = KMeans.train(data, k=3, maxIterations=1) + var model = KMeans.train(data, k = 3, maxIterations = 1) assert(model.clusterCenters.size === 3) } @@ -281,7 +280,7 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext { } } -object KMeansSuite extends FunSuite { +object KMeansSuite extends SparkFunSuite { def createModel(dim: Int, k: Int, isSparse: Boolean): KMeansModel = { val singlePoint = isSparse match { case true => @@ -305,7 +304,7 @@ object KMeansSuite extends FunSuite { } } -class KMeansClusterSuite extends FunSuite with LocalClusterSparkContext { +class KMeansClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small in both training and prediction") { val m = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index d5b7d96335744..406affa25539d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -19,13 +19,12 @@ package org.apache.spark.mllib.clustering import breeze.linalg.{DenseMatrix => BDM} -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Vector, DenseMatrix, Matrix, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class LDASuite extends FunSuite with MLlibTestSparkContext { +class LDASuite extends SparkFunSuite with MLlibTestSparkContext { import LDASuite._ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala index 6d6fe6fe46bab..19e65f1b53ab5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala @@ -20,15 +20,13 @@ package org.apache.spark.mllib.clustering import scala.collection.mutable import scala.util.Random -import org.scalatest.FunSuite - -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.graphx.{Edge, Graph} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils -class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext { +class PowerIterationClusteringSuite extends SparkFunSuite with MLlibTestSparkContext { import org.apache.spark.mllib.clustering.PowerIterationClustering._ @@ -58,7 +56,7 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext predictions(a.cluster) += a.id } assert(predictions.toSet == Set((0 to 3).toSet, (4 to 15).toSet)) - + val model2 = new PowerIterationClustering() .setK(2) .setInitializationMode("degree") @@ -94,11 +92,13 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext */ val similarities = Seq[(Long, Long, Double)]( (0, 1, 1.0), (0, 2, 1.0), (0, 3, 1.0), (1, 2, 1.0), (2, 3, 1.0)) + // scalastyle:off val expected = Array( Array(0.0, 1.0/3.0, 1.0/3.0, 1.0/3.0), Array(1.0/2.0, 0.0, 1.0/2.0, 0.0), Array(1.0/3.0, 1.0/3.0, 0.0, 1.0/3.0), Array(1.0/2.0, 0.0, 1.0/2.0, 0.0)) + // scalastyle:on val w = normalize(sc.parallelize(similarities, 2)) w.edges.collect().foreach { case Edge(i, j, x) => assert(x ~== expected(i.toInt)(j.toInt) absTol 1e-14) @@ -128,7 +128,7 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext } } -object PowerIterationClusteringSuite extends FunSuite { +object PowerIterationClusteringSuite extends SparkFunSuite { def createModel(sc: SparkContext, k: Int, nPoints: Int): PowerIterationClusteringModel = { val assignments = sc.parallelize( (0 until nPoints).map(p => PowerIterationClustering.Assignment(p, Random.nextInt(k)))) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala index f90025d535e45..ac01622b8a089 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.clustering -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.streaming.TestSuiteBase import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.random.XORShiftRandom -class StreamingKMeansSuite extends FunSuite with TestSuiteBase { +class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase { override def maxWaitTimeMillis: Int = 30000 @@ -133,6 +132,13 @@ class StreamingKMeansSuite extends FunSuite with TestSuiteBase { assert(math.abs(c1) ~== 0.8 absTol 0.6) } + test("SPARK-7946 setDecayFactor") { + val kMeans = new StreamingKMeans() + assert(kMeans.decayFactor === 1.0) + kMeans.setDecayFactor(2.0) + assert(kMeans.decayFactor === 2.0) + } + def StreamingKMeansDataGenerator( numPoints: Int, numBatches: Int, diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala index 79847633ff0dc..87ccc7eda44ea 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.evaluation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class AreaUnderCurveSuite extends FunSuite with MLlibTestSparkContext { +class AreaUnderCurveSuite extends SparkFunSuite with MLlibTestSparkContext { test("auc computation") { val curve = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (3.0, 0.0)) val auc = 4.0 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala index e0224f960cc43..99d52fabc5309 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.evaluation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class BinaryClassificationMetricsSuite extends FunSuite with MLlibTestSparkContext { +class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { private def areWithinEpsilon(x: (Double, Double)): Boolean = x._1 ~= (x._2) absTol 1E-5 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala index 7dc4f3cfbc4e4..d55bc8c3ec09f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.evaluation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Matrices import org.apache.spark.mllib.util.MLlibTestSparkContext -class MulticlassMetricsSuite extends FunSuite with MLlibTestSparkContext { +class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { test("Multiclass evaluation metrics") { /* * Confusion matrix for 3-class classification with total 9 instances: diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala index 2537dd62c92f2..f3b19aeb42f84 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.evaluation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -class MultilabelMetricsSuite extends FunSuite with MLlibTestSparkContext { +class MultilabelMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { test("Multilabel evaluation metrics") { /* * Documents true labels (5x class0, 3x class1, 4x class2): diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala index 609eed983ff4e..c0924a213a844 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.evaluation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext -class RankingMetricsSuite extends FunSuite with MLlibTestSparkContext { +class RankingMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { test("Ranking metrics: map, ndcg") { val predictionAndLabels = sc.parallelize( Seq( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala index 670b4c34e6095..9de2bdb6d7246 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala @@ -17,16 +17,15 @@ package org.apache.spark.mllib.evaluation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class RegressionMetricsSuite extends FunSuite with MLlibTestSparkContext { +class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { test("regression metrics") { val predictionAndObservations = sc.parallelize( - Seq((2.5,3.0),(0.0,-0.5),(2.0,2.0),(8.0,7.0)), 2) + Seq((2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)), 2) val metrics = new RegressionMetrics(predictionAndObservations) assert(metrics.explainedVariance ~== 0.95717 absTol 1E-5, "explained variance regression score mismatch") @@ -39,7 +38,7 @@ class RegressionMetricsSuite extends FunSuite with MLlibTestSparkContext { test("regression metrics with complete fitting") { val predictionAndObservations = sc.parallelize( - Seq((3.0,3.0),(0.0,0.0),(2.0,2.0),(8.0,8.0)), 2) + Seq((3.0, 3.0), (0.0, 0.0), (2.0, 2.0), (8.0, 8.0)), 2) val metrics = new RegressionMetrics(predictionAndObservations) assert(metrics.explainedVariance ~== 1.0 absTol 1E-5, "explained variance regression score mismatch") diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala index 747f5914598ec..889727fb55823 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLlibTestSparkContext -class ChiSqSelectorSuite extends FunSuite with MLlibTestSparkContext { +class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext { /* * Contingency tables diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ElementwiseProductSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ElementwiseProductSuite.scala index f3a482abda873..ccbf8a91cdd37 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ElementwiseProductSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ElementwiseProductSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class ElementwiseProductSuite extends FunSuite with MLlibTestSparkContext { +class ElementwiseProductSuite extends SparkFunSuite with MLlibTestSparkContext { test("elementwise (hadamard) product should properly apply vector to dense data set") { val denseData = Array( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala index 0c4dfb7b97c7f..cf279c02334e9 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext -class HashingTFSuite extends FunSuite with MLlibTestSparkContext { +class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext { test("hashing tf on a single doc") { val hashingTF = new HashingTF(1000) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala index 0a5cad7caf8e4..21163633051e5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class IDFSuite extends FunSuite with MLlibTestSparkContext { +class IDFSuite extends SparkFunSuite with MLlibTestSparkContext { test("idf") { val n = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala index 5c4af2b99e68b..34122d6ed2e95 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - import breeze.linalg.{norm => brzNorm} +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class NormalizerSuite extends FunSuite with MLlibTestSparkContext { +class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext { val data = Array( Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala index 758af588f1c69..e57f49191378f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.linalg.distributed.RowMatrix import org.apache.spark.mllib.util.MLlibTestSparkContext -class PCASuite extends FunSuite with MLlibTestSparkContext { +class PCASuite extends SparkFunSuite with MLlibTestSparkContext { private val data = Array( Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))), diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala index 7f94564b2a3ae..6ab2fa6770123 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer} import org.apache.spark.rdd.RDD -class StandardScalerSuite extends FunSuite with MLlibTestSparkContext { +class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext { // When the input data is all constant, the variance is zero. The standardization against // zero variance is not well-defined, but we decide to just set it into zero here. @@ -360,7 +359,7 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext { } withClue("model needs std and mean vectors to be equal size when both are provided") { intercept[IllegalArgumentException] { - val model = new StandardScalerModel(Vectors.dense(0.0), Vectors.dense(0.0,1.0)) + val model = new StandardScalerModel(Vectors.dense(0.0), Vectors.dense(0.0, 1.0)) } } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala index 98a98a7599bcb..b6818369208d7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.mllib.feature -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils -class Word2VecSuite extends FunSuite with MLlibTestSparkContext { +class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { // TODO: add more tests diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala index bd5b9cc3afa10..66ae3543ecc4e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala @@ -16,11 +16,10 @@ */ package org.apache.spark.mllib.fpm -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext -class FPGrowthSuite extends FunSuite with MLlibTestSparkContext { +class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { test("FP-Growth using String type") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala index 04017f67c311d..a56d7b3579213 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala @@ -19,11 +19,10 @@ package org.apache.spark.mllib.fpm import scala.language.existentials -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext -class FPTreeSuite extends FunSuite with MLlibTestSparkContext { +class FPTreeSuite extends SparkFunSuite with MLlibTestSparkContext { test("add transaction") { val tree = new FPTree[String] diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala index 699f009f0f2ec..d34888af2d73b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala @@ -17,18 +17,16 @@ package org.apache.spark.mllib.impl -import org.scalatest.FunSuite - import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.graphx.{Edge, Graph} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils -class PeriodicGraphCheckpointerSuite extends FunSuite with MLlibTestSparkContext { +class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext { import PeriodicGraphCheckpointerSuite._ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala index 002cb253862b5..b0f3f71113c57 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.linalg -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.linalg.BLAS._ -class BLASSuite extends FunSuite { +class BLASSuite extends SparkFunSuite { test("copy") { val sx = Vectors.sparse(4, Array(0, 2), Array(1.0, -2.0)) @@ -140,7 +139,7 @@ class BLASSuite extends FunSuite { syr(alpha, x, dA) assert(dA ~== expected absTol 1e-15) - + val dB = new DenseMatrix(3, 4, Array(0.0, 1.2, 2.2, 3.1, 1.2, 3.2, 5.3, 4.6, 2.2, 5.3, 1.8, 3.0)) @@ -149,7 +148,7 @@ class BLASSuite extends FunSuite { syr(alpha, x, dB) } } - + val dC = new DenseMatrix(3, 3, Array(0.0, 1.2, 2.2, 1.2, 3.2, 5.3, 2.2, 5.3, 1.8)) @@ -158,7 +157,7 @@ class BLASSuite extends FunSuite { syr(alpha, x, dC) } } - + val y = new DenseVector(Array(0.0, 2.7, 3.5, 2.1, 1.5)) withClue("Size of vector must match the rank of matrix") { @@ -257,32 +256,96 @@ class BLASSuite extends FunSuite { new DenseMatrix(4, 3, Array(0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0)) val sA = new SparseMatrix(4, 3, Array(0, 1, 3, 4), Array(1, 0, 2, 3), Array(1.0, 2.0, 1.0, 3.0)) - val x = new DenseVector(Array(1.0, 2.0, 3.0)) + val dA2 = + new DenseMatrix(4, 3, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0), true) + val sA2 = + new SparseMatrix(4, 3, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0), + true) + + val dx = new DenseVector(Array(1.0, 2.0, 3.0)) + val sx = dx.toSparse val expected = new DenseVector(Array(4.0, 1.0, 2.0, 9.0)) - assert(dA.multiply(x) ~== expected absTol 1e-15) - assert(sA.multiply(x) ~== expected absTol 1e-15) + assert(dA.multiply(dx) ~== expected absTol 1e-15) + assert(sA.multiply(dx) ~== expected absTol 1e-15) + assert(dA.multiply(sx) ~== expected absTol 1e-15) + assert(sA.multiply(sx) ~== expected absTol 1e-15) val y1 = new DenseVector(Array(1.0, 3.0, 1.0, 0.0)) val y2 = y1.copy val y3 = y1.copy val y4 = y1.copy + val y5 = y1.copy + val y6 = y1.copy + val y7 = y1.copy + val y8 = y1.copy + val y9 = y1.copy + val y10 = y1.copy + val y11 = y1.copy + val y12 = y1.copy + val y13 = y1.copy + val y14 = y1.copy + val y15 = y1.copy + val y16 = y1.copy + val expected2 = new DenseVector(Array(6.0, 7.0, 4.0, 9.0)) val expected3 = new DenseVector(Array(10.0, 8.0, 6.0, 18.0)) - gemv(1.0, dA, x, 2.0, y1) - gemv(1.0, sA, x, 2.0, y2) - gemv(2.0, dA, x, 2.0, y3) - gemv(2.0, sA, x, 2.0, y4) + gemv(1.0, dA, dx, 2.0, y1) + gemv(1.0, sA, dx, 2.0, y2) + gemv(1.0, dA, sx, 2.0, y3) + gemv(1.0, sA, sx, 2.0, y4) + + gemv(1.0, dA2, dx, 2.0, y5) + gemv(1.0, sA2, dx, 2.0, y6) + gemv(1.0, dA2, sx, 2.0, y7) + gemv(1.0, sA2, sx, 2.0, y8) + + gemv(2.0, dA, dx, 2.0, y9) + gemv(2.0, sA, dx, 2.0, y10) + gemv(2.0, dA, sx, 2.0, y11) + gemv(2.0, sA, sx, 2.0, y12) + + gemv(2.0, dA2, dx, 2.0, y13) + gemv(2.0, sA2, dx, 2.0, y14) + gemv(2.0, dA2, sx, 2.0, y15) + gemv(2.0, sA2, sx, 2.0, y16) + assert(y1 ~== expected2 absTol 1e-15) assert(y2 ~== expected2 absTol 1e-15) - assert(y3 ~== expected3 absTol 1e-15) - assert(y4 ~== expected3 absTol 1e-15) + assert(y3 ~== expected2 absTol 1e-15) + assert(y4 ~== expected2 absTol 1e-15) + + assert(y5 ~== expected2 absTol 1e-15) + assert(y6 ~== expected2 absTol 1e-15) + assert(y7 ~== expected2 absTol 1e-15) + assert(y8 ~== expected2 absTol 1e-15) + + assert(y9 ~== expected3 absTol 1e-15) + assert(y10 ~== expected3 absTol 1e-15) + assert(y11 ~== expected3 absTol 1e-15) + assert(y12 ~== expected3 absTol 1e-15) + + assert(y13 ~== expected3 absTol 1e-15) + assert(y14 ~== expected3 absTol 1e-15) + assert(y15 ~== expected3 absTol 1e-15) + assert(y16 ~== expected3 absTol 1e-15) + withClue("columns of A don't match the rows of B") { intercept[Exception] { - gemv(1.0, dA.transpose, x, 2.0, y1) + gemv(1.0, dA.transpose, dx, 2.0, y1) + } + intercept[Exception] { + gemv(1.0, sA.transpose, dx, 2.0, y1) + } + intercept[Exception] { + gemv(1.0, dA.transpose, sx, 2.0, y1) + } + intercept[Exception] { + gemv(1.0, sA.transpose, sx, 2.0, y1) } } + val dAT = new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0)) val sAT = @@ -291,7 +354,9 @@ class BLASSuite extends FunSuite { val dATT = dAT.transpose val sATT = sAT.transpose - assert(dATT.multiply(x) ~== expected absTol 1e-15) - assert(sATT.multiply(x) ~== expected absTol 1e-15) + assert(dATT.multiply(dx) ~== expected absTol 1e-15) + assert(sATT.multiply(dx) ~== expected absTol 1e-15) + assert(dATT.multiply(sx) ~== expected absTol 1e-15) + assert(sATT.multiply(sx) ~== expected absTol 1e-15) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala index 2031032373971..dc04258e41d27 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala @@ -17,11 +17,11 @@ package org.apache.spark.mllib.linalg -import org.scalatest.FunSuite - import breeze.linalg.{DenseMatrix => BDM, CSCMatrix => BSM} -class BreezeMatrixConversionSuite extends FunSuite { +import org.apache.spark.SparkFunSuite + +class BreezeMatrixConversionSuite extends SparkFunSuite { test("dense matrix to breeze") { val mat = Matrices.dense(3, 2, Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0)) val breeze = mat.toBreeze.asInstanceOf[BDM[Double]] diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala index 8abdac72902c6..3772c9235ad3a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala @@ -17,14 +17,14 @@ package org.apache.spark.mllib.linalg -import org.scalatest.FunSuite - import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} +import org.apache.spark.SparkFunSuite + /** * Test Breeze vector conversions. */ -class BreezeVectorConversionSuite extends FunSuite { +class BreezeVectorConversionSuite extends SparkFunSuite { val arr = Array(0.1, 0.2, 0.3, 0.4) val n = 20 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index 86119ec38101e..8dbb70f5d1c4c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -20,13 +20,13 @@ package org.apache.spark.mllib.linalg import java.util.Random import org.mockito.Mockito.when -import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar._ import scala.collection.mutable.{Map => MutableMap} +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.TestingUtils._ -class MatricesSuite extends FunSuite { +class MatricesSuite extends SparkFunSuite { test("dense matrix construction") { val m = 3 val n = 2 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 24755e9ff46fc..c4ae0a16f7c04 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -20,12 +20,11 @@ package org.apache.spark.mllib.linalg import scala.util.Random import breeze.linalg.{DenseMatrix => BDM, squaredDistance => breezeSquaredDistance} -import org.scalatest.FunSuite -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.mllib.util.TestingUtils._ -class VectorsSuite extends FunSuite { +class VectorsSuite extends SparkFunSuite { val arr = Array(0.1, 0.0, 0.3, 0.4) val n = 4 @@ -215,13 +214,13 @@ class VectorsSuite extends FunSuite { val squaredDist = breezeSquaredDistance(sparseVector1.toBreeze, sparseVector2.toBreeze) - // SparseVector vs. SparseVector - assert(Vectors.sqdist(sparseVector1, sparseVector2) ~== squaredDist relTol 1E-8) + // SparseVector vs. SparseVector + assert(Vectors.sqdist(sparseVector1, sparseVector2) ~== squaredDist relTol 1E-8) // DenseVector vs. SparseVector assert(Vectors.sqdist(denseVector1, sparseVector2) ~== squaredDist relTol 1E-8) // DenseVector vs. DenseVector assert(Vectors.sqdist(denseVector1, denseVector2) ~== squaredDist relTol 1E-8) - } + } } test("foreachActive") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala index 949d1c9939570..93fe04c139b9a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala @@ -20,14 +20,13 @@ package org.apache.spark.mllib.linalg.distributed import java.{util => ju} import breeze.linalg.{DenseMatrix => BDM} -import org.scalatest.FunSuite -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.mllib.linalg.{SparseMatrix, DenseMatrix, Matrices, Matrix} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class BlockMatrixSuite extends FunSuite with MLlibTestSparkContext { +class BlockMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { val m = 5 val n = 4 @@ -57,11 +56,13 @@ class BlockMatrixSuite extends FunSuite with MLlibTestSparkContext { val random = new ju.Random() // This should generate a 4x4 grid of 1x2 blocks. val part0 = GridPartitioner(4, 7, suggestedNumPartitions = 12) + // scalastyle:off val expected0 = Array( Array(0, 0, 4, 4, 8, 8, 12), Array(1, 1, 5, 5, 9, 9, 13), Array(2, 2, 6, 6, 10, 10, 14), Array(3, 3, 7, 7, 11, 11, 15)) + // scalastyle:on for (i <- 0 until 4; j <- 0 until 7) { assert(part0.getPartition((i, j)) === expected0(i)(j)) assert(part0.getPartition((i, j, random.nextInt())) === expected0(i)(j)) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala index 04b36a9ef9990..f3728cd036a3f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.mllib.linalg.distributed -import org.scalatest.FunSuite - import breeze.linalg.{DenseMatrix => BDM} +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.linalg.Vectors -class CoordinateMatrixSuite extends FunSuite with MLlibTestSparkContext { +class CoordinateMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { val m = 5 val n = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala index 2ab53cc13db71..4a7b99a976f0a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.linalg.distributed -import org.scalatest.FunSuite - import breeze.linalg.{diag => brzDiag, DenseMatrix => BDM, DenseVector => BDV} +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.{Matrices, Vectors} -class IndexedRowMatrixSuite extends FunSuite with MLlibTestSparkContext { +class IndexedRowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { val m = 4 val n = 3 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala index 27bb19f472e1e..b6cb53d0c743e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala @@ -20,12 +20,12 @@ package org.apache.spark.mllib.linalg.distributed import scala.util.Random import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, norm => brzNorm, svd => brzSvd} -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Matrices, Vectors, Vector} import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} -class RowMatrixSuite extends FunSuite with MLlibTestSparkContext { +class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { val m = 4 val n = 3 @@ -240,7 +240,7 @@ class RowMatrixSuite extends FunSuite with MLlibTestSparkContext { } } -class RowMatrixClusterSuite extends FunSuite with LocalClusterSparkContext { +class RowMatrixClusterSuite extends SparkFunSuite with LocalClusterSparkContext { var mat: RowMatrix = _ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala index 86481c6e66200..a5a59e9fad5ae 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala @@ -20,8 +20,9 @@ package org.apache.spark.mllib.optimization import scala.collection.JavaConversions._ import scala.util.Random -import org.scalatest.{FunSuite, Matchers} +import org.scalatest.Matchers +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} @@ -42,7 +43,7 @@ object GradientDescentSuite { offset: Double, scale: Double, nPoints: Int, - seed: Int): Seq[LabeledPoint] = { + seed: Int): Seq[LabeledPoint] = { val rnd = new Random(seed) val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian()) @@ -61,7 +62,7 @@ object GradientDescentSuite { } } -class GradientDescentSuite extends FunSuite with MLlibTestSparkContext with Matchers { +class GradientDescentSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers { test("Assert the loss is decreasing.") { val nPoints = 10000 @@ -140,7 +141,7 @@ class GradientDescentSuite extends FunSuite with MLlibTestSparkContext with Matc } } -class GradientDescentClusterSuite extends FunSuite with LocalClusterSparkContext { +class GradientDescentClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small") { val m = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala index c8f2adcf155a7..d07b9d5b89227 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala @@ -19,14 +19,15 @@ package org.apache.spark.mllib.optimization import scala.util.Random -import org.scalatest.{FunSuite, Matchers} +import org.scalatest.Matchers +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ -class LBFGSSuite extends FunSuite with MLlibTestSparkContext with Matchers { +class LBFGSSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers { val nPoints = 10000 val A = 2.0 @@ -229,7 +230,7 @@ class LBFGSSuite extends FunSuite with MLlibTestSparkContext with Matchers { } } -class LBFGSClusterSuite extends FunSuite with LocalClusterSparkContext { +class LBFGSClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small") { val m = 10 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala index 22855e4e8f247..d8f9b8c33963d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala @@ -19,13 +19,12 @@ package org.apache.spark.mllib.optimization import scala.util.Random -import org.scalatest.FunSuite - import org.jblas.{DoubleMatrix, SimpleBlas} +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.TestingUtils._ -class NNLSSuite extends FunSuite { +class NNLSSuite extends SparkFunSuite { /** Generate an NNLS problem whose optimal solution is the all-ones vector. */ def genOnesData(n: Int, rand: Random): (DoubleMatrix, DoubleMatrix) = { val A = new DoubleMatrix(n, n, Array.fill(n*n)(rand.nextDouble()): _*) @@ -68,12 +67,14 @@ class NNLSSuite extends FunSuite { test("NNLS: nonnegativity constraint active") { val n = 5 + // scalastyle:off val ata = new DoubleMatrix(Array( Array( 4.377, -3.531, -1.306, -0.139, 3.418), Array(-3.531, 4.344, 0.934, 0.305, -2.140), Array(-1.306, 0.934, 2.644, -0.203, -0.170), Array(-0.139, 0.305, -0.203, 5.883, 1.428), Array( 3.418, -2.140, -0.170, 1.428, 4.684))) + // scalastyle:on val atb = new DoubleMatrix(Array(-1.632, 2.115, 1.094, -1.025, -0.636)) val goodx = Array(0.13025, 0.54506, 0.2874, 0.0, 0.028628) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala index 0b646cf1ce6c4..4c6e76e47419b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/BinaryClassificationPMMLModelExportSuite.scala @@ -19,13 +19,13 @@ package org.apache.spark.mllib.pmml.export import org.dmg.pmml.RegressionModel import org.dmg.pmml.RegressionNormalizationMethodType -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.classification.LogisticRegressionModel import org.apache.spark.mllib.classification.SVMModel import org.apache.spark.mllib.util.LinearDataGenerator -class BinaryClassificationPMMLModelExportSuite extends FunSuite { +class BinaryClassificationPMMLModelExportSuite extends SparkFunSuite { test("logistic regression PMML export") { val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) @@ -53,13 +53,13 @@ class BinaryClassificationPMMLModelExportSuite extends FunSuite { // ensure logistic regression has normalization method set to LOGIT assert(pmmlRegressionModel.getNormalizationMethod() == RegressionNormalizationMethodType.LOGIT) } - + test("linear SVM PMML export") { val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label) - + val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel) - + // assert that the PMML format is as expected assert(svmModelExport.isInstanceOf[PMMLModelExport]) val pmml = svmModelExport.getPmml @@ -80,5 +80,5 @@ class BinaryClassificationPMMLModelExportSuite extends FunSuite { // ensure linear SVM has normalization method set to NONE assert(pmmlRegressionModel.getNormalizationMethod() == RegressionNormalizationMethodType.NONE) } - + } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala index f9afbd888dfc5..1d32309481787 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/GeneralizedLinearPMMLModelExportSuite.scala @@ -18,12 +18,12 @@ package org.apache.spark.mllib.pmml.export import org.dmg.pmml.RegressionModel -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.regression.{LassoModel, LinearRegressionModel, RidgeRegressionModel} import org.apache.spark.mllib.util.LinearDataGenerator -class GeneralizedLinearPMMLModelExportSuite extends FunSuite { +class GeneralizedLinearPMMLModelExportSuite extends SparkFunSuite { test("linear regression PMML export") { val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala index b985d0446d7b0..b3f9750afa730 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExportSuite.scala @@ -18,12 +18,12 @@ package org.apache.spark.mllib.pmml.export import org.dmg.pmml.ClusteringModel -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.clustering.KMeansModel import org.apache.spark.mllib.linalg.Vectors -class KMeansPMMLModelExportSuite extends FunSuite { +class KMeansPMMLModelExportSuite extends SparkFunSuite { test("KMeansPMMLModelExport generate PMML format") { val clusterCenters = Array( @@ -45,5 +45,5 @@ class KMeansPMMLModelExportSuite extends FunSuite { val pmmlClusteringModel = pmml.getModels.get(0).asInstanceOf[ClusteringModel] assert(pmmlClusteringModel.getNumberOfClusters === clusterCenters.length) } - + } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala index f28a4ac8ad01f..af49450961750 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.pmml.export -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.classification.{LogisticRegressionModel, SVMModel} import org.apache.spark.mllib.clustering.KMeansModel import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.{LassoModel, LinearRegressionModel, RidgeRegressionModel} import org.apache.spark.mllib.util.LinearDataGenerator -class PMMLModelExportFactorySuite extends FunSuite { +class PMMLModelExportFactorySuite extends SparkFunSuite { test("PMMLModelExportFactory create KMeansPMMLModelExport when passing a KMeansModel") { val clusterCenters = Array( @@ -61,25 +60,25 @@ class PMMLModelExportFactorySuite extends FunSuite { test("PMMLModelExportFactory create BinaryClassificationPMMLModelExport " + "when passing a LogisticRegressionModel or SVMModel") { val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) - + val logisticRegressionModel = new LogisticRegressionModel(linearInput(0).features, linearInput(0).label) val logisticRegressionModelExport = PMMLModelExportFactory.createPMMLModelExport(logisticRegressionModel) assert(logisticRegressionModelExport.isInstanceOf[BinaryClassificationPMMLModelExport]) - + val svmModel = new SVMModel(linearInput(0).features, linearInput(0).label) val svmModelExport = PMMLModelExportFactory.createPMMLModelExport(svmModel) assert(svmModelExport.isInstanceOf[BinaryClassificationPMMLModelExport]) } - + test("PMMLModelExportFactory throw IllegalArgumentException " + "when passing a Multinomial Logistic Regression") { /** 3 classes, 2 features */ val multiclassLogisticRegressionModel = new LogisticRegressionModel( - weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 1.0, + weights = Vectors.dense(0.1, 0.2, 0.3, 0.4), intercept = 1.0, numFeatures = 2, numClasses = 3) - + intercept[IllegalArgumentException] { PMMLModelExportFactory.createPMMLModelExport(multiclassLogisticRegressionModel) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala index b792d819fdabb..a5ca1518f82f5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala @@ -19,12 +19,11 @@ package org.apache.spark.mllib.random import scala.math -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.util.StatCounter // TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged -class RandomDataGeneratorSuite extends FunSuite { +class RandomDataGeneratorSuite extends SparkFunSuite { def apiChecks(gen: RandomDataGenerator[Double]) { // resetting seed should generate the same sequence of random numbers diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala index 63f2ea916d457..413db2000d6d7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala @@ -19,8 +19,7 @@ package org.apache.spark.mllib.random import scala.collection.mutable.ArrayBuffer -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.SparkContext._ import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.rdd.{RandomRDDPartition, RandomRDD} @@ -34,7 +33,7 @@ import org.apache.spark.util.StatCounter * * TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged */ -class RandomRDDsSuite extends FunSuite with MLlibTestSparkContext with Serializable { +class RandomRDDsSuite extends SparkFunSuite with MLlibTestSparkContext with Serializable { def testGeneratedRDD(rdd: RDD[Double], expectedSize: Long, diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala index 57216e8eb4a55..10f5a2be48f7c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.rdd -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.rdd.MLPairRDDFunctions._ -class MLPairRDDFunctionsSuite extends FunSuite with MLlibTestSparkContext { +class MLPairRDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext { test("topByKey") { val topMap = sc.parallelize(Array((1, 7), (1, 3), (1, 6), (1, 1), (1, 2), (3, 2), (3, 7), (5, 1), (3, 5)), 2) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala index 6d6c0aa5be812..bc64172614830 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.rdd -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.rdd.RDDFunctions._ -class RDDFunctionsSuite extends FunSuite with MLlibTestSparkContext { +class RDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext { test("sliding") { val data = 0 until 6 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala index b3798940ddc38..05b87728d6fdb 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala @@ -21,9 +21,9 @@ import scala.collection.JavaConversions._ import scala.math.abs import scala.util.Random -import org.scalatest.FunSuite import org.jblas.DoubleMatrix +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.storage.StorageLevel @@ -84,7 +84,7 @@ object ALSSuite { } -class ALSSuite extends FunSuite with MLlibTestSparkContext { +class ALSSuite extends SparkFunSuite with MLlibTestSparkContext { test("rank-1 matrices") { testALS(50, 100, 1, 15, 0.7, 0.3) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala index 2c92866f3893d..2c8ed057a516a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.mllib.recommendation -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils -class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext { +class MatrixFactorizationModelSuite extends SparkFunSuite with MLlibTestSparkContext { val rank = 2 var userFeatures: RDD[(Int, Array[Double])] = _ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala index 3b38bdf5ef5eb..ea4f2865757c1 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala @@ -17,13 +17,14 @@ package org.apache.spark.mllib.regression -import org.scalatest.{Matchers, FunSuite} +import org.scalatest.Matchers +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils -class IsotonicRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers { +class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers { private def round(d: Double) = { math.round(d * 100).toDouble / 100 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala index 110c44a7193fd..d8364a06de4da 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.mllib.regression -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors -class LabeledPointSuite extends FunSuite { +class LabeledPointSuite extends SparkFunSuite { test("parse labeled points") { val points = Seq( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala index c9f5dc069ef2e..08a152ffc7a23 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala @@ -19,8 +19,7 @@ package org.apache.spark.mllib.regression import scala.util.Random -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, MLlibTestSparkContext} @@ -32,7 +31,7 @@ private object LassoSuite { val model = new LassoModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5) } -class LassoSuite extends FunSuite with MLlibTestSparkContext { +class LassoSuite extends SparkFunSuite with MLlibTestSparkContext { def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => @@ -67,11 +66,12 @@ class LassoSuite extends FunSuite with MLlibTestSparkContext { assert(weight1 >= -1.60 && weight1 <= -1.40, weight1 + " not in [-1.6, -1.4]") assert(weight2 >= -1.0e-3 && weight2 <= 1.0e-3, weight2 + " not in [-0.001, 0.001]") - val validationData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 17) + val validationData = LinearDataGenerator + .generateLinearInput(A, Array[Double](B, C), nPoints, 17) .map { case LabeledPoint(label, features) => LabeledPoint(label, Vectors.dense(1.0 +: features.toArray)) } - val validationRDD = sc.parallelize(validationData, 2) + val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) @@ -110,11 +110,12 @@ class LassoSuite extends FunSuite with MLlibTestSparkContext { assert(weight1 >= -1.60 && weight1 <= -1.40, weight1 + " not in [-1.6, -1.4]") assert(weight2 >= -1.0e-3 && weight2 <= 1.0e-3, weight2 + " not in [-0.001, 0.001]") - val validationData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 17) + val validationData = LinearDataGenerator + .generateLinearInput(A, Array[Double](B, C), nPoints, 17) .map { case LabeledPoint(label, features) => LabeledPoint(label, Vectors.dense(1.0 +: features.toArray)) } - val validationRDD = sc.parallelize(validationData,2) + val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) @@ -141,7 +142,7 @@ class LassoSuite extends FunSuite with MLlibTestSparkContext { } } -class LassoClusterSuite extends FunSuite with LocalClusterSparkContext { +class LassoClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small in both training and prediction") { val m = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala index 3781931c2f819..f88a1c33c9f7c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala @@ -19,8 +19,7 @@ package org.apache.spark.mllib.regression import scala.util.Random -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, MLlibTestSparkContext} @@ -32,7 +31,7 @@ private object LinearRegressionSuite { val model = new LinearRegressionModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5) } -class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { +class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => @@ -150,7 +149,7 @@ class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { } } -class LinearRegressionClusterSuite extends FunSuite with LocalClusterSparkContext { +class LinearRegressionClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small in both training and prediction") { val m = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala index d6c93cc0e49cd..7a781fee634c8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.mllib.regression import scala.util.Random import org.jblas.DoubleMatrix -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, MLlibTestSparkContext} @@ -33,7 +33,7 @@ private object RidgeRegressionSuite { val model = new RidgeRegressionModel(weights = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5) } -class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext { +class RidgeRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]): Double = { predictions.zip(input).map { case (prediction, expected) => @@ -101,7 +101,7 @@ class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext { } } -class RidgeRegressionClusterSuite extends FunSuite with LocalClusterSparkContext { +class RidgeRegressionClusterSuite extends SparkFunSuite with LocalClusterSparkContext { test("task size should be small in both training and prediction") { val m = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala index 26604dbe6c1ef..9a379406d5061 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala @@ -19,14 +19,13 @@ package org.apache.spark.mllib.regression import scala.collection.mutable.ArrayBuffer -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.LinearDataGenerator import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.TestSuiteBase -class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase { +class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase { // use longer wait time to ensure job completion override def maxWaitTimeMillis: Int = 20000 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala index d20a09b4b4925..c292ced75e870 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala @@ -17,16 +17,15 @@ package org.apache.spark.mllib.stat -import org.scalatest.FunSuite - import breeze.linalg.{DenseMatrix => BDM, Matrix => BM} +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.stat.correlation.{Correlations, PearsonCorrelation, SpearmanCorrelation} import org.apache.spark.mllib.util.MLlibTestSparkContext -class CorrelationSuite extends FunSuite with MLlibTestSparkContext { +class CorrelationSuite extends SparkFunSuite with MLlibTestSparkContext { // test input data val xData = Array(1.0, 0.0, -2.0) @@ -96,11 +95,13 @@ class CorrelationSuite extends FunSuite with MLlibTestSparkContext { val X = sc.parallelize(data) val defaultMat = Statistics.corr(X) val pearsonMat = Statistics.corr(X, "pearson") + // scalastyle:off val expected = BDM( (1.00000000, 0.05564149, Double.NaN, 0.4004714), (0.05564149, 1.00000000, Double.NaN, 0.9135959), (Double.NaN, Double.NaN, 1.00000000, Double.NaN), - (0.40047142, 0.91359586, Double.NaN,1.0000000)) + (0.40047142, 0.91359586, Double.NaN, 1.0000000)) + // scalastyle:on assert(matrixApproxEqual(defaultMat.toBreeze, expected)) assert(matrixApproxEqual(pearsonMat.toBreeze, expected)) } @@ -108,11 +109,13 @@ class CorrelationSuite extends FunSuite with MLlibTestSparkContext { test("corr(X) spearman") { val X = sc.parallelize(data) val spearmanMat = Statistics.corr(X, "spearman") + // scalastyle:off val expected = BDM( (1.0000000, 0.1054093, Double.NaN, 0.4000000), (0.1054093, 1.0000000, Double.NaN, 0.9486833), (Double.NaN, Double.NaN, 1.00000000, Double.NaN), (0.4000000, 0.9486833, Double.NaN, 1.0000000)) + // scalastyle:on assert(matrixApproxEqual(spearmanMat.toBreeze, expected)) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala index 15418e6035965..b084a5fb4313f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala @@ -19,16 +19,14 @@ package org.apache.spark.mllib.stat import java.util.Random -import org.scalatest.FunSuite - -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.mllib.linalg.{DenseVector, Matrices, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.test.ChiSqTest import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class HypothesisTestSuite extends FunSuite with MLlibTestSparkContext { +class HypothesisTestSuite extends SparkFunSuite with MLlibTestSparkContext { test("chi squared pearson goodness of fit") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala index 16ecae23dd9d4..5feccdf33681a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/KernelDensitySuite.scala @@ -17,31 +17,32 @@ package org.apache.spark.mllib.stat -import org.scalatest.FunSuite - import org.apache.commons.math3.distribution.NormalDistribution +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext -class KernelDensitySuite extends FunSuite with MLlibTestSparkContext { +class KernelDensitySuite extends SparkFunSuite with MLlibTestSparkContext { test("kernel density single sample") { val rdd = sc.parallelize(Array(5.0)) val evaluationPoints = Array(5.0, 6.0) - val densities = KernelDensity.estimate(rdd, 3.0, evaluationPoints) + val densities = new KernelDensity().setSample(rdd).setBandwidth(3.0).estimate(evaluationPoints) val normal = new NormalDistribution(5.0, 3.0) val acceptableErr = 1e-6 - assert(densities(0) - normal.density(5.0) < acceptableErr) - assert(densities(0) - normal.density(6.0) < acceptableErr) + assert(math.abs(densities(0) - normal.density(5.0)) < acceptableErr) + assert(math.abs(densities(1) - normal.density(6.0)) < acceptableErr) } test("kernel density multiple samples") { val rdd = sc.parallelize(Array(5.0, 10.0)) val evaluationPoints = Array(5.0, 6.0) - val densities = KernelDensity.estimate(rdd, 3.0, evaluationPoints) + val densities = new KernelDensity().setSample(rdd).setBandwidth(3.0).estimate(evaluationPoints) val normal1 = new NormalDistribution(5.0, 3.0) val normal2 = new NormalDistribution(10.0, 3.0) val acceptableErr = 1e-6 - assert(densities(0) - (normal1.density(5.0) + normal2.density(5.0)) / 2 < acceptableErr) - assert(densities(0) - (normal1.density(6.0) + normal2.density(6.0)) / 2 < acceptableErr) + assert(math.abs( + densities(0) - (normal1.density(5.0) + normal2.density(5.0)) / 2) < acceptableErr) + assert(math.abs( + densities(1) - (normal1.density(6.0) + normal2.density(6.0)) / 2) < acceptableErr) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala index 23b0eec865de6..07efde4f5e6dc 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.stat -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.TestingUtils._ -class MultivariateOnlineSummarizerSuite extends FunSuite { +class MultivariateOnlineSummarizerSuite extends SparkFunSuite { test("basic error handing") { val summarizer = new MultivariateOnlineSummarizer diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala index fac2498e4dcb3..aa60deb665aeb 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala @@ -17,49 +17,48 @@ package org.apache.spark.mllib.stat.distribution -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{ Vectors, Matrices } import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class MultivariateGaussianSuite extends FunSuite with MLlibTestSparkContext { +class MultivariateGaussianSuite extends SparkFunSuite with MLlibTestSparkContext { test("univariate") { val x1 = Vectors.dense(0.0) val x2 = Vectors.dense(1.5) - + val mu = Vectors.dense(0.0) val sigma1 = Matrices.dense(1, 1, Array(1.0)) val dist1 = new MultivariateGaussian(mu, sigma1) assert(dist1.pdf(x1) ~== 0.39894 absTol 1E-5) assert(dist1.pdf(x2) ~== 0.12952 absTol 1E-5) - + val sigma2 = Matrices.dense(1, 1, Array(4.0)) val dist2 = new MultivariateGaussian(mu, sigma2) assert(dist2.pdf(x1) ~== 0.19947 absTol 1E-5) assert(dist2.pdf(x2) ~== 0.15057 absTol 1E-5) } - + test("multivariate") { val x1 = Vectors.dense(0.0, 0.0) val x2 = Vectors.dense(1.0, 1.0) - + val mu = Vectors.dense(0.0, 0.0) val sigma1 = Matrices.dense(2, 2, Array(1.0, 0.0, 0.0, 1.0)) val dist1 = new MultivariateGaussian(mu, sigma1) assert(dist1.pdf(x1) ~== 0.15915 absTol 1E-5) assert(dist1.pdf(x2) ~== 0.05855 absTol 1E-5) - + val sigma2 = Matrices.dense(2, 2, Array(4.0, -1.0, -1.0, 2.0)) val dist2 = new MultivariateGaussian(mu, sigma2) assert(dist2.pdf(x1) ~== 0.060155 absTol 1E-5) assert(dist2.pdf(x2) ~== 0.033971 absTol 1E-5) } - + test("multivariate degenerate") { val x1 = Vectors.dense(0.0, 0.0) val x2 = Vectors.dense(1.0, 1.0) - + val mu = Vectors.dense(0.0, 0.0) val sigma = Matrices.dense(2, 2, Array(1.0, 1.0, 1.0, 1.0)) val dist = new MultivariateGaussian(mu, sigma) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index ce983eb27fa35..356d957f15909 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -20,8 +20,7 @@ package org.apache.spark.mllib.tree import scala.collection.JavaConverters._ import scala.collection.mutable -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ @@ -34,7 +33,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.util.Utils -class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { +class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { ///////////////////////////////////////////////////////////////////////////// // Tests examining individual elements of training @@ -859,7 +858,7 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { } } -object DecisionTreeSuite extends FunSuite { +object DecisionTreeSuite extends SparkFunSuite { def validateClassifier( model: DecisionTreeModel, diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala index 55b0bac7d49fe..84dd3b342d4c0 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.mllib.tree -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy} @@ -32,7 +31,7 @@ import org.apache.spark.util.Utils /** * Test suite for [[GradientBoostedTrees]]. */ -class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext { +class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext { test("Regression with continuous features: SquaredError") { GradientBoostedTreesSuite.testCombinations.foreach { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala index 92b498580af03..49aff21fe7914 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.tree -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.tree.impurity.{EntropyAggregator, GiniAggregator} import org.apache.spark.mllib.util.MLlibTestSparkContext /** * Test suites for [[GiniAggregator]] and [[EntropyAggregator]]. */ -class ImpuritySuite extends FunSuite with MLlibTestSparkContext { +class ImpuritySuite extends SparkFunSuite with MLlibTestSparkContext { test("Gini impurity does not support negative labels") { val gini = new GiniAggregator(2) intercept[IllegalArgumentException] { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala index ee3bc98486862..e6df5d974bf36 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala @@ -19,8 +19,7 @@ package org.apache.spark.mllib.tree import scala.collection.mutable -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ @@ -35,7 +34,7 @@ import org.apache.spark.util.Utils /** * Test suite for [[RandomForest]]. */ -class RandomForestSuite extends FunSuite with MLlibTestSparkContext { +class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { def binaryClassificationTestWithContinuousFeatures(strategy: Strategy) { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000) val rdd = sc.parallelize(arr) @@ -196,7 +195,6 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext { numClasses = 3, categoricalFeaturesInfo = categoricalFeaturesInfo) val model = RandomForest.trainClassifier(input, strategy, numTrees = 2, featureSubsetStrategy = "sqrt", seed = 12345) - EnsembleTestHelper.validateClassifier(model, arr, 1.0) } test("subsampling rate in RandomForest"){ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala index b184e936672ca..9d756da410325 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.mllib.tree.impl -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.tree.EnsembleTestHelper import org.apache.spark.mllib.util.MLlibTestSparkContext /** * Test suite for [[BaggedPoint]]. */ -class BaggedPointSuite extends FunSuite with MLlibTestSparkContext { +class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext { test("BaggedPoint RDD: without subsampling") { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala index 668fc1d43c5d6..70219e9ad9d3e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -21,19 +21,19 @@ import java.io.File import scala.io.Source -import org.scalatest.FunSuite - import breeze.linalg.{squaredDistance => breezeSquaredDistance} import com.google.common.base.Charsets import com.google.common.io.Files +import org.apache.spark.SparkException +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLUtils._ import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils -class MLUtilsSuite extends FunSuite with MLlibTestSparkContext { +class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { test("epsilon computation") { assert(1.0 + EPSILON > 1.0, s"EPSILON is too small: $EPSILON.") @@ -63,7 +63,7 @@ class MLUtilsSuite extends FunSuite with MLlibTestSparkContext { val fastSquaredDist3 = fastSquaredDistance(v2, norm2, v3, norm3, precision) assert((fastSquaredDist3 - squaredDist2) <= precision * squaredDist2, s"failed with m = $m") - if (m > 10) { + if (m > 10) { val v4 = Vectors.sparse(n, indices.slice(0, m - 10), indices.map(i => a(i) + 0.5).slice(0, m - 10)) val norm4 = Vectors.norm(v4, 2.0) @@ -109,6 +109,40 @@ class MLUtilsSuite extends FunSuite with MLlibTestSparkContext { Utils.deleteRecursively(tempDir) } + test("loadLibSVMFile throws IllegalArgumentException when indices is zero-based") { + val lines = + """ + |0 + |0 0:4.0 4:5.0 6:6.0 + """.stripMargin + val tempDir = Utils.createTempDir() + val file = new File(tempDir.getPath, "part-00000") + Files.write(lines, file, Charsets.US_ASCII) + val path = tempDir.toURI.toString + + intercept[SparkException] { + loadLibSVMFile(sc, path).collect() + } + Utils.deleteRecursively(tempDir) + } + + test("loadLibSVMFile throws IllegalArgumentException when indices is not in ascending order") { + val lines = + """ + |0 + |0 3:4.0 2:5.0 6:6.0 + """.stripMargin + val tempDir = Utils.createTempDir() + val file = new File(tempDir.getPath, "part-00000") + Files.write(lines, file, Charsets.US_ASCII) + val path = tempDir.toURI.toString + + intercept[SparkException] { + loadLibSVMFile(sc, path).collect() + } + Utils.deleteRecursively(tempDir) + } + test("saveAsLibSVMFile") { val examples = sc.parallelize(Seq( LabeledPoint(1.1, Vectors.sparse(3, Seq((0, 1.23), (2, 4.56)))), @@ -168,7 +202,7 @@ class MLUtilsSuite extends FunSuite with MLlibTestSparkContext { "Each training+validation set combined should contain all of the data.") } // K fold cross validation should only have each element in the validation set exactly once - assert(foldedRdds.map(_._2).reduce((x,y) => x.union(y)).collect().sorted === + assert(foldedRdds.map(_._2).reduce((x, y) => x.union(y)).collect().sorted === data.collect().sorted) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala index b658889476d37..5d1796ef65722 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala @@ -17,13 +17,14 @@ package org.apache.spark.mllib.util -import org.scalatest.Suite -import org.scalatest.BeforeAndAfterAll +import org.scalatest.{BeforeAndAfterAll, Suite} import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.sql.SQLContext trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite => @transient var sc: SparkContext = _ + @transient var sqlContext: SQLContext = _ override def beforeAll() { super.beforeAll() @@ -31,12 +32,15 @@ trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite => .setMaster("local[2]") .setAppName("MLlibUnitTest") sc = new SparkContext(conf) + sqlContext = new SQLContext(sc) } override def afterAll() { + sqlContext = null if (sc != null) { sc.stop() } + sc = null super.afterAll() } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala index f68fb95eac4e4..8dcb9ba9be108 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/NumericParserSuite.scala @@ -17,11 +17,9 @@ package org.apache.spark.mllib.util -import org.scalatest.FunSuite +import org.apache.spark.{SparkException, SparkFunSuite} -import org.apache.spark.SparkException - -class NumericParserSuite extends FunSuite { +class NumericParserSuite extends SparkFunSuite { test("parser") { val s = "((1.0,2e3),-4,[5e-6,7.0E8],+9)" diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala index 59e6c778806f4..8f475f30249d6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala @@ -17,12 +17,12 @@ package org.apache.spark.mllib.util +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors -import org.scalatest.FunSuite import org.apache.spark.mllib.util.TestingUtils._ import org.scalatest.exceptions.TestFailedException -class TestingUtilsSuite extends FunSuite { +class TestingUtilsSuite extends SparkFunSuite { test("Comparing doubles using relative error.") { diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java index 60485bace643c..ce954b8a289e4 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java @@ -24,6 +24,9 @@ import org.apache.spark.network.protocol.Encoders; +// Needed by ScalaDoc. See SPARK-7726 +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + /** Request to read a set of blocks. Returns {@link StreamHandle}. */ public class OpenBlocks extends BlockTransferMessage { public final String appId; diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java index 38acae3b31d64..cca8b17c4f129 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java @@ -22,6 +22,9 @@ import org.apache.spark.network.protocol.Encoders; +// Needed by ScalaDoc. See SPARK-7726 +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + /** * Initial registration message between an executor and its local shuffle server. * Returns nothing (empty bye array). diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java index 9a9220211a50c..1915295aa6cc2 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java @@ -20,6 +20,9 @@ import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; +// Needed by ScalaDoc. See SPARK-7726 +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + /** * Identifier for a fixed number of chunks to read from a stream created by an "open blocks" * message. This is used by {@link org.apache.spark.network.shuffle.OneForOneBlockFetcher}. diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java index 2ff9aaa650f92..3caed59d508fd 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java @@ -24,6 +24,9 @@ import org.apache.spark.network.protocol.Encoders; +// Needed by ScalaDoc. See SPARK-7726 +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + /** Request to upload a block with a certain StorageLevel. Returns nothing (empty byte array). */ public class UploadBlock extends BlockTransferMessage { diff --git a/pom.xml b/pom.xml index cf9279ea5a2a6..711edf9efad2b 100644 --- a/pom.xml +++ b/pom.xml @@ -107,6 +107,8 @@ examples repl launcher + external/kafka + external/kafka-assembly @@ -122,12 +124,13 @@ 1.7.10 1.2.17 2.2.0 - 2.4.1 + 2.5.0 ${hadoop.version} - 0.98.7-hadoop1 + 0.98.7-hadoop2 hbase 1.4.0 3.4.5 + 2.4.0 org.spark-project.hive 0.13.1a @@ -135,7 +138,7 @@ 0.13.1 10.10.1.1 1.6.0rc3 - 1.2.3 + 1.2.4 8.1.14.v20131031 3.0.0.v201112011016 0.5.0 @@ -143,10 +146,10 @@ 2.0.8 3.1.0 1.7.7 - + hadoop2 0.7.1 - 1.8.3 - 1.1.0 + 1.9.16 + 1.2.1 4.3.2 3.4.1 ${project.build.directory}/spark-test-classpath.txt @@ -155,7 +158,7 @@ ${scala.version} org.scala-lang 3.6.3 - 1.8.8 + 1.9.13 2.4.4 1.1.1.7 1.1.2 @@ -492,7 +495,7 @@ net.jpountz.lz4 lz4 - 1.2.0 + 1.3.0 com.clearspring.analytics @@ -669,7 +672,7 @@ org.mockito mockito-all - 1.9.0 + 1.9.5 test @@ -684,6 +687,18 @@ 4.10 test + + org.hamcrest + hamcrest-core + 1.3 + test + + + org.hamcrest + hamcrest-library + 1.3 + test + com.novocode junit-interface @@ -693,7 +708,7 @@ org.apache.curator curator-recipes - 2.4.0 + ${curator.version} ${hadoop.deps.scope} @@ -702,6 +717,16 @@ + + org.apache.curator + curator-client + ${curator.version} + + + org.apache.curator + curator-framework + ${curator.version} + org.apache.hadoop hadoop-client @@ -1632,26 +1657,27 @@ --> - hadoop-2.2 + hadoop-1 - 2.2.0 - 2.5.0 - 0.98.7-hadoop2 - hadoop2 - 1.9.13 + 1.0.4 + 2.4.1 + 0.98.7-hadoop1 + hadoop1 + 1.8.8 + + hadoop-2.2 + + + hadoop-2.3 2.3.0 - 2.5.0 0.9.3 - 0.98.7-hadoop2 3.1.1 - hadoop2 - 1.9.13 @@ -1659,12 +1685,19 @@ hadoop-2.4 2.4.0 - 2.5.0 0.9.3 - 0.98.7-hadoop2 3.1.1 - hadoop2 - 1.9.13 + +
    + + + hadoop-2.6 + + 2.6.0 + 0.9.3 + 3.1.1 + 3.4.6 + 2.6.0 @@ -1698,7 +1731,7 @@ org.apache.curator curator-recipes - 2.4.0 + ${curator.version} org.apache.zookeeper @@ -1720,22 +1753,6 @@ sql/hive-thriftserver - - hive-0.12.0 - - 0.12.0-protobuf-2.5 - 0.12.0 - 10.4.2.0 - - - - hive-0.13.1 - - 0.13.1a - 0.13.1 - 10.10.1.1 - - scala-2.10 @@ -1748,10 +1765,6 @@ ${scala.version} org.scala-lang - - external/kafka - external/kafka-assembly - diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index a47e29e2ef365..8da72b3fa7cdb 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -38,6 +38,8 @@ object MimaExcludes { Seq( MimaBuild.excludeSparkPackage("deploy"), MimaBuild.excludeSparkPackage("ml"), + // SPARK-7910 Adding a method to get the partioner to JavaRDD, + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitioner"), // SPARK-5922 Adding a generalized diff(other: RDD[(VertexId, VD)]) to VertexRDD ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.VertexRDD.diff"), // These are needed if checking against the sbt build, since they are part of @@ -87,7 +89,14 @@ object MimaExcludes { ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.mllib.linalg.Vector.toSparse"), ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.linalg.Vector.numActives") + "org.apache.spark.mllib.linalg.Vector.numActives"), + // SPARK-7681 add SparseVector support for gemv + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.multiply"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.DenseMatrix.multiply"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.SparseMatrix.multiply") ) ++ Seq( // Execution should never be included as its always internal. MimaBuild.excludeSparkPackage("sql.execution"), @@ -111,17 +120,43 @@ object MimaExcludes { "org.apache.spark.sql.parquet.ParquetRelation2$PartitionValues"), ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.sql.parquet.ParquetRelation2$PartitionValues$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetRelation2"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetRelation2$"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetRelation2$MetadataCache"), // These test support classes were moved out of src/main and into src/test: ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.sql.parquet.ParquetTestData"), ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.sql.parquet.ParquetTestData$"), ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.TestGroupWriteSupport") + "org.apache.spark.sql.parquet.TestGroupWriteSupport"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CachedData"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CachedData$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.CacheManager"), + // TODO: Remove the following rule once ParquetTest has been moved to src/test. + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetTest") ) ++ Seq( // SPARK-7530 Added StreamingContext.getState() ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.streaming.StreamingContext.state_=") + ) ++ Seq( + // SPARK-7081 changed ShuffleWriter from a trait to an abstract class and removed some + // unnecessary type bounds in order to fix some compiler warnings that occurred when + // implementing this interface in Java. Note that ShuffleWriter is private[spark]. + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.shuffle.ShuffleWriter") + ) ++ Seq( + // SPARK-6888 make jdbc driver handling user definable + // This patch renames some classes to API friendly names. + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DriverQuirks$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DriverQuirks"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.PostgresQuirks"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.NoQuirks"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.MySQLQuirks") ) case v if v.startsWith("1.3") => diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 1b87e4e98bd83..9a849639233bc 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -23,7 +23,6 @@ import scala.collection.JavaConversions._ import sbt._ import sbt.Classpaths.publishTask import sbt.Keys._ -import sbtunidoc.Plugin.genjavadocSettings import sbtunidoc.Plugin.UnidocKeys.unidocGenjavadocVersion import com.typesafe.sbt.pom.{loadEffectivePom, PomBuild, SbtPomKeys} import net.virtualvoid.sbt.graph.Plugin.graphSettings @@ -118,7 +117,12 @@ object SparkBuild extends PomBuild { lazy val MavenCompile = config("m2r") extend(Compile) lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy") - lazy val sharedSettings = graphSettings ++ genjavadocSettings ++ Seq ( + lazy val sparkGenjavadocSettings: Seq[sbt.Def.Setting[_]] = Seq( + libraryDependencies += compilerPlugin( + "org.spark-project" %% "genjavadoc-plugin" % unidocGenjavadocVersion.value cross CrossVersion.full), + scalacOptions <+= target.map(t => "-P:genjavadoc:out=" + (t / "java"))) + + lazy val sharedSettings = graphSettings ++ sparkGenjavadocSettings ++ Seq ( javaHome := sys.env.get("JAVA_HOME") .orElse(sys.props.get("java.home").map { p => new File(p).getParentFile().getAbsolutePath() }) .map(file), @@ -126,7 +130,7 @@ object SparkBuild extends PomBuild { retrieveManaged := true, retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", publishMavenStyle := true, - unidocGenjavadocVersion := "0.8", + unidocGenjavadocVersion := "0.9-spark0", resolvers += Resolver.mavenLocal, otherResolvers <<= SbtPomKeys.mvnLocalRepository(dotM2 => Seq(Resolver.file("dotM2", dotM2))), @@ -324,6 +328,7 @@ object Hive { |import org.apache.spark.sql.functions._ |import org.apache.spark.sql.hive._ |import org.apache.spark.sql.hive.test.TestHive._ + |import org.apache.spark.sql.hive.test.TestHive.implicits._ |import org.apache.spark.sql.types._""".stripMargin, cleanupCommands in console := "sparkContext.stop()", // Some of our log4j jars make it impossible to submit jobs from this JVM to Hive Map/Reduce diff --git a/project/plugins.sbt b/project/plugins.sbt index 7096b0d3ee7de..75bd604a1b857 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -25,7 +25,7 @@ addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.6") addSbtPlugin("com.alpinenow" % "junit_xml_listener" % "0.5.1") -addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.3.1") +addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.3.3") addSbtPlugin("com.cavorite" % "sbt-avro" % "0.3.2") diff --git a/python/docs/pyspark.ml.rst b/python/docs/pyspark.ml.rst index 8379b8fc8a1e1..518b8e774dd5f 100644 --- a/python/docs/pyspark.ml.rst +++ b/python/docs/pyspark.ml.rst @@ -1,8 +1,8 @@ pyspark.ml package -===================== +================== ML Pipeline APIs --------------- +---------------- .. automodule:: pyspark.ml :members: @@ -10,7 +10,7 @@ ML Pipeline APIs :inherited-members: pyspark.ml.param module -------------------------- +----------------------- .. automodule:: pyspark.ml.param :members: @@ -34,7 +34,7 @@ pyspark.ml.classification module :inherited-members: pyspark.ml.recommendation module -------------------------- +-------------------------------- .. automodule:: pyspark.ml.recommendation :members: @@ -42,7 +42,7 @@ pyspark.ml.recommendation module :inherited-members: pyspark.ml.regression module -------------------------- +---------------------------- .. automodule:: pyspark.ml.regression :members: @@ -50,7 +50,7 @@ pyspark.ml.regression module :inherited-members: pyspark.ml.tuning module --------------------------------- +------------------------ .. automodule:: pyspark.ml.tuning :members: @@ -58,7 +58,7 @@ pyspark.ml.tuning module :inherited-members: pyspark.ml.evaluation module --------------------------------- +---------------------------- .. automodule:: pyspark.ml.evaluation :members: diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index 0d21a132048a5..adca90ddaf397 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -261,3 +261,7 @@ def _start_update_server(): thread.daemon = True thread.start() return server + +if __name__ == "__main__": + import doctest + doctest.testmod() diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 31992795a9e45..aeb7ad4f2f83e 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -173,6 +173,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, self._jvm.PythonAccumulatorParam(host, port)) self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python') + self.pythonVer = "%d.%d" % sys.version_info[:2] # Broadcast's __reduce__ method stores Broadcast instances here. # This allows other code to determine which Broadcast instances have @@ -290,6 +291,11 @@ def version(self): """ return self._jsc.version() + @property + def startTime(self): + """Return the epoch time when the Spark Context was started.""" + return self._jsc.startTime() + @property def defaultParallelism(self): """ @@ -318,6 +324,22 @@ def stop(self): with SparkContext._lock: SparkContext._active_spark_context = None + def range(self, start, end, step=1, numSlices=None): + """ + Create a new RDD of int containing elements from `start` to `end` + (exclusive), increased by `step` every element. + + :param start: the start value + :param end: the end value (exclusive) + :param step: the incremental step (default: 1) + :param numSlices: the number of partitions of the new RDD + :return: An RDD of int + + >>> sc.range(1, 7, 2).collect() + [1, 3, 5] + """ + return self.parallelize(xrange(start, end, step), numSlices) + def parallelize(self, c, numSlices=None): """ Distribute a local Python collection to form an RDD. Using xrange diff --git a/python/pyspark/ml/__init__.py b/python/pyspark/ml/__init__.py index da793d9db7f91..327a11b14b5aa 100644 --- a/python/pyspark/ml/__init__.py +++ b/python/pyspark/ml/__init__.py @@ -15,6 +15,6 @@ # limitations under the License. # -from pyspark.ml.pipeline import Transformer, Estimator, Model, Pipeline, PipelineModel, Evaluator +from pyspark.ml.pipeline import Transformer, Estimator, Model, Pipeline, PipelineModel -__all__ = ["Transformer", "Estimator", "Model", "Pipeline", "PipelineModel", "Evaluator"] +__all__ = ["Transformer", "Estimator", "Model", "Pipeline", "PipelineModel"] diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 8a009c4ac721f..7abbde8b260eb 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -17,17 +17,19 @@ from pyspark.ml.util import keyword_only from pyspark.ml.wrapper import JavaEstimator, JavaModel -from pyspark.ml.param.shared import HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,\ - HasRegParam +from pyspark.ml.param.shared import * +from pyspark.ml.regression import RandomForestParams from pyspark.mllib.common import inherit_doc -__all__ = ['LogisticRegression', 'LogisticRegressionModel'] +__all__ = ['LogisticRegression', 'LogisticRegressionModel', 'DecisionTreeClassifier', + 'DecisionTreeClassificationModel', 'GBTClassifier', 'GBTClassificationModel', + 'RandomForestClassifier', 'RandomForestClassificationModel'] @inherit_doc class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, - HasRegParam): + HasRegParam, HasTol, HasProbabilityCol): """ Logistic regression. @@ -41,6 +43,10 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF() >>> model.transform(test0).head().prediction 0.0 + >>> model.weights + DenseVector([5.5...]) + >>> model.intercept + -2.68... >>> test1 = sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]).toDF() >>> model.transform(test1).head().prediction 1.0 @@ -49,26 +55,52 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti ... TypeError: Method setParams forces keyword arguments. """ - _java_class = "org.apache.spark.ml.classification.LogisticRegression" + + # a placeholder to make it appear in the generated doc + elasticNetParam = \ + Param(Params._dummy(), "elasticNetParam", + "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " + + "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.") + fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.") + threshold = Param(Params._dummy(), "threshold", + "threshold in binary classification prediction, in range [0, 1].") @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", - maxIter=100, regParam=0.1): + maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, + threshold=0.5, probabilityCol="probability"): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ - maxIter=100, regParam=0.1) + maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ + threshold=0.5, probabilityCol="probability") """ super(LogisticRegression, self).__init__() - self._setDefault(maxIter=100, regParam=0.1) + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.classification.LogisticRegression", self.uid) + #: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty + # is an L2 penalty. For alpha = 1, it is an L1 penalty. + self.elasticNetParam = \ + Param(self, "elasticNetParam", + "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty " + + "is an L2 penalty. For alpha = 1, it is an L1 penalty.") + #: param for whether to fit an intercept term. + self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.") + #: param for threshold in binary classification prediction, in range [0, 1]. + self.threshold = Param(self, "threshold", + "threshold in binary classification prediction, in range [0, 1].") + self._setDefault(maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1E-6, + fitIntercept=True, threshold=0.5) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", - maxIter=100, regParam=0.1): + maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, + threshold=0.5, probabilityCol="probability"): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ - maxIter=100, regParam=0.1) + maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ + threshold=0.5, probabilityCol="probability") Sets params for logistic regression. """ kwargs = self.setParams._input_kwargs @@ -77,12 +109,460 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return LogisticRegressionModel(java_model) + def setElasticNetParam(self, value): + """ + Sets the value of :py:attr:`elasticNetParam`. + """ + self._paramMap[self.elasticNetParam] = value + return self + + def getElasticNetParam(self): + """ + Gets the value of elasticNetParam or its default value. + """ + return self.getOrDefault(self.elasticNetParam) + + def setFitIntercept(self, value): + """ + Sets the value of :py:attr:`fitIntercept`. + """ + self._paramMap[self.fitIntercept] = value + return self + + def getFitIntercept(self): + """ + Gets the value of fitIntercept or its default value. + """ + return self.getOrDefault(self.fitIntercept) + + def setThreshold(self, value): + """ + Sets the value of :py:attr:`threshold`. + """ + self._paramMap[self.threshold] = value + return self + + def getThreshold(self): + """ + Gets the value of threshold or its default value. + """ + return self.getOrDefault(self.threshold) + class LogisticRegressionModel(JavaModel): """ Model fitted by LogisticRegression. """ + @property + def weights(self): + """ + Model weights. + """ + return self._call_java("weights") + + @property + def intercept(self): + """ + Model intercept. + """ + return self._call_java("intercept") + + +class TreeClassifierParams(object): + """ + Private class to track supported impurity measures. + """ + supportedImpurities = ["entropy", "gini"] + + +class GBTParams(object): + """ + Private class to track supported GBT params. + """ + supportedLossTypes = ["logistic"] + + +@inherit_doc +class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, + DecisionTreeParams, HasCheckpointInterval): + """ + `http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree` + learning algorithm for classification. + It supports both binary and multiclass labels, as well as both continuous and categorical + features. + + >>> from pyspark.mllib.linalg import Vectors + >>> from pyspark.ml.feature import StringIndexer + >>> df = sqlContext.createDataFrame([ + ... (1.0, Vectors.dense(1.0)), + ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) + >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed") + >>> si_model = stringIndexer.fit(df) + >>> td = si_model.transform(df) + >>> dt = DecisionTreeClassifier(maxDepth=2, labelCol="indexed") + >>> model = dt.fit(td) + >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) + >>> model.transform(test0).head().prediction + 0.0 + >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) + >>> model.transform(test1).head().prediction + 1.0 + """ + + # a placeholder to make it appear in the generated doc + impurity = Param(Params._dummy(), "impurity", + "Criterion used for information gain calculation (case-insensitive). " + + "Supported options: " + ", ".join(TreeClassifierParams.supportedImpurities)) + + @keyword_only + def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini"): + """ + __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini") + """ + super(DecisionTreeClassifier, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.classification.DecisionTreeClassifier", self.uid) + #: param for Criterion used for information gain calculation (case-insensitive). + self.impurity = \ + Param(self, "impurity", + "Criterion used for information gain calculation (case-insensitive). " + + "Supported options: " + ", ".join(TreeClassifierParams.supportedImpurities)) + self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, + impurity="gini") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, + impurity="gini"): + """ + setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini") + Sets params for the DecisionTreeClassifier. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def _create_model(self, java_model): + return DecisionTreeClassificationModel(java_model) + + def setImpurity(self, value): + """ + Sets the value of :py:attr:`impurity`. + """ + self._paramMap[self.impurity] = value + return self + + def getImpurity(self): + """ + Gets the value of impurity or its default value. + """ + return self.getOrDefault(self.impurity) + + +class DecisionTreeClassificationModel(JavaModel): + """ + Model fitted by DecisionTreeClassifier. + """ + + +@inherit_doc +class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed, + DecisionTreeParams, HasCheckpointInterval): + """ + `http://en.wikipedia.org/wiki/Random_forest Random Forest` + learning algorithm for classification. + It supports both binary and multiclass labels, as well as both continuous and categorical + features. + + >>> from pyspark.mllib.linalg import Vectors + >>> from pyspark.ml.feature import StringIndexer + >>> df = sqlContext.createDataFrame([ + ... (1.0, Vectors.dense(1.0)), + ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) + >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed") + >>> si_model = stringIndexer.fit(df) + >>> td = si_model.transform(df) + >>> rf = RandomForestClassifier(numTrees=2, maxDepth=2, labelCol="indexed", seed=42) + >>> model = rf.fit(td) + >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) + >>> model.transform(test0).head().prediction + 0.0 + >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) + >>> model.transform(test1).head().prediction + 1.0 + """ + + # a placeholder to make it appear in the generated doc + impurity = Param(Params._dummy(), "impurity", + "Criterion used for information gain calculation (case-insensitive). " + + "Supported options: " + ", ".join(TreeClassifierParams.supportedImpurities)) + subsamplingRate = Param(Params._dummy(), "subsamplingRate", + "Fraction of the training data used for learning each decision tree, " + + "in range (0, 1].") + numTrees = Param(Params._dummy(), "numTrees", "Number of trees to train (>= 1)") + featureSubsetStrategy = \ + Param(Params._dummy(), "featureSubsetStrategy", + "The number of features to consider for splits at each tree node. Supported " + + "options: " + ", ".join(RandomForestParams.supportedFeatureSubsetStrategies)) + + @keyword_only + def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", + numTrees=20, featureSubsetStrategy="auto", seed=None): + """ + __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \ + numTrees=20, featureSubsetStrategy="auto", seed=None) + """ + super(RandomForestClassifier, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.classification.RandomForestClassifier", self.uid) + #: param for Criterion used for information gain calculation (case-insensitive). + self.impurity = \ + Param(self, "impurity", + "Criterion used for information gain calculation (case-insensitive). " + + "Supported options: " + ", ".join(TreeClassifierParams.supportedImpurities)) + #: param for Fraction of the training data used for learning each decision tree, + # in range (0, 1] + self.subsamplingRate = Param(self, "subsamplingRate", + "Fraction of the training data used for learning each " + + "decision tree, in range (0, 1].") + #: param for Number of trees to train (>= 1) + self.numTrees = Param(self, "numTrees", "Number of trees to train (>= 1)") + #: param for The number of features to consider for splits at each tree node + self.featureSubsetStrategy = \ + Param(self, "featureSubsetStrategy", + "The number of features to consider for splits at each tree node. Supported " + + "options: " + ", ".join(RandomForestParams.supportedFeatureSubsetStrategies)) + self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, + impurity="gini", numTrees=20, featureSubsetStrategy="auto") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, + impurity="gini", numTrees=20, featureSubsetStrategy="auto"): + """ + setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, \ + impurity="gini", numTrees=20, featureSubsetStrategy="auto") + Sets params for linear classification. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def _create_model(self, java_model): + return RandomForestClassificationModel(java_model) + + def setImpurity(self, value): + """ + Sets the value of :py:attr:`impurity`. + """ + self._paramMap[self.impurity] = value + return self + + def getImpurity(self): + """ + Gets the value of impurity or its default value. + """ + return self.getOrDefault(self.impurity) + + def setSubsamplingRate(self, value): + """ + Sets the value of :py:attr:`subsamplingRate`. + """ + self._paramMap[self.subsamplingRate] = value + return self + + def getSubsamplingRate(self): + """ + Gets the value of subsamplingRate or its default value. + """ + return self.getOrDefault(self.subsamplingRate) + + def setNumTrees(self, value): + """ + Sets the value of :py:attr:`numTrees`. + """ + self._paramMap[self.numTrees] = value + return self + + def getNumTrees(self): + """ + Gets the value of numTrees or its default value. + """ + return self.getOrDefault(self.numTrees) + + def setFeatureSubsetStrategy(self, value): + """ + Sets the value of :py:attr:`featureSubsetStrategy`. + """ + self._paramMap[self.featureSubsetStrategy] = value + return self + + def getFeatureSubsetStrategy(self): + """ + Gets the value of featureSubsetStrategy or its default value. + """ + return self.getOrDefault(self.featureSubsetStrategy) + + +class RandomForestClassificationModel(JavaModel): + """ + Model fitted by RandomForestClassifier. + """ + + +@inherit_doc +class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, + DecisionTreeParams, HasCheckpointInterval): + """ + `http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)` + learning algorithm for classification. + It supports binary labels, as well as both continuous and categorical features. + Note: Multiclass labels are not currently supported. + + >>> from pyspark.mllib.linalg import Vectors + >>> from pyspark.ml.feature import StringIndexer + >>> df = sqlContext.createDataFrame([ + ... (1.0, Vectors.dense(1.0)), + ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) + >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed") + >>> si_model = stringIndexer.fit(df) + >>> td = si_model.transform(df) + >>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed") + >>> model = gbt.fit(td) + >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) + >>> model.transform(test0).head().prediction + 0.0 + >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) + >>> model.transform(test1).head().prediction + 1.0 + """ + + # a placeholder to make it appear in the generated doc + lossType = Param(Params._dummy(), "lossType", + "Loss function which GBT tries to minimize (case-insensitive). " + + "Supported options: " + ", ".join(GBTParams.supportedLossTypes)) + subsamplingRate = Param(Params._dummy(), "subsamplingRate", + "Fraction of the training data used for learning each decision tree, " + + "in range (0, 1].") + stepSize = Param(Params._dummy(), "stepSize", + "Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the " + + "contribution of each estimator") + + @keyword_only + def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic", + maxIter=20, stepSize=0.1): + """ + __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ + lossType="logistic", maxIter=20, stepSize=0.1) + """ + super(GBTClassifier, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.classification.GBTClassifier", self.uid) + #: param for Loss function which GBT tries to minimize (case-insensitive). + self.lossType = Param(self, "lossType", + "Loss function which GBT tries to minimize (case-insensitive). " + + "Supported options: " + ", ".join(GBTParams.supportedLossTypes)) + #: Fraction of the training data used for learning each decision tree, in range (0, 1]. + self.subsamplingRate = Param(self, "subsamplingRate", + "Fraction of the training data used for learning each " + + "decision tree, in range (0, 1].") + #: Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of + # each estimator + self.stepSize = Param(self, "stepSize", + "Step size (a.k.a. learning rate) in interval (0, 1] for shrinking " + + "the contribution of each estimator") + self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, + lossType="logistic", maxIter=20, stepSize=0.1) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, + lossType="logistic", maxIter=20, stepSize=0.1): + """ + setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ + lossType="logistic", maxIter=20, stepSize=0.1) + Sets params for Gradient Boosted Tree Classification. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def _create_model(self, java_model): + return GBTClassificationModel(java_model) + + def setLossType(self, value): + """ + Sets the value of :py:attr:`lossType`. + """ + self._paramMap[self.lossType] = value + return self + + def getLossType(self): + """ + Gets the value of lossType or its default value. + """ + return self.getOrDefault(self.lossType) + + def setSubsamplingRate(self, value): + """ + Sets the value of :py:attr:`subsamplingRate`. + """ + self._paramMap[self.subsamplingRate] = value + return self + + def getSubsamplingRate(self): + """ + Gets the value of subsamplingRate or its default value. + """ + return self.getOrDefault(self.subsamplingRate) + + def setStepSize(self, value): + """ + Sets the value of :py:attr:`stepSize`. + """ + self._paramMap[self.stepSize] = value + return self + + def getStepSize(self): + """ + Gets the value of stepSize or its default value. + """ + return self.getOrDefault(self.stepSize) + + +class GBTClassificationModel(JavaModel): + """ + Model fitted by GBTClassifier. + """ + if __name__ == "__main__": import doctest diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 02020ebff94c2..d8ddb78c6d639 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -15,13 +15,72 @@ # limitations under the License. # -from pyspark.ml.wrapper import JavaEvaluator +from abc import abstractmethod, ABCMeta + +from pyspark.ml.wrapper import JavaWrapper from pyspark.ml.param import Param, Params -from pyspark.ml.param.shared import HasLabelCol, HasRawPredictionCol +from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol from pyspark.ml.util import keyword_only from pyspark.mllib.common import inherit_doc -__all__ = ['BinaryClassificationEvaluator'] +__all__ = ['Evaluator', 'BinaryClassificationEvaluator', 'RegressionEvaluator'] + + +@inherit_doc +class Evaluator(Params): + """ + Base class for evaluators that compute metrics from predictions. + """ + + __metaclass__ = ABCMeta + + @abstractmethod + def _evaluate(self, dataset): + """ + Evaluates the output. + + :param dataset: a dataset that contains labels/observations and + predictions + :return: metric + """ + raise NotImplementedError() + + def evaluate(self, dataset, params={}): + """ + Evaluates the output with optional parameters. + + :param dataset: a dataset that contains labels/observations and + predictions + :param params: an optional param map that overrides embedded + params + :return: metric + """ + if isinstance(params, dict): + if params: + return self.copy(params)._evaluate(dataset) + else: + return self._evaluate(dataset) + else: + raise ValueError("Params must be a param map but got %s." % type(params)) + + +@inherit_doc +class JavaEvaluator(Evaluator, JavaWrapper): + """ + Base class for :py:class:`Evaluator`s that wrap Java/Scala + implementations. + """ + + __metaclass__ = ABCMeta + + def _evaluate(self, dataset): + """ + Evaluates the output. + :param dataset: a dataset that contains labels/observations and predictions. + :return: evaluation metric + """ + self._transfer_params_to_java() + return self._java_obj.evaluate(dataset._jdf) @inherit_doc @@ -42,8 +101,6 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction 0.83... """ - _java_class = "org.apache.spark.ml.evaluation.BinaryClassificationEvaluator" - # a placeholder to make it appear in the generated doc metricName = Param(Params._dummy(), "metricName", "metric name in evaluation (areaUnderROC|areaUnderPR)") @@ -56,6 +113,8 @@ def __init__(self, rawPredictionCol="rawPrediction", labelCol="label", metricName="areaUnderROC") """ super(BinaryClassificationEvaluator, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.evaluation.BinaryClassificationEvaluator", self.uid) #: param for metric name in evaluation (areaUnderROC|areaUnderPR) self.metricName = Param(self, "metricName", "metric name in evaluation (areaUnderROC|areaUnderPR)") @@ -68,7 +127,7 @@ def setMetricName(self, value): """ Sets the value of :py:attr:`metricName`. """ - self.paramMap[self.metricName] = value + self._paramMap[self.metricName] = value return self def getMetricName(self): @@ -89,6 +148,70 @@ def setParams(self, rawPredictionCol="rawPrediction", labelCol="label", return self._set(**kwargs) +@inherit_doc +class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol): + """ + Evaluator for Regression, which expects two input + columns: prediction and label. + + >>> scoreAndLabels = [(-28.98343821, -27.0), (20.21491975, 21.5), + ... (-25.98418959, -22.0), (30.69731842, 33.0), (74.69283752, 71.0)] + >>> dataset = sqlContext.createDataFrame(scoreAndLabels, ["raw", "label"]) + ... + >>> evaluator = RegressionEvaluator(predictionCol="raw") + >>> evaluator.evaluate(dataset) + 2.842... + >>> evaluator.evaluate(dataset, {evaluator.metricName: "r2"}) + 0.993... + >>> evaluator.evaluate(dataset, {evaluator.metricName: "mae"}) + 2.649... + """ + # a placeholder to make it appear in the generated doc + metricName = Param(Params._dummy(), "metricName", + "metric name in evaluation (mse|rmse|r2|mae)") + + @keyword_only + def __init__(self, predictionCol="prediction", labelCol="label", + metricName="rmse"): + """ + __init__(self, predictionCol="prediction", labelCol="label", \ + metricName="rmse") + """ + super(RegressionEvaluator, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.evaluation.RegressionEvaluator", self.uid) + #: param for metric name in evaluation (mse|rmse|r2|mae) + self.metricName = Param(self, "metricName", + "metric name in evaluation (mse|rmse|r2|mae)") + self._setDefault(predictionCol="prediction", labelCol="label", + metricName="rmse") + kwargs = self.__init__._input_kwargs + self._set(**kwargs) + + def setMetricName(self, value): + """ + Sets the value of :py:attr:`metricName`. + """ + self._paramMap[self.metricName] = value + return self + + def getMetricName(self): + """ + Gets the value of metricName or its default value. + """ + return self.getOrDefault(self.metricName) + + @keyword_only + def setParams(self, predictionCol="prediction", labelCol="label", + metricName="rmse"): + """ + setParams(self, predictionCol="prediction", labelCol="label", \ + metricName="rmse") + Sets params for regression evaluator. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + if __name__ == "__main__": import doctest from pyspark.context import SparkContext diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index f35bc1463d51b..ddb33f427ac64 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -43,7 +43,6 @@ class Binarizer(JavaTransformer, HasInputCol, HasOutputCol): 1.0 """ - _java_class = "org.apache.spark.ml.feature.Binarizer" # a placeholder to make it appear in the generated doc threshold = Param(Params._dummy(), "threshold", "threshold in binary classification prediction, in range [0, 1]") @@ -54,6 +53,7 @@ def __init__(self, threshold=0.0, inputCol=None, outputCol=None): __init__(self, threshold=0.0, inputCol=None, outputCol=None) """ super(Binarizer, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Binarizer", self.uid) self.threshold = Param(self, "threshold", "threshold in binary classification prediction, in range [0, 1]") self._setDefault(threshold=0.0) @@ -73,7 +73,7 @@ def setThreshold(self, value): """ Sets the value of :py:attr:`threshold`. """ - self.paramMap[self.threshold] = value + self._paramMap[self.threshold] = value return self def getThreshold(self): @@ -83,6 +83,83 @@ def getThreshold(self): return self.getOrDefault(self.threshold) +@inherit_doc +class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol): + """ + Maps a column of continuous features to a column of feature buckets. + + >>> df = sqlContext.createDataFrame([(0.1,), (0.4,), (1.2,), (1.5,)], ["values"]) + >>> bucketizer = Bucketizer(splits=[-float("inf"), 0.5, 1.4, float("inf")], + ... inputCol="values", outputCol="buckets") + >>> bucketed = bucketizer.transform(df).collect() + >>> bucketed[0].buckets + 0.0 + >>> bucketed[1].buckets + 0.0 + >>> bucketed[2].buckets + 1.0 + >>> bucketed[3].buckets + 2.0 + >>> bucketizer.setParams(outputCol="b").transform(df).head().b + 0.0 + """ + + # a placeholder to make it appear in the generated doc + splits = \ + Param(Params._dummy(), "splits", + "Split points for mapping continuous features into buckets. With n+1 splits, " + + "there are n buckets. A bucket defined by splits x,y holds values in the " + + "range [x,y) except the last bucket, which also includes y. The splits " + + "should be strictly increasing. Values at -inf, inf must be explicitly " + + "provided to cover all Double values; otherwise, values outside the splits " + + "specified will be treated as errors.") + + @keyword_only + def __init__(self, splits=None, inputCol=None, outputCol=None): + """ + __init__(self, splits=None, inputCol=None, outputCol=None) + """ + super(Bucketizer, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Bucketizer", self.uid) + #: param for Splitting points for mapping continuous features into buckets. With n+1 splits, + # there are n buckets. A bucket defined by splits x,y holds values in the range [x,y) + # except the last bucket, which also includes y. The splits should be strictly increasing. + # Values at -inf, inf must be explicitly provided to cover all Double values; otherwise, + # values outside the splits specified will be treated as errors. + self.splits = \ + Param(self, "splits", + "Split points for mapping continuous features into buckets. With n+1 splits, " + + "there are n buckets. A bucket defined by splits x,y holds values in the " + + "range [x,y) except the last bucket, which also includes y. The splits " + + "should be strictly increasing. Values at -inf, inf must be explicitly " + + "provided to cover all Double values; otherwise, values outside the splits " + + "specified will be treated as errors.") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, splits=None, inputCol=None, outputCol=None): + """ + setParams(self, splits=None, inputCol=None, outputCol=None) + Sets params for this Bucketizer. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setSplits(self, value): + """ + Sets the value of :py:attr:`splits`. + """ + self._paramMap[self.splits] = value + return self + + def getSplits(self): + """ + Gets the value of threshold or its default value. + """ + return self.getOrDefault(self.splits) + + @inherit_doc class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures): """ @@ -100,14 +177,13 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures): SparseVector(5, {2: 1.0, 3: 1.0, 4: 1.0}) """ - _java_class = "org.apache.spark.ml.feature.HashingTF" - @keyword_only def __init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None): """ __init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None) """ super(HashingTF, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.HashingTF", self.uid) self._setDefault(numFeatures=1 << 18) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -140,8 +216,6 @@ class IDF(JavaEstimator, HasInputCol, HasOutputCol): DenseVector([0.2877, 0.0]) """ - _java_class = "org.apache.spark.ml.feature.IDF" - # a placeholder to make it appear in the generated doc minDocFreq = Param(Params._dummy(), "minDocFreq", "minimum of documents in which a term should appear for filtering") @@ -152,6 +226,7 @@ def __init__(self, minDocFreq=0, inputCol=None, outputCol=None): __init__(self, minDocFreq=0, inputCol=None, outputCol=None) """ super(IDF, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.IDF", self.uid) self.minDocFreq = Param(self, "minDocFreq", "minimum of documents in which a term should appear for filtering") self._setDefault(minDocFreq=0) @@ -171,7 +246,7 @@ def setMinDocFreq(self, value): """ Sets the value of :py:attr:`minDocFreq`. """ - self.paramMap[self.minDocFreq] = value + self._paramMap[self.minDocFreq] = value return self def getMinDocFreq(self): @@ -180,6 +255,9 @@ def getMinDocFreq(self): """ return self.getOrDefault(self.minDocFreq) + def _create_model(self, java_model): + return IDFModel(java_model) + class IDFModel(JavaModel): """ @@ -208,14 +286,13 @@ class Normalizer(JavaTransformer, HasInputCol, HasOutputCol): # a placeholder to make it appear in the generated doc p = Param(Params._dummy(), "p", "the p norm value.") - _java_class = "org.apache.spark.ml.feature.Normalizer" - @keyword_only def __init__(self, p=2.0, inputCol=None, outputCol=None): """ __init__(self, p=2.0, inputCol=None, outputCol=None) """ super(Normalizer, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Normalizer", self.uid) self.p = Param(self, "p", "the p norm value.") self._setDefault(p=2.0) kwargs = self.__init__._input_kwargs @@ -234,7 +311,7 @@ def setP(self, value): """ Sets the value of :py:attr:`p`. """ - self.paramMap[self.p] = value + self._paramMap[self.p] = value return self def getP(self): @@ -247,66 +324,73 @@ def getP(self): @inherit_doc class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol): """ - A one-hot encoder that maps a column of label indices to a column of binary vectors, with - at most a single one-value. By default, the binary vector has an element for each category, so - with 5 categories, an input value of 2.0 would map to an output vector of - (0.0, 0.0, 1.0, 0.0, 0.0). If includeFirst is set to false, the first category is omitted, so - the output vector for the previous example would be (0.0, 1.0, 0.0, 0.0) and an input value - of 0.0 would map to a vector of all zeros. Including the first category makes the vector columns - linearly dependent because they sum up to one. + A one-hot encoder that maps a column of category indices to a + column of binary vectors, with at most a single one-value per row + that indicates the input category index. + For example with 5 categories, an input value of 2.0 would map to + an output vector of `[0.0, 0.0, 1.0, 0.0]`. + The last category is not included by default (configurable via + :py:attr:`dropLast`) because it makes the vector entries sum up to + one, and hence linearly dependent. + So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`. + Note that this is different from scikit-learn's OneHotEncoder, + which keeps all categories. + The output vectors are sparse. + + .. seealso:: - TODO: This method requires the use of StringIndexer first. Decouple them. + :py:class:`StringIndexer` for converting categorical values into + category indices >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed") >>> model = stringIndexer.fit(stringIndDf) >>> td = model.transform(stringIndDf) - >>> encoder = OneHotEncoder(includeFirst=False, inputCol="indexed", outputCol="features") + >>> encoder = OneHotEncoder(inputCol="indexed", outputCol="features") >>> encoder.transform(td).head().features - SparseVector(2, {}) + SparseVector(2, {0: 1.0}) >>> encoder.setParams(outputCol="freqs").transform(td).head().freqs - SparseVector(2, {}) - >>> params = {encoder.includeFirst: True, encoder.outputCol: "test"} + SparseVector(2, {0: 1.0}) + >>> params = {encoder.dropLast: False, encoder.outputCol: "test"} >>> encoder.transform(td, params).head().test SparseVector(3, {0: 1.0}) """ - _java_class = "org.apache.spark.ml.feature.OneHotEncoder" - # a placeholder to make it appear in the generated doc - includeFirst = Param(Params._dummy(), "includeFirst", "include first category") + dropLast = Param(Params._dummy(), "dropLast", "whether to drop the last category") @keyword_only - def __init__(self, includeFirst=True, inputCol=None, outputCol=None): + def __init__(self, dropLast=True, inputCol=None, outputCol=None): """ __init__(self, includeFirst=True, inputCol=None, outputCol=None) """ super(OneHotEncoder, self).__init__() - self.includeFirst = Param(self, "includeFirst", "include first category") - self._setDefault(includeFirst=True) + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.OneHotEncoder", self.uid) + self.dropLast = Param(self, "dropLast", "whether to drop the last category") + self._setDefault(dropLast=True) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only - def setParams(self, includeFirst=True, inputCol=None, outputCol=None): + def setParams(self, dropLast=True, inputCol=None, outputCol=None): """ - setParams(self, includeFirst=True, inputCol=None, outputCol=None) + setParams(self, dropLast=True, inputCol=None, outputCol=None) Sets params for this OneHotEncoder. """ kwargs = self.setParams._input_kwargs return self._set(**kwargs) - def setIncludeFirst(self, value): + def setDropLast(self, value): """ - Sets the value of :py:attr:`includeFirst`. + Sets the value of :py:attr:`dropLast`. """ - self.paramMap[self.includeFirst] = value + self._paramMap[self.dropLast] = value return self - def getIncludeFirst(self): + def getDropLast(self): """ - Gets the value of includeFirst or its default value. + Gets the value of dropLast or its default value. """ - return self.getOrDefault(self.includeFirst) + return self.getOrDefault(self.dropLast) @inherit_doc @@ -327,8 +411,6 @@ class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol): DenseVector([0.5, 0.25, 2.0, 1.0, 4.0]) """ - _java_class = "org.apache.spark.ml.feature.PolynomialExpansion" - # a placeholder to make it appear in the generated doc degree = Param(Params._dummy(), "degree", "the polynomial degree to expand (>= 1)") @@ -338,6 +420,8 @@ def __init__(self, degree=2, inputCol=None, outputCol=None): __init__(self, degree=2, inputCol=None, outputCol=None) """ super(PolynomialExpansion, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.feature.PolynomialExpansion", self.uid) self.degree = Param(self, "degree", "the polynomial degree to expand (>= 1)") self._setDefault(degree=2) kwargs = self.__init__._input_kwargs @@ -356,7 +440,7 @@ def setDegree(self, value): """ Sets the value of :py:attr:`degree`. """ - self.paramMap[self.degree] = value + self._paramMap[self.degree] = value return self def getDegree(self): @@ -370,23 +454,25 @@ def getDegree(self): @ignore_unicode_prefix class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol): """ - A regex based tokenizer that extracts tokens either by repeatedly matching the regex(default) - or using it to split the text (set matching to false). Optional parameters also allow filtering - tokens using a minimal length. + A regex based tokenizer that extracts tokens either by using the + provided regex pattern (in Java dialect) to split the text + (default) or repeatedly matching the regex (if gaps is true). + Optional parameters also allow filtering tokens using a minimal + length. It returns an array of strings that can be empty. - >>> df = sqlContext.createDataFrame([("a b c",)], ["text"]) + >>> df = sqlContext.createDataFrame([("a b c",)], ["text"]) >>> reTokenizer = RegexTokenizer(inputCol="text", outputCol="words") >>> reTokenizer.transform(df).head() - Row(text=u'a b c', words=[u'a', u'b', u'c']) + Row(text=u'a b c', words=[u'a', u'b', u'c']) >>> # Change a parameter. >>> reTokenizer.setParams(outputCol="tokens").transform(df).head() - Row(text=u'a b c', tokens=[u'a', u'b', u'c']) + Row(text=u'a b c', tokens=[u'a', u'b', u'c']) >>> # Temporarily modify a parameter. >>> reTokenizer.transform(df, {reTokenizer.outputCol: "words"}).head() - Row(text=u'a b c', words=[u'a', u'b', u'c']) + Row(text=u'a b c', words=[u'a', u'b', u'c']) >>> reTokenizer.transform(df).head() - Row(text=u'a b c', tokens=[u'a', u'b', u'c']) + Row(text=u'a b c', tokens=[u'a', u'b', u'c']) >>> # Must use keyword arguments to specify params. >>> reTokenizer.setParams("text") Traceback (most recent call last): @@ -394,33 +480,29 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol): TypeError: Method setParams forces keyword arguments. """ - _java_class = "org.apache.spark.ml.feature.RegexTokenizer" # a placeholder to make it appear in the generated doc minTokenLength = Param(Params._dummy(), "minTokenLength", "minimum token length (>= 0)") - gaps = Param(Params._dummy(), "gaps", "Set regex to match gaps or tokens") - pattern = Param(Params._dummy(), "pattern", "regex pattern used for tokenizing") + gaps = Param(Params._dummy(), "gaps", "whether regex splits on gaps (True) or matches tokens") + pattern = Param(Params._dummy(), "pattern", "regex pattern (Java dialect) used for tokenizing") @keyword_only - def __init__(self, minTokenLength=1, gaps=False, pattern="\\p{L}+|[^\\p{L}\\s]+", - inputCol=None, outputCol=None): + def __init__(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, outputCol=None): """ - __init__(self, minTokenLength=1, gaps=False, pattern="\\p{L}+|[^\\p{L}\\s]+", - inputCol=None, outputCol=None) + __init__(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, outputCol=None) """ super(RegexTokenizer, self).__init__() - self.minTokenLength = Param(self, "minLength", "minimum token length (>= 0)") - self.gaps = Param(self, "gaps", "Set regex to match gaps or tokens") - self.pattern = Param(self, "pattern", "regex pattern used for tokenizing") - self._setDefault(minTokenLength=1, gaps=False, pattern="\\p{L}+|[^\\p{L}\\s]+") + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RegexTokenizer", self.uid) + self.minTokenLength = Param(self, "minTokenLength", "minimum token length (>= 0)") + self.gaps = Param(self, "gaps", "whether regex splits on gaps (True) or matches tokens") + self.pattern = Param(self, "pattern", "regex pattern (Java dialect) used for tokenizing") + self._setDefault(minTokenLength=1, gaps=True, pattern="\\s+") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only - def setParams(self, minTokenLength=1, gaps=False, pattern="\\p{L}+|[^\\p{L}\\s]+", - inputCol=None, outputCol=None): + def setParams(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, outputCol=None): """ - setParams(self, minTokenLength=1, gaps=False, pattern="\\p{L}+|[^\\p{L}\\s]+", - inputCol="input", outputCol="output") + setParams(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, outputCol=None) Sets params for this RegexTokenizer. """ kwargs = self.setParams._input_kwargs @@ -430,7 +512,7 @@ def setMinTokenLength(self, value): """ Sets the value of :py:attr:`minTokenLength`. """ - self.paramMap[self.minTokenLength] = value + self._paramMap[self.minTokenLength] = value return self def getMinTokenLength(self): @@ -443,7 +525,7 @@ def setGaps(self, value): """ Sets the value of :py:attr:`gaps`. """ - self.paramMap[self.gaps] = value + self._paramMap[self.gaps] = value return self def getGaps(self): @@ -456,7 +538,7 @@ def setPattern(self, value): """ Sets the value of :py:attr:`pattern`. """ - self.paramMap[self.pattern] = value + self._paramMap[self.pattern] = value return self def getPattern(self): @@ -480,8 +562,6 @@ class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol): DenseVector([1.4142]) """ - _java_class = "org.apache.spark.ml.feature.StandardScaler" - # a placeholder to make it appear in the generated doc withMean = Param(Params._dummy(), "withMean", "Center data with mean") withStd = Param(Params._dummy(), "withStd", "Scale to unit standard deviation") @@ -492,6 +572,7 @@ def __init__(self, withMean=False, withStd=True, inputCol=None, outputCol=None): __init__(self, withMean=False, withStd=True, inputCol=None, outputCol=None) """ super(StandardScaler, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StandardScaler", self.uid) self.withMean = Param(self, "withMean", "Center data with mean") self.withStd = Param(self, "withStd", "Scale to unit standard deviation") self._setDefault(withMean=False, withStd=True) @@ -511,7 +592,7 @@ def setWithMean(self, value): """ Sets the value of :py:attr:`withMean`. """ - self.paramMap[self.withMean] = value + self._paramMap[self.withMean] = value return self def getWithMean(self): @@ -524,7 +605,7 @@ def setWithStd(self, value): """ Sets the value of :py:attr:`withStd`. """ - self.paramMap[self.withStd] = value + self._paramMap[self.withStd] = value return self def getWithStd(self): @@ -533,6 +614,9 @@ def getWithStd(self): """ return self.getOrDefault(self.withStd) + def _create_model(self, java_model): + return StandardScalerModel(java_model) + class StandardScalerModel(JavaModel): """ @@ -556,14 +640,13 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol): [(0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)] """ - _java_class = "org.apache.spark.ml.feature.StringIndexer" - @keyword_only def __init__(self, inputCol=None, outputCol=None): """ __init__(self, inputCol=None, outputCol=None) """ super(StringIndexer, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StringIndexer", self.uid) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -576,6 +659,9 @@ def setParams(self, inputCol=None, outputCol=None): kwargs = self.setParams._input_kwargs return self._set(**kwargs) + def _create_model(self, java_model): + return StringIndexerModel(java_model) + class StringIndexerModel(JavaModel): """ @@ -609,14 +695,13 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol): TypeError: Method setParams forces keyword arguments. """ - _java_class = "org.apache.spark.ml.feature.Tokenizer" - @keyword_only def __init__(self, inputCol=None, outputCol=None): """ __init__(self, inputCol=None, outputCol=None) """ super(Tokenizer, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Tokenizer", self.uid) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -646,14 +731,13 @@ class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol): DenseVector([0.0, 1.0]) """ - _java_class = "org.apache.spark.ml.feature.VectorAssembler" - @keyword_only def __init__(self, inputCols=None, outputCol=None): """ __init__(self, inputCols=None, outputCol=None) """ super(VectorAssembler, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorAssembler", self.uid) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -720,7 +804,6 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol): DenseVector([1.0, 0.0]) """ - _java_class = "org.apache.spark.ml.feature.VectorIndexer" # a placeholder to make it appear in the generated doc maxCategories = Param(Params._dummy(), "maxCategories", "Threshold for the number of values a categorical feature can take " + @@ -733,6 +816,7 @@ def __init__(self, maxCategories=20, inputCol=None, outputCol=None): __init__(self, maxCategories=20, inputCol=None, outputCol=None) """ super(VectorIndexer, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorIndexer", self.uid) self.maxCategories = Param(self, "maxCategories", "Threshold for the number of values a categorical feature " + "can take (>= 2). If a feature is found to have " + @@ -754,7 +838,7 @@ def setMaxCategories(self, value): """ Sets the value of :py:attr:`maxCategories`. """ - self.paramMap[self.maxCategories] = value + self._paramMap[self.maxCategories] = value return self def getMaxCategories(self): @@ -763,6 +847,15 @@ def getMaxCategories(self): """ return self.getOrDefault(self.maxCategories) + def _create_model(self, java_model): + return VectorIndexerModel(java_model) + + +class VectorIndexerModel(JavaModel): + """ + Model fitted by VectorIndexer. + """ + @inherit_doc @ignore_unicode_prefix @@ -778,7 +871,6 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has DenseVector([-0.0422, -0.5138, -0.2546, 0.6885, 0.276]) """ - _java_class = "org.apache.spark.ml.feature.Word2Vec" # a placeholder to make it appear in the generated doc vectorSize = Param(Params._dummy(), "vectorSize", "the dimension of codes after transforming from words") @@ -790,12 +882,13 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has @keyword_only def __init__(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, - seed=42, inputCol=None, outputCol=None): + seed=None, inputCol=None, outputCol=None): """ - __init__(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, - seed=42, inputCol=None, outputCol=None) + __init__(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, \ + seed=None, inputCol=None, outputCol=None) """ super(Word2Vec, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Word2Vec", self.uid) self.vectorSize = Param(self, "vectorSize", "the dimension of codes after transforming from words") self.numPartitions = Param(self, "numPartitions", @@ -804,15 +897,15 @@ def __init__(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, "the minimum number of times a token must appear to be included " + "in the word2vec model's vocabulary") self._setDefault(vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, - seed=42) + seed=None) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only def setParams(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, - seed=42, inputCol=None, outputCol=None): + seed=None, inputCol=None, outputCol=None): """ - setParams(self, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, seed=42, + setParams(self, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, seed=None, \ inputCol=None, outputCol=None) Sets params for this Word2Vec. """ @@ -823,7 +916,7 @@ def setVectorSize(self, value): """ Sets the value of :py:attr:`vectorSize`. """ - self.paramMap[self.vectorSize] = value + self._paramMap[self.vectorSize] = value return self def getVectorSize(self): @@ -836,7 +929,7 @@ def setNumPartitions(self, value): """ Sets the value of :py:attr:`numPartitions`. """ - self.paramMap[self.numPartitions] = value + self._paramMap[self.numPartitions] = value return self def getNumPartitions(self): @@ -849,7 +942,7 @@ def setMinCount(self, value): """ Sets the value of :py:attr:`minCount`. """ - self.paramMap[self.minCount] = value + self._paramMap[self.minCount] = value return self def getMinCount(self): @@ -858,6 +951,9 @@ def getMinCount(self): """ return self.getOrDefault(self.minCount) + def _create_model(self, java_model): + return Word2VecModel(java_model) + class Word2VecModel(JavaModel): """ diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index 49c20b4cf70cf..7845536161e07 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -16,6 +16,7 @@ # from abc import ABCMeta +import copy from pyspark.ml.util import Identifiable @@ -29,9 +30,9 @@ class Param(object): """ def __init__(self, parent, name, doc): - if not isinstance(parent, Params): - raise TypeError("Parent must be a Params but got type %s." % type(parent)) - self.parent = parent + if not isinstance(parent, Identifiable): + raise TypeError("Parent must be an Identifiable but got type %s." % type(parent)) + self.parent = parent.uid self.name = str(name) self.doc = str(doc) @@ -41,6 +42,15 @@ def __str__(self): def __repr__(self): return "Param(parent=%r, name=%r, doc=%r)" % (self.parent, self.name, self.doc) + def __hash__(self): + return hash(str(self)) + + def __eq__(self, other): + if isinstance(other, Param): + return self.parent == other.parent and self.name == other.name + else: + return False + class Params(Identifiable): """ @@ -51,10 +61,13 @@ class Params(Identifiable): __metaclass__ = ABCMeta #: internal param map for user-supplied values param map - paramMap = {} + _paramMap = {} #: internal param map for default values - defaultParamMap = {} + _defaultParamMap = {} + + #: value returned by :py:func:`params` + _params = None @property def params(self): @@ -63,10 +76,12 @@ def params(self): uses :py:func:`dir` to get all attributes of type :py:class:`Param`. """ - return list(filter(lambda attr: isinstance(attr, Param), - [getattr(self, x) for x in dir(self) if x != "params"])) + if self._params is None: + self._params = list(filter(lambda attr: isinstance(attr, Param), + [getattr(self, x) for x in dir(self) if x != "params"])) + return self._params - def _explain(self, param): + def explainParam(self, param): """ Explains a single param and returns its name, doc, and optional default value and user-supplied value in a string. @@ -74,10 +89,10 @@ def _explain(self, param): param = self._resolveParam(param) values = [] if self.isDefined(param): - if param in self.defaultParamMap: - values.append("default: %s" % self.defaultParamMap[param]) - if param in self.paramMap: - values.append("current: %s" % self.paramMap[param]) + if param in self._defaultParamMap: + values.append("default: %s" % self._defaultParamMap[param]) + if param in self._paramMap: + values.append("current: %s" % self._paramMap[param]) else: values.append("undefined") valueStr = "(" + ", ".join(values) + ")" @@ -88,7 +103,7 @@ def explainParams(self): Returns the documentation of all params with their optionally default values and user-supplied values. """ - return "\n".join([self._explain(param) for param in self.params]) + return "\n".join([self.explainParam(param) for param in self.params]) def getParam(self, paramName): """ @@ -105,56 +120,76 @@ def isSet(self, param): Checks whether a param is explicitly set by user. """ param = self._resolveParam(param) - return param in self.paramMap + return param in self._paramMap def hasDefault(self, param): """ Checks whether a param has a default value. """ param = self._resolveParam(param) - return param in self.defaultParamMap + return param in self._defaultParamMap def isDefined(self, param): """ - Checks whether a param is explicitly set by user or has a default value. + Checks whether a param is explicitly set by user or has + a default value. """ return self.isSet(param) or self.hasDefault(param) + def hasParam(self, paramName): + """ + Tests whether this instance contains a param with a given + (string) name. + """ + param = self._resolveParam(paramName) + return param in self.params + def getOrDefault(self, param): """ Gets the value of a param in the user-supplied param map or its - default value. Raises an error if either is set. + default value. Raises an error if neither is set. """ - if isinstance(param, Param): - if param in self.paramMap: - return self.paramMap[param] - else: - return self.defaultParamMap[param] - elif isinstance(param, str): - return self.getOrDefault(self.getParam(param)) + param = self._resolveParam(param) + if param in self._paramMap: + return self._paramMap[param] else: - raise KeyError("Cannot recognize %r as a param." % param) + return self._defaultParamMap[param] - def extractParamMap(self, extraParamMap={}): + def extractParamMap(self, extra={}): """ Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into a flat param map, where the latter value is used if there exist conflicts, i.e., with ordering: default param values < - user-supplied values < extraParamMap. - :param extraParamMap: extra param values + user-supplied values < extra. + :param extra: extra param values :return: merged param map """ - paramMap = self.defaultParamMap.copy() - paramMap.update(self.paramMap) - paramMap.update(extraParamMap) + paramMap = self._defaultParamMap.copy() + paramMap.update(self._paramMap) + paramMap.update(extra) return paramMap + def copy(self, extra={}): + """ + Creates a copy of this instance with the same uid and some + extra params. The default implementation creates a + shallow copy using :py:func:`copy.copy`, and then copies the + embedded and extra parameters over and returns the copy. + Subclasses should override this method if the default approach + is not sufficient. + :param extra: Extra parameters to copy to the new instance + :return: Copy of this instance + """ + that = copy.copy(self) + that._paramMap = self.extractParamMap(extra) + return that + def _shouldOwn(self, param): """ Validates that the input param belongs to this Params instance. """ - if param.parent is not self: + if not (self.uid == param.parent and self.hasParam(param.name)): raise ValueError("Param %r does not belong to %r." % (param, self)) def _resolveParam(self, param): @@ -175,7 +210,8 @@ def _resolveParam(self, param): @staticmethod def _dummy(): """ - Returns a dummy Params instance used as a placeholder to generate docs. + Returns a dummy Params instance used as a placeholder to + generate docs. """ dummy = Params() dummy.uid = "undefined" @@ -186,7 +222,7 @@ def _set(self, **kwargs): Sets user-supplied params. """ for param, value in kwargs.items(): - self.paramMap[getattr(self, param)] = value + self._paramMap[getattr(self, param)] = value return self def _setDefault(self, **kwargs): @@ -194,5 +230,19 @@ def _setDefault(self, **kwargs): Sets default params. """ for param, value in kwargs.items(): - self.defaultParamMap[getattr(self, param)] = value + self._defaultParamMap[getattr(self, param)] = value return self + + def _copyValues(self, to, extra={}): + """ + Copies param values from this instance to another instance for + params shared by them. + :param to: the target instance + :param extra: extra params to be copied + :return: the target instance with param values copied + """ + paramMap = self.extractParamMap(extra) + for p in self.params: + if p in paramMap and to.hasParam(p.name): + to._set(**{p.name: paramMap[p]}) + return to diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index 4a5cc6e64f023..69efc424ec4ef 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -56,9 +56,10 @@ def _gen_param_header(name, doc, defaultValueStr): def __init__(self): super(Has$Name, self).__init__() #: param for $doc - self.$name = Param(self, "$name", "$doc") - if $defaultValueStr is not None: - self._setDefault($name=$defaultValueStr)''' + self.$name = Param(self, "$name", "$doc")''' + if defaultValueStr is not None: + template += ''' + self._setDefault($name=$defaultValueStr)''' Name = name[0].upper() + name[1:] return template \ @@ -83,7 +84,7 @@ def set$Name(self, value): """ Sets the value of :py:attr:`$name`. """ - self.paramMap[self.$name] = value + self._paramMap[self.$name] = value return self def get$Name(self): @@ -109,13 +110,16 @@ def get$Name(self): ("featuresCol", "features column name", "'features'"), ("labelCol", "label column name", "'label'"), ("predictionCol", "prediction column name", "'prediction'"), + ("probabilityCol", "Column name for predicted class conditional probabilities. " + + "Note: Not all models output well-calibrated probability estimates! These probabilities " + + "should be treated as confidences, not precise probabilities.", "'probability'"), ("rawPredictionCol", "raw prediction (a.k.a. confidence) column name", "'rawPrediction'"), ("inputCol", "input column name", None), ("inputCols", "input column names", None), - ("outputCol", "output column name", None), + ("outputCol", "output column name", "self.uid + '__output'"), ("numFeatures", "number of features", None), ("checkpointInterval", "checkpoint interval (>= 1)", None), - ("seed", "random seed", None), + ("seed", "random seed", "hash(type(self).__name__)"), ("tol", "the convergence tolerance for iterative algorithms", None), ("stepSize", "Step size to be used for each iteration of optimization.", None)] code = [] @@ -156,6 +160,7 @@ def __init__(self): for name, doc in decisionTreeParams: variable = paramTemplate.replace("$name", name).replace("$doc", doc) dummyPlaceholders += variable.replace("$owner", "Params._dummy()") + "\n " + realParams += "#: param for " + doc + "\n " realParams += "self." + variable.replace("$owner", "self") + "\n " dtParamMethods += _gen_param_code(name, doc, None) + "\n" code.append(decisionTreeCode.replace("$dummyPlaceHolders", dummyPlaceholders) diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 779cabe853f8e..bc088e4c29e26 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -32,14 +32,12 @@ def __init__(self): super(HasMaxIter, self).__init__() #: param for max number of iterations (>= 0) self.maxIter = Param(self, "maxIter", "max number of iterations (>= 0)") - if None is not None: - self._setDefault(maxIter=None) def setMaxIter(self, value): """ Sets the value of :py:attr:`maxIter`. """ - self.paramMap[self.maxIter] = value + self._paramMap[self.maxIter] = value return self def getMaxIter(self): @@ -61,14 +59,12 @@ def __init__(self): super(HasRegParam, self).__init__() #: param for regularization parameter (>= 0) self.regParam = Param(self, "regParam", "regularization parameter (>= 0)") - if None is not None: - self._setDefault(regParam=None) def setRegParam(self, value): """ Sets the value of :py:attr:`regParam`. """ - self.paramMap[self.regParam] = value + self._paramMap[self.regParam] = value return self def getRegParam(self): @@ -90,14 +86,13 @@ def __init__(self): super(HasFeaturesCol, self).__init__() #: param for features column name self.featuresCol = Param(self, "featuresCol", "features column name") - if 'features' is not None: - self._setDefault(featuresCol='features') + self._setDefault(featuresCol='features') def setFeaturesCol(self, value): """ Sets the value of :py:attr:`featuresCol`. """ - self.paramMap[self.featuresCol] = value + self._paramMap[self.featuresCol] = value return self def getFeaturesCol(self): @@ -119,14 +114,13 @@ def __init__(self): super(HasLabelCol, self).__init__() #: param for label column name self.labelCol = Param(self, "labelCol", "label column name") - if 'label' is not None: - self._setDefault(labelCol='label') + self._setDefault(labelCol='label') def setLabelCol(self, value): """ Sets the value of :py:attr:`labelCol`. """ - self.paramMap[self.labelCol] = value + self._paramMap[self.labelCol] = value return self def getLabelCol(self): @@ -148,14 +142,13 @@ def __init__(self): super(HasPredictionCol, self).__init__() #: param for prediction column name self.predictionCol = Param(self, "predictionCol", "prediction column name") - if 'prediction' is not None: - self._setDefault(predictionCol='prediction') + self._setDefault(predictionCol='prediction') def setPredictionCol(self, value): """ Sets the value of :py:attr:`predictionCol`. """ - self.paramMap[self.predictionCol] = value + self._paramMap[self.predictionCol] = value return self def getPredictionCol(self): @@ -165,6 +158,34 @@ def getPredictionCol(self): return self.getOrDefault(self.predictionCol) +class HasProbabilityCol(Params): + """ + Mixin for param probabilityCol: Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.. + """ + + # a placeholder to make it appear in the generated doc + probabilityCol = Param(Params._dummy(), "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.") + + def __init__(self): + super(HasProbabilityCol, self).__init__() + #: param for Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities. + self.probabilityCol = Param(self, "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.") + self._setDefault(probabilityCol='probability') + + def setProbabilityCol(self, value): + """ + Sets the value of :py:attr:`probabilityCol`. + """ + self._paramMap[self.probabilityCol] = value + return self + + def getProbabilityCol(self): + """ + Gets the value of probabilityCol or its default value. + """ + return self.getOrDefault(self.probabilityCol) + + class HasRawPredictionCol(Params): """ Mixin for param rawPredictionCol: raw prediction (a.k.a. confidence) column name. @@ -177,14 +198,13 @@ def __init__(self): super(HasRawPredictionCol, self).__init__() #: param for raw prediction (a.k.a. confidence) column name self.rawPredictionCol = Param(self, "rawPredictionCol", "raw prediction (a.k.a. confidence) column name") - if 'rawPrediction' is not None: - self._setDefault(rawPredictionCol='rawPrediction') + self._setDefault(rawPredictionCol='rawPrediction') def setRawPredictionCol(self, value): """ Sets the value of :py:attr:`rawPredictionCol`. """ - self.paramMap[self.rawPredictionCol] = value + self._paramMap[self.rawPredictionCol] = value return self def getRawPredictionCol(self): @@ -206,14 +226,12 @@ def __init__(self): super(HasInputCol, self).__init__() #: param for input column name self.inputCol = Param(self, "inputCol", "input column name") - if None is not None: - self._setDefault(inputCol=None) def setInputCol(self, value): """ Sets the value of :py:attr:`inputCol`. """ - self.paramMap[self.inputCol] = value + self._paramMap[self.inputCol] = value return self def getInputCol(self): @@ -235,14 +253,12 @@ def __init__(self): super(HasInputCols, self).__init__() #: param for input column names self.inputCols = Param(self, "inputCols", "input column names") - if None is not None: - self._setDefault(inputCols=None) def setInputCols(self, value): """ Sets the value of :py:attr:`inputCols`. """ - self.paramMap[self.inputCols] = value + self._paramMap[self.inputCols] = value return self def getInputCols(self): @@ -264,14 +280,13 @@ def __init__(self): super(HasOutputCol, self).__init__() #: param for output column name self.outputCol = Param(self, "outputCol", "output column name") - if None is not None: - self._setDefault(outputCol=None) + self._setDefault(outputCol=self.uid + '__output') def setOutputCol(self, value): """ Sets the value of :py:attr:`outputCol`. """ - self.paramMap[self.outputCol] = value + self._paramMap[self.outputCol] = value return self def getOutputCol(self): @@ -293,14 +308,12 @@ def __init__(self): super(HasNumFeatures, self).__init__() #: param for number of features self.numFeatures = Param(self, "numFeatures", "number of features") - if None is not None: - self._setDefault(numFeatures=None) def setNumFeatures(self, value): """ Sets the value of :py:attr:`numFeatures`. """ - self.paramMap[self.numFeatures] = value + self._paramMap[self.numFeatures] = value return self def getNumFeatures(self): @@ -322,14 +335,12 @@ def __init__(self): super(HasCheckpointInterval, self).__init__() #: param for checkpoint interval (>= 1) self.checkpointInterval = Param(self, "checkpointInterval", "checkpoint interval (>= 1)") - if None is not None: - self._setDefault(checkpointInterval=None) def setCheckpointInterval(self, value): """ Sets the value of :py:attr:`checkpointInterval`. """ - self.paramMap[self.checkpointInterval] = value + self._paramMap[self.checkpointInterval] = value return self def getCheckpointInterval(self): @@ -351,14 +362,13 @@ def __init__(self): super(HasSeed, self).__init__() #: param for random seed self.seed = Param(self, "seed", "random seed") - if None is not None: - self._setDefault(seed=None) + self._setDefault(seed=hash(type(self).__name__)) def setSeed(self, value): """ Sets the value of :py:attr:`seed`. """ - self.paramMap[self.seed] = value + self._paramMap[self.seed] = value return self def getSeed(self): @@ -380,14 +390,12 @@ def __init__(self): super(HasTol, self).__init__() #: param for the convergence tolerance for iterative algorithms self.tol = Param(self, "tol", "the convergence tolerance for iterative algorithms") - if None is not None: - self._setDefault(tol=None) def setTol(self, value): """ Sets the value of :py:attr:`tol`. """ - self.paramMap[self.tol] = value + self._paramMap[self.tol] = value return self def getTol(self): @@ -409,14 +417,12 @@ def __init__(self): super(HasStepSize, self).__init__() #: param for Step size to be used for each iteration of optimization. self.stepSize = Param(self, "stepSize", "Step size to be used for each iteration of optimization.") - if None is not None: - self._setDefault(stepSize=None) def setStepSize(self, value): """ Sets the value of :py:attr:`stepSize`. """ - self.paramMap[self.stepSize] = value + self._paramMap[self.stepSize] = value return self def getStepSize(self): @@ -438,6 +444,7 @@ class DecisionTreeParams(Params): minInfoGain = Param(Params._dummy(), "minInfoGain", "Minimum information gain for a split to be considered at a tree node.") maxMemoryInMB = Param(Params._dummy(), "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.") cacheNodeIds = Param(Params._dummy(), "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.") + def __init__(self): super(DecisionTreeParams, self).__init__() @@ -453,12 +460,12 @@ def __init__(self): self.maxMemoryInMB = Param(self, "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.") #: param for If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. self.cacheNodeIds = Param(self, "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.") - + def setMaxDepth(self, value): """ Sets the value of :py:attr:`maxDepth`. """ - self.paramMap[self.maxDepth] = value + self._paramMap[self.maxDepth] = value return self def getMaxDepth(self): @@ -471,7 +478,7 @@ def setMaxBins(self, value): """ Sets the value of :py:attr:`maxBins`. """ - self.paramMap[self.maxBins] = value + self._paramMap[self.maxBins] = value return self def getMaxBins(self): @@ -484,7 +491,7 @@ def setMinInstancesPerNode(self, value): """ Sets the value of :py:attr:`minInstancesPerNode`. """ - self.paramMap[self.minInstancesPerNode] = value + self._paramMap[self.minInstancesPerNode] = value return self def getMinInstancesPerNode(self): @@ -497,7 +504,7 @@ def setMinInfoGain(self, value): """ Sets the value of :py:attr:`minInfoGain`. """ - self.paramMap[self.minInfoGain] = value + self._paramMap[self.minInfoGain] = value return self def getMinInfoGain(self): @@ -510,7 +517,7 @@ def setMaxMemoryInMB(self, value): """ Sets the value of :py:attr:`maxMemoryInMB`. """ - self.paramMap[self.maxMemoryInMB] = value + self._paramMap[self.maxMemoryInMB] = value return self def getMaxMemoryInMB(self): @@ -523,7 +530,7 @@ def setCacheNodeIds(self, value): """ Sets the value of :py:attr:`cacheNodeIds`. """ - self.paramMap[self.cacheNodeIds] = value + self._paramMap[self.cacheNodeIds] = value return self def getCacheNodeIds(self): diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index a328bcf84a2e7..a563024b2cdcb 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -31,18 +31,40 @@ class Estimator(Params): __metaclass__ = ABCMeta @abstractmethod - def fit(self, dataset, params={}): + def _fit(self, dataset): """ - Fits a model to the input dataset with optional parameters. + Fits a model to the input dataset. This is called by the + default implementation of fit. :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame` - :param params: an optional param map that overwrites embedded - params :returns: fitted model """ raise NotImplementedError() + def fit(self, dataset, params={}): + """ + Fits a model to the input dataset with optional parameters. + + :param dataset: input dataset, which is an instance of + :py:class:`pyspark.sql.DataFrame` + :param params: an optional param map that overrides embedded + params. If a list/tuple of param maps is given, + this calls fit on each param map and returns a + list of models. + :returns: fitted model(s) + """ + if isinstance(params, (list, tuple)): + return [self.fit(dataset, paramMap) for paramMap in params] + elif isinstance(params, dict): + if params: + return self.copy(params)._fit(dataset) + else: + return self._fit(dataset) + else: + raise ValueError("Params must be either a param map or a list/tuple of param maps, " + "but got %s." % type(params)) + @inherit_doc class Transformer(Params): @@ -54,18 +76,34 @@ class Transformer(Params): __metaclass__ = ABCMeta @abstractmethod - def transform(self, dataset, params={}): + def _transform(self, dataset): """ Transforms the input dataset with optional parameters. :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame` - :param params: an optional param map that overwrites embedded - params :returns: transformed dataset """ raise NotImplementedError() + def transform(self, dataset, params={}): + """ + Transforms the input dataset with optional parameters. + + :param dataset: input dataset, which is an instance of + :py:class:`pyspark.sql.DataFrame` + :param params: an optional param map that overrides embedded + params. + :returns: transformed dataset + """ + if isinstance(params, dict): + if params: + return self.copy(params,)._transform(dataset) + else: + return self._transform(dataset) + else: + raise ValueError("Params must be either a param map but got %s." % type(params)) + @inherit_doc class Model(Transformer): @@ -113,15 +151,15 @@ def setStages(self, value): :param value: a list of transformers or estimators :return: the pipeline instance """ - self.paramMap[self.stages] = value + self._paramMap[self.stages] = value return self def getStages(self): """ Get pipeline stages. """ - if self.stages in self.paramMap: - return self.paramMap[self.stages] + if self.stages in self._paramMap: + return self._paramMap[self.stages] @keyword_only def setParams(self, stages=[]): @@ -132,9 +170,8 @@ def setParams(self, stages=[]): kwargs = self.setParams._input_kwargs return self._set(**kwargs) - def fit(self, dataset, params={}): - paramMap = self.extractParamMap(params) - stages = paramMap[self.stages] + def _fit(self, dataset): + stages = self.getStages() for stage in stages: if not (isinstance(stage, Estimator) or isinstance(stage, Transformer)): raise TypeError( @@ -148,16 +185,21 @@ def fit(self, dataset, params={}): if i <= indexOfLastEstimator: if isinstance(stage, Transformer): transformers.append(stage) - dataset = stage.transform(dataset, paramMap) + dataset = stage.transform(dataset) else: # must be an Estimator - model = stage.fit(dataset, paramMap) + model = stage.fit(dataset) transformers.append(model) if i < indexOfLastEstimator: - dataset = model.transform(dataset, paramMap) + dataset = model.transform(dataset) else: transformers.append(stage) return PipelineModel(transformers) + def copy(self, extra={}): + that = Params.copy(self, extra) + stages = [stage.copy(extra) for stage in that.getStages()] + return that.setStages(stages) + @inherit_doc class PipelineModel(Model): @@ -165,33 +207,15 @@ class PipelineModel(Model): Represents a compiled pipeline with transformers and fitted models. """ - def __init__(self, transformers): + def __init__(self, stages): super(PipelineModel, self).__init__() - self.transformers = transformers + self.stages = stages - def transform(self, dataset, params={}): - paramMap = self.extractParamMap(params) - for t in self.transformers: - dataset = t.transform(dataset, paramMap) + def _transform(self, dataset): + for t in self.stages: + dataset = t.transform(dataset) return dataset - -class Evaluator(Params): - """ - Base class for evaluators that compute metrics from predictions. - """ - - __metaclass__ = ABCMeta - - @abstractmethod - def evaluate(self, dataset, params={}): - """ - Evaluates the output. - - :param dataset: a dataset that contains labels/observations and - predictions - :param params: an optional param map that overrides embedded - params - :return: metric - """ - raise NotImplementedError() + def copy(self, extra={}): + stages = [stage.copy(extra) for stage in self.stages] + return PipelineModel(stages) diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index 4846b907e85ec..b06099ac0aee6 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -63,8 +63,15 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha indicated user preferences rather than explicit ratings given to items. + >>> df = sqlContext.createDataFrame( + ... [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)], + ... ["user", "item", "rating"]) >>> als = ALS(rank=10, maxIter=5) >>> model = als.fit(df) + >>> model.rank + 10 + >>> model.userFactors.orderBy("id").collect() + [Row(id=0, features=[...]), Row(id=1, ...), Row(id=2, ...)] >>> test = sqlContext.createDataFrame([(0, 2), (1, 0), (2, 0)], ["user", "item"]) >>> predictions = sorted(model.transform(test).collect(), key=lambda r: r[0]) >>> predictions[0] @@ -74,7 +81,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha >>> predictions[2] Row(user=2, item=0, prediction=-1.15...) """ - _java_class = "org.apache.spark.ml.recommendation.ALS" + # a placeholder to make it appear in the generated doc rank = Param(Params._dummy(), "rank", "rank of the factorization") numUserBlocks = Param(Params._dummy(), "numUserBlocks", "number of user blocks") @@ -89,14 +96,15 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha @keyword_only def __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, - implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=0, + implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, ratingCol="rating", nonnegative=False, checkpointInterval=10): """ - __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, - implicitPrefs=false, alpha=1.0, userCol="user", itemCol="item", seed=0, + __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, \ + implicitPrefs=false, alpha=1.0, userCol="user", itemCol="item", seed=None, \ ratingCol="rating", nonnegative=false, checkpointInterval=10) """ super(ALS, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.recommendation.ALS", self.uid) self.rank = Param(self, "rank", "rank of the factorization") self.numUserBlocks = Param(self, "numUserBlocks", "number of user blocks") self.numItemBlocks = Param(self, "numItemBlocks", "number of item blocks") @@ -108,18 +116,18 @@ def __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemB self.nonnegative = Param(self, "nonnegative", "whether to use nonnegative constraint for least squares") self._setDefault(rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, - implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=0, + implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, ratingCol="rating", nonnegative=False, checkpointInterval=10) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only def setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, - implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=0, + implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, ratingCol="rating", nonnegative=False, checkpointInterval=10): """ - setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, - implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=0, + setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, \ + implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, \ ratingCol="rating", nonnegative=False, checkpointInterval=10) Sets params for ALS. """ @@ -133,7 +141,7 @@ def setRank(self, value): """ Sets the value of :py:attr:`rank`. """ - self.paramMap[self.rank] = value + self._paramMap[self.rank] = value return self def getRank(self): @@ -146,7 +154,7 @@ def setNumUserBlocks(self, value): """ Sets the value of :py:attr:`numUserBlocks`. """ - self.paramMap[self.numUserBlocks] = value + self._paramMap[self.numUserBlocks] = value return self def getNumUserBlocks(self): @@ -159,7 +167,7 @@ def setNumItemBlocks(self, value): """ Sets the value of :py:attr:`numItemBlocks`. """ - self.paramMap[self.numItemBlocks] = value + self._paramMap[self.numItemBlocks] = value return self def getNumItemBlocks(self): @@ -172,14 +180,14 @@ def setNumBlocks(self, value): """ Sets both :py:attr:`numUserBlocks` and :py:attr:`numItemBlocks` to the specific value. """ - self.paramMap[self.numUserBlocks] = value - self.paramMap[self.numItemBlocks] = value + self._paramMap[self.numUserBlocks] = value + self._paramMap[self.numItemBlocks] = value def setImplicitPrefs(self, value): """ Sets the value of :py:attr:`implicitPrefs`. """ - self.paramMap[self.implicitPrefs] = value + self._paramMap[self.implicitPrefs] = value return self def getImplicitPrefs(self): @@ -192,7 +200,7 @@ def setAlpha(self, value): """ Sets the value of :py:attr:`alpha`. """ - self.paramMap[self.alpha] = value + self._paramMap[self.alpha] = value return self def getAlpha(self): @@ -205,7 +213,7 @@ def setUserCol(self, value): """ Sets the value of :py:attr:`userCol`. """ - self.paramMap[self.userCol] = value + self._paramMap[self.userCol] = value return self def getUserCol(self): @@ -218,7 +226,7 @@ def setItemCol(self, value): """ Sets the value of :py:attr:`itemCol`. """ - self.paramMap[self.itemCol] = value + self._paramMap[self.itemCol] = value return self def getItemCol(self): @@ -231,7 +239,7 @@ def setRatingCol(self, value): """ Sets the value of :py:attr:`ratingCol`. """ - self.paramMap[self.ratingCol] = value + self._paramMap[self.ratingCol] = value return self def getRatingCol(self): @@ -244,7 +252,7 @@ def setNonnegative(self, value): """ Sets the value of :py:attr:`nonnegative`. """ - self.paramMap[self.nonnegative] = value + self._paramMap[self.nonnegative] = value return self def getNonnegative(self): @@ -259,6 +267,27 @@ class ALSModel(JavaModel): Model fitted by ALS. """ + @property + def rank(self): + """rank of the matrix factorization model""" + return self._call_java("rank") + + @property + def userFactors(self): + """ + a DataFrame that stores user factors in two columns: `id` and + `features` + """ + return self._call_java("userFactors") + + @property + def itemFactors(self): + """ + a DataFrame that stores item factors in two columns: `id` and + `features` + """ + return self._call_java("itemFactors") + if __name__ == "__main__": import doctest @@ -271,8 +300,6 @@ class ALSModel(JavaModel): sqlContext = SQLContext(sc) globs['sc'] = sc globs['sqlContext'] = sqlContext - globs['df'] = sqlContext.createDataFrame([(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), - (2, 1, 1.0), (2, 2, 5.0)], ["user", "item", "rating"]) (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) sc.stop() if failure_count: diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 0ab5c6c3d20c3..b139e27372d80 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -33,8 +33,7 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction Linear regression. The learning objective is to minimize the squared error, with regularization. - The specific squared error loss function used is: - L = 1/2n ||A weights - y||^2^ + The specific squared error loss function used is: L = 1/2n ||A weights - y||^2^ This support multiple types of regularization: - none (a.k.a. ordinary least squares) @@ -51,6 +50,10 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> model.transform(test0).head().prediction -1.0 + >>> model.weights + DenseVector([1.0]) + >>> model.intercept + 0.0 >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 @@ -59,7 +62,7 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction ... TypeError: Method setParams forces keyword arguments. """ - _java_class = "org.apache.spark.ml.regression.LinearRegression" + # a placeholder to make it appear in the generated doc elasticNetParam = \ Param(Params._dummy(), "elasticNetParam", @@ -74,6 +77,8 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6) """ super(LinearRegression, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.regression.LinearRegression", self.uid) #: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty # is an L2 penalty. For alpha = 1, it is an L1 penalty. self.elasticNetParam = \ @@ -102,7 +107,7 @@ def setElasticNetParam(self, value): """ Sets the value of :py:attr:`elasticNetParam`. """ - self.paramMap[self.elasticNetParam] = value + self._paramMap[self.elasticNetParam] = value return self def getElasticNetParam(self): @@ -117,6 +122,20 @@ class LinearRegressionModel(JavaModel): Model fitted by LinearRegression. """ + @property + def weights(self): + """ + Model weights. + """ + return self._call_java("weights") + + @property + def intercept(self): + """ + Model intercept. + """ + return self._call_java("intercept") + class TreeRegressorParams(object): """ @@ -161,7 +180,6 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi 1.0 """ - _java_class = "org.apache.spark.ml.regression.DecisionTreeRegressor" # a placeholder to make it appear in the generated doc impurity = Param(Params._dummy(), "impurity", "Criterion used for information gain calculation (case-insensitive). " + @@ -173,10 +191,12 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance"): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ - maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance") """ super(DecisionTreeRegressor, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.regression.DecisionTreeRegressor", self.uid) #: param for Criterion used for information gain calculation (case-insensitive). self.impurity = \ Param(self, "impurity", @@ -195,9 +215,8 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre impurity="variance"): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ - maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, - impurity="variance") + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance") Sets params for the DecisionTreeRegressor. """ kwargs = self.setParams._input_kwargs @@ -210,7 +229,7 @@ def setImpurity(self, value): """ Sets the value of :py:attr:`impurity`. """ - self.paramMap[self.impurity] = value + self._paramMap[self.impurity] = value return self def getImpurity(self): @@ -238,7 +257,7 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi >>> df = sqlContext.createDataFrame([ ... (1.0, Vectors.dense(1.0)), ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) - >>> rf = RandomForestRegressor(numTrees=2, maxDepth=2) + >>> rf = RandomForestRegressor(numTrees=2, maxDepth=2, seed=42) >>> model = rf.fit(df) >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> model.transform(test0).head().prediction @@ -248,7 +267,6 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi 0.5 """ - _java_class = "org.apache.spark.ml.regression.RandomForestRegressor" # a placeholder to make it appear in the generated doc impurity = Param(Params._dummy(), "impurity", "Criterion used for information gain calculation (case-insensitive). " + @@ -266,14 +284,17 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance", - numTrees=20, featureSubsetStrategy="auto", seed=42): + numTrees=20, featureSubsetStrategy="auto", seed=None): """ - __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", - maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance", - numTrees=20, featureSubsetStrategy="auto", seed=42) + __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ + impurity="variance", numTrees=20, \ + featureSubsetStrategy="auto", seed=None) """ super(RandomForestRegressor, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.regression.RandomForestRegressor", self.uid) #: param for Criterion used for information gain calculation (case-insensitive). self.impurity = \ Param(self, "impurity", @@ -292,7 +313,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred "The number of features to consider for splits at each tree node. Supported " + "options: " + ", ".join(RandomForestParams.supportedFeatureSubsetStrategies)) self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=42, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, impurity="variance", numTrees=20, featureSubsetStrategy="auto") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -300,12 +321,12 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred @keyword_only def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=42, + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, impurity="variance", numTrees=20, featureSubsetStrategy="auto"): """ - setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", - maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=42, + setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, \ impurity="variance", numTrees=20, featureSubsetStrategy="auto") Sets params for linear regression. """ @@ -319,7 +340,7 @@ def setImpurity(self, value): """ Sets the value of :py:attr:`impurity`. """ - self.paramMap[self.impurity] = value + self._paramMap[self.impurity] = value return self def getImpurity(self): @@ -332,7 +353,7 @@ def setSubsamplingRate(self, value): """ Sets the value of :py:attr:`subsamplingRate`. """ - self.paramMap[self.subsamplingRate] = value + self._paramMap[self.subsamplingRate] = value return self def getSubsamplingRate(self): @@ -345,7 +366,7 @@ def setNumTrees(self, value): """ Sets the value of :py:attr:`numTrees`. """ - self.paramMap[self.numTrees] = value + self._paramMap[self.numTrees] = value return self def getNumTrees(self): @@ -358,7 +379,7 @@ def setFeatureSubsetStrategy(self, value): """ Sets the value of :py:attr:`featureSubsetStrategy`. """ - self.paramMap[self.featureSubsetStrategy] = value + self._paramMap[self.featureSubsetStrategy] = value return self def getFeatureSubsetStrategy(self): @@ -396,7 +417,6 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, 1.0 """ - _java_class = "org.apache.spark.ml.regression.GBTRegressor" # a placeholder to make it appear in the generated doc lossType = Param(Params._dummy(), "lossType", "Loss function which GBT tries to minimize (case-insensitive). " + @@ -414,12 +434,13 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1): """ - __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", - maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="squared", - maxIter=20, stepSize=0.1) + __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ + lossType="squared", maxIter=20, stepSize=0.1) """ super(GBTRegressor, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.regression.GBTRegressor", self.uid) #: param for Loss function which GBT tries to minimize (case-insensitive). self.lossType = Param(self, "lossType", "Loss function which GBT tries to minimize (case-insensitive). " + @@ -445,9 +466,9 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1): """ - setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", - maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, + setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ + maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ lossType="squared", maxIter=20, stepSize=0.1) Sets params for Gradient Boosted Tree Regression. """ @@ -461,7 +482,7 @@ def setLossType(self, value): """ Sets the value of :py:attr:`lossType`. """ - self.paramMap[self.lossType] = value + self._paramMap[self.lossType] = value return self def getLossType(self): @@ -474,7 +495,7 @@ def setSubsamplingRate(self, value): """ Sets the value of :py:attr:`subsamplingRate`. """ - self.paramMap[self.subsamplingRate] = value + self._paramMap[self.subsamplingRate] = value return self def getSubsamplingRate(self): @@ -487,7 +508,7 @@ def setStepSize(self, value): """ Sets the value of :py:attr:`stepSize`. """ - self.paramMap[self.stepSize] = value + self._paramMap[self.stepSize] = value return self def getStepSize(self): diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index ba6478dcd58a9..6adbf166f34a8 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -31,10 +31,13 @@ import unittest from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase -from pyspark.sql import DataFrame -from pyspark.ml.param import Param -from pyspark.ml.param.shared import HasMaxIter, HasInputCol -from pyspark.ml.pipeline import Estimator, Model, Pipeline, Transformer +from pyspark.sql import DataFrame, SQLContext +from pyspark.ml.param import Param, Params +from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed +from pyspark.ml.util import keyword_only +from pyspark.ml import Estimator, Model, Pipeline, Transformer +from pyspark.ml.feature import * +from pyspark.mllib.linalg import DenseVector class MockDataset(DataFrame): @@ -43,44 +46,43 @@ def __init__(self): self.index = 0 -class MockTransformer(Transformer): +class HasFake(Params): + + def __init__(self): + super(HasFake, self).__init__() + self.fake = Param(self, "fake", "fake param") + + def getFake(self): + return self.getOrDefault(self.fake) + + +class MockTransformer(Transformer, HasFake): def __init__(self): super(MockTransformer, self).__init__() - self.fake = Param(self, "fake", "fake") self.dataset_index = None - self.fake_param_value = None - def transform(self, dataset, params={}): + def _transform(self, dataset): self.dataset_index = dataset.index - if self.fake in params: - self.fake_param_value = params[self.fake] dataset.index += 1 return dataset -class MockEstimator(Estimator): +class MockEstimator(Estimator, HasFake): def __init__(self): super(MockEstimator, self).__init__() - self.fake = Param(self, "fake", "fake") self.dataset_index = None - self.fake_param_value = None - self.model = None - def fit(self, dataset, params={}): + def _fit(self, dataset): self.dataset_index = dataset.index - if self.fake in params: - self.fake_param_value = params[self.fake] model = MockModel() - self.model = model + self._copyValues(model) return model -class MockModel(MockTransformer, Model): - - def __init__(self): - super(MockModel, self).__init__() +class MockModel(MockTransformer, Model, HasFake): + pass class PipelineTests(PySparkTestCase): @@ -91,19 +93,17 @@ def test_pipeline(self): transformer1 = MockTransformer() estimator2 = MockEstimator() transformer3 = MockTransformer() - pipeline = Pipeline() \ - .setStages([estimator0, transformer1, estimator2, transformer3]) + pipeline = Pipeline(stages=[estimator0, transformer1, estimator2, transformer3]) pipeline_model = pipeline.fit(dataset, {estimator0.fake: 0, transformer1.fake: 1}) - self.assertEqual(0, estimator0.dataset_index) - self.assertEqual(0, estimator0.fake_param_value) - model0 = estimator0.model + model0, transformer1, model2, transformer3 = pipeline_model.stages self.assertEqual(0, model0.dataset_index) + self.assertEqual(0, model0.getFake()) self.assertEqual(1, transformer1.dataset_index) - self.assertEqual(1, transformer1.fake_param_value) - self.assertEqual(2, estimator2.dataset_index) - model2 = estimator2.model - self.assertIsNone(model2.dataset_index, "The model produced by the last estimator should " - "not be called during fit.") + self.assertEqual(1, transformer1.getFake()) + self.assertEqual(2, dataset.index) + self.assertIsNone(model2.dataset_index, "The last model shouldn't be called in fit.") + self.assertIsNone(transformer3.dataset_index, + "The last transformer shouldn't be called in fit.") dataset = pipeline_model.transform(dataset) self.assertEqual(2, model0.dataset_index) self.assertEqual(3, transformer1.dataset_index) @@ -112,14 +112,46 @@ def test_pipeline(self): self.assertEqual(6, dataset.index) -class TestParams(HasMaxIter, HasInputCol): +class TestParams(HasMaxIter, HasInputCol, HasSeed): """ - A subclass of Params mixed with HasMaxIter and HasInputCol. + A subclass of Params mixed with HasMaxIter, HasInputCol and HasSeed. """ - - def __init__(self): + @keyword_only + def __init__(self, seed=None): super(TestParams, self).__init__() self._setDefault(maxIter=10) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, seed=None): + """ + setParams(self, seed=None) + Sets params for this test. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + +class OtherTestParams(HasMaxIter, HasInputCol, HasSeed): + """ + A subclass of Params mixed with HasMaxIter, HasInputCol and HasSeed. + """ + @keyword_only + def __init__(self, seed=None): + super(OtherTestParams, self).__init__() + self._setDefault(maxIter=10) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, seed=None): + """ + setParams(self, seed=None) + Sets params for this test. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) class ParamTests(PySparkTestCase): @@ -129,16 +161,18 @@ def test_param(self): maxIter = testParams.maxIter self.assertEqual(maxIter.name, "maxIter") self.assertEqual(maxIter.doc, "max number of iterations (>= 0)") - self.assertTrue(maxIter.parent is testParams) + self.assertTrue(maxIter.parent == testParams.uid) def test_params(self): testParams = TestParams() maxIter = testParams.maxIter inputCol = testParams.inputCol + seed = testParams.seed params = testParams.params - self.assertEqual(params, [inputCol, maxIter]) + self.assertEqual(params, [inputCol, maxIter, seed]) + self.assertTrue(testParams.hasParam(maxIter)) self.assertTrue(testParams.hasDefault(maxIter)) self.assertFalse(testParams.isSet(maxIter)) self.assertTrue(testParams.isDefined(maxIter)) @@ -147,16 +181,76 @@ def test_params(self): self.assertTrue(testParams.isSet(maxIter)) self.assertEquals(testParams.getMaxIter(), 100) + self.assertTrue(testParams.hasParam(inputCol)) self.assertFalse(testParams.hasDefault(inputCol)) self.assertFalse(testParams.isSet(inputCol)) self.assertFalse(testParams.isDefined(inputCol)) with self.assertRaises(KeyError): testParams.getInputCol() + # Since the default is normally random, set it to a known number for debug str + testParams._setDefault(seed=41) + testParams.setSeed(43) + self.assertEquals( testParams.explainParams(), "\n".join(["inputCol: input column name (undefined)", - "maxIter: max number of iterations (>= 0) (default: 10, current: 100)"])) + "maxIter: max number of iterations (>= 0) (default: 10, current: 100)", + "seed: random seed (default: 41, current: 43)"])) + + def test_hasseed(self): + noSeedSpecd = TestParams() + withSeedSpecd = TestParams(seed=42) + other = OtherTestParams() + # Check that we no longer use 42 as the magic number + self.assertNotEqual(noSeedSpecd.getSeed(), 42) + origSeed = noSeedSpecd.getSeed() + # Check that we only compute the seed once + self.assertEqual(noSeedSpecd.getSeed(), origSeed) + # Check that a specified seed is honored + self.assertEqual(withSeedSpecd.getSeed(), 42) + # Check that a different class has a different seed + self.assertNotEqual(other.getSeed(), noSeedSpecd.getSeed()) + + +class FeatureTests(PySparkTestCase): + + def test_binarizer(self): + b0 = Binarizer() + self.assertListEqual(b0.params, [b0.inputCol, b0.outputCol, b0.threshold]) + self.assertTrue(all([~b0.isSet(p) for p in b0.params])) + self.assertTrue(b0.hasDefault(b0.threshold)) + self.assertEqual(b0.getThreshold(), 0.0) + b0.setParams(inputCol="input", outputCol="output").setThreshold(1.0) + self.assertTrue(all([b0.isSet(p) for p in b0.params])) + self.assertEqual(b0.getThreshold(), 1.0) + self.assertEqual(b0.getInputCol(), "input") + self.assertEqual(b0.getOutputCol(), "output") + + b0c = b0.copy({b0.threshold: 2.0}) + self.assertEqual(b0c.uid, b0.uid) + self.assertListEqual(b0c.params, b0.params) + self.assertEqual(b0c.getThreshold(), 2.0) + + b1 = Binarizer(threshold=2.0, inputCol="input", outputCol="output") + self.assertNotEqual(b1.uid, b0.uid) + self.assertEqual(b1.getThreshold(), 2.0) + self.assertEqual(b1.getInputCol(), "input") + self.assertEqual(b1.getOutputCol(), "output") + + def test_idf(self): + sqlContext = SQLContext(self.sc) + dataset = sqlContext.createDataFrame([ + (DenseVector([1.0, 2.0]),), + (DenseVector([0.0, 1.0]),), + (DenseVector([3.0, 0.2]),)], ["tf"]) + idf0 = IDF(inputCol="tf") + self.assertListEqual(idf0.params, [idf0.inputCol, idf0.minDocFreq, idf0.outputCol]) + idf0m = idf0.fit(dataset, {idf0.outputCol: "idf"}) + self.assertEqual(idf0m.uid, idf0.uid, + "Model should inherit the UID from its parent estimator.") + output = idf0m.transform(dataset) + self.assertIsNotNone(output.head().idf) if __name__ == "__main__": diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 86f4dc7368be0..0bf988fd72f14 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -91,20 +91,19 @@ class CrossValidator(Estimator): >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator >>> from pyspark.mllib.linalg import Vectors >>> dataset = sqlContext.createDataFrame( - ... [(Vectors.dense([0.0, 1.0]), 0.0), - ... (Vectors.dense([1.0, 2.0]), 1.0), - ... (Vectors.dense([0.55, 3.0]), 0.0), - ... (Vectors.dense([0.45, 4.0]), 1.0), - ... (Vectors.dense([0.51, 5.0]), 1.0)] * 10, + ... [(Vectors.dense([0.0]), 0.0), + ... (Vectors.dense([0.4]), 1.0), + ... (Vectors.dense([0.5]), 0.0), + ... (Vectors.dense([0.6]), 1.0), + ... (Vectors.dense([1.0]), 1.0)] * 10, ... ["features", "label"]) >>> lr = LogisticRegression() - >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1, 5]).build() + >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() >>> evaluator = BinaryClassificationEvaluator() >>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) - >>> # SPARK-7432: The following test is flaky. - >>> # cvModel = cv.fit(dataset) - >>> # expected = lr.fit(dataset, {lr.maxIter: 5}).transform(dataset) - >>> # cvModel.transform(dataset).collect() == expected.collect() + >>> cvModel = cv.fit(dataset) + >>> evaluator.evaluate(cvModel.transform(dataset)) + 0.8333... """ # a placeholder to make it appear in the generated doc @@ -155,7 +154,7 @@ def setEstimator(self, value): """ Sets the value of :py:attr:`estimator`. """ - self.paramMap[self.estimator] = value + self._paramMap[self.estimator] = value return self def getEstimator(self): @@ -168,7 +167,7 @@ def setEstimatorParamMaps(self, value): """ Sets the value of :py:attr:`estimatorParamMaps`. """ - self.paramMap[self.estimatorParamMaps] = value + self._paramMap[self.estimatorParamMaps] = value return self def getEstimatorParamMaps(self): @@ -181,7 +180,7 @@ def setEvaluator(self, value): """ Sets the value of :py:attr:`evaluator`. """ - self.paramMap[self.evaluator] = value + self._paramMap[self.evaluator] = value return self def getEvaluator(self): @@ -194,7 +193,7 @@ def setNumFolds(self, value): """ Sets the value of :py:attr:`numFolds`. """ - self.paramMap[self.numFolds] = value + self._paramMap[self.numFolds] = value return self def getNumFolds(self): @@ -203,13 +202,12 @@ def getNumFolds(self): """ return self.getOrDefault(self.numFolds) - def fit(self, dataset, params={}): - paramMap = self.extractParamMap(params) - est = paramMap[self.estimator] - epm = paramMap[self.estimatorParamMaps] + def _fit(self, dataset): + est = self.getOrDefault(self.estimator) + epm = self.getOrDefault(self.estimatorParamMaps) numModels = len(epm) - eva = paramMap[self.evaluator] - nFolds = paramMap[self.numFolds] + eva = self.getOrDefault(self.evaluator) + nFolds = self.getOrDefault(self.numFolds) h = 1.0 / nFolds randCol = self.uid + "_rand" df = dataset.select("*", rand(0).alias(randCol)) @@ -229,6 +227,15 @@ def fit(self, dataset, params={}): bestModel = est.fit(dataset, epm[bestIndex]) return CrossValidatorModel(bestModel) + def copy(self, extra={}): + newCV = Params.copy(self, extra) + if self.isSet(self.estimator): + newCV.setEstimator(self.getEstimator().copy(extra)) + # estimatorParamMaps remain the same + if self.isSet(self.evaluator): + newCV.setEvaluator(self.getEvaluator().copy(extra)) + return newCV + class CrossValidatorModel(Model): """ @@ -240,8 +247,19 @@ def __init__(self, bestModel): #: best model from cross validation self.bestModel = bestModel - def transform(self, dataset, params={}): - return self.bestModel.transform(dataset, params) + def _transform(self, dataset): + return self.bestModel.transform(dataset) + + def copy(self, extra={}): + """ + Creates a copy of this instance with a randomly generated uid + and some extra params. This copies the underlying bestModel, + creates a deep copy of the embedded paramMap, and + copies the embedded and extra parameters over. + :param extra: Extra parameters to copy to the new instance + :return: Copy of this instance + """ + return CrossValidatorModel(self.bestModel.copy(extra)) if __name__ == "__main__": diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index d3cb100a9efa5..cee9d67b05325 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -39,9 +39,16 @@ class Identifiable(object): """ def __init__(self): - #: A unique id for the object. The default implementation - #: concatenates the class name, "_", and 8 random hex chars. - self.uid = type(self).__name__ + "_" + uuid.uuid4().hex[:8] + #: A unique id for the object. + self.uid = self._randomUID() def __repr__(self): return self.uid + + @classmethod + def _randomUID(cls): + """ + Generate a unique id for the object. The default implementation + concatenates the class name, "_", and 12 random hex chars. + """ + return cls.__name__ + "_" + uuid.uuid4().hex[12:] diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index f5ac2a398642a..7b0893e2cdadc 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -20,8 +20,8 @@ from pyspark import SparkContext from pyspark.sql import DataFrame from pyspark.ml.param import Params -from pyspark.ml.pipeline import Estimator, Transformer, Evaluator, Model -from pyspark.mllib.common import inherit_doc +from pyspark.ml.pipeline import Estimator, Transformer, Model +from pyspark.mllib.common import inherit_doc, _java2py, _py2java def _jvm(): @@ -45,46 +45,61 @@ class JavaWrapper(Params): __metaclass__ = ABCMeta - #: Fully-qualified class name of the wrapped Java component. - _java_class = None + #: The wrapped Java companion object. Subclasses should initialize + #: it properly. The param values in the Java object should be + #: synced with the Python wrapper in fit/transform/evaluate/copy. + _java_obj = None - def _java_obj(self): + @staticmethod + def _new_java_obj(java_class, *args): """ - Returns or creates a Java object. + Construct a new Java object. """ + sc = SparkContext._active_spark_context java_obj = _jvm() - for name in self._java_class.split("."): + for name in java_class.split("."): java_obj = getattr(java_obj, name) - return java_obj() + java_args = [_py2java(sc, arg) for arg in args] + return java_obj(*java_args) - def _transfer_params_to_java(self, params, java_obj): + def _make_java_param_pair(self, param, value): """ - Transforms the embedded params and additional params to the - input Java object. - :param params: additional params (overwriting embedded values) - :param java_obj: Java object to receive the params + Makes a Java parm pair. + """ + sc = SparkContext._active_spark_context + param = self._resolveParam(param) + java_param = self._java_obj.getParam(param.name) + java_value = _py2java(sc, value) + return java_param.w(java_value) + + def _transfer_params_to_java(self): """ - paramMap = self.extractParamMap(params) + Transforms the embedded params to the companion Java object. + """ + paramMap = self.extractParamMap() for param in self.params: if param in paramMap: - value = paramMap[param] - java_param = java_obj.getParam(param.name) - java_obj.set(java_param.w(value)) + pair = self._make_java_param_pair(param, paramMap[param]) + self._java_obj.set(pair) + + def _transfer_params_from_java(self): + """ + Transforms the embedded params from the companion Java object. + """ + sc = SparkContext._active_spark_context + for param in self.params: + if self._java_obj.hasParam(param.name): + java_param = self._java_obj.getParam(param.name) + value = _java2py(sc, self._java_obj.getOrDefault(java_param)) + self._paramMap[param] = value - def _empty_java_param_map(self): + @staticmethod + def _empty_java_param_map(): """ Returns an empty Java ParamMap reference. """ return _jvm().org.apache.spark.ml.param.ParamMap() - def _create_java_param_map(self, params, java_obj): - paramMap = self._empty_java_param_map() - for param, value in params.items(): - if param.parent is self: - java_param = java_obj.getParam(param.name) - paramMap.put(java_param.w(value)) - return paramMap - @inherit_doc class JavaEstimator(Estimator, JavaWrapper): @@ -99,9 +114,9 @@ def _create_model(self, java_model): """ Creates a model from the input Java model reference. """ - return JavaModel(java_model) + raise NotImplementedError() - def _fit_java(self, dataset, params={}): + def _fit_java(self, dataset): """ Fits a Java model to the input dataset. :param dataset: input dataset, which is an instance of @@ -109,12 +124,11 @@ def _fit_java(self, dataset, params={}): :param params: additional params (overwriting embedded values) :return: fitted Java model """ - java_obj = self._java_obj() - self._transfer_params_to_java(params, java_obj) - return java_obj.fit(dataset._jdf, self._empty_java_param_map()) + self._transfer_params_to_java() + return self._java_obj.fit(dataset._jdf) - def fit(self, dataset, params={}): - java_model = self._fit_java(dataset, params) + def _fit(self, dataset): + java_model = self._fit_java(dataset) return self._create_model(java_model) @@ -127,39 +141,47 @@ class JavaTransformer(Transformer, JavaWrapper): __metaclass__ = ABCMeta - def transform(self, dataset, params={}): - java_obj = self._java_obj() - self._transfer_params_to_java(params, java_obj) - return DataFrame(java_obj.transform(dataset._jdf), dataset.sql_ctx) + def _transform(self, dataset): + self._transfer_params_to_java() + return DataFrame(self._java_obj.transform(dataset._jdf), dataset.sql_ctx) @inherit_doc class JavaModel(Model, JavaTransformer): """ Base class for :py:class:`Model`s that wrap Java/Scala - implementations. + implementations. Subclasses should inherit this class before + param mix-ins, because this sets the UID from the Java model. """ __metaclass__ = ABCMeta def __init__(self, java_model): - super(JavaTransformer, self).__init__() - self._java_model = java_model - - def _java_obj(self): - return self._java_model - - -@inherit_doc -class JavaEvaluator(Evaluator, JavaWrapper): - """ - Base class for :py:class:`Evaluator`s that wrap Java/Scala - implementations. - """ - - __metaclass__ = ABCMeta + """ + Initialize this instance with a Java model object. + Subclasses should call this constructor, initialize params, + and then call _transformer_params_from_java. + """ + super(JavaModel, self).__init__() + self._java_obj = java_model + self.uid = java_model.uid() - def evaluate(self, dataset, params={}): - java_obj = self._java_obj() - self._transfer_params_to_java(params, java_obj) - return java_obj.evaluate(dataset._jdf, self._empty_java_param_map()) + def copy(self, extra={}): + """ + Creates a copy of this instance with the same uid and some + extra params. This implementation first calls Params.copy and + then make a copy of the companion Java model with extra params. + So both the Python wrapper and the Java model get copied. + :param extra: Extra parameters to copy to the new instance + :return: Copy of this instance + """ + that = super(JavaModel, self).copy(extra) + that._java_obj = self._java_obj.copy(self._empty_java_param_map()) + that._transfer_params_to_java() + return that + + def _call_java(self, name, *args): + m = getattr(self._java_obj, name) + sc = SparkContext._active_spark_context + java_args = [_py2java(sc, arg) for arg in args] + return _java2py(sc, m(*java_args)) diff --git a/python/pyspark/mllib/__init__.py b/python/pyspark/mllib/__init__.py index 07507b2ad0d05..acba3a717d21a 100644 --- a/python/pyspark/mllib/__init__.py +++ b/python/pyspark/mllib/__init__.py @@ -23,16 +23,10 @@ # MLlib currently needs NumPy 1.4+, so complain if lower import numpy -if numpy.version.version < '1.4': + +ver = [int(x) for x in numpy.version.version.split('.')[:2]] +if ver < [1, 4]: raise Exception("MLlib requires NumPy 1.4+") __all__ = ['classification', 'clustering', 'feature', 'fpm', 'linalg', 'random', 'recommendation', 'regression', 'stat', 'tree', 'util'] - -import sys -from . import rand as random -modname = __name__ + '.random' -random.__name__ = modname -random.RandomRDDs.__module__ = modname -sys.modules[modname] = random -del modname, sys diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 04e67158514f5..b55583f82223f 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -142,6 +142,7 @@ class GaussianMixtureModel(object): """A clustering model derived from the Gaussian Mixture Model method. + >>> from pyspark.mllib.linalg import Vectors, DenseMatrix >>> clusterdata_1 = sc.parallelize(array([-0.1,-0.05,-0.01,-0.1, ... 0.9,0.8,0.75,0.935, ... -0.83,-0.68,-0.91,-0.76 ]).reshape(6, 2)) @@ -154,11 +155,12 @@ class GaussianMixtureModel(object): True >>> labels[4]==labels[5] True - >>> clusterdata_2 = sc.parallelize(array([-5.1971, -2.5359, -3.8220, - ... -5.2211, -5.0602, 4.7118, - ... 6.8989, 3.4592, 4.6322, - ... 5.7048, 4.6567, 5.5026, - ... 4.5605, 5.2043, 6.2734]).reshape(5, 3)) + >>> data = array([-5.1971, -2.5359, -3.8220, + ... -5.2211, -5.0602, 4.7118, + ... 6.8989, 3.4592, 4.6322, + ... 5.7048, 4.6567, 5.5026, + ... 4.5605, 5.2043, 6.2734]) + >>> clusterdata_2 = sc.parallelize(data.reshape(5,3)) >>> model = GaussianMixture.train(clusterdata_2, 2, convergenceTol=0.0001, ... maxIterations=150, seed=10) >>> labels = model.predict(clusterdata_2).collect() @@ -166,12 +168,38 @@ class GaussianMixtureModel(object): True >>> labels[3]==labels[4] True + >>> clusterdata_3 = sc.parallelize(data.reshape(15, 1)) + >>> im = GaussianMixtureModel([0.5, 0.5], + ... [MultivariateGaussian(Vectors.dense([-1.0]), DenseMatrix(1, 1, [1.0])), + ... MultivariateGaussian(Vectors.dense([1.0]), DenseMatrix(1, 1, [1.0]))]) + >>> model = GaussianMixture.train(clusterdata_3, 2, initialModel=im) """ def __init__(self, weights, gaussians): - self.weights = weights - self.gaussians = gaussians - self.k = len(self.weights) + self._weights = weights + self._gaussians = gaussians + self._k = len(self._weights) + + @property + def weights(self): + """ + Weights for each Gaussian distribution in the mixture, where weights[i] is + the weight for Gaussian i, and weights.sum == 1. + """ + return self._weights + + @property + def gaussians(self): + """ + Array of MultivariateGaussian where gaussians[i] represents + the Multivariate Gaussian (Normal) Distribution for Gaussian i. + """ + return self._gaussians + + @property + def k(self): + """Number of gaussians in mixture.""" + return self._k def predict(self, x): """ @@ -184,6 +212,9 @@ def predict(self, x): if isinstance(x, RDD): cluster_labels = self.predictSoft(x).map(lambda z: z.index(max(z))) return cluster_labels + else: + raise TypeError("x should be represented by an RDD, " + "but got %s." % type(x)) def predictSoft(self, x): """ @@ -193,10 +224,13 @@ def predictSoft(self, x): :return: membership_matrix. RDD of array of double values. """ if isinstance(x, RDD): - means, sigmas = zip(*[(g.mu, g.sigma) for g in self.gaussians]) + means, sigmas = zip(*[(g.mu, g.sigma) for g in self._gaussians]) membership_matrix = callMLlibFunc("predictSoftGMM", x.map(_convert_to_vector), - _convert_to_vector(self.weights), means, sigmas) + _convert_to_vector(self._weights), means, sigmas) return membership_matrix.map(lambda x: pyarray.array('d', x)) + else: + raise TypeError("x should be represented by an RDD, " + "but got %s." % type(x)) class GaussianMixture(object): @@ -208,13 +242,24 @@ class GaussianMixture(object): :param convergenceTol: Threshold value to check the convergence criteria. Defaults to 1e-3 :param maxIterations: Number of iterations. Default to 100 :param seed: Random Seed + :param initialModel: GaussianMixtureModel for initializing learning """ @classmethod - def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None): + def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initialModel=None): """Train a Gaussian Mixture clustering model.""" - weight, mu, sigma = callMLlibFunc("trainGaussianMixture", - rdd.map(_convert_to_vector), k, - convergenceTol, maxIterations, seed) + initialModelWeights = None + initialModelMu = None + initialModelSigma = None + if initialModel is not None: + if initialModel.k != k: + raise Exception("Mismatched cluster count, initialModel.k = %s, however k = %s" + % (initialModel.k, k)) + initialModelWeights = initialModel.weights + initialModelMu = [initialModel.gaussians[i].mu for i in range(initialModel.k)] + initialModelSigma = [initialModel.gaussians[i].sigma for i in range(initialModel.k)] + weight, mu, sigma = callMLlibFunc("trainGaussianMixture", rdd.map(_convert_to_vector), k, + convergenceTol, maxIterations, seed, initialModelWeights, + initialModelMu, initialModelSigma) mvg_obj = [MultivariateGaussian(mu[i], sigma[i]) for i in range(k)] return GaussianMixtureModel(weight, mvg_obj) diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py index ba6058978880a..855e85f57155e 100644 --- a/python/pyspark/mllib/common.py +++ b/python/pyspark/mllib/common.py @@ -27,7 +27,7 @@ from pyspark import RDD, SparkContext from pyspark.serializers import PickleSerializer, AutoBatchedSerializer - +from pyspark.sql import DataFrame, SQLContext # Hack for support float('inf') in Py4j _old_smart_decode = py4j.protocol.smart_decode @@ -99,6 +99,9 @@ def _java2py(sc, r, encoding="bytes"): jrdd = sc._jvm.SerDe.javaToPython(r) return RDD(jrdd, sc) + if clsName == 'DataFrame': + return DataFrame(r, SQLContext(sc)) + if clsName in _picklable_classes: r = sc._jvm.SerDe.dumps(r) elif isinstance(r, (JavaArray, JavaList)): diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index 4c777f2180dc9..c5cf3a4e7ff22 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -27,6 +27,8 @@ class BinaryClassificationMetrics(JavaModelWrapper): """ Evaluator for binary classification. + :param scoreAndLabels: an RDD of (score, label) pairs + >>> scoreAndLabels = sc.parallelize([ ... (0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)], 2) >>> metrics = BinaryClassificationMetrics(scoreAndLabels) @@ -38,9 +40,6 @@ class BinaryClassificationMetrics(JavaModelWrapper): """ def __init__(self, scoreAndLabels): - """ - :param scoreAndLabels: an RDD of (score, label) pairs - """ sc = scoreAndLabels.ctx sql_ctx = SQLContext(sc) df = sql_ctx.createDataFrame(scoreAndLabels, schema=StructType([ @@ -76,6 +75,9 @@ class RegressionMetrics(JavaModelWrapper): """ Evaluator for regression. + :param predictionAndObservations: an RDD of (prediction, + observation) pairs. + >>> predictionAndObservations = sc.parallelize([ ... (2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)]) >>> metrics = RegressionMetrics(predictionAndObservations) @@ -92,9 +94,6 @@ class RegressionMetrics(JavaModelWrapper): """ def __init__(self, predictionAndObservations): - """ - :param predictionAndObservations: an RDD of (prediction, observation) pairs. - """ sc = predictionAndObservations.ctx sql_ctx = SQLContext(sc) df = sql_ctx.createDataFrame(predictionAndObservations, schema=StructType([ @@ -148,6 +147,8 @@ class MulticlassMetrics(JavaModelWrapper): """ Evaluator for multiclass classification. + :param predictionAndLabels an RDD of (prediction, label) pairs. + >>> predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0), ... (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)]) >>> metrics = MulticlassMetrics(predictionAndLabels) @@ -176,9 +177,6 @@ class MulticlassMetrics(JavaModelWrapper): """ def __init__(self, predictionAndLabels): - """ - :param predictionAndLabels an RDD of (prediction, label) pairs. - """ sc = predictionAndLabels.ctx sql_ctx = SQLContext(sc) df = sql_ctx.createDataFrame(predictionAndLabels, schema=StructType([ @@ -277,6 +275,9 @@ class RankingMetrics(JavaModelWrapper): """ Evaluator for ranking algorithms. + :param predictionAndLabels: an RDD of (predicted ranking, + ground truth set) pairs. + >>> predictionAndLabels = sc.parallelize([ ... ([1, 6, 2, 7, 8, 3, 9, 10, 4, 5], [1, 2, 3, 4, 5]), ... ([4, 1, 5, 6, 2, 7, 3, 8, 9, 10], [1, 2, 3]), @@ -298,9 +299,6 @@ class RankingMetrics(JavaModelWrapper): """ def __init__(self, predictionAndLabels): - """ - :param predictionAndLabels: an RDD of (predicted ranking, ground truth set) pairs. - """ sc = predictionAndLabels.ctx sql_ctx = SQLContext(sc) df = sql_ctx.createDataFrame(predictionAndLabels, @@ -334,16 +332,136 @@ def ndcgAt(self, k): """ Compute the average NDCG value of all the queries, truncated at ranking position k. The discounted cumulative gain at position k is computed as: - sum,,i=1,,^k^ (2^{relevance of ''i''th item}^ - 1) / log(i + 1), + sum,,i=1,,^k^ (2^{relevance of ''i''th item}^ - 1) / log(i + 1), and the NDCG is obtained by dividing the DCG value on the ground truth set. In the current implementation, the relevance value is binary. - - If a query has an empty ground truth set, zero will be used as ndcg together with + If a query has an empty ground truth set, zero will be used as NDCG together with a log warning. """ return self.call("ndcgAt", int(k)) +class MultilabelMetrics(JavaModelWrapper): + """ + Evaluator for multilabel classification. + + :param predictionAndLabels: an RDD of (predictions, labels) pairs, + both are non-null Arrays, each with + unique elements. + + >>> predictionAndLabels = sc.parallelize([([0.0, 1.0], [0.0, 2.0]), ([0.0, 2.0], [0.0, 1.0]), + ... ([], [0.0]), ([2.0], [2.0]), ([2.0, 0.0], [2.0, 0.0]), + ... ([0.0, 1.0, 2.0], [0.0, 1.0]), ([1.0], [1.0, 2.0])]) + >>> metrics = MultilabelMetrics(predictionAndLabels) + >>> metrics.precision(0.0) + 1.0 + >>> metrics.recall(1.0) + 0.66... + >>> metrics.f1Measure(2.0) + 0.5 + >>> metrics.precision() + 0.66... + >>> metrics.recall() + 0.64... + >>> metrics.f1Measure() + 0.63... + >>> metrics.microPrecision + 0.72... + >>> metrics.microRecall + 0.66... + >>> metrics.microF1Measure + 0.69... + >>> metrics.hammingLoss + 0.33... + >>> metrics.subsetAccuracy + 0.28... + >>> metrics.accuracy + 0.54... + """ + + def __init__(self, predictionAndLabels): + sc = predictionAndLabels.ctx + sql_ctx = SQLContext(sc) + df = sql_ctx.createDataFrame(predictionAndLabels, + schema=sql_ctx._inferSchema(predictionAndLabels)) + java_class = sc._jvm.org.apache.spark.mllib.evaluation.MultilabelMetrics + java_model = java_class(df._jdf) + super(MultilabelMetrics, self).__init__(java_model) + + def precision(self, label=None): + """ + Returns precision or precision for a given label (category) if specified. + """ + if label is None: + return self.call("precision") + else: + return self.call("precision", float(label)) + + def recall(self, label=None): + """ + Returns recall or recall for a given label (category) if specified. + """ + if label is None: + return self.call("recall") + else: + return self.call("recall", float(label)) + + def f1Measure(self, label=None): + """ + Returns f1Measure or f1Measure for a given label (category) if specified. + """ + if label is None: + return self.call("f1Measure") + else: + return self.call("f1Measure", float(label)) + + @property + def microPrecision(self): + """ + Returns micro-averaged label-based precision. + (equals to micro-averaged document-based precision) + """ + return self.call("microPrecision") + + @property + def microRecall(self): + """ + Returns micro-averaged label-based recall. + (equals to micro-averaged document-based recall) + """ + return self.call("microRecall") + + @property + def microF1Measure(self): + """ + Returns micro-averaged label-based f1-measure. + (equals to micro-averaged document-based f1-measure) + """ + return self.call("microF1Measure") + + @property + def hammingLoss(self): + """ + Returns Hamming-loss. + """ + return self.call("hammingLoss") + + @property + def subsetAccuracy(self): + """ + Returns subset accuracy. + (for equal sets of labels) + """ + return self.call("subsetAccuracy") + + @property + def accuracy(self): + """ + Returns accuracy. + """ + return self.call("accuracy") + + def _test(): import doctest from pyspark import SparkContext diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index aac305db6c19a..da90554f41437 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -68,6 +68,8 @@ class Normalizer(VectorTransformer): For `p` = float('inf'), max(abs(vector)) will be used as norm for normalization. + :param p: Normalization in L^p^ space, p = 2 by default. + >>> v = Vectors.dense(range(3)) >>> nor = Normalizer(1) >>> nor.transform(v) @@ -82,9 +84,6 @@ class Normalizer(VectorTransformer): DenseVector([0.0, 0.5, 1.0]) """ def __init__(self, p=2.0): - """ - :param p: Normalization in L^p^ space, p = 2 by default. - """ assert p >= 1.0, "p should be greater than 1.0" self.p = float(p) @@ -94,7 +93,7 @@ def transform(self, vector): :param vector: vector or RDD of vector to be normalized. :return: normalized vector. If the norm of the input is zero, it - will return the input vector. + will return the input vector. """ sc = SparkContext._active_spark_context assert sc is not None, "SparkContext should be initialized first" @@ -164,6 +163,13 @@ class StandardScaler(object): variance using column summary statistics on the samples in the training set. + :param withMean: False by default. Centers the data with mean + before scaling. It will build a dense output, so this + does not work on sparse input and will raise an + exception. + :param withStd: True by default. Scales the data to unit + standard deviation. + >>> vs = [Vectors.dense([-2.0, 2.3, 0]), Vectors.dense([3.8, 0.0, 1.9])] >>> dataset = sc.parallelize(vs) >>> standardizer = StandardScaler(True, True) @@ -174,14 +180,6 @@ class StandardScaler(object): DenseVector([0.7071, -0.7071, 0.7071]) """ def __init__(self, withMean=False, withStd=True): - """ - :param withMean: False by default. Centers the data with mean - before scaling. It will build a dense output, so this - does not work on sparse input and will raise an - exception. - :param withStd: True by default. Scales the data to unit - standard deviation. - """ if not (withMean or withStd): warnings.warn("Both withMean and withStd are false. The model does nothing.") self.withMean = withMean @@ -193,7 +191,7 @@ def fit(self, dataset): for later scaling. :param data: The data used to compute the mean and variance - to build the transformation model. + to build the transformation model. :return: a StandardScalarModel """ dataset = dataset.map(_convert_to_vector) @@ -223,6 +221,8 @@ class ChiSqSelector(object): Creates a ChiSquared feature selector. + :param numTopFeatures: number of features that selector will select. + >>> data = [ ... LabeledPoint(0.0, SparseVector(3, {0: 8.0, 1: 7.0})), ... LabeledPoint(1.0, SparseVector(3, {1: 9.0, 2: 6.0})), @@ -236,9 +236,6 @@ class ChiSqSelector(object): DenseVector([5.0]) """ def __init__(self, numTopFeatures): - """ - :param numTopFeatures: number of features that selector will select. - """ self.numTopFeatures = int(numTopFeatures) def fit(self, data): @@ -246,9 +243,9 @@ def fit(self, data): Returns a ChiSquared feature selector. :param data: an `RDD[LabeledPoint]` containing the labeled dataset - with categorical features. Real-valued features will be - treated as categorical for each distinct value. - Apply feature discretizer before using this function. + with categorical features. Real-valued features will be + treated as categorical for each distinct value. + Apply feature discretizer before using this function. """ jmodel = callMLlibFunc("fitChiSqSelector", self.numTopFeatures, data) return ChiSqSelectorModel(jmodel) @@ -263,15 +260,14 @@ class HashingTF(object): Note: the terms must be hashable (can not be dict/set/list...). + :param numFeatures: number of features (default: 2^20) + >>> htf = HashingTF(100) >>> doc = "a a b b c d".split(" ") >>> htf.transform(doc) SparseVector(100, {...}) """ def __init__(self, numFeatures=1 << 20): - """ - :param numFeatures: number of features (default: 2^20) - """ self.numFeatures = numFeatures def indexOf(self, term): @@ -311,7 +307,7 @@ def transform(self, x): Call transform directly on the RDD instead. :param x: an RDD of term frequency vectors or a term frequency - vector + vector :return: an RDD of TF-IDF vectors or a TF-IDF vector """ if isinstance(x, RDD): @@ -342,6 +338,9 @@ class IDF(object): `minDocFreq`). For terms that are not in at least `minDocFreq` documents, the IDF is found as 0, resulting in TF-IDFs of 0. + :param minDocFreq: minimum of documents in which a term + should appear for filtering + >>> n = 4 >>> freqs = [Vectors.sparse(n, (1, 3), (1.0, 2.0)), ... Vectors.dense([0.0, 1.0, 2.0, 3.0]), @@ -362,10 +361,6 @@ class IDF(object): SparseVector(4, {1: 0.0, 3: 0.5754}) """ def __init__(self, minDocFreq=0): - """ - :param minDocFreq: minimum of documents in which a term - should appear for filtering - """ self.minDocFreq = minDocFreq def fit(self, dataset): diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py index d8df02bdbaba9..bdc4a132b1b18 100644 --- a/python/pyspark/mllib/fpm.py +++ b/python/pyspark/mllib/fpm.py @@ -61,12 +61,12 @@ class FPGrowth(object): def train(cls, data, minSupport=0.3, numPartitions=-1): """ Computes an FP-Growth model that contains frequent itemsets. - :param data: The input data set, each element - contains a transaction. - :param minSupport: The minimal support level - (default: `0.3`). - :param numPartitions: The number of partitions used by parallel - FP-growth (default: same as input data). + + :param data: The input data set, each element contains a + transaction. + :param minSupport: The minimal support level (default: `0.3`). + :param numPartitions: The number of partitions used by + parallel FP-growth (default: same as input data). """ model = callMLlibFunc("trainFPGrowthModel", data, float(minSupport), int(numPartitions)) return FPGrowthModel(model) diff --git a/python/pyspark/mllib/rand.py b/python/pyspark/mllib/random.py similarity index 100% rename from python/pyspark/mllib/rand.py rename to python/pyspark/mllib/random.py diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 545c5ad20cb96..98a8ff8606366 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -813,13 +813,21 @@ def op(x, y): def fold(self, zeroValue, op): """ Aggregate the elements of each partition, and then the results for all - the partitions, using a given associative function and a neutral "zero - value." + the partitions, using a given associative and commutative function and + a neutral "zero value." The function C{op(t1, t2)} is allowed to modify C{t1} and return it as its result value to avoid object allocation; however, it should not modify C{t2}. + This behaves somewhat differently from fold operations implemented + for non-distributed collections in functional languages like Scala. + This fold operation may be applied to partitions individually, and then + fold those results into the final result, rather than apply the fold + to each element sequentially in some defined ordering. For functions + that are not commutative, the result may differ from that of a fold + applied to a non-distributed collection. + >>> from operator import add >>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add) 15 @@ -2260,7 +2268,7 @@ def toLocalIterator(self): def _prepare_for_python_RDD(sc, command, obj=None): # the serialized command will be compressed by broadcast ser = CloudPickleSerializer() - pickled_command = ser.dumps((command, sys.version_info[:2])) + pickled_command = ser.dumps(command) if len(pickled_command) > (1 << 20): # 1M # The broadcast will have same life cycle as created PythonRDD broadcast = sc.broadcast(pickled_command) @@ -2344,7 +2352,7 @@ def _jrdd(self): python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), bytearray(pickled_cmd), env, includes, self.preservesPartitioning, - self.ctx.pythonExec, + self.ctx.pythonExec, self.ctx.pythonVer, bvars, self.ctx._javaAccumulator) self._jrdd_val = python_rdd.asJavaRDD() diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 1d0b16cade8bb..81c420ce16541 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -362,7 +362,7 @@ def _spill(self): self.spills += 1 gc.collect() # release the memory as much as possible - MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 + MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20 def items(self): """ Return all merged items as iterator """ @@ -515,7 +515,7 @@ def load(f): gc.collect() batch //= 2 limit = self._next_limit() - MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 + MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20 DiskBytesSpilled += os.path.getsize(path) os.unlink(path) # data will be deleted after close @@ -630,7 +630,7 @@ def _spill(self): self.values = [] gc.collect() DiskBytesSpilled += self._file.tell() - pos - MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 + MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20 class ExternalListOfList(ExternalList): @@ -794,7 +794,7 @@ def _spill(self): self.spills += 1 gc.collect() # release the memory as much as possible - MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 + MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20 def _merged_items(self, index): size = sum(os.path.getsize(os.path.join(self._get_spill_dir(j), str(index))) diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py index 7192c89b3dc7f..ad9c891ba1c04 100644 --- a/python/pyspark/sql/__init__.py +++ b/python/pyspark/sql/__init__.py @@ -18,47 +18,58 @@ """ Important classes of Spark SQL and DataFrames: - - L{SQLContext} + - :class:`pyspark.sql.SQLContext` Main entry point for :class:`DataFrame` and SQL functionality. - - L{DataFrame} + - :class:`pyspark.sql.DataFrame` A distributed collection of data grouped into named columns. - - L{Column} + - :class:`pyspark.sql.Column` A column expression in a :class:`DataFrame`. - - L{Row} + - :class:`pyspark.sql.Row` A row of data in a :class:`DataFrame`. - - L{HiveContext} + - :class:`pyspark.sql.HiveContext` Main entry point for accessing data stored in Apache Hive. - - L{GroupedData} + - :class:`pyspark.sql.GroupedData` Aggregation methods, returned by :func:`DataFrame.groupBy`. - - L{DataFrameNaFunctions} + - :class:`pyspark.sql.DataFrameNaFunctions` Methods for handling missing data (null values). - - L{DataFrameStatFunctions} + - :class:`pyspark.sql.DataFrameStatFunctions` Methods for statistics functionality. - - L{functions} + - :class:`pyspark.sql.functions` List of built-in functions available for :class:`DataFrame`. - - L{types} + - :class:`pyspark.sql.types` List of data types available. + - :class:`pyspark.sql.Window` + For working with window functions. """ from __future__ import absolute_import -# fix the module name conflict for Python 3+ -import sys -from . import _types as types -modname = __name__ + '.types' -types.__name__ = modname -# update the __module__ for all objects, make them picklable -for v in types.__dict__.values(): - if hasattr(v, "__module__") and v.__module__.endswith('._types'): - v.__module__ = modname -sys.modules[modname] = types -del modname, sys + +def since(version): + """ + A decorator that annotates a function to append the version of Spark the function was added. + """ + import re + indent_p = re.compile(r'\n( +)') + + def deco(f): + indents = indent_p.findall(f.__doc__) + indent = ' ' * (min(len(m) for m in indents) if indents else 0) + f.__doc__ = f.__doc__.rstrip() + "\n\n%s.. versionadded:: %s" % (indent, version) + return f + return deco + from pyspark.sql.types import Row from pyspark.sql.context import SQLContext, HiveContext -from pyspark.sql.dataframe import DataFrame, GroupedData, Column, SchemaRDD, DataFrameNaFunctions -from pyspark.sql.dataframe import DataFrameStatFunctions +from pyspark.sql.column import Column +from pyspark.sql.dataframe import DataFrame, SchemaRDD, DataFrameNaFunctions, DataFrameStatFunctions +from pyspark.sql.group import GroupedData +from pyspark.sql.readwriter import DataFrameReader, DataFrameWriter +from pyspark.sql.window import Window, WindowSpec + __all__ = [ 'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row', - 'DataFrameNaFunctions', 'DataFrameStatFunctions' + 'DataFrameNaFunctions', 'DataFrameStatFunctions', 'Window', 'WindowSpec', + 'DataFrameReader', 'DataFrameWriter' ] diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py new file mode 100644 index 0000000000000..1ecec5b126505 --- /dev/null +++ b/python/pyspark/sql/column.py @@ -0,0 +1,425 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys + +if sys.version >= '3': + basestring = str + long = int + +from pyspark.context import SparkContext +from pyspark.rdd import ignore_unicode_prefix +from pyspark.sql import since +from pyspark.sql.types import * + +__all__ = ["DataFrame", "Column", "SchemaRDD", "DataFrameNaFunctions", + "DataFrameStatFunctions"] + + +def _create_column_from_literal(literal): + sc = SparkContext._active_spark_context + return sc._jvm.functions.lit(literal) + + +def _create_column_from_name(name): + sc = SparkContext._active_spark_context + return sc._jvm.functions.col(name) + + +def _to_java_column(col): + if isinstance(col, Column): + jcol = col._jc + else: + jcol = _create_column_from_name(col) + return jcol + + +def _to_seq(sc, cols, converter=None): + """ + Convert a list of Column (or names) into a JVM Seq of Column. + + An optional `converter` could be used to convert items in `cols` + into JVM Column objects. + """ + if converter: + cols = [converter(c) for c in cols] + return sc._jvm.PythonUtils.toSeq(cols) + + +def _unary_op(name, doc="unary operator"): + """ Create a method for given unary operator """ + def _(self): + jc = getattr(self._jc, name)() + return Column(jc) + _.__doc__ = doc + return _ + + +def _func_op(name, doc=''): + def _(self): + sc = SparkContext._active_spark_context + jc = getattr(sc._jvm.functions, name)(self._jc) + return Column(jc) + _.__doc__ = doc + return _ + + +def _bin_op(name, doc="binary operator"): + """ Create a method for given binary operator + """ + def _(self, other): + jc = other._jc if isinstance(other, Column) else other + njc = getattr(self._jc, name)(jc) + return Column(njc) + _.__doc__ = doc + return _ + + +def _reverse_op(name, doc="binary operator"): + """ Create a method for binary operator (this object is on right side) + """ + def _(self, other): + jother = _create_column_from_literal(other) + jc = getattr(jother, name)(self._jc) + return Column(jc) + _.__doc__ = doc + return _ + + +class Column(object): + + """ + A column in a DataFrame. + + :class:`Column` instances can be created by:: + + # 1. Select a column out of a DataFrame + + df.colName + df["colName"] + + # 2. Create from an expression + df.colName + 1 + 1 / df.colName + + .. note:: Experimental + + .. versionadded:: 1.3 + """ + + def __init__(self, jc): + self._jc = jc + + # arithmetic operators + __neg__ = _func_op("negate") + __add__ = _bin_op("plus") + __sub__ = _bin_op("minus") + __mul__ = _bin_op("multiply") + __div__ = _bin_op("divide") + __truediv__ = _bin_op("divide") + __mod__ = _bin_op("mod") + __radd__ = _bin_op("plus") + __rsub__ = _reverse_op("minus") + __rmul__ = _bin_op("multiply") + __rdiv__ = _reverse_op("divide") + __rtruediv__ = _reverse_op("divide") + __rmod__ = _reverse_op("mod") + + # logistic operators + __eq__ = _bin_op("equalTo") + __ne__ = _bin_op("notEqual") + __lt__ = _bin_op("lt") + __le__ = _bin_op("leq") + __ge__ = _bin_op("geq") + __gt__ = _bin_op("gt") + + # `and`, `or`, `not` cannot be overloaded in Python, + # so use bitwise operators as boolean operators + __and__ = _bin_op('and') + __or__ = _bin_op('or') + __invert__ = _func_op('not') + __rand__ = _bin_op("and") + __ror__ = _bin_op("or") + + # container operators + __contains__ = _bin_op("contains") + __getitem__ = _bin_op("apply") + + # bitwise operators + bitwiseOR = _bin_op("bitwiseOR") + bitwiseAND = _bin_op("bitwiseAND") + bitwiseXOR = _bin_op("bitwiseXOR") + + @since(1.3) + def getItem(self, key): + """ + An expression that gets an item at position ``ordinal`` out of a list, + or gets an item by key out of a dict. + + >>> df = sc.parallelize([([1, 2], {"key": "value"})]).toDF(["l", "d"]) + >>> df.select(df.l.getItem(0), df.d.getItem("key")).show() + +----+------+ + |l[0]|d[key]| + +----+------+ + | 1| value| + +----+------+ + >>> df.select(df.l[0], df.d["key"]).show() + +----+------+ + |l[0]|d[key]| + +----+------+ + | 1| value| + +----+------+ + """ + return self[key] + + @since(1.3) + def getField(self, name): + """ + An expression that gets a field by name in a StructField. + + >>> from pyspark.sql import Row + >>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF() + >>> df.select(df.r.getField("b")).show() + +----+ + |r[b]| + +----+ + | b| + +----+ + >>> df.select(df.r.a).show() + +----+ + |r[a]| + +----+ + | 1| + +----+ + """ + return self[name] + + def __getattr__(self, item): + if item.startswith("__"): + raise AttributeError(item) + return self.getField(item) + + # string methods + rlike = _bin_op("rlike") + like = _bin_op("like") + startswith = _bin_op("startsWith") + endswith = _bin_op("endsWith") + + @ignore_unicode_prefix + @since(1.3) + def substr(self, startPos, length): + """ + Return a :class:`Column` which is a substring of the column. + + :param startPos: start position (int or Column) + :param length: length of the substring (int or Column) + + >>> df.select(df.name.substr(1, 3).alias("col")).collect() + [Row(col=u'Ali'), Row(col=u'Bob')] + """ + if type(startPos) != type(length): + raise TypeError("Can not mix the type") + if isinstance(startPos, (int, long)): + jc = self._jc.substr(startPos, length) + elif isinstance(startPos, Column): + jc = self._jc.substr(startPos._jc, length._jc) + else: + raise TypeError("Unexpected type: %s" % type(startPos)) + return Column(jc) + + __getslice__ = substr + + @ignore_unicode_prefix + @since(1.3) + def inSet(self, *cols): + """ + A boolean expression that is evaluated to true if the value of this + expression is contained by the evaluated values of the arguments. + + >>> df[df.name.inSet("Bob", "Mike")].collect() + [Row(age=5, name=u'Bob')] + >>> df[df.age.inSet([1, 2, 3])].collect() + [Row(age=2, name=u'Alice')] + """ + if len(cols) == 1 and isinstance(cols[0], (list, set)): + cols = cols[0] + cols = [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols] + sc = SparkContext._active_spark_context + jc = getattr(self._jc, "in")(_to_seq(sc, cols)) + return Column(jc) + + # order + asc = _unary_op("asc", "Returns a sort expression based on the" + " ascending order of the given column name.") + desc = _unary_op("desc", "Returns a sort expression based on the" + " descending order of the given column name.") + + isNull = _unary_op("isNull", "True if the current expression is null.") + isNotNull = _unary_op("isNotNull", "True if the current expression is not null.") + + @since(1.3) + def alias(self, *alias): + """ + Returns this column aliased with a new name or names (in the case of expressions that + return more than one column, such as explode). + + >>> df.select(df.age.alias("age2")).collect() + [Row(age2=2), Row(age2=5)] + """ + + if len(alias) == 1: + return Column(getattr(self._jc, "as")(alias[0])) + else: + sc = SparkContext._active_spark_context + return Column(getattr(self._jc, "as")(_to_seq(sc, list(alias)))) + + @ignore_unicode_prefix + @since(1.3) + def cast(self, dataType): + """ Convert the column into type ``dataType``. + + >>> df.select(df.age.cast("string").alias('ages')).collect() + [Row(ages=u'2'), Row(ages=u'5')] + >>> df.select(df.age.cast(StringType()).alias('ages')).collect() + [Row(ages=u'2'), Row(ages=u'5')] + """ + if isinstance(dataType, basestring): + jc = self._jc.cast(dataType) + elif isinstance(dataType, DataType): + sc = SparkContext._active_spark_context + ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) + jdt = ssql_ctx.parseDataType(dataType.json()) + jc = self._jc.cast(jdt) + else: + raise TypeError("unexpected type: %s" % type(dataType)) + return Column(jc) + + astype = cast + + @since(1.3) + def between(self, lowerBound, upperBound): + """ + A boolean expression that is evaluated to true if the value of this + expression is between the given columns. + + >>> df.select(df.name, df.age.between(2, 4)).show() + +-----+--------------------------+ + | name|((age >= 2) && (age <= 4))| + +-----+--------------------------+ + |Alice| true| + | Bob| false| + +-----+--------------------------+ + """ + return (self >= lowerBound) & (self <= upperBound) + + @since(1.4) + def when(self, condition, value): + """ + Evaluates a list of conditions and returns one of multiple possible result expressions. + If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions. + + See :func:`pyspark.sql.functions.when` for example usage. + + :param condition: a boolean :class:`Column` expression. + :param value: a literal value, or a :class:`Column` expression. + + >>> from pyspark.sql import functions as F + >>> df.select(df.name, F.when(df.age > 4, 1).when(df.age < 3, -1).otherwise(0)).show() + +-----+--------------------------------------------------------+ + | name|CASE WHEN (age > 4) THEN 1 WHEN (age < 3) THEN -1 ELSE 0| + +-----+--------------------------------------------------------+ + |Alice| -1| + | Bob| 1| + +-----+--------------------------------------------------------+ + """ + if not isinstance(condition, Column): + raise TypeError("condition should be a Column") + v = value._jc if isinstance(value, Column) else value + jc = self._jc.when(condition._jc, v) + return Column(jc) + + @since(1.4) + def otherwise(self, value): + """ + Evaluates a list of conditions and returns one of multiple possible result expressions. + If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions. + + See :func:`pyspark.sql.functions.when` for example usage. + + :param value: a literal value, or a :class:`Column` expression. + + >>> from pyspark.sql import functions as F + >>> df.select(df.name, F.when(df.age > 3, 1).otherwise(0)).show() + +-----+---------------------------------+ + | name|CASE WHEN (age > 3) THEN 1 ELSE 0| + +-----+---------------------------------+ + |Alice| 0| + | Bob| 1| + +-----+---------------------------------+ + """ + v = value._jc if isinstance(value, Column) else value + jc = self._jc.otherwise(v) + return Column(jc) + + @since(1.4) + def over(self, window): + """ + Define a windowing column. + + :param window: a :class:`WindowSpec` + :return: a Column + + >>> from pyspark.sql import Window + >>> window = Window.partitionBy("name").orderBy("age").rowsBetween(-1, 1) + >>> from pyspark.sql.functions import rank, min + >>> # df.select(rank().over(window), min('age').over(window)) + + .. note:: Window functions is only supported with HiveContext in 1.4 + """ + from pyspark.sql.window import WindowSpec + if not isinstance(window, WindowSpec): + raise TypeError("window should be WindowSpec") + jc = self._jc.over(window._jspec) + return Column(jc) + + def __repr__(self): + return 'Column<%s>' % self._jc.toString().encode('utf8') + + +def _test(): + import doctest + from pyspark.context import SparkContext + from pyspark.sql import SQLContext + import pyspark.sql.column + globs = pyspark.sql.column.__dict__.copy() + sc = SparkContext('local[4]', 'PythonTest') + globs['sc'] = sc + globs['sqlContext'] = SQLContext(sc) + globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \ + .toDF(StructType([StructField('age', IntegerType()), + StructField('name', StringType())])) + + (failure_count, test_count) = doctest.testmod( + pyspark.sql.column, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index f6f107ca32d2f..9fdf43c3e6eb5 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -28,9 +28,11 @@ from pyspark.rdd import RDD, _prepare_for_python_RDD, ignore_unicode_prefix from pyspark.serializers import AutoBatchedSerializer, PickleSerializer +from pyspark.sql import since from pyspark.sql.types import Row, StringType, StructType, _verify_type, \ _infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter from pyspark.sql.dataframe import DataFrame +from pyspark.sql.readwriter import DataFrameReader try: import pandas @@ -105,11 +107,13 @@ def _ssql_ctx(self): self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc()) return self._scala_SQLContext + @since(1.3) def setConf(self, key, value): """Sets the given Spark SQL configuration property. """ self._ssql_ctx.setConf(key, value) + @since(1.3) def getConf(self, key, defaultValue): """Returns the value of Spark SQL configuration property for the given key. @@ -118,11 +122,37 @@ def getConf(self, key, defaultValue): return self._ssql_ctx.getConf(key, defaultValue) @property + @since("1.3.1") def udf(self): - """Returns a :class:`UDFRegistration` for UDF registration.""" + """Returns a :class:`UDFRegistration` for UDF registration. + + :return: :class:`UDFRegistration` + """ return UDFRegistration(self) + @since(1.4) + def range(self, start, end, step=1, numPartitions=None): + """ + Create a :class:`DataFrame` with single LongType column named `id`, + containing elements in a range from `start` to `end` (exclusive) with + step value `step`. + + :param start: the start value + :param end: the end value (exclusive) + :param step: the incremental step (default: 1) + :param numPartitions: the number of partitions of the DataFrame + :return: :class:`DataFrame` + + >>> sqlContext.range(1, 7, 2).collect() + [Row(id=1), Row(id=3), Row(id=5)] + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + jdf = self._ssql_ctx.range(int(start), int(end), int(step), int(numPartitions)) + return DataFrame(jdf, self) + @ignore_unicode_prefix + @since(1.2) def registerFunction(self, name, f, returnType=StringType()): """Registers a lambda function as a UDF so it can be used in SQL statements. @@ -157,6 +187,7 @@ def registerFunction(self, name, f, returnType=StringType()): env, includes, self._sc.pythonExec, + self._sc.pythonVer, bvars, self._sc._javaAccumulator, returnType.json()) @@ -167,8 +198,8 @@ def _inferSchema(self, rdd, samplingRatio=None): raise ValueError("The first row in RDD is empty, " "can not infer schema") if type(first) is dict: - warnings.warn("Using RDD of dict to inferSchema is deprecated," - "please use pyspark.sql.Row instead") + warnings.warn("Using RDD of dict to inferSchema is deprecated. " + "Use pyspark.sql.Row instead") if samplingRatio is None: schema = _infer_schema(first) @@ -188,9 +219,10 @@ def _inferSchema(self, rdd, samplingRatio=None): @ignore_unicode_prefix def inferSchema(self, rdd, samplingRatio=None): - """::note: Deprecated in 1.3, use :func:`createDataFrame` instead. """ - warnings.warn("inferSchema is deprecated, please use createDataFrame instead") + .. note:: Deprecated in 1.3, use :func:`createDataFrame` instead. + """ + warnings.warn("inferSchema is deprecated, please use createDataFrame instead.") if isinstance(rdd, DataFrame): raise TypeError("Cannot apply schema to DataFrame") @@ -199,7 +231,8 @@ def inferSchema(self, rdd, samplingRatio=None): @ignore_unicode_prefix def applySchema(self, rdd, schema): - """::note: Deprecated in 1.3, use :func:`createDataFrame` instead. + """ + .. note:: Deprecated in 1.3, use :func:`createDataFrame` instead. """ warnings.warn("applySchema is deprecated, please use createDataFrame instead") @@ -211,6 +244,7 @@ def applySchema(self, rdd, schema): return self.createDataFrame(rdd, schema) + @since(1.3) @ignore_unicode_prefix def createDataFrame(self, data, schema=None, samplingRatio=None): """ @@ -231,6 +265,7 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): :class:`list`, or :class:`pandas.DataFrame`. :param schema: a :class:`StructType` or list of column names. default None. :param samplingRatio: the sample ratio of rows used for inferring + :return: :class:`DataFrame` >>> l = [('Alice', 1)] >>> sqlContext.createDataFrame(l).collect() @@ -315,6 +350,7 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) return DataFrame(df, self) + @since(1.3) def registerDataFrameAsTable(self, df, tableName): """Registers the given :class:`DataFrame` as a temporary table in the catalog. @@ -330,14 +366,12 @@ def registerDataFrameAsTable(self, df, tableName): def parquetFile(self, *paths): """Loads a Parquet file, returning the result as a :class:`DataFrame`. - >>> import tempfile, shutil - >>> parquetFile = tempfile.mkdtemp() - >>> shutil.rmtree(parquetFile) - >>> df.saveAsParquetFile(parquetFile) - >>> df2 = sqlContext.parquetFile(parquetFile) - >>> sorted(df.collect()) == sorted(df2.collect()) - True + .. note:: Deprecated in 1.4, use :func:`DataFrameReader.parquet` instead. + + >>> sqlContext.parquetFile('python/test_support/sql/parquet_partitioned').dtypes + [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] """ + warnings.warn("parquetFile is deprecated. Use read.parquet() instead.") gateway = self._sc._gateway jpaths = gateway.new_array(gateway.jvm.java.lang.String, len(paths)) for i in range(0, len(paths)): @@ -348,35 +382,12 @@ def parquetFile(self, *paths): def jsonFile(self, path, schema=None, samplingRatio=1.0): """Loads a text file storing one JSON object per line as a :class:`DataFrame`. - If the schema is provided, applies the given schema to this JSON dataset. - Otherwise, it samples the dataset with ratio ``samplingRatio`` to determine the schema. - - >>> import tempfile, shutil - >>> jsonFile = tempfile.mkdtemp() - >>> shutil.rmtree(jsonFile) - >>> with open(jsonFile, 'w') as f: - ... f.writelines(jsonStrings) - >>> df1 = sqlContext.jsonFile(jsonFile) - >>> df1.printSchema() - root - |-- field1: long (nullable = true) - |-- field2: string (nullable = true) - |-- field3: struct (nullable = true) - | |-- field4: long (nullable = true) + .. note:: Deprecated in 1.4, use :func:`DataFrameReader.json` instead. - >>> from pyspark.sql.types import * - >>> schema = StructType([ - ... StructField("field2", StringType()), - ... StructField("field3", - ... StructType([StructField("field5", ArrayType(IntegerType()))]))]) - >>> df2 = sqlContext.jsonFile(jsonFile, schema) - >>> df2.printSchema() - root - |-- field2: string (nullable = true) - |-- field3: struct (nullable = true) - | |-- field5: array (nullable = true) - | | |-- element: integer (containsNull = true) + >>> sqlContext.jsonFile('python/test_support/sql/people.json').dtypes + [('age', 'bigint'), ('name', 'string')] """ + warnings.warn("jsonFile is deprecated. Use read.json() instead.") if schema is None: df = self._ssql_ctx.jsonFile(path, samplingRatio) else: @@ -385,6 +396,7 @@ def jsonFile(self, path, schema=None, samplingRatio=1.0): return DataFrame(df, self) @ignore_unicode_prefix + @since(1.0) def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): """Loads an RDD storing one JSON object per string as a :class:`DataFrame`. @@ -430,28 +442,13 @@ def func(iterator): def load(self, path=None, source=None, schema=None, **options): """Returns the dataset in a data source as a :class:`DataFrame`. - The data source is specified by the ``source`` and a set of ``options``. - If ``source`` is not specified, the default data source configured by - ``spark.sql.sources.default`` will be used. - - Optionally, a schema can be provided as the schema of the returned DataFrame. + .. note:: Deprecated in 1.4, use :func:`DataFrameReader.load` instead. """ - if path is not None: - options["path"] = path - if source is None: - source = self.getConf("spark.sql.sources.default", - "org.apache.spark.sql.parquet") - if schema is None: - df = self._ssql_ctx.load(source, options) - else: - if not isinstance(schema, StructType): - raise TypeError("schema should be StructType") - scala_datatype = self._ssql_ctx.parseDataType(schema.json()) - df = self._ssql_ctx.load(source, scala_datatype, options) - return DataFrame(df, self) + warnings.warn("load is deprecated. Use read.load() instead.") + return self.read.load(path, source, schema, **options) - def createExternalTable(self, tableName, path=None, source=None, - schema=None, **options): + @since(1.3) + def createExternalTable(self, tableName, path=None, source=None, schema=None, **options): """Creates an external table based on the dataset in a data source. It returns the DataFrame associated with the external table. @@ -462,6 +459,8 @@ def createExternalTable(self, tableName, path=None, source=None, Optionally, a schema can be provided as the schema of the returned :class:`DataFrame` and created external table. + + :return: :class:`DataFrame` """ if path is not None: options["path"] = path @@ -479,9 +478,12 @@ def createExternalTable(self, tableName, path=None, source=None, return DataFrame(df, self) @ignore_unicode_prefix + @since(1.0) def sql(self, sqlQuery): """Returns a :class:`DataFrame` representing the result of the given query. + :return: :class:`DataFrame` + >>> sqlContext.registerDataFrameAsTable(df, "table1") >>> df2 = sqlContext.sql("SELECT field1 AS f1, field2 as f2 from table1") >>> df2.collect() @@ -489,9 +491,12 @@ def sql(self, sqlQuery): """ return DataFrame(self._ssql_ctx.sql(sqlQuery), self) + @since(1.0) def table(self, tableName): """Returns the specified table as a :class:`DataFrame`. + :return: :class:`DataFrame` + >>> sqlContext.registerDataFrameAsTable(df, "table1") >>> df2 = sqlContext.table("table1") >>> sorted(df.collect()) == sorted(df2.collect()) @@ -500,6 +505,7 @@ def table(self, tableName): return DataFrame(self._ssql_ctx.table(tableName), self) @ignore_unicode_prefix + @since(1.3) def tables(self, dbName=None): """Returns a :class:`DataFrame` containing names of tables in the given database. @@ -508,6 +514,9 @@ def tables(self, dbName=None): The returned DataFrame has two columns: ``tableName`` and ``isTemporary`` (a column with :class:`BooleanType` indicating if a table is a temporary one or not). + :param dbName: string, name of the database to use. + :return: :class:`DataFrame` + >>> sqlContext.registerDataFrameAsTable(df, "table1") >>> df2 = sqlContext.tables() >>> df2.filter("tableName = 'table1'").first() @@ -518,10 +527,12 @@ def tables(self, dbName=None): else: return DataFrame(self._ssql_ctx.tables(dbName), self) + @since(1.3) def tableNames(self, dbName=None): """Returns a list of names of tables in the database ``dbName``. - If ``dbName`` is not specified, the current database will be used. + :param dbName: string, name of the database to use. Default to the current database. + :return: list of table names, in string >>> sqlContext.registerDataFrameAsTable(df, "table1") >>> "table1" in sqlContext.tableNames() @@ -534,18 +545,32 @@ def tableNames(self, dbName=None): else: return [name for name in self._ssql_ctx.tableNames(dbName)] + @since(1.0) def cacheTable(self, tableName): """Caches the specified table in-memory.""" self._ssql_ctx.cacheTable(tableName) + @since(1.0) def uncacheTable(self, tableName): """Removes the specified table from the in-memory cache.""" self._ssql_ctx.uncacheTable(tableName) + @since(1.3) def clearCache(self): """Removes all cached tables from the in-memory cache. """ self._ssql_ctx.clearCache() + @property + @since(1.4) + def read(self): + """ + Returns a :class:`DataFrameReader` that can be used to read data + in as a :class:`DataFrame`. + + :return: :class:`DataFrameReader` + """ + return DataFrameReader(self) + class HiveContext(SQLContext): """A variant of Spark SQL that integrates with data stored in Hive. @@ -600,10 +625,14 @@ def register(self, name, f, returnType=StringType()): def _test(): + import os import doctest from pyspark.context import SparkContext from pyspark.sql import Row, SQLContext import pyspark.sql.context + + os.chdir(os.environ["SPARK_HOME"]) + globs = pyspark.sql.context.__dict__.copy() sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 82cb1c2fdbf94..7673153abe0e2 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -25,17 +25,17 @@ else: from itertools import imap as map -from pyspark.context import SparkContext from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync -from pyspark.sql.types import * +from pyspark.sql import since from pyspark.sql.types import _create_cls, _parse_datatype_json_string +from pyspark.sql.column import Column, _to_seq, _to_java_column +from pyspark.sql.readwriter import DataFrameWriter +from pyspark.sql.types import * - -__all__ = ["DataFrame", "GroupedData", "Column", "SchemaRDD", "DataFrameNaFunctions", - "DataFrameStatFunctions"] +__all__ = ["DataFrame", "SchemaRDD", "DataFrameNaFunctions", "DataFrameStatFunctions"] class DataFrame(object): @@ -44,7 +44,7 @@ class DataFrame(object): A :class:`DataFrame` is equivalent to a relational table in Spark SQL, and can be created using various functions in :class:`SQLContext`:: - people = sqlContext.parquetFile("...") + people = sqlContext.read.parquet("...") Once created, it can be manipulated using the various domain-specific-language (DSL) functions defined in: :class:`DataFrame`, :class:`Column`. @@ -56,11 +56,15 @@ class DataFrame(object): A more concrete example:: # To create DataFrame using SQLContext - people = sqlContext.parquetFile("...") - department = sqlContext.parquetFile("...") + people = sqlContext.read.parquet("...") + department = sqlContext.read.parquet("...") people.filter(people.age > 30).join(department, people.deptId == department.id)) \ .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"}) + + .. note:: Experimental + + .. versionadded:: 1.3 """ def __init__(self, jdf, sql_ctx): @@ -72,6 +76,7 @@ def __init__(self, jdf, sql_ctx): self._lazy_rdd = None @property + @since(1.3) def rdd(self): """Returns the content as an :class:`pyspark.RDD` of :class:`Row`. """ @@ -89,18 +94,21 @@ def applySchema(it): return self._lazy_rdd @property + @since("1.3.1") def na(self): """Returns a :class:`DataFrameNaFunctions` for handling missing values. """ return DataFrameNaFunctions(self) @property + @since(1.4) def stat(self): """Returns a :class:`DataFrameStatFunctions` for statistic functions. """ return DataFrameStatFunctions(self) @ignore_unicode_prefix + @since(1.3) def toJSON(self, use_unicode=True): """Converts a :class:`DataFrame` into a :class:`RDD` of string. @@ -115,19 +123,12 @@ def toJSON(self, use_unicode=True): def saveAsParquetFile(self, path): """Saves the contents as a Parquet file, preserving the schema. - Files that are written out using this method can be read back in as - a :class:`DataFrame` using :func:`SQLContext.parquetFile`. - - >>> import tempfile, shutil - >>> parquetFile = tempfile.mkdtemp() - >>> shutil.rmtree(parquetFile) - >>> df.saveAsParquetFile(parquetFile) - >>> df2 = sqlContext.parquetFile(parquetFile) - >>> sorted(df2.collect()) == sorted(df.collect()) - True + .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.parquet` instead. """ + warnings.warn("saveAsParquetFile is deprecated. Use write.parquet() instead.") self._jdf.saveAsParquetFile(path) + @since(1.3) def registerTempTable(self, name): """Registers this RDD as a temporary table using the given name. @@ -142,81 +143,49 @@ def registerTempTable(self, name): self._jdf.registerTempTable(name) def registerAsTable(self, name): - """DEPRECATED: use :func:`registerTempTable` instead""" - warnings.warn("Use registerTempTable instead of registerAsTable.", DeprecationWarning) + """ + .. note:: Deprecated in 1.4, use :func:`registerTempTable` instead. + """ + warnings.warn("Use registerTempTable instead of registerAsTable.") self.registerTempTable(name) def insertInto(self, tableName, overwrite=False): """Inserts the contents of this :class:`DataFrame` into the specified table. - Optionally overwriting any existing data. + .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.insertInto` instead. """ - self._jdf.insertInto(tableName, overwrite) - - def _java_save_mode(self, mode): - """Returns the Java save mode based on the Python save mode represented by a string. - """ - jSaveMode = self._sc._jvm.org.apache.spark.sql.SaveMode - jmode = jSaveMode.ErrorIfExists - mode = mode.lower() - if mode == "append": - jmode = jSaveMode.Append - elif mode == "overwrite": - jmode = jSaveMode.Overwrite - elif mode == "ignore": - jmode = jSaveMode.Ignore - elif mode == "error": - pass - else: - raise ValueError( - "Only 'append', 'overwrite', 'ignore', and 'error' are acceptable save mode.") - return jmode + warnings.warn("insertInto is deprecated. Use write.insertInto() instead.") + self.write.insertInto(tableName, overwrite) def saveAsTable(self, tableName, source=None, mode="error", **options): """Saves the contents of this :class:`DataFrame` to a data source as a table. - The data source is specified by the ``source`` and a set of ``options``. - If ``source`` is not specified, the default data source configured by - ``spark.sql.sources.default`` will be used. - - Additionally, mode is used to specify the behavior of the saveAsTable operation when - table already exists in the data source. There are four modes: - - * `append`: Append contents of this :class:`DataFrame` to existing data. - * `overwrite`: Overwrite existing data. - * `error`: Throw an exception if data already exists. - * `ignore`: Silently ignore this operation if data already exists. + .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.saveAsTable` instead. """ - if source is None: - source = self.sql_ctx.getConf("spark.sql.sources.default", - "org.apache.spark.sql.parquet") - jmode = self._java_save_mode(mode) - self._jdf.saveAsTable(tableName, source, jmode, options) + warnings.warn("insertInto is deprecated. Use write.saveAsTable() instead.") + self.write.saveAsTable(tableName, source, mode, **options) + @since(1.3) def save(self, path=None, source=None, mode="error", **options): """Saves the contents of the :class:`DataFrame` to a data source. - The data source is specified by the ``source`` and a set of ``options``. - If ``source`` is not specified, the default data source configured by - ``spark.sql.sources.default`` will be used. + .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.save` instead. + """ + warnings.warn("insertInto is deprecated. Use write.save() instead.") + return self.write.save(path, source, mode, **options) - Additionally, mode is used to specify the behavior of the save operation when - data already exists in the data source. There are four modes: + @property + @since(1.4) + def write(self): + """ + Interface for saving the content of the :class:`DataFrame` out into external storage. - * `append`: Append contents of this :class:`DataFrame` to existing data. - * `overwrite`: Overwrite existing data. - * `error`: Throw an exception if data already exists. - * `ignore`: Silently ignore this operation if data already exists. + :return: :class:`DataFrameWriter` """ - if path is not None: - options["path"] = path - if source is None: - source = self.sql_ctx.getConf("spark.sql.sources.default", - "org.apache.spark.sql.parquet") - jmode = self._java_save_mode(mode) - self._jdf.save(source, jmode, options) + return DataFrameWriter(self) @property + @since(1.3) def schema(self): """Returns the schema of this :class:`DataFrame` as a :class:`types.StructType`. @@ -227,6 +196,7 @@ def schema(self): self._schema = _parse_datatype_json_string(self._jdf.schema().json()) return self._schema + @since(1.3) def printSchema(self): """Prints out the schema in the tree format. @@ -238,6 +208,7 @@ def printSchema(self): """ print(self._jdf.schema().treeString()) + @since(1.3) def explain(self, extended=False): """Prints the (logical and physical) plans to the console for debugging purpose. @@ -263,12 +234,14 @@ def explain(self, extended=False): else: print(self._jdf.queryExecution().executedPlan().toString()) + @since(1.3) def isLocal(self): """Returns ``True`` if the :func:`collect` and :func:`take` methods can be run locally (without any Spark executors). """ return self._jdf.isLocal() + @since(1.3) def show(self, n=20): """Prints the first ``n`` rows to the console. @@ -287,6 +260,7 @@ def show(self, n=20): def __repr__(self): return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes)) + @since(1.3) def count(self): """Returns the number of rows in this :class:`DataFrame`. @@ -296,6 +270,7 @@ def count(self): return int(self._jdf.count()) @ignore_unicode_prefix + @since(1.3) def collect(self): """Returns all the records as a list of :class:`Row`. @@ -309,6 +284,7 @@ def collect(self): return [cls(r) for r in rs] @ignore_unicode_prefix + @since(1.3) def limit(self, num): """Limits the result count to the number specified. @@ -321,6 +297,7 @@ def limit(self, num): return DataFrame(jdf, self.sql_ctx) @ignore_unicode_prefix + @since(1.3) def take(self, num): """Returns the first ``num`` rows as a :class:`list` of :class:`Row`. @@ -330,6 +307,7 @@ def take(self, num): return self.limit(num).collect() @ignore_unicode_prefix + @since(1.3) def map(self, f): """ Returns a new :class:`RDD` by applying a the ``f`` function to each :class:`Row`. @@ -341,6 +319,7 @@ def map(self, f): return self.rdd.map(f) @ignore_unicode_prefix + @since(1.3) def flatMap(self, f): """ Returns a new :class:`RDD` by first applying the ``f`` function to each :class:`Row`, and then flattening the results. @@ -352,6 +331,7 @@ def flatMap(self, f): """ return self.rdd.flatMap(f) + @since(1.3) def mapPartitions(self, f, preservesPartitioning=False): """Returns a new :class:`RDD` by applying the ``f`` function to each partition. @@ -364,6 +344,7 @@ def mapPartitions(self, f, preservesPartitioning=False): """ return self.rdd.mapPartitions(f, preservesPartitioning) + @since(1.3) def foreach(self, f): """Applies the ``f`` function to all :class:`Row` of this :class:`DataFrame`. @@ -375,6 +356,7 @@ def foreach(self, f): """ return self.rdd.foreach(f) + @since(1.3) def foreachPartition(self, f): """Applies the ``f`` function to each partition of this :class:`DataFrame`. @@ -387,6 +369,7 @@ def foreachPartition(self, f): """ return self.rdd.foreachPartition(f) + @since(1.3) def cache(self): """ Persists with the default storage level (C{MEMORY_ONLY_SER}). """ @@ -394,6 +377,7 @@ def cache(self): self._jdf.cache() return self + @since(1.3) def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER): """Sets the storage level to persist its values across operations after the first time it is computed. This can only be used to assign @@ -405,6 +389,7 @@ def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER): self._jdf.persist(javaStorageLevel) return self + @since(1.3) def unpersist(self, blocking=True): """Marks the :class:`DataFrame` as non-persistent, and remove all blocks for it from memory and disk. @@ -413,10 +398,22 @@ def unpersist(self, blocking=True): self._jdf.unpersist(blocking) return self - # def coalesce(self, numPartitions, shuffle=False): - # rdd = self._jdf.coalesce(numPartitions, shuffle, None) - # return DataFrame(rdd, self.sql_ctx) + @since(1.4) + def coalesce(self, numPartitions): + """ + Returns a new :class:`DataFrame` that has exactly `numPartitions` partitions. + + Similar to coalesce defined on an :class:`RDD`, this operation results in a + narrow dependency, e.g. if you go from 1000 partitions to 100 partitions, + there will not be a shuffle, instead each of the 100 new partitions will + claim 10 of the current partitions. + + >>> df.coalesce(1).rdd.getNumPartitions() + 1 + """ + return DataFrame(self._jdf.coalesce(numPartitions), self.sql_ctx) + @since(1.3) def repartition(self, numPartitions): """Returns a new :class:`DataFrame` that has exactly ``numPartitions`` partitions. @@ -425,6 +422,7 @@ def repartition(self, numPartitions): """ return DataFrame(self._jdf.repartition(numPartitions), self.sql_ctx) + @since(1.3) def distinct(self): """Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`. @@ -433,6 +431,7 @@ def distinct(self): """ return DataFrame(self._jdf.distinct(), self.sql_ctx) + @since(1.3) def sample(self, withReplacement, fraction, seed=None): """Returns a sampled subset of this :class:`DataFrame`. @@ -444,6 +443,7 @@ def sample(self, withReplacement, fraction, seed=None): rdd = self._jdf.sample(withReplacement, fraction, long(seed)) return DataFrame(rdd, self.sql_ctx) + @since(1.4) def randomSplit(self, weights, seed=None): """Randomly splits this :class:`DataFrame` with the provided weights. @@ -466,6 +466,7 @@ def randomSplit(self, weights, seed=None): return [DataFrame(rdd, self.sql_ctx) for rdd in rdd_array] @property + @since(1.3) def dtypes(self): """Returns all column names and their data types as a list. @@ -476,6 +477,7 @@ def dtypes(self): @property @ignore_unicode_prefix + @since(1.3) def columns(self): """Returns all column names as a list. @@ -485,6 +487,7 @@ def columns(self): return [f.name for f in self.schema.fields] @ignore_unicode_prefix + @since(1.3) def alias(self, alias): """Returns a new :class:`DataFrame` with an alias set. @@ -499,6 +502,7 @@ def alias(self, alias): return DataFrame(getattr(self._jdf, "as")(alias), self.sql_ctx) @ignore_unicode_prefix + @since(1.3) def join(self, other, joinExprs=None, joinType=None): """Joins with another :class:`DataFrame`, using the given join expression. @@ -532,6 +536,7 @@ def join(self, other, joinExprs=None, joinType=None): return DataFrame(jdf, self.sql_ctx) @ignore_unicode_prefix + @since(1.3) def sort(self, *cols, **kwargs): """Returns a new :class:`DataFrame` sorted by the specified column(s). @@ -591,12 +596,16 @@ def _jcols(self, *cols): cols = cols[0] return self._jseq(cols, _to_java_column) + @since("1.3.1") def describe(self, *cols): """Computes statistics for numeric columns. This include count, mean, stddev, min, and max. If no columns are given, this function computes statistics for all numerical columns. + .. note:: This function is meant for exploratory data analysis, as we make no \ + guarantee about the backward compatibility of the schema of the resulting DataFrame. + >>> df.describe().show() +-------+---+ |summary|age| @@ -612,10 +621,13 @@ def describe(self, *cols): return DataFrame(jdf, self.sql_ctx) @ignore_unicode_prefix + @since(1.3) def head(self, n=None): - """ - Returns the first ``n`` rows as a list of :class:`Row`, - or the first :class:`Row` if ``n`` is ``None.`` + """Returns the first ``n`` rows. + + :param n: int, default 1. Number of rows to return. + :return: If n is greater than 1, return a list of :class:`Row`. + If n is 1, return a single Row. >>> df.head() Row(age=2, name=u'Alice') @@ -628,6 +640,7 @@ def head(self, n=None): return self.take(n) @ignore_unicode_prefix + @since(1.3) def first(self): """Returns the first row as a :class:`Row`. @@ -637,6 +650,7 @@ def first(self): return self.head() @ignore_unicode_prefix + @since(1.3) def __getitem__(self, item): """Returns the column as a :class:`Column`. @@ -664,6 +678,7 @@ def __getitem__(self, item): else: raise TypeError("unexpected item type: %s" % type(item)) + @since(1.3) def __getattr__(self, name): """Returns the :class:`Column` denoted by ``name``. @@ -677,6 +692,7 @@ def __getattr__(self, name): return Column(jc) @ignore_unicode_prefix + @since(1.3) def select(self, *cols): """Projects a set of expressions and returns a new :class:`DataFrame`. @@ -694,6 +710,7 @@ def select(self, *cols): jdf = self._jdf.select(self._jcols(*cols)) return DataFrame(jdf, self.sql_ctx) + @since(1.3) def selectExpr(self, *expr): """Projects a set of SQL expressions and returns a new :class:`DataFrame`. @@ -708,6 +725,7 @@ def selectExpr(self, *expr): return DataFrame(jdf, self.sql_ctx) @ignore_unicode_prefix + @since(1.3) def filter(self, condition): """Filters rows using the given condition. @@ -737,6 +755,7 @@ def filter(self, condition): where = filter @ignore_unicode_prefix + @since(1.3) def groupBy(self, *cols): """Groups the :class:`DataFrame` using the specified columns, so we can run aggregation on them. See :class:`GroupedData` @@ -756,9 +775,55 @@ def groupBy(self, *cols): >>> df.groupBy(['name', df.age]).count().collect() [Row(name=u'Bob', age=5, count=1), Row(name=u'Alice', age=2, count=1)] """ - jdf = self._jdf.groupBy(self._jcols(*cols)) - return GroupedData(jdf, self.sql_ctx) - + jgd = self._jdf.groupBy(self._jcols(*cols)) + from pyspark.sql.group import GroupedData + return GroupedData(jgd, self.sql_ctx) + + @since(1.4) + def rollup(self, *cols): + """ + Create a multi-dimensional rollup for the current :class:`DataFrame` using + the specified columns, so we can run aggregation on them. + + >>> df.rollup('name', df.age).count().show() + +-----+----+-----+ + | name| age|count| + +-----+----+-----+ + |Alice|null| 1| + | Bob| 5| 1| + | Bob|null| 1| + | null|null| 2| + |Alice| 2| 1| + +-----+----+-----+ + """ + jgd = self._jdf.rollup(self._jcols(*cols)) + from pyspark.sql.group import GroupedData + return GroupedData(jgd, self.sql_ctx) + + @since(1.4) + def cube(self, *cols): + """ + Create a multi-dimensional cube for the current :class:`DataFrame` using + the specified columns, so we can run aggregation on them. + + >>> df.cube('name', df.age).count().show() + +-----+----+-----+ + | name| age|count| + +-----+----+-----+ + | null| 2| 1| + |Alice|null| 1| + | Bob| 5| 1| + | Bob|null| 1| + | null| 5| 1| + | null|null| 2| + |Alice| 2| 1| + +-----+----+-----+ + """ + jgd = self._jdf.cube(self._jcols(*cols)) + from pyspark.sql.group import GroupedData + return GroupedData(jgd, self.sql_ctx) + + @since(1.3) def agg(self, *exprs): """ Aggregate on the entire :class:`DataFrame` without groups (shorthand for ``df.groupBy.agg()``). @@ -771,6 +836,7 @@ def agg(self, *exprs): """ return self.groupBy().agg(*exprs) + @since(1.3) def unionAll(self, other): """ Return a new :class:`DataFrame` containing union of rows in this frame and another frame. @@ -779,6 +845,7 @@ def unionAll(self, other): """ return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx) + @since(1.3) def intersect(self, other): """ Return a new :class:`DataFrame` containing rows only in both this frame and another frame. @@ -787,6 +854,7 @@ def intersect(self, other): """ return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx) + @since(1.3) def subtract(self, other): """ Return a new :class:`DataFrame` containing rows in this frame but not in another frame. @@ -795,6 +863,7 @@ def subtract(self, other): """ return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx) + @since(1.4) def dropDuplicates(self, subset=None): """Return a new :class:`DataFrame` with duplicate rows removed, optionally only considering certain columns. @@ -825,6 +894,7 @@ def dropDuplicates(self, subset=None): jdf = self._jdf.dropDuplicates(self._jseq(subset)) return DataFrame(jdf, self.sql_ctx) + @since("1.3.1") def dropna(self, how='any', thresh=None, subset=None): """Returns a new :class:`DataFrame` omitting rows with null values. @@ -867,6 +937,7 @@ def dropna(self, how='any', thresh=None, subset=None): return DataFrame(self._jdf.na().drop(thresh, self._jseq(subset)), self.sql_ctx) + @since("1.3.1") def fillna(self, value, subset=None): """Replace null values, alias for ``na.fill()``. @@ -928,6 +999,7 @@ def fillna(self, value, subset=None): return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx) + @since(1.4) def replace(self, to_replace, value, subset=None): """Returns a new :class:`DataFrame` replacing a value with another value. @@ -944,6 +1016,7 @@ def replace(self, to_replace, value, subset=None): Columns specified in subset that do not have matching data type are ignored. For example, if `value` is a string, and subset contains a non-string column, then the non-string column is simply ignored. + >>> df4.replace(10, 20).show() +----+------+-----+ | age|height| name| @@ -1002,6 +1075,7 @@ def replace(self, to_replace, value, subset=None): return DataFrame( self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)), self.sql_ctx) + @since(1.4) def corr(self, col1, col2, method=None): """ Calculates the correlation of two columns of a DataFrame as a double value. Currently only @@ -1023,6 +1097,7 @@ def corr(self, col1, col2, method=None): "coefficient is supported.") return self._jdf.stat().corr(col1, col2, method) + @since(1.4) def cov(self, col1, col2): """ Calculate the sample covariance for the given columns, specified by their names, as a @@ -1037,6 +1112,7 @@ def cov(self, col1, col2): raise ValueError("col2 should be a string.") return self._jdf.stat().cov(col1, col2) + @since(1.4) def crosstab(self, col1, col2): """ Computes a pair-wise frequency table of the given columns. Also known as a contingency @@ -1058,6 +1134,7 @@ def crosstab(self, col1, col2): raise ValueError("col2 should be a string.") return DataFrame(self._jdf.stat().crosstab(col1, col2), self.sql_ctx) + @since(1.4) def freqItems(self, cols, support=None): """ Finding frequent items for columns, possibly with false positives. Using the @@ -1065,6 +1142,9 @@ def freqItems(self, cols, support=None): "http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou". :func:`DataFrame.freqItems` and :func:`DataFrameStatFunctions.freqItems` are aliases. + .. note:: This function is meant for exploratory data analysis, as we make no \ + guarantee about the backward compatibility of the schema of the resulting DataFrame. + :param cols: Names of the columns to calculate frequent items for as a list or tuple of strings. :param support: The frequency with which to consider an item 'frequent'. Default is 1%. @@ -1079,6 +1159,7 @@ def freqItems(self, cols, support=None): return DataFrame(self._jdf.stat().freqItems(_to_seq(self._sc, cols), support), self.sql_ctx) @ignore_unicode_prefix + @since(1.3) def withColumn(self, colName, col): """Returns a new :class:`DataFrame` by adding a column. @@ -1091,6 +1172,7 @@ def withColumn(self, colName, col): return self.select('*', col.alias(colName)) @ignore_unicode_prefix + @since(1.3) def withColumnRenamed(self, existing, new): """Returns a new :class:`DataFrame` by renaming an existing column. @@ -1105,6 +1187,7 @@ def withColumnRenamed(self, existing, new): for c in self.columns] return self.select(*cols) + @since(1.4) @ignore_unicode_prefix def drop(self, colName): """Returns a new :class:`DataFrame` that drops the specified column. @@ -1117,6 +1200,7 @@ def drop(self, colName): jdf = self._jdf.drop(colName) return DataFrame(jdf, self.sql_ctx) + @since(1.3) def toPandas(self): """Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``. @@ -1141,169 +1225,6 @@ class SchemaRDD(DataFrame): """ -def dfapi(f): - def _api(self): - name = f.__name__ - jdf = getattr(self._jdf, name)() - return DataFrame(jdf, self.sql_ctx) - _api.__name__ = f.__name__ - _api.__doc__ = f.__doc__ - return _api - - -def df_varargs_api(f): - def _api(self, *args): - name = f.__name__ - jdf = getattr(self._jdf, name)(_to_seq(self.sql_ctx._sc, args)) - return DataFrame(jdf, self.sql_ctx) - _api.__name__ = f.__name__ - _api.__doc__ = f.__doc__ - return _api - - -class GroupedData(object): - """ - A set of methods for aggregations on a :class:`DataFrame`, - created by :func:`DataFrame.groupBy`. - """ - - def __init__(self, jdf, sql_ctx): - self._jdf = jdf - self.sql_ctx = sql_ctx - - @ignore_unicode_prefix - def agg(self, *exprs): - """Compute aggregates and returns the result as a :class:`DataFrame`. - - The available aggregate functions are `avg`, `max`, `min`, `sum`, `count`. - - If ``exprs`` is a single :class:`dict` mapping from string to string, then the key - is the column to perform aggregation on, and the value is the aggregate function. - - Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions. - - :param exprs: a dict mapping from column name (string) to aggregate functions (string), - or a list of :class:`Column`. - - >>> gdf = df.groupBy(df.name) - >>> gdf.agg({"*": "count"}).collect() - [Row(name=u'Alice', COUNT(1)=1), Row(name=u'Bob', COUNT(1)=1)] - - >>> from pyspark.sql import functions as F - >>> gdf.agg(F.min(df.age)).collect() - [Row(name=u'Alice', MIN(age)=2), Row(name=u'Bob', MIN(age)=5)] - """ - assert exprs, "exprs should not be empty" - if len(exprs) == 1 and isinstance(exprs[0], dict): - jdf = self._jdf.agg(exprs[0]) - else: - # Columns - assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column" - jdf = self._jdf.agg(exprs[0]._jc, - _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]])) - return DataFrame(jdf, self.sql_ctx) - - @dfapi - def count(self): - """Counts the number of records for each group. - - >>> df.groupBy(df.age).count().collect() - [Row(age=2, count=1), Row(age=5, count=1)] - """ - - @df_varargs_api - def mean(self, *cols): - """Computes average values for each numeric columns for each group. - - :func:`mean` is an alias for :func:`avg`. - - :param cols: list of column names (string). Non-numeric columns are ignored. - - >>> df.groupBy().mean('age').collect() - [Row(AVG(age)=3.5)] - >>> df3.groupBy().mean('age', 'height').collect() - [Row(AVG(age)=3.5, AVG(height)=82.5)] - """ - - @df_varargs_api - def avg(self, *cols): - """Computes average values for each numeric columns for each group. - - :func:`mean` is an alias for :func:`avg`. - - :param cols: list of column names (string). Non-numeric columns are ignored. - - >>> df.groupBy().avg('age').collect() - [Row(AVG(age)=3.5)] - >>> df3.groupBy().avg('age', 'height').collect() - [Row(AVG(age)=3.5, AVG(height)=82.5)] - """ - - @df_varargs_api - def max(self, *cols): - """Computes the max value for each numeric columns for each group. - - >>> df.groupBy().max('age').collect() - [Row(MAX(age)=5)] - >>> df3.groupBy().max('age', 'height').collect() - [Row(MAX(age)=5, MAX(height)=85)] - """ - - @df_varargs_api - def min(self, *cols): - """Computes the min value for each numeric column for each group. - - :param cols: list of column names (string). Non-numeric columns are ignored. - - >>> df.groupBy().min('age').collect() - [Row(MIN(age)=2)] - >>> df3.groupBy().min('age', 'height').collect() - [Row(MIN(age)=2, MIN(height)=80)] - """ - - @df_varargs_api - def sum(self, *cols): - """Compute the sum for each numeric columns for each group. - - :param cols: list of column names (string). Non-numeric columns are ignored. - - >>> df.groupBy().sum('age').collect() - [Row(SUM(age)=7)] - >>> df3.groupBy().sum('age', 'height').collect() - [Row(SUM(age)=7, SUM(height)=165)] - """ - - -def _create_column_from_literal(literal): - sc = SparkContext._active_spark_context - return sc._jvm.functions.lit(literal) - - -def _create_column_from_name(name): - sc = SparkContext._active_spark_context - return sc._jvm.functions.col(name) - - -def _to_java_column(col): - if isinstance(col, Column): - jcol = col._jc - else: - jcol = _create_column_from_name(col) - return jcol - - -def _to_seq(sc, cols, converter=None): - """ - Convert a list of Column (or names) into a JVM Seq of Column. - - An optional `converter` could be used to convert items in `cols` - into JVM Column objects. - """ - if converter: - cols = [converter(c) for c in cols] - return sc._jvm.PythonUtils.toSeq(cols) - - def _to_scala_map(sc, jm): """ Convert a dict into a JVM Map. @@ -1311,276 +1232,6 @@ def _to_scala_map(sc, jm): return sc._jvm.PythonUtils.toScalaMap(jm) -def _unary_op(name, doc="unary operator"): - """ Create a method for given unary operator """ - def _(self): - jc = getattr(self._jc, name)() - return Column(jc) - _.__doc__ = doc - return _ - - -def _func_op(name, doc=''): - def _(self): - sc = SparkContext._active_spark_context - jc = getattr(sc._jvm.functions, name)(self._jc) - return Column(jc) - _.__doc__ = doc - return _ - - -def _bin_op(name, doc="binary operator"): - """ Create a method for given binary operator - """ - def _(self, other): - jc = other._jc if isinstance(other, Column) else other - njc = getattr(self._jc, name)(jc) - return Column(njc) - _.__doc__ = doc - return _ - - -def _reverse_op(name, doc="binary operator"): - """ Create a method for binary operator (this object is on right side) - """ - def _(self, other): - jother = _create_column_from_literal(other) - jc = getattr(jother, name)(self._jc) - return Column(jc) - _.__doc__ = doc - return _ - - -class Column(object): - - """ - A column in a DataFrame. - - :class:`Column` instances can be created by:: - - # 1. Select a column out of a DataFrame - - df.colName - df["colName"] - - # 2. Create from an expression - df.colName + 1 - 1 / df.colName - """ - - def __init__(self, jc): - self._jc = jc - - # arithmetic operators - __neg__ = _func_op("negate") - __add__ = _bin_op("plus") - __sub__ = _bin_op("minus") - __mul__ = _bin_op("multiply") - __div__ = _bin_op("divide") - __truediv__ = _bin_op("divide") - __mod__ = _bin_op("mod") - __radd__ = _bin_op("plus") - __rsub__ = _reverse_op("minus") - __rmul__ = _bin_op("multiply") - __rdiv__ = _reverse_op("divide") - __rtruediv__ = _reverse_op("divide") - __rmod__ = _reverse_op("mod") - - # logistic operators - __eq__ = _bin_op("equalTo") - __ne__ = _bin_op("notEqual") - __lt__ = _bin_op("lt") - __le__ = _bin_op("leq") - __ge__ = _bin_op("geq") - __gt__ = _bin_op("gt") - - # `and`, `or`, `not` cannot be overloaded in Python, - # so use bitwise operators as boolean operators - __and__ = _bin_op('and') - __or__ = _bin_op('or') - __invert__ = _func_op('not') - __rand__ = _bin_op("and") - __ror__ = _bin_op("or") - - # container operators - __contains__ = _bin_op("contains") - __getitem__ = _bin_op("apply") - - # bitwise operators - bitwiseOR = _bin_op("bitwiseOR") - bitwiseAND = _bin_op("bitwiseAND") - bitwiseXOR = _bin_op("bitwiseXOR") - - def getItem(self, key): - """An expression that gets an item at position `ordinal` out of a list, - or gets an item by key out of a dict. - - >>> df = sc.parallelize([([1, 2], {"key": "value"})]).toDF(["l", "d"]) - >>> df.select(df.l.getItem(0), df.d.getItem("key")).show() - +----+------+ - |l[0]|d[key]| - +----+------+ - | 1| value| - +----+------+ - >>> df.select(df.l[0], df.d["key"]).show() - +----+------+ - |l[0]|d[key]| - +----+------+ - | 1| value| - +----+------+ - """ - return self[key] - - def getField(self, name): - """An expression that gets a field by name in a StructField. - - >>> from pyspark.sql import Row - >>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF() - >>> df.select(df.r.getField("b")).show() - +----+ - |r[b]| - +----+ - | b| - +----+ - >>> df.select(df.r.a).show() - +----+ - |r[a]| - +----+ - | 1| - +----+ - """ - return self[name] - - def __getattr__(self, item): - if item.startswith("__"): - raise AttributeError(item) - return self.getField(item) - - # string methods - rlike = _bin_op("rlike") - like = _bin_op("like") - startswith = _bin_op("startsWith") - endswith = _bin_op("endsWith") - - @ignore_unicode_prefix - def substr(self, startPos, length): - """ - Return a :class:`Column` which is a substring of the column - - :param startPos: start position (int or Column) - :param length: length of the substring (int or Column) - - >>> df.select(df.name.substr(1, 3).alias("col")).collect() - [Row(col=u'Ali'), Row(col=u'Bob')] - """ - if type(startPos) != type(length): - raise TypeError("Can not mix the type") - if isinstance(startPos, (int, long)): - jc = self._jc.substr(startPos, length) - elif isinstance(startPos, Column): - jc = self._jc.substr(startPos._jc, length._jc) - else: - raise TypeError("Unexpected type: %s" % type(startPos)) - return Column(jc) - - __getslice__ = substr - - @ignore_unicode_prefix - def inSet(self, *cols): - """ A boolean expression that is evaluated to true if the value of this - expression is contained by the evaluated values of the arguments. - - >>> df[df.name.inSet("Bob", "Mike")].collect() - [Row(age=5, name=u'Bob')] - >>> df[df.age.inSet([1, 2, 3])].collect() - [Row(age=2, name=u'Alice')] - """ - if len(cols) == 1 and isinstance(cols[0], (list, set)): - cols = cols[0] - cols = [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols] - sc = SparkContext._active_spark_context - jc = getattr(self._jc, "in")(_to_seq(sc, cols)) - return Column(jc) - - # order - asc = _unary_op("asc", "Returns a sort expression based on the" - " ascending order of the given column name.") - desc = _unary_op("desc", "Returns a sort expression based on the" - " descending order of the given column name.") - - isNull = _unary_op("isNull", "True if the current expression is null.") - isNotNull = _unary_op("isNotNull", "True if the current expression is not null.") - - def alias(self, alias): - """Return a alias for this column - - >>> df.select(df.age.alias("age2")).collect() - [Row(age2=2), Row(age2=5)] - """ - return Column(getattr(self._jc, "as")(alias)) - - @ignore_unicode_prefix - def cast(self, dataType): - """ Convert the column into type `dataType` - - >>> df.select(df.age.cast("string").alias('ages')).collect() - [Row(ages=u'2'), Row(ages=u'5')] - >>> df.select(df.age.cast(StringType()).alias('ages')).collect() - [Row(ages=u'2'), Row(ages=u'5')] - """ - if isinstance(dataType, basestring): - jc = self._jc.cast(dataType) - elif isinstance(dataType, DataType): - sc = SparkContext._active_spark_context - ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) - jdt = ssql_ctx.parseDataType(dataType.json()) - jc = self._jc.cast(jdt) - else: - raise TypeError("unexpected type: %s" % type(dataType)) - return Column(jc) - - @ignore_unicode_prefix - def between(self, lowerBound, upperBound): - """ A boolean expression that is evaluated to true if the value of this - expression is between the given columns. - """ - return (self >= lowerBound) & (self <= upperBound) - - @ignore_unicode_prefix - def when(self, condition, value): - """Evaluates a list of conditions and returns one of multiple possible result expressions. - If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions. - - See :func:`pyspark.sql.functions.when` for example usage. - - :param condition: a boolean :class:`Column` expression. - :param value: a literal value, or a :class:`Column` expression. - - """ - sc = SparkContext._active_spark_context - if not isinstance(condition, Column): - raise TypeError("condition should be a Column") - v = value._jc if isinstance(value, Column) else value - jc = sc._jvm.functions.when(condition._jc, v) - return Column(jc) - - @ignore_unicode_prefix - def otherwise(self, value): - """Evaluates a list of conditions and returns one of multiple possible result expressions. - If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions. - - See :func:`pyspark.sql.functions.when` for example usage. - - :param value: a literal value, or a :class:`Column` expression. - """ - v = value._jc if isinstance(value, Column) else value - jc = self._jc.otherwise(value) - return Column(jc) - - def __repr__(self): - return 'Column<%s>' % self._jc.toString().encode('utf8') - - class DataFrameNaFunctions(object): """Functionality for working with missing data in :class:`DataFrame`. """ @@ -1640,9 +1291,6 @@ def _test(): .toDF(StructType([StructField('age', IntegerType()), StructField('name', StringType())])) globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF() - globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80), - Row(name='Bob', age=5, height=85)]).toDF() - globs['df4'] = sc.parallelize([Row(name='Alice', age=10, height=80), Row(name='Bob', age=5, height=None), Row(name='Tom', age=None, height=None), diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index d91265ee0bec8..bbf465aca8d4d 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -26,21 +26,27 @@ from pyspark import SparkContext from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix from pyspark.serializers import PickleSerializer, AutoBatchedSerializer +from pyspark.sql import since from pyspark.sql.types import StringType -from pyspark.sql.dataframe import Column, _to_java_column, _to_seq +from pyspark.sql.column import Column, _to_java_column, _to_seq __all__ = [ + 'array', 'approxCountDistinct', 'coalesce', 'countDistinct', + 'explode', 'monotonicallyIncreasingId', 'rand', 'randn', 'sparkPartitionId', + 'struct', 'udf', 'when'] +__all__ += ['lag', 'lead', 'ntile'] + def _create_function(name, doc=""): """ Create a function for aggregator by name""" @@ -66,6 +72,17 @@ def _(col1, col2): return _ +def _create_window_function(name, doc=''): + """ Create a window function by name """ + def _(): + sc = SparkContext._active_spark_context + jc = getattr(sc._jvm.functions, name)() + return Column(jc) + _.__name__ = name + _.__doc__ = 'Window function: ' + doc + return _ + + _functions = { 'lit': 'Creates a :class:`Column` of literal value.', 'col': 'Returns a :class:`Column` based on the given column name.', @@ -78,6 +95,18 @@ def _(col1, col2): 'sqrt': 'Computes the square root of the specified float value.', 'abs': 'Computes the absolute value.', + 'max': 'Aggregate function: returns the maximum value of the expression in a group.', + 'min': 'Aggregate function: returns the minimum value of the expression in a group.', + 'first': 'Aggregate function: returns the first value in a group.', + 'last': 'Aggregate function: returns the last value in a group.', + 'count': 'Aggregate function: returns the number of items in a group.', + 'sum': 'Aggregate function: returns the sum of all values in the expression.', + 'avg': 'Aggregate function: returns the average of the values in a group.', + 'mean': 'Aggregate function: returns the average of the values in a group.', + 'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.', +} + +_functions_1_4 = { # unary math functions 'acos': 'Computes the cosine inverse of the given value; the returned angle is in the range' + '0.0 through pi.', @@ -102,21 +131,11 @@ def _(col1, col2): 'tan': 'Computes the tangent of the given value.', 'tanh': 'Computes the hyperbolic tangent of the given value.', 'toDegrees': 'Converts an angle measured in radians to an approximately equivalent angle ' + - 'measured in degrees.', + 'measured in degrees.', 'toRadians': 'Converts an angle measured in degrees to an approximately equivalent angle ' + - 'measured in radians.', + 'measured in radians.', 'bitwiseNOT': 'Computes bitwise not.', - - 'max': 'Aggregate function: returns the maximum value of the expression in a group.', - 'min': 'Aggregate function: returns the minimum value of the expression in a group.', - 'first': 'Aggregate function: returns the first value in a group.', - 'last': 'Aggregate function: returns the last value in a group.', - 'count': 'Aggregate function: returns the number of items in a group.', - 'sum': 'Aggregate function: returns the sum of all values in the expression.', - 'avg': 'Aggregate function: returns the average of the values in a group.', - 'mean': 'Aggregate function: returns the average of the values in a group.', - 'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.', } # math functions that take two arguments as input @@ -127,16 +146,57 @@ def _(col1, col2): 'pow': 'Returns the value of the first argument raised to the power of the second argument.' } +_window_functions = { + 'rowNumber': + """returns a sequential number starting at 1 within a window partition. + + This is equivalent to the ROW_NUMBER function in SQL.""", + 'denseRank': + """returns the rank of rows within a window partition, without any gaps. + + The difference between rank and denseRank is that denseRank leaves no gaps in ranking + sequence when there are ties. That is, if you were ranking a competition using denseRank + and had three people tie for second place, you would say that all three were in second + place and that the next person came in third. + + This is equivalent to the DENSE_RANK function in SQL.""", + 'rank': + """returns the rank of rows within a window partition. + + The difference between rank and denseRank is that denseRank leaves no gaps in ranking + sequence when there are ties. That is, if you were ranking a competition using denseRank + and had three people tie for second place, you would say that all three were in second + place and that the next person came in third. + + This is equivalent to the RANK function in SQL.""", + 'cumeDist': + """returns the cumulative distribution of values within a window partition, + i.e. the fraction of rows that are below the current row. + + This is equivalent to the CUME_DIST function in SQL.""", + 'percentRank': + """returns the relative rank (i.e. percentile) of rows within a window partition. + + This is equivalent to the PERCENT_RANK function in SQL.""", +} + for _name, _doc in _functions.items(): - globals()[_name] = _create_function(_name, _doc) + globals()[_name] = since(1.3)(_create_function(_name, _doc)) +for _name, _doc in _functions_1_4.items(): + globals()[_name] = since(1.4)(_create_function(_name, _doc)) for _name, _doc in _binary_mathfunctions.items(): - globals()[_name] = _create_binary_mathfunction(_name, _doc) + globals()[_name] = since(1.4)(_create_binary_mathfunction(_name, _doc)) +for _name, _doc in _window_functions.items(): + globals()[_name] = since(1.4)(_create_window_function(_name, _doc)) del _name, _doc __all__ += _functions.keys() +__all__ += _functions_1_4.keys() __all__ += _binary_mathfunctions.keys() +__all__ += _window_functions.keys() __all__.sort() +@since(1.4) def array(*cols): """Creates a new array column. @@ -155,6 +215,7 @@ def array(*cols): return Column(jc) +@since(1.3) def approxCountDistinct(col, rsd=None): """Returns a new :class:`Column` for approximate distinct count of ``col``. @@ -169,6 +230,7 @@ def approxCountDistinct(col, rsd=None): return Column(jc) +@since(1.4) def coalesce(*cols): """Returns the first column that is not null. @@ -205,6 +267,7 @@ def coalesce(*cols): return Column(jc) +@since(1.3) def countDistinct(col, *cols): """Returns a new :class:`Column` for distinct count of ``col`` or ``cols``. @@ -219,6 +282,28 @@ def countDistinct(col, *cols): return Column(jc) +@since(1.4) +def explode(col): + """Returns a new row for each element in the given array or map. + + >>> from pyspark.sql import Row + >>> eDF = sqlContext.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})]) + >>> eDF.select(explode(eDF.intlist).alias("anInt")).collect() + [Row(anInt=1), Row(anInt=2), Row(anInt=3)] + + >>> eDF.select(explode(eDF.mapfield).alias("key", "value")).show() + +---+-----+ + |key|value| + +---+-----+ + | a| b| + +---+-----+ + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.explode(_to_java_column(col)) + return Column(jc) + + +@since(1.4) def monotonicallyIncreasingId(): """A column that generates monotonically increasing 64-bit integers. @@ -227,7 +312,7 @@ def monotonicallyIncreasingId(): within each partition in the lower 33 bits. The assumption is that the data frame has less than 1 billion partitions, and each partition has less than 8 billion records. - As an example, consider a [[DataFrame]] with two partitions, each with 3 records. + As an example, consider a :class:`DataFrame` with two partitions, each with 3 records. This expression would return the following IDs: 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. @@ -239,6 +324,7 @@ def monotonicallyIncreasingId(): return Column(sc._jvm.functions.monotonicallyIncreasingId()) +@since(1.4) def rand(seed=None): """Generates a random column with i.i.d. samples from U[0.0, 1.0]. """ @@ -250,6 +336,7 @@ def rand(seed=None): return Column(jc) +@since(1.4) def randn(seed=None): """Generates a column with i.i.d. samples from the standard normal distribution. """ @@ -261,6 +348,7 @@ def randn(seed=None): return Column(jc) +@since(1.4) def sparkPartitionId(): """A column for partition ID of the Spark task. @@ -274,6 +362,7 @@ def sparkPartitionId(): @ignore_unicode_prefix +@since(1.4) def struct(*cols): """Creates a new struct column. @@ -292,6 +381,7 @@ def struct(*cols): return Column(jc) +@since(1.4) def when(condition, value): """Evaluates a list of conditions and returns one of multiple possible result expressions. If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions. @@ -313,9 +403,60 @@ def when(condition, value): return Column(jc) +@since(1.4) +def lag(col, count=1, default=None): + """ + Window function: returns the value that is `offset` rows before the current row, and + `defaultValue` if there is less than `offset` rows before the current row. For example, + an `offset` of one will return the previous row at any given point in the window partition. + + This is equivalent to the LAG function in SQL. + + :param col: name of column or expression + :param count: number of row to extend + :param default: default value + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.lag(_to_java_column(col), count, default)) + + +@since(1.4) +def lead(col, count=1, default=None): + """ + Window function: returns the value that is `offset` rows after the current row, and + `defaultValue` if there is less than `offset` rows after the current row. For example, + an `offset` of one will return the next row at any given point in the window partition. + + This is equivalent to the LEAD function in SQL. + + :param col: name of column or expression + :param count: number of row to extend + :param default: default value + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.lead(_to_java_column(col), count, default)) + + +@since(1.4) +def ntile(n): + """ + Window function: returns a group id from 1 to `n` (inclusive) in a round-robin fashion in + a window partition. Fow example, if `n` is 3, the first row will get 1, the second row will + get 2, the third row will get 3, and the fourth row will get 1... + + This is equivalent to the NTILE function in SQL. + + :param n: an integer + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.ntile(int(n))) + + class UserDefinedFunction(object): """ User defined function in Python + + .. versionadded:: 1.3 """ def __init__(self, func, returnType): self.func = func @@ -333,8 +474,8 @@ def _create_judf(self): ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) jdt = ssql_ctx.parseDataType(self.returnType.json()) fname = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__ - judf = sc._jvm.UserDefinedPythonFunction(fname, bytearray(pickled_command), env, - includes, sc.pythonExec, broadcast_vars, + judf = sc._jvm.UserDefinedPythonFunction(fname, bytearray(pickled_command), env, includes, + sc.pythonExec, sc.pythonVer, broadcast_vars, sc._javaAccumulator, jdt) return judf @@ -349,6 +490,7 @@ def __call__(self, *cols): return Column(jc) +@since(1.3) def udf(f, returnType=StringType()): """Creates a :class:`Column` expression representing a user defined function (UDF). diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py new file mode 100644 index 0000000000000..5a37a673ee80c --- /dev/null +++ b/python/pyspark/sql/group.py @@ -0,0 +1,195 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.rdd import ignore_unicode_prefix +from pyspark.sql import since +from pyspark.sql.column import Column, _to_seq +from pyspark.sql.dataframe import DataFrame +from pyspark.sql.types import * + +__all__ = ["GroupedData"] + + +def dfapi(f): + def _api(self): + name = f.__name__ + jdf = getattr(self._jdf, name)() + return DataFrame(jdf, self.sql_ctx) + _api.__name__ = f.__name__ + _api.__doc__ = f.__doc__ + return _api + + +def df_varargs_api(f): + def _api(self, *args): + name = f.__name__ + jdf = getattr(self._jdf, name)(_to_seq(self.sql_ctx._sc, args)) + return DataFrame(jdf, self.sql_ctx) + _api.__name__ = f.__name__ + _api.__doc__ = f.__doc__ + return _api + + +class GroupedData(object): + """ + A set of methods for aggregations on a :class:`DataFrame`, + created by :func:`DataFrame.groupBy`. + + .. note:: Experimental + + .. versionadded:: 1.3 + """ + + def __init__(self, jdf, sql_ctx): + self._jdf = jdf + self.sql_ctx = sql_ctx + + @ignore_unicode_prefix + @since(1.3) + def agg(self, *exprs): + """Compute aggregates and returns the result as a :class:`DataFrame`. + + The available aggregate functions are `avg`, `max`, `min`, `sum`, `count`. + + If ``exprs`` is a single :class:`dict` mapping from string to string, then the key + is the column to perform aggregation on, and the value is the aggregate function. + + Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions. + + :param exprs: a dict mapping from column name (string) to aggregate functions (string), + or a list of :class:`Column`. + + >>> gdf = df.groupBy(df.name) + >>> gdf.agg({"*": "count"}).collect() + [Row(name=u'Alice', COUNT(1)=1), Row(name=u'Bob', COUNT(1)=1)] + + >>> from pyspark.sql import functions as F + >>> gdf.agg(F.min(df.age)).collect() + [Row(name=u'Alice', MIN(age)=2), Row(name=u'Bob', MIN(age)=5)] + """ + assert exprs, "exprs should not be empty" + if len(exprs) == 1 and isinstance(exprs[0], dict): + jdf = self._jdf.agg(exprs[0]) + else: + # Columns + assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column" + jdf = self._jdf.agg(exprs[0]._jc, + _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]])) + return DataFrame(jdf, self.sql_ctx) + + @dfapi + @since(1.3) + def count(self): + """Counts the number of records for each group. + + >>> df.groupBy(df.age).count().collect() + [Row(age=2, count=1), Row(age=5, count=1)] + """ + + @df_varargs_api + @since(1.3) + def mean(self, *cols): + """Computes average values for each numeric columns for each group. + + :func:`mean` is an alias for :func:`avg`. + + :param cols: list of column names (string). Non-numeric columns are ignored. + + >>> df.groupBy().mean('age').collect() + [Row(AVG(age)=3.5)] + >>> df3.groupBy().mean('age', 'height').collect() + [Row(AVG(age)=3.5, AVG(height)=82.5)] + """ + + @df_varargs_api + @since(1.3) + def avg(self, *cols): + """Computes average values for each numeric columns for each group. + + :func:`mean` is an alias for :func:`avg`. + + :param cols: list of column names (string). Non-numeric columns are ignored. + + >>> df.groupBy().avg('age').collect() + [Row(AVG(age)=3.5)] + >>> df3.groupBy().avg('age', 'height').collect() + [Row(AVG(age)=3.5, AVG(height)=82.5)] + """ + + @df_varargs_api + @since(1.3) + def max(self, *cols): + """Computes the max value for each numeric columns for each group. + + >>> df.groupBy().max('age').collect() + [Row(MAX(age)=5)] + >>> df3.groupBy().max('age', 'height').collect() + [Row(MAX(age)=5, MAX(height)=85)] + """ + + @df_varargs_api + @since(1.3) + def min(self, *cols): + """Computes the min value for each numeric column for each group. + + :param cols: list of column names (string). Non-numeric columns are ignored. + + >>> df.groupBy().min('age').collect() + [Row(MIN(age)=2)] + >>> df3.groupBy().min('age', 'height').collect() + [Row(MIN(age)=2, MIN(height)=80)] + """ + + @df_varargs_api + @since(1.3) + def sum(self, *cols): + """Compute the sum for each numeric columns for each group. + + :param cols: list of column names (string). Non-numeric columns are ignored. + + >>> df.groupBy().sum('age').collect() + [Row(SUM(age)=7)] + >>> df3.groupBy().sum('age', 'height').collect() + [Row(SUM(age)=7, SUM(height)=165)] + """ + + +def _test(): + import doctest + from pyspark.context import SparkContext + from pyspark.sql import Row, SQLContext + import pyspark.sql.group + globs = pyspark.sql.group.__dict__.copy() + sc = SparkContext('local[4]', 'PythonTest') + globs['sc'] = sc + globs['sqlContext'] = SQLContext(sc) + globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \ + .toDF(StructType([StructField('age', IntegerType()), + StructField('name', StringType())])) + globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80), + Row(name='Bob', age=5, height=85)]).toDF() + + (failure_count, test_count) = doctest.testmod( + pyspark.sql.group, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py new file mode 100644 index 0000000000000..f036644acc961 --- /dev/null +++ b/python/pyspark/sql/readwriter.py @@ -0,0 +1,407 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from py4j.java_gateway import JavaClass + +from pyspark.sql import since +from pyspark.sql.column import _to_seq +from pyspark.sql.types import * + +__all__ = ["DataFrameReader", "DataFrameWriter"] + + +class DataFrameReader(object): + """ + Interface used to load a :class:`DataFrame` from external storage systems + (e.g. file systems, key-value stores, etc). Use :func:`SQLContext.read` + to access this. + + ::Note: Experimental + + .. versionadded:: 1.4 + """ + + def __init__(self, sqlContext): + self._jreader = sqlContext._ssql_ctx.read() + self._sqlContext = sqlContext + + def _df(self, jdf): + from pyspark.sql.dataframe import DataFrame + return DataFrame(jdf, self._sqlContext) + + @since(1.4) + def format(self, source): + """Specifies the input data source format. + + :param source: string, name of the data source, e.g. 'json', 'parquet'. + + >>> df = sqlContext.read.format('json').load('python/test_support/sql/people.json') + >>> df.dtypes + [('age', 'bigint'), ('name', 'string')] + + """ + self._jreader = self._jreader.format(source) + return self + + @since(1.4) + def schema(self, schema): + """Specifies the input schema. + + Some data sources (e.g. JSON) can infer the input schema automatically from data. + By specifying the schema here, the underlying data source can skip the schema + inference step, and thus speed up data loading. + + :param schema: a StructType object + """ + if not isinstance(schema, StructType): + raise TypeError("schema should be StructType") + jschema = self._sqlContext._ssql_ctx.parseDataType(schema.json()) + self._jreader = self._jreader.schema(jschema) + return self + + @since(1.4) + def options(self, **options): + """Adds input options for the underlying data source. + """ + for k in options: + self._jreader = self._jreader.option(k, options[k]) + return self + + @since(1.4) + def load(self, path=None, format=None, schema=None, **options): + """Loads data from a data source and returns it as a :class`DataFrame`. + + :param path: optional string for file-system backed data sources. + :param format: optional string for format of the data source. Default to 'parquet'. + :param schema: optional :class:`StructType` for the input schema. + :param options: all other string options + + >>> df = sqlContext.read.load('python/test_support/sql/parquet_partitioned') + >>> df.dtypes + [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] + """ + if format is not None: + self.format(format) + if schema is not None: + self.schema(schema) + self.options(**options) + if path is not None: + return self._df(self._jreader.load(path)) + else: + return self._df(self._jreader.load()) + + @since(1.4) + def json(self, path, schema=None): + """ + Loads a JSON file (one object per line) and returns the result as + a :class`DataFrame`. + + If the ``schema`` parameter is not specified, this function goes + through the input once to determine the input schema. + + :param path: string, path to the JSON dataset. + :param schema: an optional :class:`StructType` for the input schema. + + >>> df = sqlContext.read.json('python/test_support/sql/people.json') + >>> df.dtypes + [('age', 'bigint'), ('name', 'string')] + + """ + if schema is not None: + self.schema(schema) + return self._df(self._jreader.json(path)) + + @since(1.4) + def table(self, tableName): + """Returns the specified table as a :class:`DataFrame`. + + :param tableName: string, name of the table. + + >>> df = sqlContext.read.parquet('python/test_support/sql/parquet_partitioned') + >>> df.registerTempTable('tmpTable') + >>> sqlContext.read.table('tmpTable').dtypes + [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] + """ + return self._df(self._jreader.table(tableName)) + + @since(1.4) + def parquet(self, *path): + """Loads a Parquet file, returning the result as a :class:`DataFrame`. + + >>> df = sqlContext.read.parquet('python/test_support/sql/parquet_partitioned') + >>> df.dtypes + [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] + """ + return self._df(self._jreader.parquet(_to_seq(self._sqlContext._sc, path))) + + @since(1.4) + def jdbc(self, url, table, column=None, lowerBound=None, upperBound=None, numPartitions=None, + predicates=None, properties={}): + """ + Construct a :class:`DataFrame` representing the database table accessible + via JDBC URL `url` named `table` and connection `properties`. + + The `column` parameter could be used to partition the table, then it will + be retrieved in parallel based on the parameters passed to this function. + + The `predicates` parameter gives a list expressions suitable for inclusion + in WHERE clauses; each one defines one partition of the :class:`DataFrame`. + + ::Note: Don't create too many partitions in parallel on a large cluster; + otherwise Spark might crash your external database systems. + + :param url: a JDBC URL + :param table: name of table + :param column: the column used to partition + :param lowerBound: the lower bound of partition column + :param upperBound: the upper bound of the partition column + :param numPartitions: the number of partitions + :param predicates: a list of expressions + :param properties: JDBC database connection arguments, a list of arbitrary string + tag/value. Normally at least a "user" and "password" property + should be included. + :return: a DataFrame + """ + jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)() + for k in properties: + jprop.setProperty(k, properties[k]) + if column is not None: + if numPartitions is None: + numPartitions = self._sqlContext._sc.defaultParallelism + return self._df(self._jreader.jdbc(url, table, column, int(lowerBound), int(upperBound), + int(numPartitions), jprop)) + if predicates is not None: + arr = self._sqlContext._sc._jvm.PythonUtils.toArray(predicates) + return self._df(self._jreader.jdbc(url, table, arr, jprop)) + return self._df(self._jreader.jdbc(url, table, jprop)) + + +class DataFrameWriter(object): + """ + Interface used to write a [[DataFrame]] to external storage systems + (e.g. file systems, key-value stores, etc). Use :func:`DataFrame.write` + to access this. + + ::Note: Experimental + + .. versionadded:: 1.4 + """ + def __init__(self, df): + self._df = df + self._sqlContext = df.sql_ctx + self._jwrite = df._jdf.write() + + @since(1.4) + def mode(self, saveMode): + """Specifies the behavior when data or table already exists. + + Options include: + + * `append`: Append contents of this :class:`DataFrame` to existing data. + * `overwrite`: Overwrite existing data. + * `error`: Throw an exception if data already exists. + * `ignore`: Silently ignore this operation if data already exists. + + >>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data')) + """ + self._jwrite = self._jwrite.mode(saveMode) + return self + + @since(1.4) + def format(self, source): + """Specifies the underlying output data source. + + :param source: string, name of the data source, e.g. 'json', 'parquet'. + + >>> df.write.format('json').save(os.path.join(tempfile.mkdtemp(), 'data')) + """ + self._jwrite = self._jwrite.format(source) + return self + + @since(1.4) + def options(self, **options): + """Adds output options for the underlying data source. + """ + for k in options: + self._jwrite = self._jwrite.option(k, options[k]) + return self + + @since(1.4) + def partitionBy(self, *cols): + """Partitions the output by the given columns on the file system. + + If specified, the output is laid out on the file system similar + to Hive's partitioning scheme. + + :param cols: name of columns + + >>> df.write.partitionBy('year', 'month').parquet(os.path.join(tempfile.mkdtemp(), 'data')) + """ + if len(cols) == 1 and isinstance(cols[0], (list, tuple)): + cols = cols[0] + self._jwrite = self._jwrite.partitionBy(_to_seq(self._sqlContext._sc, cols)) + return self + + @since(1.4) + def save(self, path=None, format=None, mode="error", **options): + """Saves the contents of the :class:`DataFrame` to a data source. + + The data source is specified by the ``format`` and a set of ``options``. + If ``format`` is not specified, the default data source configured by + ``spark.sql.sources.default`` will be used. + + :param path: the path in a Hadoop supported file system + :param format: the format used to save + :param mode: specifies the behavior of the save operation when data already exists. + + * ``append``: Append contents of this :class:`DataFrame` to existing data. + * ``overwrite``: Overwrite existing data. + * ``ignore``: Silently ignore this operation if data already exists. + * ``error`` (default case): Throw an exception if data already exists. + :param options: all other string options + + >>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data')) + """ + self.mode(mode).options(**options) + if format is not None: + self.format(format) + if path is None: + self._jwrite.save() + else: + self._jwrite.save(path) + + @since(1.4) + def insertInto(self, tableName, overwrite=False): + """Inserts the content of the :class:`DataFrame` to the specified table. + + It requires that the schema of the class:`DataFrame` is the same as the + schema of the table. + + Optionally overwriting any existing data. + """ + self._jwrite.mode("overwrite" if overwrite else "append").insertInto(tableName) + + @since(1.4) + def saveAsTable(self, name, format=None, mode="error", **options): + """Saves the content of the :class:`DataFrame` as the specified table. + + In the case the table already exists, behavior of this function depends on the + save mode, specified by the `mode` function (default to throwing an exception). + When `mode` is `Overwrite`, the schema of the [[DataFrame]] does not need to be + the same as that of the existing table. + + * `append`: Append contents of this :class:`DataFrame` to existing data. + * `overwrite`: Overwrite existing data. + * `error`: Throw an exception if data already exists. + * `ignore`: Silently ignore this operation if data already exists. + + :param name: the table name + :param format: the format used to save + :param mode: one of `append`, `overwrite`, `error`, `ignore` (default: error) + :param options: all other string options + """ + self.mode(mode).options(**options) + if format is not None: + self.format(format) + self._jwrite.saveAsTable(name) + + @since(1.4) + def json(self, path, mode="error"): + """Saves the content of the :class:`DataFrame` in JSON format at the specified path. + + :param path: the path in any Hadoop supported file system + :param mode: specifies the behavior of the save operation when data already exists. + + * ``append``: Append contents of this :class:`DataFrame` to existing data. + * ``overwrite``: Overwrite existing data. + * ``ignore``: Silently ignore this operation if data already exists. + * ``error`` (default case): Throw an exception if data already exists. + + >>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data')) + """ + self._jwrite.mode(mode).json(path) + + @since(1.4) + def parquet(self, path, mode="error"): + """Saves the content of the :class:`DataFrame` in Parquet format at the specified path. + + :param path: the path in any Hadoop supported file system + :param mode: specifies the behavior of the save operation when data already exists. + + * ``append``: Append contents of this :class:`DataFrame` to existing data. + * ``overwrite``: Overwrite existing data. + * ``ignore``: Silently ignore this operation if data already exists. + * ``error`` (default case): Throw an exception if data already exists. + + >>> df.write.parquet(os.path.join(tempfile.mkdtemp(), 'data')) + """ + self._jwrite.mode(mode).parquet(path) + + @since(1.4) + def jdbc(self, url, table, mode="error", properties={}): + """Saves the content of the :class:`DataFrame` to a external database table via JDBC. + + .. note:: Don't create too many partitions in parallel on a large cluster;\ + otherwise Spark might crash your external database systems. + + :param url: a JDBC URL of the form ``jdbc:subprotocol:subname`` + :param table: Name of the table in the external database. + :param mode: specifies the behavior of the save operation when data already exists. + + * ``append``: Append contents of this :class:`DataFrame` to existing data. + * ``overwrite``: Overwrite existing data. + * ``ignore``: Silently ignore this operation if data already exists. + * ``error`` (default case): Throw an exception if data already exists. + :param properties: JDBC database connection arguments, a list of + arbitrary string tag/value. Normally at least a + "user" and "password" property should be included. + """ + jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)() + for k in properties: + jprop.setProperty(k, properties[k]) + self._jwrite.mode(mode).jdbc(url, table, jprop) + + +def _test(): + import doctest + import os + import tempfile + from pyspark.context import SparkContext + from pyspark.sql import Row, SQLContext + import pyspark.sql.readwriter + + os.chdir(os.environ["SPARK_HOME"]) + + globs = pyspark.sql.readwriter.__dict__.copy() + sc = SparkContext('local[4]', 'PythonTest') + + globs['tempfile'] = tempfile + globs['os'] = os + globs['sc'] = sc + globs['sqlContext'] = SQLContext(sc) + globs['df'] = globs['sqlContext'].read.parquet('python/test_support/sql/parquet_partitioned') + + (failure_count, test_count) = doctest.testmod( + pyspark.sql.readwriter, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 1922d03af61da..6e498f0af0af5 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -44,6 +44,7 @@ from pyspark.sql.types import UserDefinedType, _infer_type from pyspark.tests import ReusedPySparkTestCase from pyspark.sql.functions import UserDefinedFunction +from pyspark.sql.window import Window class ExamplePointUDT(UserDefinedType): @@ -99,6 +100,15 @@ def test_data_type_eq(self): lt2 = pickle.loads(pickle.dumps(LongType())) self.assertEquals(lt, lt2) + # regression test for SPARK-7978 + def test_decimal_type(self): + t1 = DecimalType() + t2 = DecimalType(10, 2) + self.assertTrue(t2 is not t1) + self.assertNotEqual(t1, t2) + t3 = DecimalType(8) + self.assertNotEqual(t2, t3) + class SQLTests(ReusedPySparkTestCase): @@ -117,6 +127,26 @@ def tearDownClass(cls): ReusedPySparkTestCase.tearDownClass() shutil.rmtree(cls.tempdir.name, ignore_errors=True) + def test_range(self): + self.assertEqual(self.sqlCtx.range(1, 1).count(), 0) + self.assertEqual(self.sqlCtx.range(1, 0, -1).count(), 1) + self.assertEqual(self.sqlCtx.range(0, 1 << 40, 1 << 39).count(), 2) + + def test_explode(self): + from pyspark.sql.functions import explode + d = [Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})] + rdd = self.sc.parallelize(d) + data = self.sqlCtx.createDataFrame(rdd) + + result = data.select(explode(data.intlist).alias("a")).select("a").collect() + self.assertEqual(result[0][0], 1) + self.assertEqual(result[1][0], 2) + self.assertEqual(result[2][0], 3) + + result = data.select(explode(data.mapfield).alias("a", "b")).select("a", "b").collect() + self.assertEqual(result[0][0], "a") + self.assertEqual(result[0][1], "b") + def test_udf_with_callable(self): d = [Row(number=i, squared=i**2) for i in range(10)] rdd = self.sc.parallelize(d) @@ -465,29 +495,29 @@ def test_save_and_load(self): df = self.df tmpPath = tempfile.mkdtemp() shutil.rmtree(tmpPath) - df.save(tmpPath, "org.apache.spark.sql.json", "error") - actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json") - self.assertTrue(sorted(df.collect()) == sorted(actual.collect())) + df.write.json(tmpPath) + actual = self.sqlCtx.read.json(tmpPath) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) schema = StructType([StructField("value", StringType(), True)]) - actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json", schema) - self.assertTrue(sorted(df.select("value").collect()) == sorted(actual.collect())) + actual = self.sqlCtx.read.json(tmpPath, schema) + self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect())) - df.save(tmpPath, "org.apache.spark.sql.json", "overwrite") - actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json") - self.assertTrue(sorted(df.collect()) == sorted(actual.collect())) + df.write.json(tmpPath, "overwrite") + actual = self.sqlCtx.read.json(tmpPath) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) - df.save(source="org.apache.spark.sql.json", mode="overwrite", path=tmpPath, - noUse="this options will not be used in save.") - actual = self.sqlCtx.load(source="org.apache.spark.sql.json", path=tmpPath, - noUse="this options will not be used in load.") - self.assertTrue(sorted(df.collect()) == sorted(actual.collect())) + df.write.save(format="json", mode="overwrite", path=tmpPath, + noUse="this options will not be used in save.") + actual = self.sqlCtx.read.load(format="json", path=tmpPath, + noUse="this options will not be used in load.") + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default", "org.apache.spark.sql.parquet") self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") actual = self.sqlCtx.load(path=tmpPath) - self.assertTrue(sorted(df.collect()) == sorted(actual.collect())) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName) shutil.rmtree(tmpPath) @@ -723,11 +753,11 @@ def setUpClass(cls): try: cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf() except py4j.protocol.Py4JError: - cls.sqlCtx = None - return + cls.tearDownClass() + raise unittest.SkipTest("Hive is not available") except TypeError: - cls.sqlCtx = None - return + cls.tearDownClass() + raise unittest.SkipTest("Hive is not available") os.unlink(cls.tempdir.name) _scala_HiveContext =\ cls.sc._jvm.org.apache.spark.sql.hive.test.TestHiveContext(cls.sc._jsc.sc()) @@ -741,57 +771,68 @@ def tearDownClass(cls): shutil.rmtree(cls.tempdir.name, ignore_errors=True) def test_save_and_load_table(self): - if self.sqlCtx is None: - return # no hive available, skipped - df = self.df tmpPath = tempfile.mkdtemp() shutil.rmtree(tmpPath) - df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "append", path=tmpPath) - actual = self.sqlCtx.createExternalTable("externalJsonTable", tmpPath, - "org.apache.spark.sql.json") - self.assertTrue( - sorted(df.collect()) == - sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect())) - self.assertTrue( - sorted(df.collect()) == - sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect())) - self.assertTrue(sorted(df.collect()) == sorted(actual.collect())) + df.write.saveAsTable("savedJsonTable", "json", "append", path=tmpPath) + actual = self.sqlCtx.createExternalTable("externalJsonTable", tmpPath, "json") + self.assertEqual(sorted(df.collect()), + sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect())) + self.assertEqual(sorted(df.collect()), + sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect())) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) self.sqlCtx.sql("DROP TABLE externalJsonTable") - df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "overwrite", path=tmpPath) + df.write.saveAsTable("savedJsonTable", "json", "overwrite", path=tmpPath) schema = StructType([StructField("value", StringType(), True)]) - actual = self.sqlCtx.createExternalTable("externalJsonTable", - source="org.apache.spark.sql.json", + actual = self.sqlCtx.createExternalTable("externalJsonTable", source="json", schema=schema, path=tmpPath, noUse="this options will not be used") - self.assertTrue( - sorted(df.collect()) == - sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect())) - self.assertTrue( - sorted(df.select("value").collect()) == - sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect())) - self.assertTrue(sorted(df.select("value").collect()) == sorted(actual.collect())) + self.assertEqual(sorted(df.collect()), + sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect())) + self.assertEqual(sorted(df.select("value").collect()), + sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect())) + self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect())) self.sqlCtx.sql("DROP TABLE savedJsonTable") self.sqlCtx.sql("DROP TABLE externalJsonTable") defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default", "org.apache.spark.sql.parquet") self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") - df.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite") + df.write.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite") actual = self.sqlCtx.createExternalTable("externalJsonTable", path=tmpPath) - self.assertTrue( - sorted(df.collect()) == - sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect())) - self.assertTrue( - sorted(df.collect()) == - sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect())) - self.assertTrue(sorted(df.collect()) == sorted(actual.collect())) + self.assertEqual(sorted(df.collect()), + sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect())) + self.assertEqual(sorted(df.collect()), + sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect())) + self.assertEqual(sorted(df.collect()), sorted(actual.collect())) self.sqlCtx.sql("DROP TABLE savedJsonTable") self.sqlCtx.sql("DROP TABLE externalJsonTable") self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName) shutil.rmtree(tmpPath) + def test_window_functions(self): + df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) + w = Window.partitionBy("value").orderBy("key") + from pyspark.sql import functions as F + sel = df.select(df.value, df.key, + F.max("key").over(w.rowsBetween(0, 1)), + F.min("key").over(w.rowsBetween(0, 1)), + F.count("key").over(w.rowsBetween(float('-inf'), float('inf'))), + F.rowNumber().over(w), + F.rank().over(w), + F.denseRank().over(w), + F.ntile(2).over(w)) + rs = sorted(sel.collect()) + expected = [ + ("1", 1, 1, 1, 1, 1, 1, 1, 1), + ("2", 1, 1, 1, 3, 1, 1, 1, 1), + ("2", 1, 2, 1, 3, 2, 1, 1, 1), + ("2", 2, 2, 2, 3, 3, 3, 2, 2) + ] + for r, ex in zip(rs, expected): + self.assertEqual(tuple(r), ex[:len(r)]) + if __name__ == "__main__": unittest.main() diff --git a/python/pyspark/sql/_types.py b/python/pyspark/sql/types.py similarity index 95% rename from python/pyspark/sql/_types.py rename to python/pyspark/sql/types.py index b96851a174d49..b6ec6137c9180 100644 --- a/python/pyspark/sql/_types.py +++ b/python/pyspark/sql/types.py @@ -73,56 +73,84 @@ def json(self): # This singleton pattern does not work with pickle, you will get # another object after pickle and unpickle -class PrimitiveTypeSingleton(type): - """Metaclass for PrimitiveType""" +class DataTypeSingleton(type): + """Metaclass for DataType""" _instances = {} def __call__(cls): if cls not in cls._instances: - cls._instances[cls] = super(PrimitiveTypeSingleton, cls).__call__() + cls._instances[cls] = super(DataTypeSingleton, cls).__call__() return cls._instances[cls] -class PrimitiveType(DataType): - """Spark SQL PrimitiveType""" +class NullType(DataType): + """Null type. + + The data type representing None, used for the types that cannot be inferred. + """ - __metaclass__ = PrimitiveTypeSingleton + __metaclass__ = DataTypeSingleton -class NullType(PrimitiveType): - """Null type. +class AtomicType(DataType): + """An internal type used to represent everything that is not + null, UDTs, arrays, structs, and maps.""" - The data type representing None, used for the types that cannot be inferred. + +class NumericType(AtomicType): + """Numeric data types. + """ + + +class IntegralType(NumericType): + """Integral data types. + """ + + __metaclass__ = DataTypeSingleton + + +class FractionalType(NumericType): + """Fractional data types. """ -class StringType(PrimitiveType): +class StringType(AtomicType): """String data type. """ + __metaclass__ = DataTypeSingleton -class BinaryType(PrimitiveType): + +class BinaryType(AtomicType): """Binary (byte array) data type. """ + __metaclass__ = DataTypeSingleton + -class BooleanType(PrimitiveType): +class BooleanType(AtomicType): """Boolean data type. """ + __metaclass__ = DataTypeSingleton + -class DateType(PrimitiveType): +class DateType(AtomicType): """Date (datetime.date) data type. """ + __metaclass__ = DataTypeSingleton -class TimestampType(PrimitiveType): + +class TimestampType(AtomicType): """Timestamp (datetime.datetime) data type. """ + __metaclass__ = DataTypeSingleton + -class DecimalType(DataType): +class DecimalType(FractionalType): """Decimal (decimal.Decimal) data type. """ @@ -150,31 +178,35 @@ def __repr__(self): return "DecimalType()" -class DoubleType(PrimitiveType): +class DoubleType(FractionalType): """Double data type, representing double precision floats. """ + __metaclass__ = DataTypeSingleton + -class FloatType(PrimitiveType): +class FloatType(FractionalType): """Float data type, representing single precision floats. """ + __metaclass__ = DataTypeSingleton -class ByteType(PrimitiveType): + +class ByteType(IntegralType): """Byte data type, i.e. a signed integer in a single byte. """ def simpleString(self): return 'tinyint' -class IntegerType(PrimitiveType): +class IntegerType(IntegralType): """Int data type, i.e. a signed 32-bit integer. """ def simpleString(self): return 'int' -class LongType(PrimitiveType): +class LongType(IntegralType): """Long data type, i.e. a signed 64-bit integer. If the values are beyond the range of [-9223372036854775808, 9223372036854775807], @@ -184,7 +216,7 @@ def simpleString(self): return 'bigint' -class ShortType(PrimitiveType): +class ShortType(IntegralType): """Short data type, i.e. a signed 16-bit integer. """ def simpleString(self): @@ -426,11 +458,9 @@ def __eq__(self, other): return type(self) == type(other) -_all_primitive_types = dict((v.typeName(), v) - for v in list(globals().values()) - if (type(v) is type or type(v) is PrimitiveTypeSingleton) - and v.__base__ == PrimitiveType) - +_atomic_types = [StringType, BinaryType, BooleanType, DecimalType, FloatType, DoubleType, + ByteType, ShortType, IntegerType, LongType, DateType, TimestampType] +_all_atomic_types = dict((t.typeName(), t) for t in _atomic_types) _all_complex_types = dict((v.typeName(), v) for v in [ArrayType, MapType, StructType]) @@ -444,7 +474,7 @@ def _parse_datatype_json_string(json_string): ... scala_datatype = sqlContext._ssql_ctx.parseDataType(datatype.json()) ... python_datatype = _parse_datatype_json_string(scala_datatype.json()) ... assert datatype == python_datatype - >>> for cls in _all_primitive_types.values(): + >>> for cls in _all_atomic_types.values(): ... check_datatype(cls()) >>> # Simple ArrayType. @@ -494,8 +524,8 @@ def _parse_datatype_json_string(json_string): def _parse_datatype_json_value(json_value): if not isinstance(json_value, dict): - if json_value in _all_primitive_types.keys(): - return _all_primitive_types[json_value]() + if json_value in _all_atomic_types.keys(): + return _all_atomic_types[json_value]() elif json_value == 'decimal': return DecimalType() elif _FIXED_DECIMAL.match(json_value): @@ -930,7 +960,7 @@ def _infer_schema_type(obj, dataType): DecimalType: (decimal.Decimal,), StringType: (str, unicode), BinaryType: (bytearray,), - DateType: (datetime.date,), + DateType: (datetime.date, datetime.datetime), TimestampType: (datetime.datetime,), ArrayType: (list, tuple, array), MapType: (dict,), @@ -1125,7 +1155,7 @@ def Dict(d): return lambda datum: dataType.deserialize(datum) elif not isinstance(dataType, StructType): - # no wrapper for primitive types + # no wrapper for atomic types return lambda x: x class Row(tuple): diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py new file mode 100644 index 0000000000000..0a0e006bdf83a --- /dev/null +++ b/python/pyspark/sql/window.py @@ -0,0 +1,158 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys + +from pyspark import SparkContext +from pyspark.sql import since +from pyspark.sql.column import _to_seq, _to_java_column + +__all__ = ["Window", "WindowSpec"] + + +def _to_java_cols(cols): + sc = SparkContext._active_spark_context + if len(cols) == 1 and isinstance(cols[0], list): + cols = cols[0] + return _to_seq(sc, cols, _to_java_column) + + +class Window(object): + + """ + Utility functions for defining window in DataFrames. + + For example: + + >>> # PARTITION BY country ORDER BY date ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + >>> window = Window.partitionBy("country").orderBy("date").rowsBetween(-sys.maxsize, 0) + + >>> # PARTITION BY country ORDER BY date RANGE BETWEEN 3 PRECEDING AND 3 FOLLOWING + >>> window = Window.orderBy("date").partitionBy("country").rangeBetween(-3, 3) + + .. note:: Experimental + + .. versionadded:: 1.4 + """ + @staticmethod + @since(1.4) + def partitionBy(*cols): + """ + Creates a :class:`WindowSpec` with the partitioning defined. + """ + sc = SparkContext._active_spark_context + jspec = sc._jvm.org.apache.spark.sql.expressions.Window.partitionBy(_to_java_cols(cols)) + return WindowSpec(jspec) + + @staticmethod + @since(1.4) + def orderBy(*cols): + """ + Creates a :class:`WindowSpec` with the partitioning defined. + """ + sc = SparkContext._active_spark_context + jspec = sc._jvm.org.apache.spark.sql.expressions.Window.partitionBy(_to_java_cols(cols)) + return WindowSpec(jspec) + + +class WindowSpec(object): + """ + A window specification that defines the partitioning, ordering, + and frame boundaries. + + Use the static methods in :class:`Window` to create a :class:`WindowSpec`. + + .. note:: Experimental + + .. versionadded:: 1.4 + """ + + _JAVA_MAX_LONG = (1 << 63) - 1 + _JAVA_MIN_LONG = - (1 << 63) + + def __init__(self, jspec): + self._jspec = jspec + + @since(1.4) + def partitionBy(self, *cols): + """ + Defines the partitioning columns in a :class:`WindowSpec`. + + :param cols: names of columns or expressions + """ + return WindowSpec(self._jspec.partitionBy(_to_java_cols(cols))) + + @since(1.4) + def orderBy(self, *cols): + """ + Defines the ordering columns in a :class:`WindowSpec`. + + :param cols: names of columns or expressions + """ + return WindowSpec(self._jspec.orderBy(_to_java_cols(cols))) + + @since(1.4) + def rowsBetween(self, start, end): + """ + Defines the frame boundaries, from `start` (inclusive) to `end` (inclusive). + + Both `start` and `end` are relative positions from the current row. + For example, "0" means "current row", while "-1" means the row before + the current row, and "5" means the fifth row after the current row. + + :param start: boundary start, inclusive. + The frame is unbounded if this is ``-sys.maxsize`` (or lower). + :param end: boundary end, inclusive. + The frame is unbounded if this is ``sys.maxsize`` (or higher). + """ + if start <= -sys.maxsize: + start = self._JAVA_MIN_LONG + if end >= sys.maxsize: + end = self._JAVA_MAX_LONG + return WindowSpec(self._jspec.rowsBetween(start, end)) + + @since(1.4) + def rangeBetween(self, start, end): + """ + Defines the frame boundaries, from `start` (inclusive) to `end` (inclusive). + + Both `start` and `end` are relative from the current row. For example, + "0" means "current row", while "-1" means one off before the current row, + and "5" means the five off after the current row. + + :param start: boundary start, inclusive. + The frame is unbounded if this is ``-sys.maxsize`` (or lower). + :param end: boundary end, inclusive. + The frame is unbounded if this is ``sys.maxsize`` (or higher). + """ + if start <= -sys.maxsize: + start = self._JAVA_MIN_LONG + if end >= sys.maxsize: + end = self._JAVA_MAX_LONG + return WindowSpec(self._jspec.rangeBetween(start, end)) + + +def _test(): + import doctest + SparkContext('local[4]', 'PythonTest') + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index e278b29003f69..10a859a532e28 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -132,11 +132,12 @@ def createRDD(sc, kafkaParams, offsetRanges, leaders={}, .. note:: Experimental Create a RDD from Kafka using offset ranges for each topic and partition. + :param sc: SparkContext object :param kafkaParams: Additional params for Kafka :param offsetRanges: list of offsetRange to specify topic:partition:[start, end) to consume :param leaders: Kafka brokers for each TopicAndPartition in offsetRanges. May be an empty - map, in which case leaders will be looked up on the driver. + map, in which case leaders will be looked up on the driver. :param keyDecoder: A function used to decode key (default is utf8_decoder) :param valueDecoder: A function used to decode value (default is utf8_decoder) :return: A RDD object diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 33ea8c9293d74..46cb18b2e8ef9 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -41,8 +41,8 @@ class PySparkStreamingTestCase(unittest.TestCase): - timeout = 4 # seconds - duration = .2 + timeout = 10 # seconds + duration = .5 @classmethod def setUpClass(cls): @@ -379,13 +379,13 @@ def func(dstream): class WindowFunctionTests(PySparkStreamingTestCase): - timeout = 5 + timeout = 15 def test_window(self): input = [range(1), range(2), range(3), range(4), range(5)] def func(dstream): - return dstream.window(.6, .2).count() + return dstream.window(1.5, .5).count() expected = [[1], [3], [6], [9], [12], [9], [5]] self._test_func(input, func, expected) @@ -394,7 +394,7 @@ def test_count_by_window(self): input = [range(1), range(2), range(3), range(4), range(5)] def func(dstream): - return dstream.countByWindow(.6, .2) + return dstream.countByWindow(1.5, .5) expected = [[1], [3], [6], [9], [12], [9], [5]] self._test_func(input, func, expected) @@ -403,7 +403,7 @@ def test_count_by_window_large(self): input = [range(1), range(2), range(3), range(4), range(5), range(6)] def func(dstream): - return dstream.countByWindow(1, .2) + return dstream.countByWindow(2.5, .5) expected = [[1], [3], [6], [10], [15], [20], [18], [15], [11], [6]] self._test_func(input, func, expected) @@ -412,7 +412,7 @@ def test_count_by_value_and_window(self): input = [range(1), range(2), range(3), range(4), range(5), range(6)] def func(dstream): - return dstream.countByValueAndWindow(1, .2) + return dstream.countByValueAndWindow(2.5, .5) expected = [[1], [2], [3], [4], [5], [6], [6], [6], [6], [6]] self._test_func(input, func, expected) @@ -421,7 +421,7 @@ def test_group_by_key_and_window(self): input = [[('a', i)] for i in range(5)] def func(dstream): - return dstream.groupByKeyAndWindow(.6, .2).mapValues(list) + return dstream.groupByKeyAndWindow(1.5, .5).mapValues(list) expected = [[('a', [0])], [('a', [0, 1])], [('a', [0, 1, 2])], [('a', [1, 2, 3])], [('a', [2, 3, 4])], [('a', [3, 4])], [('a', [4])]] diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 09de4d159fdcf..f9fb37f7fc139 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -444,6 +444,11 @@ def func(x): class RDDTests(ReusedPySparkTestCase): + def test_range(self): + self.assertEqual(self.sc.range(1, 1).count(), 0) + self.assertEqual(self.sc.range(1, 0, -1).count(), 1) + self.assertEqual(self.sc.range(0, 1 << 40, 1 << 39).count(), 2) + def test_id(self): rdd = self.sc.parallelize(range(10)) id = rdd.id() @@ -1543,13 +1548,13 @@ def count(): def test_with_different_versions_of_python(self): rdd = self.sc.parallelize(range(10)) rdd.count() - version = sys.version_info - sys.version_info = (2, 0, 0) + version = self.sc.pythonVer + self.sc.pythonVer = "2.0" try: with QuietTest(self.sc): self.assertRaises(Py4JJavaError, lambda: rdd.count()) finally: - sys.version_info = version + self.sc.pythonVer = version class SparkSubmitTests(unittest.TestCase): @@ -1804,6 +1809,10 @@ def run(): sc.stop() + def test_startTime(self): + with SparkContext() as sc: + self.assertGreater(sc.startTime, 0) + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index fbdaf3a5814cd..93df9002be377 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -57,6 +57,12 @@ def main(infile, outfile): if split_index == -1: # for unit tests exit(-1) + version = utf8_deserializer.loads(infile) + if version != "%d.%d" % sys.version_info[:2]: + raise Exception(("Python in worker has different version %s than that in " + + "driver %s, PySpark cannot run with different minor versions") % + ("%d.%d" % sys.version_info[:2], version)) + # initialize global state shuffle.MemoryBytesSpilled = 0 shuffle.DiskBytesSpilled = 0 @@ -92,11 +98,7 @@ def main(infile, outfile): command = pickleSer._read_with_length(infile) if isinstance(command, Broadcast): command = pickleSer.loads(command.value) - (func, profiler, deserializer, serializer), version = command - if version != sys.version_info[:2]: - raise Exception(("Python in worker has different version %s than that in " + - "driver %s, PySpark cannot run with different minor versions") % - (sys.version_info[:2], version)) + func, profiler, deserializer, serializer = command init_time = time.time() def process(): diff --git a/python/run-tests b/python/run-tests index f2757a3967e81..17dda3eadac0c 100755 --- a/python/run-tests +++ b/python/run-tests @@ -57,52 +57,56 @@ function run_test() { function run_core_tests() { echo "Run core tests ..." - run_test "pyspark/rdd.py" - run_test "pyspark/context.py" - run_test "pyspark/conf.py" - PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py" - PYSPARK_DOC_TEST=1 run_test "pyspark/accumulators.py" - run_test "pyspark/serializers.py" - run_test "pyspark/profiler.py" - run_test "pyspark/shuffle.py" - run_test "pyspark/tests.py" + run_test "pyspark.rdd" + run_test "pyspark.context" + run_test "pyspark.conf" + run_test "pyspark.broadcast" + run_test "pyspark.accumulators" + run_test "pyspark.serializers" + run_test "pyspark.profiler" + run_test "pyspark.shuffle" + run_test "pyspark.tests" } function run_sql_tests() { echo "Run sql tests ..." - run_test "pyspark/sql/_types.py" - run_test "pyspark/sql/context.py" - run_test "pyspark/sql/dataframe.py" - run_test "pyspark/sql/functions.py" - run_test "pyspark/sql/tests.py" + run_test "pyspark.sql.types" + run_test "pyspark.sql.context" + run_test "pyspark.sql.column" + run_test "pyspark.sql.dataframe" + run_test "pyspark.sql.group" + run_test "pyspark.sql.functions" + run_test "pyspark.sql.readwriter" + run_test "pyspark.sql.window" + run_test "pyspark.sql.tests" } function run_mllib_tests() { echo "Run mllib tests ..." - run_test "pyspark/mllib/classification.py" - run_test "pyspark/mllib/clustering.py" - run_test "pyspark/mllib/evaluation.py" - run_test "pyspark/mllib/feature.py" - run_test "pyspark/mllib/fpm.py" - run_test "pyspark/mllib/linalg.py" - run_test "pyspark/mllib/rand.py" - run_test "pyspark/mllib/recommendation.py" - run_test "pyspark/mllib/regression.py" - run_test "pyspark/mllib/stat/_statistics.py" - run_test "pyspark/mllib/tree.py" - run_test "pyspark/mllib/util.py" - run_test "pyspark/mllib/tests.py" + run_test "pyspark.mllib.classification" + run_test "pyspark.mllib.clustering" + run_test "pyspark.mllib.evaluation" + run_test "pyspark.mllib.feature" + run_test "pyspark.mllib.fpm" + run_test "pyspark.mllib.linalg" + run_test "pyspark.mllib.random" + run_test "pyspark.mllib.recommendation" + run_test "pyspark.mllib.regression" + run_test "pyspark.mllib.stat._statistics" + run_test "pyspark.mllib.tree" + run_test "pyspark.mllib.util" + run_test "pyspark.mllib.tests" } function run_ml_tests() { echo "Run ml tests ..." - run_test "pyspark/ml/feature.py" - run_test "pyspark/ml/classification.py" - run_test "pyspark/ml/recommendation.py" - run_test "pyspark/ml/regression.py" - run_test "pyspark/ml/tuning.py" - run_test "pyspark/ml/tests.py" - run_test "pyspark/ml/evaluation.py" + run_test "pyspark.ml.feature" + run_test "pyspark.ml.classification" + run_test "pyspark.ml.recommendation" + run_test "pyspark.ml.regression" + run_test "pyspark.ml.tuning" + run_test "pyspark.ml.tests" + run_test "pyspark.ml.evaluation" } function run_streaming_tests() { @@ -122,8 +126,8 @@ function run_streaming_tests() { done export PYSPARK_SUBMIT_ARGS="--jars ${KAFKA_ASSEMBLY_JAR} pyspark-shell" - run_test "pyspark/streaming/util.py" - run_test "pyspark/streaming/tests.py" + run_test "pyspark.streaming.util" + run_test "pyspark.streaming.tests" } echo "Running PySpark tests. Output is in python/$LOG_FILE." diff --git a/python/test_support/sql/parquet_partitioned/_SUCCESS b/python/test_support/sql/parquet_partitioned/_SUCCESS new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/python/test_support/sql/parquet_partitioned/_common_metadata b/python/test_support/sql/parquet_partitioned/_common_metadata new file mode 100644 index 0000000000000..7ef2320651dee Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/_common_metadata differ diff --git a/python/test_support/sql/parquet_partitioned/_metadata b/python/test_support/sql/parquet_partitioned/_metadata new file mode 100644 index 0000000000000..78a1ca7d38279 Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/_metadata differ diff --git a/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/.part-r-00008.gz.parquet.crc b/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/.part-r-00008.gz.parquet.crc new file mode 100644 index 0000000000000..e93f42ed6f350 Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/.part-r-00008.gz.parquet.crc differ diff --git a/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/part-r-00008.gz.parquet b/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/part-r-00008.gz.parquet new file mode 100644 index 0000000000000..461c382937ecd Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2014/month=9/day=1/part-r-00008.gz.parquet differ diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00002.gz.parquet.crc b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00002.gz.parquet.crc new file mode 100644 index 0000000000000..b63c4d6d1e1dc Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00002.gz.parquet.crc differ diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00004.gz.parquet.crc b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00004.gz.parquet.crc new file mode 100644 index 0000000000000..5bc0ebd713563 Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/.part-r-00004.gz.parquet.crc differ diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00002.gz.parquet b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00002.gz.parquet new file mode 100644 index 0000000000000..62a63915beac2 Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00002.gz.parquet differ diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00004.gz.parquet b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00004.gz.parquet new file mode 100644 index 0000000000000..67665a7b55da6 Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=25/part-r-00004.gz.parquet differ diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/.part-r-00005.gz.parquet.crc b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/.part-r-00005.gz.parquet.crc new file mode 100644 index 0000000000000..ae94a15d08c81 Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/.part-r-00005.gz.parquet.crc differ diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/part-r-00005.gz.parquet b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/part-r-00005.gz.parquet new file mode 100644 index 0000000000000..6cb8538aa8904 Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=10/day=26/part-r-00005.gz.parquet differ diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/.part-r-00007.gz.parquet.crc b/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/.part-r-00007.gz.parquet.crc new file mode 100644 index 0000000000000..58d9bb5fc5883 Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/.part-r-00007.gz.parquet.crc differ diff --git a/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/part-r-00007.gz.parquet b/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/part-r-00007.gz.parquet new file mode 100644 index 0000000000000..9b00805481e7b Binary files /dev/null and b/python/test_support/sql/parquet_partitioned/year=2015/month=9/day=1/part-r-00007.gz.parquet differ diff --git a/python/test_support/sql/people.json b/python/test_support/sql/people.json new file mode 100644 index 0000000000000..50a859cbd7ee8 --- /dev/null +++ b/python/test_support/sql/people.json @@ -0,0 +1,3 @@ +{"name":"Michael"} +{"name":"Andy", "age":30} +{"name":"Justin", "age":19} diff --git a/repl/pom.xml b/repl/pom.xml index 03053b4c3b287..6e5cb7f77e1df 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -48,6 +48,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-bagel_${scala.binary.version} diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 488f3a9f33256..2b235525250c2 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -206,7 +206,8 @@ class SparkILoop( // e.g. file:/C:/my/path.jar -> C:/my/path.jar SparkILoop.getAddedJars.map { jar => new URI(jar).getPath.stripPrefix("/") } } else { - SparkILoop.getAddedJars + // We need new URI(jar).getPath here for the case that `jar` includes encoded white space (%20). + SparkILoop.getAddedJars.map { jar => new URI(jar).getPath } } // work around for Scala bug val totalClassPath = addedJars.foldLeft( @@ -1109,7 +1110,7 @@ object SparkILoop extends Logging { if (settings.classpath.isDefault) settings.classpath.value = sys.props("java.class.path") - getAddedJars.foreach(settings.classpath.append(_)) + getAddedJars.map(jar => new URI(jar).getPath).foreach(settings.classpath.append(_)) repl process settings } diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 934daaeaafca1..50fd43a418bca 100644 --- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -22,13 +22,12 @@ import java.net.URLClassLoader import scala.collection.mutable.ArrayBuffer -import org.scalatest.FunSuite -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.util.Utils -class ReplSuite extends FunSuite { +class ReplSuite extends SparkFunSuite { def runInterpreter(master: String, input: String): String = { val CONF_EXECUTOR_CLASSPATH = "spark.executor.extraClassPath" diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 14f5e9ed4f25e..9ecc7c229e38a 100644 --- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -24,14 +24,13 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.tools.nsc.interpreter.SparkILoop -import org.scalatest.FunSuite import org.apache.commons.lang3.StringEscapeUtils -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.util.Utils -class ReplSuite extends FunSuite { +class ReplSuite extends SparkFunSuite { def runInterpreter(master: String, input: String): String = { val CONF_EXECUTOR_CLASSPATH = "spark.executor.extraClassPath" diff --git a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala index c709cde740748..a58eda12b1120 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala @@ -25,7 +25,6 @@ import scala.language.implicitConversions import scala.language.postfixOps import org.scalatest.BeforeAndAfterAll -import org.scalatest.FunSuite import org.scalatest.concurrent.Interruptor import org.scalatest.concurrent.Timeouts._ import org.scalatest.mock.MockitoSugar @@ -35,7 +34,7 @@ import org.apache.spark._ import org.apache.spark.util.Utils class ExecutorClassLoaderSuite - extends FunSuite + extends SparkFunSuite with BeforeAndAfterAll with MockitoSugar with Logging { diff --git a/sbin/start-master.sh b/sbin/start-master.sh index 17fff58f4f768..a7f5d5702fd80 100755 --- a/sbin/start-master.sh +++ b/sbin/start-master.sh @@ -22,6 +22,8 @@ sbin="`dirname "$0"`" sbin="`cd "$sbin"; pwd`" +ORIGINAL_ARGS="$@" + START_TACHYON=false while (( "$#" )); do @@ -53,7 +55,9 @@ if [ "$SPARK_MASTER_WEBUI_PORT" = "" ]; then SPARK_MASTER_WEBUI_PORT=8080 fi -"$sbin"/spark-daemon.sh start org.apache.spark.deploy.master.Master 1 --ip $SPARK_MASTER_IP --port $SPARK_MASTER_PORT --webui-port $SPARK_MASTER_WEBUI_PORT +"$sbin"/spark-daemon.sh start org.apache.spark.deploy.master.Master 1 \ + --ip $SPARK_MASTER_IP --port $SPARK_MASTER_PORT --webui-port $SPARK_MASTER_WEBUI_PORT \ + $ORIGINAL_ARGS if [ "$START_TACHYON" == "true" ]; then "$sbin"/../tachyon/bin/tachyon bootstrap-conf $SPARK_MASTER_IP diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 7168d5b2a8e26..d6f927b6fa803 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -14,25 +14,41 @@ ~ See the License for the specific language governing permissions and ~ limitations under the License. --> - - - - - - + - Scalastyle standard configuration - - - - - - - - - Scalastyle standard configuration + + + + + + + + + + - - - - - - - - - - true - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + true + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + ARROW, EQUALS, ELSE, TRY, CATCH, FINALLY, LARROW, RARROW + + + + + + ARROW, EQUALS, COMMA, COLON, IF, ELSE, DO, WHILE, FOR, MATCH, TRY, CATCH, FINALLY, LARROW, RARROW + + + + + + + + + ^FunSuite[A-Za-z]*$ + Tests must extend org.apache.spark.SparkFunSuite instead. + + + + + + + + + ^println$ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 800> + + + + + 30 + + + + + 10 + + + + + 50 + + + + + + + + + + + -1,0,1,2,3 + + diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 5c322d032d474..d9e1cdb84bb27 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -50,6 +50,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-unsafe_${scala.binary.version} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index 4190b7ffe1c8f..0d460b634d9b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -55,6 +55,9 @@ object Row { // TODO: Improve the performance of this if used in performance critical part. new GenericRow(rows.flatMap(_.toSeq).toArray) } + + /** Returns an empty row. */ + val empty = apply() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala index 2eb3e167baad5..ef7b3ad9432cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala @@ -103,7 +103,7 @@ class SqlLexical extends StdLexical { ( identChar ~ (identChar | digit).* ^^ { case first ~ rest => processIdent((first :: rest).mkString) } | rep1(digit) ~ ('.' ~> digit.*).? ^^ { - case i ~ None => NumericLit(i.mkString) + case i ~ None => NumericLit(i.mkString) case i ~ Some(d) => FloatLit(i.mkString + "." + d.mkString) } | '\'' ~> chrExcept('\'', '\n', EofCh).* <~ '\'' ^^ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index a13e2f36a1a1f..2e7b4c236d8f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -18,11 +18,15 @@ package org.apache.spark.sql.catalyst import java.lang.{Iterable => JavaIterable} +import java.math.{BigDecimal => JavaBigDecimal} +import java.sql.Date import java.util.{Map => JavaMap} +import javax.annotation.Nullable import scala.collection.mutable.HashMap import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types._ /** @@ -33,197 +37,338 @@ object CatalystTypeConverters { // Since the map values can be mutable, we explicitly import scala.collection.Map at here. import scala.collection.Map + private def isPrimitive(dataType: DataType): Boolean = { + dataType match { + case BooleanType => true + case ByteType => true + case ShortType => true + case IntegerType => true + case LongType => true + case FloatType => true + case DoubleType => true + case _ => false + } + } + + private def getConverterForType(dataType: DataType): CatalystTypeConverter[Any, Any, Any] = { + val converter = dataType match { + case udt: UserDefinedType[_] => UDTConverter(udt) + case arrayType: ArrayType => ArrayConverter(arrayType.elementType) + case mapType: MapType => MapConverter(mapType.keyType, mapType.valueType) + case structType: StructType => StructConverter(structType) + case StringType => StringConverter + case DateType => DateConverter + case dt: DecimalType => BigDecimalConverter + case BooleanType => BooleanConverter + case ByteType => ByteConverter + case ShortType => ShortConverter + case IntegerType => IntConverter + case LongType => LongConverter + case FloatType => FloatConverter + case DoubleType => DoubleConverter + case _ => IdentityConverter + } + converter.asInstanceOf[CatalystTypeConverter[Any, Any, Any]] + } + /** - * Converts Scala objects to catalyst rows / types. This method is slow, and for batch - * conversion you should be using converter produced by createToCatalystConverter. - * Note: This is always called after schemaFor has been called. - * This ordering is important for UDT registration. + * Converts a Scala type to its Catalyst equivalent (and vice versa). + * + * @tparam ScalaInputType The type of Scala values that can be converted to Catalyst. + * @tparam ScalaOutputType The type of Scala values returned when converting Catalyst to Scala. + * @tparam CatalystType The internal Catalyst type used to represent values of this Scala type. */ - def convertToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match { - // Check UDT first since UDTs can override other types - case (obj, udt: UserDefinedType[_]) => - udt.serialize(obj) - - case (o: Option[_], _) => - o.map(convertToCatalyst(_, dataType)).orNull - - case (s: Seq[_], arrayType: ArrayType) => - s.map(convertToCatalyst(_, arrayType.elementType)) - - case (jit: JavaIterable[_], arrayType: ArrayType) => { - val iter = jit.iterator - var listOfItems: List[Any] = List() - while (iter.hasNext) { - val item = iter.next() - listOfItems :+= convertToCatalyst(item, arrayType.elementType) + private abstract class CatalystTypeConverter[ScalaInputType, ScalaOutputType, CatalystType] + extends Serializable { + + /** + * Converts a Scala type to its Catalyst equivalent while automatically handling nulls + * and Options. + */ + final def toCatalyst(@Nullable maybeScalaValue: Any): CatalystType = { + if (maybeScalaValue == null) { + null.asInstanceOf[CatalystType] + } else if (maybeScalaValue.isInstanceOf[Option[ScalaInputType]]) { + val opt = maybeScalaValue.asInstanceOf[Option[ScalaInputType]] + if (opt.isDefined) { + toCatalystImpl(opt.get) + } else { + null.asInstanceOf[CatalystType] + } + } else { + toCatalystImpl(maybeScalaValue.asInstanceOf[ScalaInputType]) } - listOfItems } - case (s: Array[_], arrayType: ArrayType) => - s.toSeq.map(convertToCatalyst(_, arrayType.elementType)) + /** + * Given a Catalyst row, convert the value at column `column` to its Scala equivalent. + */ + final def toScala(row: Row, column: Int): ScalaOutputType = { + if (row.isNullAt(column)) null.asInstanceOf[ScalaOutputType] else toScalaImpl(row, column) + } - case (m: Map[_, _], mapType: MapType) => - m.map { case (k, v) => - convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType) - } + /** + * Convert a Catalyst value to its Scala equivalent. + */ + def toScala(@Nullable catalystValue: CatalystType): ScalaOutputType + + /** + * Converts a Scala value to its Catalyst equivalent. + * @param scalaValue the Scala value, guaranteed not to be null. + * @return the Catalyst value. + */ + protected def toCatalystImpl(scalaValue: ScalaInputType): CatalystType + + /** + * Given a Catalyst row, convert the value at column `column` to its Scala equivalent. + * This method will only be called on non-null columns. + */ + protected def toScalaImpl(row: Row, column: Int): ScalaOutputType + } + + private object IdentityConverter extends CatalystTypeConverter[Any, Any, Any] { + override def toCatalystImpl(scalaValue: Any): Any = scalaValue + override def toScala(catalystValue: Any): Any = catalystValue + override def toScalaImpl(row: Row, column: Int): Any = row(column) + } + + private case class UDTConverter( + udt: UserDefinedType[_]) extends CatalystTypeConverter[Any, Any, Any] { + override def toCatalystImpl(scalaValue: Any): Any = udt.serialize(scalaValue) + override def toScala(catalystValue: Any): Any = udt.deserialize(catalystValue) + override def toScalaImpl(row: Row, column: Int): Any = toScala(row(column)) + } - case (jmap: JavaMap[_, _], mapType: MapType) => - val iter = jmap.entrySet.iterator - var listOfEntries: List[(Any, Any)] = List() - while (iter.hasNext) { - val entry = iter.next() - listOfEntries :+= (convertToCatalyst(entry.getKey, mapType.keyType), - convertToCatalyst(entry.getValue, mapType.valueType)) + /** Converter for arrays, sequences, and Java iterables. */ + private case class ArrayConverter( + elementType: DataType) extends CatalystTypeConverter[Any, Seq[Any], Seq[Any]] { + + private[this] val elementConverter = getConverterForType(elementType) + + override def toCatalystImpl(scalaValue: Any): Seq[Any] = { + scalaValue match { + case a: Array[_] => a.toSeq.map(elementConverter.toCatalyst) + case s: Seq[_] => s.map(elementConverter.toCatalyst) + case i: JavaIterable[_] => + val iter = i.iterator + var convertedIterable: List[Any] = List() + while (iter.hasNext) { + val item = iter.next() + convertedIterable :+= elementConverter.toCatalyst(item) + } + convertedIterable } - listOfEntries.toMap - - case (p: Product, structType: StructType) => - val ar = new Array[Any](structType.size) - val iter = p.productIterator - var idx = 0 - while (idx < structType.size) { - ar(idx) = convertToCatalyst(iter.next(), structType.fields(idx).dataType) - idx += 1 + } + + override def toScala(catalystValue: Seq[Any]): Seq[Any] = { + if (catalystValue == null) { + null + } else { + catalystValue.asInstanceOf[Seq[_]].map(elementConverter.toScala) } - new GenericRowWithSchema(ar, structType) + } + + override def toScalaImpl(row: Row, column: Int): Seq[Any] = + toScala(row(column).asInstanceOf[Seq[Any]]) + } - case (d: String, _) => - UTF8String(d) + private case class MapConverter( + keyType: DataType, + valueType: DataType) + extends CatalystTypeConverter[Any, Map[Any, Any], Map[Any, Any]] { - case (d: BigDecimal, _) => - Decimal(d) + private[this] val keyConverter = getConverterForType(keyType) + private[this] val valueConverter = getConverterForType(valueType) - case (d: java.math.BigDecimal, _) => - Decimal(d) + override def toCatalystImpl(scalaValue: Any): Map[Any, Any] = scalaValue match { + case m: Map[_, _] => + m.map { case (k, v) => + keyConverter.toCatalyst(k) -> valueConverter.toCatalyst(v) + } - case (d: java.sql.Date, _) => - DateUtils.fromJavaDate(d) + case jmap: JavaMap[_, _] => + val iter = jmap.entrySet.iterator + val convertedMap: HashMap[Any, Any] = HashMap() + while (iter.hasNext) { + val entry = iter.next() + val key = keyConverter.toCatalyst(entry.getKey) + convertedMap(key) = valueConverter.toCatalyst(entry.getValue) + } + convertedMap + } - case (r: Row, structType: StructType) => - val converters = structType.fields.map { - f => (item: Any) => convertToCatalyst(item, f.dataType) + override def toScala(catalystValue: Map[Any, Any]): Map[Any, Any] = { + if (catalystValue == null) { + null + } else { + catalystValue.map { case (k, v) => + keyConverter.toScala(k) -> valueConverter.toScala(v) + } } - convertRowWithConverters(r, structType, converters) + } - case (other, _) => - other + override def toScalaImpl(row: Row, column: Int): Map[Any, Any] = + toScala(row(column).asInstanceOf[Map[Any, Any]]) } - /** - * Creates a converter function that will convert Scala objects to the specified catalyst type. - * Typical use case would be converting a collection of rows that have the same schema. You will - * call this function once to get a converter, and apply it to every row. - */ - private[sql] def createToCatalystConverter(dataType: DataType): Any => Any = { - def extractOption(item: Any): Any = item match { - case opt: Option[_] => opt.orNull - case other => other - } + private case class StructConverter( + structType: StructType) extends CatalystTypeConverter[Any, Row, Row] { - dataType match { - // Check UDT first since UDTs can override other types - case udt: UserDefinedType[_] => - (item) => extractOption(item) match { - case null => null - case other => udt.serialize(other) - } + private[this] val converters = structType.fields.map { f => getConverterForType(f.dataType) } - case arrayType: ArrayType => - val elementConverter = createToCatalystConverter(arrayType.elementType) - (item: Any) => { - extractOption(item) match { - case a: Array[_] => a.toSeq.map(elementConverter) - case s: Seq[_] => s.map(elementConverter) - case i: JavaIterable[_] => { - val iter = i.iterator - var convertedIterable: List[Any] = List() - while (iter.hasNext) { - val item = iter.next() - convertedIterable :+= elementConverter(item) - } - convertedIterable - } - case null => null - } + override def toCatalystImpl(scalaValue: Any): Row = scalaValue match { + case row: Row => + val ar = new Array[Any](row.size) + var idx = 0 + while (idx < row.size) { + ar(idx) = converters(idx).toCatalyst(row(idx)) + idx += 1 } - - case mapType: MapType => - val keyConverter = createToCatalystConverter(mapType.keyType) - val valueConverter = createToCatalystConverter(mapType.valueType) - (item: Any) => { - extractOption(item) match { - case m: Map[_, _] => - m.map { case (k, v) => - keyConverter(k) -> valueConverter(v) - } - - case jmap: JavaMap[_, _] => - val iter = jmap.entrySet.iterator - val convertedMap: HashMap[Any, Any] = HashMap() - while (iter.hasNext) { - val entry = iter.next() - convertedMap(keyConverter(entry.getKey)) = valueConverter(entry.getValue) - } - convertedMap - - case null => null - } + new GenericRowWithSchema(ar, structType) + + case p: Product => + val ar = new Array[Any](structType.size) + val iter = p.productIterator + var idx = 0 + while (idx < structType.size) { + ar(idx) = converters(idx).toCatalyst(iter.next()) + idx += 1 } + new GenericRowWithSchema(ar, structType) + } - case structType: StructType => - val converters = structType.fields.map(f => createToCatalystConverter(f.dataType)) - (item: Any) => { - extractOption(item) match { - case r: Row => - convertRowWithConverters(r, structType, converters) - - case p: Product => - val ar = new Array[Any](structType.size) - val iter = p.productIterator - var idx = 0 - while (idx < structType.size) { - ar(idx) = converters(idx)(iter.next()) - idx += 1 - } - new GenericRowWithSchema(ar, structType) - - case null => - null - } + override def toScala(row: Row): Row = { + if (row == null) { + null + } else { + val ar = new Array[Any](row.size) + var idx = 0 + while (idx < row.size) { + ar(idx) = converters(idx).toScala(row, idx) + idx += 1 } - - case dateType: DateType => (item: Any) => extractOption(item) match { - case d: java.sql.Date => DateUtils.fromJavaDate(d) - case other => other + new GenericRowWithSchema(ar, structType) } + } - case dataType: StringType => (item: Any) => extractOption(item) match { - case s: String => UTF8String(s) - case other => other - } + override def toScalaImpl(row: Row, column: Int): Row = toScala(row(column).asInstanceOf[Row]) + } + + private object StringConverter extends CatalystTypeConverter[Any, String, Any] { + override def toCatalystImpl(scalaValue: Any): UTF8String = scalaValue match { + case str: String => UTF8String(str) + case utf8: UTF8String => utf8 + } + override def toScala(catalystValue: Any): String = catalystValue match { + case null => null + case str: String => str + case utf8: UTF8String => utf8.toString() + } + override def toScalaImpl(row: Row, column: Int): String = row(column).toString + } + + private object DateConverter extends CatalystTypeConverter[Date, Date, Any] { + override def toCatalystImpl(scalaValue: Date): Int = DateUtils.fromJavaDate(scalaValue) + override def toScala(catalystValue: Any): Date = + if (catalystValue == null) null else DateUtils.toJavaDate(catalystValue.asInstanceOf[Int]) + override def toScalaImpl(row: Row, column: Int): Date = toScala(row.getInt(column)) + } - case _ => - (item: Any) => extractOption(item) match { - case d: BigDecimal => Decimal(d) - case d: java.math.BigDecimal => Decimal(d) - case other => other + private object BigDecimalConverter extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] { + override def toCatalystImpl(scalaValue: Any): Decimal = scalaValue match { + case d: BigDecimal => Decimal(d) + case d: JavaBigDecimal => Decimal(d) + case d: Decimal => d + } + override def toScala(catalystValue: Decimal): JavaBigDecimal = catalystValue.toJavaBigDecimal + override def toScalaImpl(row: Row, column: Int): JavaBigDecimal = row.get(column) match { + case d: JavaBigDecimal => d + case d: Decimal => d.toJavaBigDecimal + } + } + + private abstract class PrimitiveConverter[T] extends CatalystTypeConverter[T, Any, Any] { + final override def toScala(catalystValue: Any): Any = catalystValue + final override def toCatalystImpl(scalaValue: T): Any = scalaValue + } + + private object BooleanConverter extends PrimitiveConverter[Boolean] { + override def toScalaImpl(row: Row, column: Int): Boolean = row.getBoolean(column) + } + + private object ByteConverter extends PrimitiveConverter[Byte] { + override def toScalaImpl(row: Row, column: Int): Byte = row.getByte(column) + } + + private object ShortConverter extends PrimitiveConverter[Short] { + override def toScalaImpl(row: Row, column: Int): Short = row.getShort(column) + } + + private object IntConverter extends PrimitiveConverter[Int] { + override def toScalaImpl(row: Row, column: Int): Int = row.getInt(column) + } + + private object LongConverter extends PrimitiveConverter[Long] { + override def toScalaImpl(row: Row, column: Int): Long = row.getLong(column) + } + + private object FloatConverter extends PrimitiveConverter[Float] { + override def toScalaImpl(row: Row, column: Int): Float = row.getFloat(column) + } + + private object DoubleConverter extends PrimitiveConverter[Double] { + override def toScalaImpl(row: Row, column: Int): Double = row.getDouble(column) + } + + /** + * Converts Scala objects to catalyst rows / types. This method is slow, and for batch + * conversion you should be using converter produced by createToCatalystConverter. + * Note: This is always called after schemaFor has been called. + * This ordering is important for UDT registration. + */ + def convertToCatalyst(scalaValue: Any, dataType: DataType): Any = { + getConverterForType(dataType).toCatalyst(scalaValue) + } + + /** + * Creates a converter function that will convert Scala objects to the specified Catalyst type. + * Typical use case would be converting a collection of rows that have the same schema. You will + * call this function once to get a converter, and apply it to every row. + */ + private[sql] def createToCatalystConverter(dataType: DataType): Any => Any = { + if (isPrimitive(dataType)) { + // Although the `else` branch here is capable of handling inbound conversion of primitives, + // we add some special-case handling for those types here. The motivation for this relates to + // Java method invocation costs: if we have rows that consist entirely of primitive columns, + // then returning the same conversion function for all of the columns means that the call site + // will be monomorphic instead of polymorphic. In microbenchmarks, this actually resulted in + // a measurable performance impact. Note that this optimization will be unnecessary if we + // use code generation to construct Scala Row -> Catalyst Row converters. + def convert(maybeScalaValue: Any): Any = { + if (maybeScalaValue.isInstanceOf[Option[Any]]) { + maybeScalaValue.asInstanceOf[Option[Any]].orNull + } else { + maybeScalaValue } + } + convert + } else { + getConverterForType(dataType).toCatalyst } } /** - * Converts Scala objects to catalyst rows / types. + * Converts Scala objects to Catalyst rows / types. * * Note: This should be called before do evaluation on Row * (It does not support UDT) * This is used to create an RDD or test results with correct types for Catalyst. */ def convertToCatalyst(a: Any): Any = a match { - case s: String => UTF8String(s) - case d: java.sql.Date => DateUtils.fromJavaDate(d) - case d: BigDecimal => Decimal(d) - case d: java.math.BigDecimal => Decimal(d) + case s: String => StringConverter.toCatalyst(s) + case d: Date => DateConverter.toCatalyst(d) + case d: BigDecimal => BigDecimalConverter.toCatalyst(d) + case d: JavaBigDecimal => BigDecimalConverter.toCatalyst(d) case seq: Seq[Any] => seq.map(convertToCatalyst) case r: Row => Row(r.toSeq.map(convertToCatalyst): _*) case arr: Array[Any] => arr.toSeq.map(convertToCatalyst).toArray @@ -232,38 +377,13 @@ object CatalystTypeConverters { case other => other } - /** + /** * Converts Catalyst types used internally in rows to standard Scala types * This method is slow, and for batch conversion you should be using converter * produced by createToScalaConverter. */ - def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match { - // Check UDT first since UDTs can override other types - case (d, udt: UserDefinedType[_]) => - udt.deserialize(d) - - case (s: Seq[_], arrayType: ArrayType) => - s.map(convertToScala(_, arrayType.elementType)) - - case (m: Map[_, _], mapType: MapType) => - m.map { case (k, v) => - convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType) - } - - case (r: Row, s: StructType) => - convertRowToScala(r, s) - - case (d: Decimal, _: DecimalType) => - d.toJavaBigDecimal - - case (i: Int, DateType) => - DateUtils.toJavaDate(i) - - case (s: UTF8String, StringType) => - s.toString() - - case (other, _) => - other + def convertToScala(catalystValue: Any, dataType: DataType): Any = { + getConverterForType(dataType).toScala(catalystValue) } /** @@ -271,82 +391,7 @@ object CatalystTypeConverters { * Typical use case would be converting a collection of rows that have the same schema. You will * call this function once to get a converter, and apply it to every row. */ - private[sql] def createToScalaConverter(dataType: DataType): Any => Any = dataType match { - // Check UDT first since UDTs can override other types - case udt: UserDefinedType[_] => - (item: Any) => if (item == null) null else udt.deserialize(item) - - case arrayType: ArrayType => - val elementConverter = createToScalaConverter(arrayType.elementType) - (item: Any) => if (item == null) null else item.asInstanceOf[Seq[_]].map(elementConverter) - - case mapType: MapType => - val keyConverter = createToScalaConverter(mapType.keyType) - val valueConverter = createToScalaConverter(mapType.valueType) - (item: Any) => if (item == null) { - null - } else { - item.asInstanceOf[Map[_, _]].map { case (k, v) => - keyConverter(k) -> valueConverter(v) - } - } - - case s: StructType => - val converters = s.fields.map(f => createToScalaConverter(f.dataType)) - (item: Any) => { - if (item == null) { - null - } else { - convertRowWithConverters(item.asInstanceOf[Row], s, converters) - } - } - - case _: DecimalType => - (item: Any) => item match { - case d: Decimal => d.toJavaBigDecimal - case other => other - } - - case DateType => - (item: Any) => item match { - case i: Int => DateUtils.toJavaDate(i) - case other => other - } - - case StringType => - (item: Any) => item match { - case s: UTF8String => s.toString() - case other => other - } - - case other => - (item: Any) => item - } - - def convertRowToScala(r: Row, schema: StructType): Row = { - val ar = new Array[Any](r.size) - var idx = 0 - while (idx < r.size) { - ar(idx) = convertToScala(r(idx), schema.fields(idx).dataType) - idx += 1 - } - new GenericRowWithSchema(ar, schema) - } - - /** - * Converts a row by applying the provided set of converter functions. It is used for both - * toScala and toCatalyst conversions. - */ - private[sql] def convertRowWithConverters( - row: Row, - schema: StructType, - converters: Array[Any => Any]): Row = { - val ar = new Array[Any](row.size) - var idx = 0 - while (idx < row.size) { - ar(idx) = converters(idx)(row(idx)) - idx += 1 - } - new GenericRowWithSchema(ar, schema) + private[sql] def createToScalaConverter(dataType: DataType): Any => Any = { + getConverterForType(dataType).toScala } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala similarity index 93% rename from sql/core/src/main/scala/org/apache/spark/sql/JavaTypeInference.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 1ec874f79617c..9a3f9694e4c48 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.catalyst import java.beans.Introspector import java.lang.{Iterable => JIterable} @@ -24,10 +24,8 @@ import java.util.{Iterator => JIterator, Map => JMap} import scala.language.existentials import com.google.common.reflect.TypeToken - import org.apache.spark.sql.types._ - /** * Type-inference utilities for POJOs and Java collections. */ @@ -40,12 +38,21 @@ private [sql] object JavaTypeInference { private val keySetReturnType = classOf[JMap[_, _]].getMethod("keySet").getGenericReturnType private val valuesReturnType = classOf[JMap[_, _]].getMethod("values").getGenericReturnType + /** + * Infers the corresponding SQL data type of a JavaClean class. + * @param beanClass Java type + * @return (SQL data type, nullable) + */ + def inferDataType(beanClass: Class[_]): (DataType, Boolean) = { + inferDataType(TypeToken.of(beanClass)) + } + /** * Infers the corresponding SQL data type of a Java type. * @param typeToken Java type * @return (SQL data type, nullable) */ - private [sql] def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = { + private def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = { // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific. typeToken.getRawType match { case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala index 05a92b06f9fd9..554fb4eb25eb1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala @@ -31,3 +31,39 @@ abstract class ParserDialect { // this is the main function that will be implemented by sql parser. def parse(sqlText: String): LogicalPlan } + +/** + * Currently we support the default dialect named "sql", associated with the class + * [[DefaultParserDialect]] + * + * And we can also provide custom SQL Dialect, for example in Spark SQL CLI: + * {{{ + *-- switch to "hiveql" dialect + * spark-sql>SET spark.sql.dialect=hiveql; + * spark-sql>SELECT * FROM src LIMIT 1; + * + *-- switch to "sql" dialect + * spark-sql>SET spark.sql.dialect=sql; + * spark-sql>SELECT * FROM src LIMIT 1; + * + *-- register the new SQL dialect + * spark-sql> SET spark.sql.dialect=com.xxx.xxx.SQL99Dialect; + * spark-sql> SELECT * FROM src LIMIT 1; + * + *-- register the non-exist SQL dialect + * spark-sql> SET spark.sql.dialect=NotExistedClass; + * spark-sql> SELECT * FROM src LIMIT 1; + * + *-- Exception will be thrown and switch to dialect + *-- "sql" (for SQLContext) or + *-- "hiveql" (for HiveContext) + * }}} + */ +private[spark] class DefaultParserDialect extends ParserDialect { + @transient + protected val sqlParser = new SqlParser + + override def parse(sqlText: String): LogicalPlan = { + sqlParser.parse(sqlText) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index fc36b9f1f20d2..e85312aee7d16 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -140,7 +140,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { (HAVING ~> expression).? ~ sortType.? ~ (LIMIT ~> expression).? ^^ { - case d ~ p ~ r ~ f ~ g ~ h ~ o ~ l => + case d ~ p ~ r ~ f ~ g ~ h ~ o ~ l => val base = r.getOrElse(OneRowRelation) val withFilter = f.map(Filter(_, base)).getOrElse(base) val withProjection = g @@ -212,7 +212,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val ordering: Parser[Seq[SortOrder]] = ( rep1sep(expression ~ direction.? , ",") ^^ { - case exps => exps.map(pair => SortOrder(pair._1, pair._2.getOrElse(Ascending))) + case exps => exps.map(pair => SortOrder(pair._1, pair._2.getOrElse(Ascending))) } ) @@ -242,7 +242,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { | termExpression ~ NOT.? ~ (BETWEEN ~> termExpression) ~ (AND ~> termExpression) ^^ { case e ~ not ~ el ~ eu => val betweenExpr: Expression = And(GreaterThanOrEqual(e, el), LessThanOrEqual(e, eu)) - not.fold(betweenExpr)(f=> Not(betweenExpr)) + not.fold(betweenExpr)(f => Not(betweenExpr)) } | termExpression ~ (RLIKE ~> termExpression) ^^ { case e1 ~ e2 => RLike(e1, e2) } | termExpression ~ (REGEXP ~> termExpression) ^^ { case e1 ~ e2 => RLike(e1, e2) } @@ -365,7 +365,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val baseExpression: Parser[Expression] = ( "*" ^^^ UnresolvedStar(None) - | ident <~ "." ~ "*" ^^ { case tableName => UnresolvedStar(Option(tableName)) } + | ident <~ "." ~ "*" ^^ { case tableName => UnresolvedStar(Option(tableName)) } | primary ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a4c61149dd975..bc17169f35a46 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.types._ -import org.apache.spark.util.collection.OpenHashSet /** * A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing @@ -73,7 +72,6 @@ class Analyzer( ResolveGroupingAnalytics :: ResolveSortReferences :: ResolveGenerate :: - ImplicitGenerate :: ResolveFunctions :: ExtractWindowExpressions :: GlobalAggregates :: @@ -143,25 +141,6 @@ class Analyzer( } object ResolveGroupingAnalytics extends Rule[LogicalPlan] { - /** - * Extract attribute set according to the grouping id - * @param bitmask bitmask to represent the selected of the attribute sequence - * @param exprs the attributes in sequence - * @return the attributes of non selected specified via bitmask (with the bit set to 1) - */ - private def buildNonSelectExprSet(bitmask: Int, exprs: Seq[Expression]) - : OpenHashSet[Expression] = { - val set = new OpenHashSet[Expression](2) - - var bit = exprs.length - 1 - while (bit >= 0) { - if (((bitmask >> bit) & 1) == 0) set.add(exprs(bit)) - bit -= 1 - } - - set - } - /* * GROUP BY a, b, c WITH ROLLUP * is equivalent to @@ -198,10 +177,15 @@ class Analyzer( g.bitmasks.foreach { bitmask => // get the non selected grouping attributes according to the bit mask - val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, g.groupByExprs) + val nonSelectedGroupExprs = ArrayBuffer.empty[Expression] + var bit = g.groupByExprs.length - 1 + while (bit >= 0) { + if (((bitmask >> bit) & 1) == 0) nonSelectedGroupExprs += g.groupByExprs(bit) + bit -= 1 + } val substitution = (g.child.output :+ g.gid).map(expr => expr transformDown { - case x: Expression if nonSelectedGroupExprSet.contains(x) => + case x: Expression if nonSelectedGroupExprs.find(_ semanticEquals x).isDefined => // if the input attribute in the Invalid Grouping Expression set of for this group // replace it with constant null Literal.create(null, expr.dataType) @@ -322,6 +306,16 @@ class Analyzer( case oldVersion @ Aggregate(_, aggregateExpressions, _) if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty => (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions))) + + case oldVersion: Generate + if oldVersion.generatedSet.intersect(conflictingAttributes).nonEmpty => + val newOutput = oldVersion.generatorOutput.map(_.newInstance()) + (oldVersion, oldVersion.copy(generatorOutput = newOutput)) + + case oldVersion @ Window(_, windowExpressions, _, child) + if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes) + .nonEmpty => + (oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions))) }.headOption.getOrElse { // Only handle first case, others will be fixed on the next pass. sys.error( s""" @@ -500,7 +494,7 @@ class Analyzer( def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _)) if aggregate.resolved && containsAggregate(havingCondition) => { - val evaluatedCondition = Alias(havingCondition, "havingCondition")() + val evaluatedCondition = Alias(havingCondition, "havingCondition")() val aggExprsWithHaving = evaluatedCondition +: originalAggExprs Project(aggregate.output, @@ -516,66 +510,103 @@ class Analyzer( } /** - * When a SELECT clause has only a single expression and that expression is a - * [[catalyst.expressions.Generator Generator]] we convert the - * [[catalyst.plans.logical.Project Project]] to a [[catalyst.plans.logical.Generate Generate]]. + * Rewrites table generating expressions that either need one or more of the following in order + * to be resolved: + * - concrete attribute references for their output. + * - to be relocated from a SELECT clause (i.e. from a [[Project]]) into a [[Generate]]). + * + * Names for the output [[Attribute]]s are extracted from [[Alias]] or [[MultiAlias]] expressions + * that wrap the [[Generator]]. If more than one [[Generator]] is found in a Project, an + * [[AnalysisException]] is throw. */ - object ImplicitGenerate extends Rule[LogicalPlan] { + object ResolveGenerate extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case Project(Seq(Alias(g: Generator, name)), child) => - Generate(g, join = false, outer = false, - qualifier = None, UnresolvedAttribute(name) :: Nil, child) - case Project(Seq(MultiAlias(g: Generator, names)), child) => - Generate(g, join = false, outer = false, - qualifier = None, names.map(UnresolvedAttribute(_)), child) + case p: Generate if !p.child.resolved || !p.generator.resolved => p + case g: Generate if !g.resolved => + g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name))) + + case p @ Project(projectList, child) => + // Holds the resolved generator, if one exists in the project list. + var resolvedGenerator: Generate = null + + val newProjectList = projectList.flatMap { + case AliasedGenerator(generator, names) if generator.childrenResolved => + if (resolvedGenerator != null) { + failAnalysis( + s"Only one generator allowed per select but ${resolvedGenerator.nodeName} and " + + s"and ${generator.nodeName} found.") + } + + resolvedGenerator = + Generate( + generator, + join = projectList.size > 1, // Only join if there are other expressions in SELECT. + outer = false, + qualifier = None, + generatorOutput = makeGeneratorOutput(generator, names), + child) + + resolvedGenerator.generatorOutput + case other => other :: Nil + } + + if (resolvedGenerator != null) { + Project(newProjectList, resolvedGenerator) + } else { + p + } } - } - /** - * Resolve the Generate, if the output names specified, we will take them, otherwise - * we will try to provide the default names, which follow the same rule with Hive. - */ - object ResolveGenerate extends Rule[LogicalPlan] { - // Construct the output attributes for the generator, - // The output attribute names can be either specified or - // auto generated. + /** Extracts a [[Generator]] expression and any names assigned by aliases to their output. */ + private object AliasedGenerator { + def unapply(e: Expression): Option[(Generator, Seq[String])] = e match { + case Alias(g: Generator, name) + if g.elementTypes.size > 1 && java.util.regex.Pattern.matches("_c[0-9]+", name) => { + // Assume the default name given by parser is "_c[0-9]+", + // TODO in long term, move the naming logic from Parser to Analyzer. + // In projection, Parser gave default name for TGF as does for normal UDF, + // but the TGF probably have multiple output columns/names. + // e.g. SELECT explode(map(key, value)) FROM src; + // Let's simply ignore the default given name for this case. + Some((g, Nil)) + } + case Alias(g: Generator, name) if g.elementTypes.size > 1 => + // If not given the default names, and the TGF with multiple output columns + failAnalysis( + s"""Expect multiple names given for ${g.getClass.getName}, + |but only single name '${name}' specified""".stripMargin) + case Alias(g: Generator, name) => Some((g, name :: Nil)) + case MultiAlias(g: Generator, names) => Some(g, names) + case _ => None + } + } + + /** + * Construct the output attributes for a [[Generator]], given a list of names. If the list of + * names is empty names are assigned by ordinal (i.e., _c0, _c1, ...) to match Hive's defaults. + */ private def makeGeneratorOutput( generator: Generator, - generatorOutput: Seq[Attribute]): Seq[Attribute] = { + names: Seq[String]): Seq[Attribute] = { val elementTypes = generator.elementTypes - if (generatorOutput.length == elementTypes.length) { - generatorOutput.zip(elementTypes).map { - case (a, (t, nullable)) if !a.resolved => - AttributeReference(a.name, t, nullable)() - case (a, _) => a + if (names.length == elementTypes.length) { + names.zip(elementTypes).map { + case (name, (t, nullable)) => + AttributeReference(name, t, nullable)() } - } else if (generatorOutput.length == 0) { + } else if (names.isEmpty) { elementTypes.zipWithIndex.map { // keep the default column names as Hive does _c0, _c1, _cN case ((t, nullable), i) => AttributeReference(s"_c$i", t, nullable)() } } else { - throw new AnalysisException( - s""" - |The number of aliases supplied in the AS clause does not match - |the number of columns output by the UDTF expected - |${elementTypes.size} aliases but got ${generatorOutput.size} - """.stripMargin) + failAnalysis( + "The number of aliases supplied in the AS clause does not match the number of columns " + + s"output by the UDTF expected ${elementTypes.size} aliases but got " + + s"${names.mkString(",")} ") } } - - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case p: Generate if !p.child.resolved || !p.generator.resolved => p - case p: Generate if p.resolved == false => - // if the generator output names are not specified, we will use the default ones. - Generate( - p.generator, - join = p.join, - outer = p.outer, - p.qualifier, - makeGeneratorOutput(p.generator, p.generatorOutput), p.child) - } } /** @@ -602,10 +633,10 @@ class Analyzer( * it into the plan tree. */ object ExtractWindowExpressions extends Rule[LogicalPlan] { - def hasWindowFunction(projectList: Seq[NamedExpression]): Boolean = + private def hasWindowFunction(projectList: Seq[NamedExpression]): Boolean = projectList.exists(hasWindowFunction) - def hasWindowFunction(expr: NamedExpression): Boolean = { + private def hasWindowFunction(expr: NamedExpression): Boolean = { expr.find { case window: WindowExpression => true case _ => false @@ -613,14 +644,24 @@ class Analyzer( } /** - * From a Seq of [[NamedExpression]]s, extract window expressions and - * other regular expressions. + * From a Seq of [[NamedExpression]]s, extract expressions containing window expressions and + * other regular expressions that do not contain any window expression. For example, for + * `col1, Sum(col2 + col3) OVER (PARTITION BY col4 ORDER BY col5)`, we will extract + * `col1`, `col2 + col3`, `col4`, and `col5` out and replace their appearances in + * the window expression as attribute references. So, the first returned value will be + * `[Sum(_w0) OVER (PARTITION BY _w1 ORDER BY _w2)]` and the second returned value will be + * [col1, col2 + col3 as _w0, col4 as _w1, col5 as _w2]. + * + * @return (seq of expressions containing at lease one window expressions, + * seq of non-window expressions) */ - def extract( + private def extract( expressions: Seq[NamedExpression]): (Seq[NamedExpression], Seq[NamedExpression]) = { - // First, we simple partition the input expressions to two part, one having - // WindowExpressions and another one without WindowExpressions. - val (windowExpressions, regularExpressions) = expressions.partition(hasWindowFunction) + // First, we partition the input expressions to two part. For the first part, + // every expression in it contain at least one WindowExpression. + // Expressions in the second part do not have any WindowExpression. + val (expressionsWithWindowFunctions, regularExpressions) = + expressions.partition(hasWindowFunction) // Then, we need to extract those regular expressions used in the WindowExpression. // For example, when we have col1 - Sum(col2 + col3) OVER (PARTITION BY col4 ORDER BY col5), @@ -629,8 +670,8 @@ class Analyzer( val extractedExprBuffer = new ArrayBuffer[NamedExpression]() def extractExpr(expr: Expression): Expression = expr match { case ne: NamedExpression => - // If a named expression is not in regularExpressions, add extract it and replace it - // with an AttributeReference. + // If a named expression is not in regularExpressions, add it to + // extractedExprBuffer and replace it with an AttributeReference. val missingExpr = AttributeSet(Seq(expr)) -- (regularExpressions ++ extractedExprBuffer) if (missingExpr.nonEmpty) { @@ -647,8 +688,9 @@ class Analyzer( withName.toAttribute } - // Now, we extract expressions from windowExpressions by using extractExpr. - val newWindowExpressions = windowExpressions.map { + // Now, we extract regular expressions from expressionsWithWindowFunctions + // by using extractExpr. + val newExpressionsWithWindowFunctions = expressionsWithWindowFunctions.map { _.transform { // Extracts children expressions of a WindowFunction (input parameters of // a WindowFunction). @@ -674,37 +716,80 @@ class Analyzer( }.asInstanceOf[NamedExpression] } - (newWindowExpressions, regularExpressions ++ extractedExprBuffer) - } + (newExpressionsWithWindowFunctions, regularExpressions ++ extractedExprBuffer) + } // end of extract /** * Adds operators for Window Expressions. Every Window operator handles a single Window Spec. */ - def addWindow(windowExpressions: Seq[NamedExpression], child: LogicalPlan): LogicalPlan = { - // First, we group window expressions based on their Window Spec. - val groupedWindowExpression = windowExpressions.groupBy { expr => - val windowSpec = expr.collectFirst { + private def addWindow( + expressionsWithWindowFunctions: Seq[NamedExpression], + child: LogicalPlan): LogicalPlan = { + // First, we need to extract all WindowExpressions from expressionsWithWindowFunctions + // and put those extracted WindowExpressions to extractedWindowExprBuffer. + // This step is needed because it is possible that an expression contains multiple + // WindowExpressions with different Window Specs. + // After extracting WindowExpressions, we need to construct a project list to generate + // expressionsWithWindowFunctions based on extractedWindowExprBuffer. + // For example, for "sum(a) over (...) / sum(b) over (...)", we will first extract + // "sum(a) over (...)" and "sum(b) over (...)" out, and assign "_we0" as the alias to + // "sum(a) over (...)" and "_we1" as the alias to "sum(b) over (...)". + // Then, the projectList will be [_we0/_we1]. + val extractedWindowExprBuffer = new ArrayBuffer[NamedExpression]() + val newExpressionsWithWindowFunctions = expressionsWithWindowFunctions.map { + // We need to use transformDown because we want to trigger + // "case alias @ Alias(window: WindowExpression, _)" first. + _.transformDown { + case alias @ Alias(window: WindowExpression, _) => + // If a WindowExpression has an assigned alias, just use it. + extractedWindowExprBuffer += alias + alias.toAttribute + case window: WindowExpression => + // If there is no alias assigned to the WindowExpressions. We create an + // internal column. + val withName = Alias(window, s"_we${extractedWindowExprBuffer.length}")() + extractedWindowExprBuffer += withName + withName.toAttribute + }.asInstanceOf[NamedExpression] + } + + // Second, we group extractedWindowExprBuffer based on their Window Spec. + val groupedWindowExpressions = extractedWindowExprBuffer.groupBy { expr => + val distinctWindowSpec = expr.collect { case window: WindowExpression => window.windowSpec + }.distinct + + // We do a final check and see if we only have a single Window Spec defined in an + // expressions. + if (distinctWindowSpec.length == 0 ) { + failAnalysis(s"$expr does not have any WindowExpression.") + } else if (distinctWindowSpec.length > 1) { + // newExpressionsWithWindowFunctions only have expressions with a single + // WindowExpression. If we reach here, we have a bug. + failAnalysis(s"$expr has multiple Window Specifications ($distinctWindowSpec)." + + s"Please file a bug report with this error message, stack trace, and the query.") + } else { + distinctWindowSpec.head } - windowSpec.getOrElse( - failAnalysis(s"$windowExpressions does not have any WindowExpression.")) }.toSeq - // For every Window Spec, we add a Window operator and set currentChild as the child of it. + // Third, for every Window Spec, we add a Window operator and set currentChild as the + // child of it. var currentChild = child var i = 0 - while (i < groupedWindowExpression.size) { - val (windowSpec, windowExpressions) = groupedWindowExpression(i) + while (i < groupedWindowExpressions.size) { + val (windowSpec, windowExpressions) = groupedWindowExpressions(i) // Set currentChild to the newly created Window operator. currentChild = Window(currentChild.output, windowExpressions, windowSpec, currentChild) - // Move to next WindowExpression. + // Move to next Window Spec. i += 1 } - // We return the top operator. - currentChild - } + // Finally, we create a Project to output currentChild's output + // newExpressionsWithWindowFunctions. + Project(currentChild.output ++ newExpressionsWithWindowFunctions, currentChild) + } // end of addWindow // We have to use transformDown at here to make sure the rule of // "Aggregate with Having clause" will be triggered. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index 208021c421326..3e240fd55e250 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -140,7 +140,7 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog { trait OverrideCatalog extends Catalog { // TODO: This doesn't work when the database changes... - val overrides = new mutable.HashMap[(Option[String],String), LogicalPlan]() + val overrides = new mutable.HashMap[(Option[String], String), LogicalPlan]() abstract override def tableExists(tableIdentifier: Seq[String]): Boolean = { val tableIdent = processTableIdentifier(tableIdentifier) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index f104e742c90fe..c0695ae369421 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -62,14 +62,21 @@ trait CheckAnalysis { val from = operator.inputSet.map(_.name).mkString(", ") a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from") + case e: Expression if e.checkInputDataTypes().isFailure => + e.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckFailure(message) => + e.failAnalysis( + s"cannot resolve '${e.prettyString}' due to data type mismatch: $message") + } + case c: Cast if !c.resolved => failAnalysis( s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}") - case b: BinaryExpression if !b.resolved => + case WindowExpression(UnresolvedWindowFunction(name, _), _) => failAnalysis( - s"invalid expression ${b.prettyString} " + - s"between ${b.left.dataType.simpleString} and ${b.right.dataType.simpleString}") + s"Could not resolve window function '$name'. " + + "Note that, using window functions currently requires a HiveContext") case w @ WindowExpression(windowFunction, windowSpec) if windowSpec.validate.nonEmpty => // The window spec is not valid. @@ -86,12 +93,12 @@ trait CheckAnalysis { case Aggregate(groupingExprs, aggregateExprs, child) => def checkValidAggregateExpression(expr: Expression): Unit = expr match { case _: AggregateExpression => // OK - case e: Attribute if !groupingExprs.contains(e) => + case e: Attribute if groupingExprs.find(_ semanticEquals e).isEmpty => failAnalysis( s"expression '${e.prettyString}' is neither present in the group by, " + s"nor is it an aggregate function. " + "Add to group by or wrap in first() if you don't care which value you get.") - case e if groupingExprs.contains(e) => // OK + case e if groupingExprs.find(_ semanticEquals e).isDefined => // OK case e if e.references.isEmpty => // OK case e => e.children.foreach(checkValidAggregateExpression) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 16ca5bcd57a72..0849faa9bfa7b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions.Expression import scala.collection.mutable @@ -28,12 +29,12 @@ trait FunctionRegistry { def lookupFunction(name: String, children: Seq[Expression]): Expression - def caseSensitive: Boolean + def conf: CatalystConf } trait OverrideFunctionRegistry extends FunctionRegistry { - val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive) + val functionBuilders = StringKeyHashMap[FunctionBuilder](conf.caseSensitiveAnalysis) override def registerFunction(name: String, builder: FunctionBuilder): Unit = { functionBuilders.put(name, builder) @@ -44,8 +45,9 @@ trait OverrideFunctionRegistry extends FunctionRegistry { } } -class SimpleFunctionRegistry(val caseSensitive: Boolean) extends FunctionRegistry { - val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive) +class SimpleFunctionRegistry(val conf: CatalystConf) extends FunctionRegistry { + + val functionBuilders = StringKeyHashMap[FunctionBuilder](conf.caseSensitiveAnalysis) override def registerFunction(name: String, builder: FunctionBuilder): Unit = { functionBuilders.put(name, builder) @@ -69,7 +71,7 @@ object EmptyFunctionRegistry extends FunctionRegistry { throw new UnsupportedOperationException } - override def caseSensitive: Boolean = throw new UnsupportedOperationException + override def conf: CatalystConf = throw new UnsupportedOperationException } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 168a4e30eab86..b064600e94fac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -41,7 +41,7 @@ object HiveTypeCoercion { * with primitive types, because in that case the precision and scale of the result depends on * the operation. Those rules are implemented in [[HiveTypeCoercion.DecimalPrecision]]. */ - val findTightestCommonType: (DataType, DataType) => Option[DataType] = { + val findTightestCommonTypeOfTwo: (DataType, DataType) => Option[DataType] = { case (t1, t2) if t1 == t2 => Some(t1) case (NullType, t1) => Some(t1) case (t1, NullType) => Some(t1) @@ -57,6 +57,17 @@ object HiveTypeCoercion { case _ => None } + + /** + * Find the tightest common type of a set of types by continuously applying + * `findTightestCommonTypeOfTwo` on these types. + */ + private def findTightestCommonType(types: Seq[DataType]) = { + types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { + case None => None + case Some(d) => findTightestCommonTypeOfTwo(d, c) + }) + } } /** @@ -76,8 +87,7 @@ trait HiveTypeCoercion { WidenTypes :: PromoteStrings :: DecimalPrecision :: - BooleanComparisons :: - BooleanCasts :: + BooleanEqualization :: StringToIntegralCasts :: FunctionArgumentConversion :: CaseWhenCoercion :: @@ -120,7 +130,7 @@ trait HiveTypeCoercion { * the appropriate numeric equivalent. */ object ConvertNaNs extends Rule[LogicalPlan] { - val stringNaN = Literal("NaN") + private val stringNaN = Literal("NaN") def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressions { @@ -181,7 +191,7 @@ trait HiveTypeCoercion { case (l, r) if l.dataType != r.dataType => logDebug(s"Resolving mismatched union input ${l.dataType}, ${r.dataType}") - findTightestCommonType(l.dataType, r.dataType).map { widestType => + findTightestCommonTypeOfTwo(l.dataType, r.dataType).map { widestType => val newLeft = if (l.dataType == widestType) l else Alias(Cast(l, widestType), l.name)() val newRight = @@ -218,7 +228,7 @@ trait HiveTypeCoercion { case e if !e.childrenResolved => e case b: BinaryExpression if b.left.dataType != b.right.dataType => - findTightestCommonType(b.left.dataType, b.right.dataType).map { widestType => + findTightestCommonTypeOfTwo(b.left.dataType, b.right.dataType).map { widestType => val newLeft = if (b.left.dataType == widestType) b.left else Cast(b.left, widestType) val newRight = @@ -251,10 +261,10 @@ trait HiveTypeCoercion { p.makeCopy(Array(Cast(p.left, StringType), p.right)) case p: BinaryComparison if p.left.dataType == StringType && p.right.dataType == TimestampType => - p.makeCopy(Array(p.left, Cast(p.right, StringType))) + p.makeCopy(Array(Cast(p.left, TimestampType), p.right)) case p: BinaryComparison if p.left.dataType == TimestampType && p.right.dataType == StringType => - p.makeCopy(Array(Cast(p.left, StringType), p.right)) + p.makeCopy(Array(p.left, Cast(p.right, TimestampType))) case p: BinaryComparison if p.left.dataType == TimestampType && p.right.dataType == DateType => p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType))) @@ -274,7 +284,7 @@ trait HiveTypeCoercion { i.makeCopy(Array(Cast(a, StringType), b)) case i @ In(a, b) if a.dataType == TimestampType && b.forall(_.dataType == StringType) => - i.makeCopy(Array(Cast(a, StringType), b)) + i.makeCopy(Array(a, b.map(Cast(_, TimestampType)))) case i @ In(a, b) if a.dataType == DateType && b.forall(_.dataType == TimestampType) => i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType)))) @@ -296,6 +306,9 @@ trait HiveTypeCoercion { */ object InConversion extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + // Skip nodes who's children have not been resolved yet. + case e if !e.childrenResolved => e + case i @ In(a, b) if b.exists(_.dataType != a.dataType) => i.makeCopy(Array(a, b.map(Cast(_, a.dataType)))) } @@ -347,17 +360,17 @@ trait HiveTypeCoercion { import scala.math.{max, min} // Conversion rules for integer types into fixed-precision decimals - val intTypeToFixed: Map[DataType, DecimalType] = Map( + private val intTypeToFixed: Map[DataType, DecimalType] = Map( ByteType -> DecimalType(3, 0), ShortType -> DecimalType(5, 0), IntegerType -> DecimalType(10, 0), LongType -> DecimalType(20, 0) ) - def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType + private def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType // Conversion rules for float and double into fixed-precision decimals - val floatTypeToFixed: Map[DataType, DecimalType] = Map( + private val floatTypeToFixed: Map[DataType, DecimalType] = Map( FloatType -> DecimalType(7, 7), DoubleType -> DecimalType(15, 15) ) @@ -439,21 +452,18 @@ trait HiveTypeCoercion { DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) ) - case LessThan(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - LessThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) - - case LessThanOrEqual(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - LessThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) - - case GreaterThan(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - GreaterThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) - - case GreaterThanOrEqual(e1 @ DecimalType.Expression(p1, s1), - e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => - GreaterThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)) + // When we compare 2 decimal types with different precisions, cast them to the smallest + // common precision. + case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1), + e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => + val resultType = DecimalType(max(p1, p2), max(s1, s2)) + b.makeCopy(Array(Cast(e1, resultType), Cast(e2, resultType))) + case b @ BinaryComparison(e1 @ DecimalType.Fixed(_, _), e2) + if e2.dataType == DecimalType.Unlimited => + b.makeCopy(Array(Cast(e1, DecimalType.Unlimited), e2)) + case b @ BinaryComparison(e1, e2 @ DecimalType.Fixed(_, _)) + if e1.dataType == DecimalType.Unlimited => + b.makeCopy(Array(e1, Cast(e2, DecimalType.Unlimited))) // Promote integers inside a binary expression with fixed-precision decimals to decimals, // and fixed-precision decimals in an expression with floats / doubles to doubles @@ -480,56 +490,66 @@ trait HiveTypeCoercion { } /** - * Changes Boolean values to Bytes so that expressions like true < false can be Evaluated. + * Changes numeric values to booleans so that expressions like true = 1 can be evaluated. */ - object BooleanComparisons extends Rule[LogicalPlan] { - val trueValues = Seq(1, 1L, 1.toByte, 1.toShort, new java.math.BigDecimal(1)).map(Literal(_)) - val falseValues = Seq(0, 0L, 0.toByte, 0.toShort, new java.math.BigDecimal(0)).map(Literal(_)) + object BooleanEqualization extends Rule[LogicalPlan] { + private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, new java.math.BigDecimal(1)) + private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, new java.math.BigDecimal(0)) + + private def buildCaseKeyWhen(booleanExpr: Expression, numericExpr: Expression) = { + CaseKeyWhen(numericExpr, Seq( + Literal(trueValues.head), booleanExpr, + Literal(falseValues.head), Not(booleanExpr), + Literal(false))) + } - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - // Skip nodes who's children have not been resolved yet. - case e if !e.childrenResolved => e + private def transform(booleanExpr: Expression, numericExpr: Expression) = { + If(Or(IsNull(booleanExpr), IsNull(numericExpr)), + Literal.create(null, BooleanType), + buildCaseKeyWhen(booleanExpr, numericExpr)) + } - // Hive treats (true = 1) as true and (false = 0) as true. - case EqualTo(l @ BooleanType(), r) if trueValues.contains(r) => l - case EqualTo(l, r @ BooleanType()) if trueValues.contains(l) => r - case EqualTo(l @ BooleanType(), r) if falseValues.contains(r) => Not(l) - case EqualTo(l, r @ BooleanType()) if falseValues.contains(l) => Not(r) - - // No need to change other EqualTo operators as that actually makes sense for boolean types. - case e: EqualTo => e - // No need to change the EqualNullSafe operators, too - case e: EqualNullSafe => e - // Otherwise turn them to Byte types so that there exists and ordering. - case p: BinaryComparison if p.left.dataType == BooleanType && - p.right.dataType == BooleanType => - p.makeCopy(Array(Cast(p.left, ByteType), Cast(p.right, ByteType))) + private def transformNullSafe(booleanExpr: Expression, numericExpr: Expression) = { + CaseWhen(Seq( + And(IsNull(booleanExpr), IsNull(numericExpr)), Literal(true), + Or(IsNull(booleanExpr), IsNull(numericExpr)), Literal(false), + buildCaseKeyWhen(booleanExpr, numericExpr) + )) } - } - /** - * Casts to/from [[BooleanType]] are transformed into comparisons since - * the JVM does not consider Booleans to be numeric types. - */ - object BooleanCasts extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - // Skip if the type is boolean type already. Note that this extra cast should be removed - // by optimizer.SimplifyCasts. - case Cast(e, BooleanType) if e.dataType == BooleanType => e - // DateType should be null if be cast to boolean. - case Cast(e, BooleanType) if e.dataType == DateType => Cast(e, BooleanType) - // If the data type is not boolean and is being cast boolean, turn it into a comparison - // with the numeric value, i.e. x != 0. This will coerce the type into numeric type. - case Cast(e, BooleanType) if e.dataType != BooleanType => Not(EqualTo(e, Literal(0))) - // Stringify boolean if casting to StringType. - // TODO Ensure true/false string letter casing is consistent with Hive in all cases. - case Cast(e, StringType) if e.dataType == BooleanType => - If(e, Literal("true"), Literal("false")) - // Turn true into 1, and false into 0 if casting boolean into other types. - case Cast(e, dataType) if e.dataType == BooleanType => - Cast(If(e, Literal(1), Literal(0)), dataType) + + // Hive treats (true = 1) as true and (false = 0) as true, + // all other cases are considered as false. + + // We may simplify the expression if one side is literal numeric values + case EqualTo(l @ BooleanType(), Literal(value, _: NumericType)) + if trueValues.contains(value) => l + case EqualTo(l @ BooleanType(), Literal(value, _: NumericType)) + if falseValues.contains(value) => Not(l) + case EqualTo(Literal(value, _: NumericType), r @ BooleanType()) + if trueValues.contains(value) => r + case EqualTo(Literal(value, _: NumericType), r @ BooleanType()) + if falseValues.contains(value) => Not(r) + case EqualNullSafe(l @ BooleanType(), Literal(value, _: NumericType)) + if trueValues.contains(value) => And(IsNotNull(l), l) + case EqualNullSafe(l @ BooleanType(), Literal(value, _: NumericType)) + if falseValues.contains(value) => And(IsNotNull(l), Not(l)) + case EqualNullSafe(Literal(value, _: NumericType), r @ BooleanType()) + if trueValues.contains(value) => And(IsNotNull(r), r) + case EqualNullSafe(Literal(value, _: NumericType), r @ BooleanType()) + if falseValues.contains(value) => And(IsNotNull(r), Not(r)) + + case EqualTo(l @ BooleanType(), r @ NumericType()) => + transform(l , r) + case EqualTo(l @ NumericType(), r @ BooleanType()) => + transform(r, l) + case EqualNullSafe(l @ BooleanType(), r @ NumericType()) => + transformNullSafe(l, r) + case EqualNullSafe(l @ NumericType(), r @ BooleanType()) => + transformNullSafe(r, l) } } @@ -558,8 +578,7 @@ trait HiveTypeCoercion { case a @ CreateArray(children) if !a.resolved => val commonType = a.childTypes.reduce( - (a,b) => - findTightestCommonType(a,b).getOrElse(StringType)) + (a, b) => findTightestCommonTypeOfTwo(a, b).getOrElse(StringType)) CreateArray( children.map(c => if (c.dataType == commonType) c else Cast(c, commonType))) @@ -588,14 +607,9 @@ trait HiveTypeCoercion { // from the list. So we need to make sure the return type is deterministic and // compatible with every child column. case Coalesce(es) if es.map(_.dataType).distinct.size > 1 => - val dt: Option[DataType] = Some(NullType) val types = es.map(_.dataType) - val rt = types.foldLeft(dt)((r, c) => r match { - case None => None - case Some(d) => findTightestCommonType(d, c) - }) - rt match { - case Some(finaldt) => Coalesce(es.map(Cast(_, finaldt))) + findTightestCommonType(types) match { + case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType))) case None => sys.error(s"Could not determine return type of Coalesce for ${types.mkString(",")}") } @@ -608,17 +622,13 @@ trait HiveTypeCoercion { */ object Division extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - // Skip nodes who's children have not been resolved yet. - case e if !e.childrenResolved => e + // Skip nodes who has not been resolved yet, + // as this is an extra rule which should be applied at last. + case e if !e.resolved => e // Decimal and Double remain the same - case d: Divide if d.resolved && d.dataType == DoubleType => d - case d: Divide if d.resolved && d.dataType.isInstanceOf[DecimalType] => d - - case Divide(l, r) if l.dataType.isInstanceOf[DecimalType] => - Divide(l, Cast(r, DecimalType.Unlimited)) - case Divide(l, r) if r.dataType.isInstanceOf[DecimalType] => - Divide(Cast(l, DecimalType.Unlimited), r) + case d: Divide if d.dataType == DoubleType => d + case d: Divide if d.dataType.isInstanceOf[DecimalType] => d case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType)) } @@ -631,25 +641,33 @@ trait HiveTypeCoercion { import HiveTypeCoercion._ def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case cw: CaseWhenLike if !cw.resolved && cw.childrenResolved && !cw.valueTypesEqual => - logDebug(s"Input values for null casting ${cw.valueTypes.mkString(",")}") - val commonType = cw.valueTypes.reduce { (v1, v2) => - findTightestCommonType(v1, v2).getOrElse(sys.error( - s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2")) - } - val transformedBranches = cw.branches.sliding(2, 2).map { - case Seq(when, value) if value.dataType != commonType => - Seq(when, Cast(value, commonType)) - case Seq(elseVal) if elseVal.dataType != commonType => - Seq(Cast(elseVal, commonType)) - case s => s - }.reduce(_ ++ _) - cw match { - case _: CaseWhen => - CaseWhen(transformedBranches) - case CaseKeyWhen(key, _) => - CaseKeyWhen(key, transformedBranches) - } + case c: CaseWhenLike if c.childrenResolved && !c.valueTypesEqual => + logDebug(s"Input values for null casting ${c.valueTypes.mkString(",")}") + val maybeCommonType = findTightestCommonType(c.valueTypes) + maybeCommonType.map { commonType => + val castedBranches = c.branches.grouped(2).map { + case Seq(when, value) if value.dataType != commonType => + Seq(when, Cast(value, commonType)) + case Seq(elseVal) if elseVal.dataType != commonType => + Seq(Cast(elseVal, commonType)) + case other => other + }.reduce(_ ++ _) + c match { + case _: CaseWhen => CaseWhen(castedBranches) + case CaseKeyWhen(key, _) => CaseKeyWhen(key, castedBranches) + } + }.getOrElse(c) + + case c: CaseKeyWhen if c.childrenResolved && !c.resolved => + val maybeCommonType = findTightestCommonType((c.key +: c.whenList).map(_.dataType)) + maybeCommonType.map { commonType => + val castedBranches = c.branches.grouped(2).map { + case Seq(when, then) if when.dataType != commonType => + Seq(Cast(when, commonType), then) + case other => other + }.reduce(_ ++ _) + CaseKeyWhen(Cast(c.key, commonType), castedBranches) + }.getOrElse(c) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala new file mode 100644 index 0000000000000..79c3528a522d3 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +/** + * Represents the result of `Expression.checkInputDataTypes`. + * We will throw `AnalysisException` in `CheckAnalysis` if `isFailure` is true. + */ +trait TypeCheckResult { + def isFailure: Boolean = !isSuccess + def isSuccess: Boolean +} + +object TypeCheckResult { + + /** + * Represents the successful result of `Expression.checkInputDataTypes`. + */ + object TypeCheckSuccess extends TypeCheckResult { + def isSuccess: Boolean = true + } + + /** + * Represents the failing result of `Expression.checkInputDataTypes`, + * with a error message to show the reason of failure. + */ + case class TypeCheckFailure(message: String) extends TypeCheckResult { + def isSuccess: Boolean = false + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 2999c2ef3efe1..bbb150c1e83c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -67,7 +67,7 @@ case class UnresolvedAttribute(nameParts: Seq[String]) override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName) // Unresolved attributes are transient at compile time and don't get evaluated during execution. - override def eval(input: Row = null): EvaluatedType = + override def eval(input: Row = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") override def toString: String = s"'$name" @@ -85,7 +85,7 @@ case class UnresolvedFunction(name: String, children: Seq[Expression]) extends E override lazy val resolved = false // Unresolved functions are transient at compile time and don't get evaluated during execution. - override def eval(input: Row = null): EvaluatedType = + override def eval(input: Row = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") override def toString: String = s"'$name(${children.mkString(",")})" @@ -107,7 +107,7 @@ trait Star extends NamedExpression with trees.LeafNode[Expression] { override lazy val resolved = false // Star gets expanded at runtime so we never evaluate a Star. - override def eval(input: Row = null): EvaluatedType = + override def eval(input: Row = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] @@ -166,7 +166,7 @@ case class MultiAlias(child: Expression, names: Seq[String]) override lazy val resolved = false - override def eval(input: Row = null): EvaluatedType = + override def eval(input: Row = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") override def toString: String = s"$child AS $names" @@ -200,7 +200,7 @@ case class UnresolvedExtractValue(child: Expression, extraction: Expression) override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false - override def eval(input: Row = null): EvaluatedType = + override def eval(input: Row = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") override def toString: String = s"$child[$extraction]" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 4c0d70203c6f5..51821757967d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst import java.sql.{Date, Timestamp} import scala.language.implicitConversions -import scala.reflect.runtime.universe.{TypeTag, typeTag} import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedExtractValue, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ @@ -61,7 +60,7 @@ package object dsl { trait ImplicitOperators { def expr: Expression - def unary_- : Expression= UnaryMinus(expr) + def unary_- : Expression = UnaryMinus(expr) def unary_! : Predicate = Not(expr) def unary_~ : Expression = BitwiseNot(expr) @@ -141,7 +140,7 @@ package object dsl { // Note that if we make ExpressionConversions an object rather than a trait, we can // then make this a value class to avoid the small penalty of runtime instantiation. def $(args: Any*): analysis.UnresolvedAttribute = { - analysis.UnresolvedAttribute(sc.s(args :_*)) + analysis.UnresolvedAttribute(sc.s(args : _*)) } } @@ -234,129 +233,59 @@ package object dsl { implicit class DslAttribute(a: AttributeReference) { def notNull: AttributeReference = a.withNullability(false) def nullable: AttributeReference = a.withNullability(true) - - // Protobuf terminology - def required: AttributeReference = a.withNullability(false) - def at(ordinal: Int): BoundReference = BoundReference(ordinal, a.dataType, a.nullable) } } - object expressions extends ExpressionConversions // scalastyle:ignore - abstract class LogicalPlanFunctions { - def logicalPlan: LogicalPlan - - def select(exprs: NamedExpression*): LogicalPlan = Project(exprs, logicalPlan) + object plans { // scalastyle:ignore + implicit class DslLogicalPlan(val logicalPlan: LogicalPlan) { + def select(exprs: NamedExpression*): LogicalPlan = Project(exprs, logicalPlan) - def where(condition: Expression): LogicalPlan = Filter(condition, logicalPlan) + def where(condition: Expression): LogicalPlan = Filter(condition, logicalPlan) - def limit(limitExpr: Expression): LogicalPlan = Limit(limitExpr, logicalPlan) + def limit(limitExpr: Expression): LogicalPlan = Limit(limitExpr, logicalPlan) - def join( + def join( otherPlan: LogicalPlan, joinType: JoinType = Inner, condition: Option[Expression] = None): LogicalPlan = - Join(logicalPlan, otherPlan, joinType, condition) + Join(logicalPlan, otherPlan, joinType, condition) - def orderBy(sortExprs: SortOrder*): LogicalPlan = Sort(sortExprs, true, logicalPlan) + def orderBy(sortExprs: SortOrder*): LogicalPlan = Sort(sortExprs, true, logicalPlan) - def sortBy(sortExprs: SortOrder*): LogicalPlan = Sort(sortExprs, false, logicalPlan) + def sortBy(sortExprs: SortOrder*): LogicalPlan = Sort(sortExprs, false, logicalPlan) - def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*): LogicalPlan = { - val aliasedExprs = aggregateExprs.map { - case ne: NamedExpression => ne - case e => Alias(e, e.toString)() + def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*): LogicalPlan = { + val aliasedExprs = aggregateExprs.map { + case ne: NamedExpression => ne + case e => Alias(e, e.toString)() + } + Aggregate(groupingExprs, aliasedExprs, logicalPlan) } - Aggregate(groupingExprs, aliasedExprs, logicalPlan) - } - def subquery(alias: Symbol): LogicalPlan = Subquery(alias.name, logicalPlan) + def subquery(alias: Symbol): LogicalPlan = Subquery(alias.name, logicalPlan) + + def except(otherPlan: LogicalPlan): LogicalPlan = Except(logicalPlan, otherPlan) - def unionAll(otherPlan: LogicalPlan): LogicalPlan = Union(logicalPlan, otherPlan) + def intersect(otherPlan: LogicalPlan): LogicalPlan = Intersect(logicalPlan, otherPlan) - def sfilter[T1](arg1: Symbol)(udf: (T1) => Boolean): LogicalPlan = - Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan) + def unionAll(otherPlan: LogicalPlan): LogicalPlan = Union(logicalPlan, otherPlan) - // TODO specify the output column names - def generate( + // TODO specify the output column names + def generate( generator: Generator, join: Boolean = false, outer: Boolean = false, alias: Option[String] = None): LogicalPlan = - Generate(generator, join = join, outer = outer, alias, Nil, logicalPlan) + Generate(generator, join = join, outer = outer, alias, Nil, logicalPlan) - def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan = - InsertIntoTable( - analysis.UnresolvedRelation(Seq(tableName)), Map.empty, logicalPlan, overwrite, false) + def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan = + InsertIntoTable( + analysis.UnresolvedRelation(Seq(tableName)), Map.empty, logicalPlan, overwrite, false) - def analyze: LogicalPlan = EliminateSubQueries(analysis.SimpleAnalyzer.execute(logicalPlan)) - } - - object plans { // scalastyle:ignore - implicit class DslLogicalPlan(val logicalPlan: LogicalPlan) extends LogicalPlanFunctions { - def writeToFile(path: String): LogicalPlan = WriteToFile(path, logicalPlan) + def analyze: LogicalPlan = EliminateSubQueries(analysis.SimpleAnalyzer.execute(logicalPlan)) } } - - case class ScalaUdfBuilder[T: TypeTag](f: AnyRef) { - def call(args: Expression*): ScalaUdf = { - ScalaUdf(f, ScalaReflection.schemaFor(typeTag[T]).dataType, args) - } - } - - // scalastyle:off - /** functionToUdfBuilder 1-22 were generated by this script - - (1 to 22).map { x => - val argTypes = Seq.fill(x)("_").mkString(", ") - s"implicit def functionToUdfBuilder[T: TypeTag](func: Function$x[$argTypes, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func)" - } - */ - - implicit def functionToUdfBuilder[T: TypeTag](func: Function1[_, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function2[_, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function3[_, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function4[_, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function5[_, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function6[_, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function7[_, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function8[_, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function9[_, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function10[_, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function11[_, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function12[_, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function13[_, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - - implicit def functionToUdfBuilder[T: TypeTag](func: Function22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): ScalaUdfBuilder[T] = ScalaUdfBuilder(func) - // scalastyle:on } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala index 0fd4f9b374ee0..d2a90a50c89f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala @@ -49,11 +49,4 @@ package object errors { case e: Exception => throw new TreeNodeException(tree, msg, e) } } - - /** - * Executes `f` which is expected to throw a - * [[catalyst.errors.TreeNodeException TreeNodeException]]. The first tree encountered in - * the stack of exceptions of type `TreeType` is returned. - */ - def getTree[TreeType <: TreeNode[_]](f: => Unit): TreeType = ??? // TODO: Implement } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index c6217f07c452d..1ffc95c676f6f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -30,8 +30,6 @@ import org.apache.spark.sql.catalyst.trees case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) extends NamedExpression with trees.LeafNode[Expression] { - type EvaluatedType = Any - override def toString: String = s"input[$ordinal]" override def eval(input: Row): Any = input(ordinal) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index adf941ab2a45f..21adac144112e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types._ /** Cast the child expression to the target data type. */ @@ -34,48 +35,48 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w private[this] def forceNullable(from: DataType, to: DataType) = (from, to) match { case (StringType, _: NumericType) => true - case (StringType, TimestampType) => true - case (DoubleType, TimestampType) => true - case (FloatType, TimestampType) => true - case (StringType, DateType) => true - case (_: NumericType, DateType) => true - case (BooleanType, DateType) => true - case (DateType, _: NumericType) => true - case (DateType, BooleanType) => true + case (StringType, TimestampType) => true + case (DoubleType, TimestampType) => true + case (FloatType, TimestampType) => true + case (StringType, DateType) => true + case (_: NumericType, DateType) => true + case (BooleanType, DateType) => true + case (DateType, _: NumericType) => true + case (DateType, BooleanType) => true case (DoubleType, _: DecimalType) => true - case (FloatType, _: DecimalType) => true + case (FloatType, _: DecimalType) => true case (_, DecimalType.Fixed(_, _)) => true // TODO: not all upcasts here can really give null - case _ => false + case _ => false } private[this] def resolvableNullability(from: Boolean, to: Boolean) = !from || to private[this] def resolve(from: DataType, to: DataType): Boolean = { (from, to) match { - case (from, to) if from == to => true + case (from, to) if from == to => true - case (NullType, _) => true + case (NullType, _) => true - case (_, StringType) => true + case (_, StringType) => true - case (StringType, BinaryType) => true + case (StringType, BinaryType) => true - case (StringType, BooleanType) => true - case (DateType, BooleanType) => true - case (TimestampType, BooleanType) => true - case (_: NumericType, BooleanType) => true + case (StringType, BooleanType) => true + case (DateType, BooleanType) => true + case (TimestampType, BooleanType) => true + case (_: NumericType, BooleanType) => true - case (StringType, TimestampType) => true - case (BooleanType, TimestampType) => true - case (DateType, TimestampType) => true - case (_: NumericType, TimestampType) => true + case (StringType, TimestampType) => true + case (BooleanType, TimestampType) => true + case (DateType, TimestampType) => true + case (_: NumericType, TimestampType) => true - case (_, DateType) => true + case (_, DateType) => true - case (StringType, _: NumericType) => true - case (BooleanType, _: NumericType) => true - case (DateType, _: NumericType) => true - case (TimestampType, _: NumericType) => true + case (StringType, _: NumericType) => true + case (BooleanType, _: NumericType) => true + case (DateType, _: NumericType) => true + case (TimestampType, _: NumericType) => true case (_: NumericType, _: NumericType) => true case (ArrayType(from, fn), ArrayType(to, tn)) => @@ -104,8 +105,6 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w override def toString: String = s"CAST($child, $dataType)" - type EvaluatedType = Any - // [[func]] assumes the input is no longer null because eval already does the null check. @inline private[this] def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T]) @@ -411,21 +410,21 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w private[this] def cast(from: DataType, to: DataType): Any => Any = to match { case dt if dt == child.dataType => identity[Any] - case StringType => castToString(from) - case BinaryType => castToBinary(from) - case DateType => castToDate(from) - case decimal: DecimalType => castToDecimal(from, decimal) - case TimestampType => castToTimestamp(from) - case BooleanType => castToBoolean(from) - case ByteType => castToByte(from) - case ShortType => castToShort(from) - case IntegerType => castToInt(from) - case FloatType => castToFloat(from) - case LongType => castToLong(from) - case DoubleType => castToDouble(from) - case array: ArrayType => castArray(from.asInstanceOf[ArrayType], array) - case map: MapType => castMap(from.asInstanceOf[MapType], map) - case struct: StructType => castStruct(from.asInstanceOf[StructType], struct) + case StringType => castToString(from) + case BinaryType => castToBinary(from) + case DateType => castToDate(from) + case decimal: DecimalType => castToDecimal(from, decimal) + case TimestampType => castToTimestamp(from) + case BooleanType => castToBoolean(from) + case ByteType => castToByte(from) + case ShortType => castToShort(from) + case IntegerType => castToInt(from) + case FloatType => castToFloat(from) + case LongType => castToLong(from) + case DoubleType => castToDouble(from) + case array: ArrayType => castArray(from.asInstanceOf[ArrayType], array) + case map: MapType => castMap(from.asInstanceOf[MapType], map) + case struct: StructType => castStruct(from.asInstanceOf[StructType], struct) } private[this] lazy val cast: Any => Any = cast(child.dataType, dataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 0837a3179d897..3cf851aec15ea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute} import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types._ @@ -25,9 +25,6 @@ import org.apache.spark.sql.types._ abstract class Expression extends TreeNode[Expression] { self: Product => - /** The narrowest possible type that is produced when this expression is evaluated. */ - type EvaluatedType <: Any - /** * Returns true when an expression is a candidate for static evaluation before the query is * executed. @@ -40,19 +37,28 @@ abstract class Expression extends TreeNode[Expression] { * - A [[Cast]] or [[UnaryMinus]] is foldable if its child is foldable */ def foldable: Boolean = false + + /** + * Returns true when the current expression always return the same result for fixed input values. + */ + // TODO: Need to define explicit input values vs implicit input values. + def deterministic: Boolean = true + def nullable: Boolean + def references: AttributeSet = AttributeSet(children.flatMap(_.references.iterator)) /** Returns the result of evaluating this expression on a given input Row */ - def eval(input: Row = null): EvaluatedType + def eval(input: Row = null): Any /** * Returns `true` if this expression and all its children have been resolved to a specific schema - * and `false` if it still contains any unresolved placeholders. Implementations of expressions - * should override this if the resolution of this type of expression involves more than just - * the resolution of its children. + * and input data types checking passed, and `false` if it still contains any unresolved + * placeholders or has data types mismatch. + * Implementations of expressions should override this if the resolution of this type of + * expression involves more than just the resolution of its children and type checking. */ - lazy val resolved: Boolean = childrenResolved + lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess /** * Returns the [[DataType]] of the result of evaluating this expression. It is @@ -76,12 +82,34 @@ abstract class Expression extends TreeNode[Expression] { case u: UnresolvedAttribute => PrettyAttribute(u.name) }.toString } + + /** + * Returns true when two expressions will always compute the same result, even if they differ + * cosmetically (i.e. capitalization of names in attributes may be different). + */ + def semanticEquals(other: Expression): Boolean = this.getClass == other.getClass && { + val elements1 = this.productIterator.toSeq + val elements2 = other.asInstanceOf[Product].productIterator.toSeq + elements1.length == elements2.length && elements1.zip(elements2).forall { + case (e1: Expression, e2: Expression) => e1 semanticEquals e2 + case (i1, i2) => i1 == i2 + } + } + + /** + * Checks the input data types, returns `TypeCheckResult.success` if it's valid, + * or returns a `TypeCheckResult` with an error message if invalid. + * Note: it's not valid to call this method until `childrenResolved == true` + * TODO: we should remove the default implementation and implement it for all + * expressions with proper error message. + */ + def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess } abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] { self: Product => - def symbol: String + def symbol: String = sys.error(s"BinaryExpressions must override either toString or symbol") override def foldable: Boolean = left.foldable && right.foldable @@ -104,8 +132,7 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio // not like a real expressions. case class GroupExpression(children: Seq[Expression]) extends Expression { self: Product => - type EvaluatedType = Seq[Any] - override def eval(input: Row): EvaluatedType = throw new UnsupportedOperationException + override def eval(input: Row): Any = throw new UnsupportedOperationException override def nullable: Boolean = false override def foldable: Boolean = false override def dataType: DataType = throw new UnsupportedOperationException @@ -116,7 +143,13 @@ case class GroupExpression(children: Seq[Expression]) extends Expression { * so that the proper type conversions can be performed in the analyzer. */ trait ExpectsInputTypes { + self: Expression => def expectedChildTypes: Seq[DataType] + override def checkInputDataTypes(): TypeCheckResult = { + // We will always do type casting for `ExpectsInputTypes` in `HiveTypeCoercion`, + // so type mismatch error won't be reported here, but for underling `Cast`s. + TypeCheckResult.TypeCheckSuccess + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala index e05926cbfe74b..a1e0819e8a433 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala @@ -47,7 +47,7 @@ object ExtractValue { case (ArrayType(StructType(fields), containsNull), Literal(fieldName, StringType)) => val ordinal = findField(fields, fieldName.toString, resolver) GetArrayStructFields(child, fields(ordinal), ordinal, containsNull) - case (_: ArrayType, _) if extraction.dataType.isInstanceOf[IntegralType] => + case (_: ArrayType, _) if extraction.dataType.isInstanceOf[IntegralType] => GetArrayItem(child, extraction) case (_: MapType, _) => GetMapValue(child, extraction) @@ -92,8 +92,6 @@ object ExtractValue { trait ExtractValue extends UnaryExpression { self: Product => - - type EvaluatedType = Any } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala index 9a77ca624ebe2..5b45347872cca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -27,8 +27,6 @@ import org.apache.spark.sql.types.DataType case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expression]) extends Expression { - type EvaluatedType = Any - override def nullable: Boolean = true override def toString: String = s"scalaUDF(${children.mkString(",")})" @@ -55,9 +53,9 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi }.foreach(println) */ - - val f = children.size match { - case 0 => + + private[this] val f = children.size match { + case 0 => val func = function.asInstanceOf[() => Any] (input: Row) => { func() @@ -956,7 +954,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi } // scalastyle:on - - override def eval(input: Row): Any = CatalystTypeConverters.convertToCatalyst(f(input), dataType) + private[this] val converter = CatalystTypeConverters.createToCatalystConverter(dataType) + override def eval(input: Row): Any = converter(f(input)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index 83074eb1e6310..99340a14c9ecc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -29,14 +29,14 @@ case object Descending extends SortDirection * An expression that can be used to sort a tuple. This class extends expression primarily so that * transformations over expression will descend into its child. */ -case class SortOrder(child: Expression, direction: SortDirection) extends Expression +case class SortOrder(child: Expression, direction: SortDirection) extends Expression with trees.UnaryNode[Expression] { override def dataType: DataType = child.dataType override def nullable: Boolean = child.nullable // SortOrder itself is never evaluated. - override def eval(input: Row = null): EvaluatedType = + override def eval(input: Row = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") override def toString: String = s"$child ${if (direction == Ascending) "ASC" else "DESC"}" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index f3830c6d3bcf2..0266084a6d174 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -37,7 +37,7 @@ abstract class AggregateExpression extends Expression { * [[AggregateExpression.eval]] should never be invoked because [[AggregateExpression]]'s are * replaced with a physical aggregate operator at runtime. */ - override def eval(input: Row = null): EvaluatedType = + override def eval(input: Row = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } @@ -74,8 +74,6 @@ abstract class AggregateFunction extends AggregateExpression with Serializable with trees.LeafNode[Expression] { self: Product => - override type EvaluatedType = Any - /** Base should return the generic aggregate expression that this function is computing */ val base: AggregateExpression @@ -113,7 +111,7 @@ case class MinFunction(expr: Expression, base: AggregateExpression) extends Aggr override def update(input: Row): Unit = { if (currentMin.value == null) { currentMin.value = expr.eval(input) - } else if(cmp.eval(input) == true) { + } else if (cmp.eval(input) == true) { currentMin.value = expr.eval(input) } } @@ -144,7 +142,7 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends Aggr override def update(input: Row): Unit = { if (currentMax.value == null) { currentMax.value = expr.eval(input) - } else if(cmp.eval(input) == true) { + } else if (cmp.eval(input) == true) { currentMax.value = expr.eval(input) } } @@ -396,13 +394,13 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[ * Combining PartitionLevel InputData * <-- null * Zero <-- Zero <-- null - * + * * <-- null <-- no data - * null <-- null <-- no data + * null <-- null <-- no data */ case class CombineSum(child: Expression) extends AggregateExpression { def this() = this(null) - + override def children: Seq[Expression] = child :: Nil override def nullable: Boolean = true override def dataType: DataType = child.dataType @@ -618,7 +616,7 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr private val sum = MutableLiteral(null, calcType) - private val addFunction = + private val addFunction = Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero)) override def update(input: Row): Unit = { @@ -636,7 +634,7 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr case class CombineSumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { - + def this() = this(null, null) // Required for serialization. private val calcType = @@ -651,12 +649,12 @@ case class CombineSumFunction(expr: Expression, base: AggregateExpression) private val sum = MutableLiteral(null, calcType) - private val addFunction = + private val addFunction = Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero)) - + override def update(input: Row): Unit = { val result = expr.eval(input) - // partial sum result can be null only when no input rows present + // partial sum result can be null only when no input rows present if(result != null) { sum.update(addFunction, input) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index c7a37ad966df6..2ac53f8f6613f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -17,76 +17,89 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.analysis.UnresolvedException -import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ -case class UnaryMinus(child: Expression) extends UnaryExpression { - type EvaluatedType = Any +abstract class UnaryArithmetic extends UnaryExpression { + self: Product => - override def dataType: DataType = child.dataType override def foldable: Boolean = child.foldable override def nullable: Boolean = child.nullable - override def toString: String = s"-$child" - - lazy val numeric = dataType match { - case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] - case other => sys.error(s"Type $other does not support numeric operations") - } + override def dataType: DataType = child.dataType override def eval(input: Row): Any = { val evalE = child.eval(input) if (evalE == null) { null } else { - numeric.negate(evalE) + evalInternal(evalE) } } + + protected def evalInternal(evalE: Any): Any = + sys.error(s"UnaryArithmetics must override either eval or evalInternal") } -case class Sqrt(child: Expression) extends UnaryExpression { - type EvaluatedType = Any +case class UnaryMinus(child: Expression) extends UnaryArithmetic { + override def toString: String = s"-$child" + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "operator -") + private lazy val numeric = TypeUtils.getNumeric(dataType) + + protected override def evalInternal(evalE: Any) = numeric.negate(evalE) +} + +case class Sqrt(child: Expression) extends UnaryArithmetic { override def dataType: DataType = DoubleType - override def foldable: Boolean = child.foldable override def nullable: Boolean = true override def toString: String = s"SQRT($child)" - lazy val numeric = child.dataType match { - case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] - case other => sys.error(s"Type $other does not support non-negative numeric operations") - } + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function sqrt") - override def eval(input: Row): Any = { - val evalE = child.eval(input) - if (evalE == null) { - null - } else { - val value = numeric.toDouble(evalE) - if (value < 0) null - else math.sqrt(value) - } + private lazy val numeric = TypeUtils.getNumeric(child.dataType) + + protected override def evalInternal(evalE: Any) = { + val value = numeric.toDouble(evalE) + if (value < 0) null + else math.sqrt(value) } } +/** + * A function that get the absolute value of the numeric value. + */ +case class Abs(child: Expression) extends UnaryArithmetic { + override def toString: String = s"Abs($child)" + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function abs") + + private lazy val numeric = TypeUtils.getNumeric(dataType) + + protected override def evalInternal(evalE: Any) = numeric.abs(evalE) +} + abstract class BinaryArithmetic extends BinaryExpression { self: Product => - type EvaluatedType = Any + override def dataType: DataType = left.dataType - override lazy val resolved = - left.resolved && right.resolved && - left.dataType == right.dataType && - !DecimalType.isFixed(left.dataType) - - override def dataType: DataType = { - if (!resolved) { - throw new UnresolvedException(this, - s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") + override def checkInputDataTypes(): TypeCheckResult = { + if (left.dataType != right.dataType) { + TypeCheckResult.TypeCheckFailure( + s"differing types in ${this.getClass.getSimpleName} " + + s"(${left.dataType} and ${right.dataType}).") + } else { + checkTypesInternal(dataType) } - left.dataType } + protected def checkTypesInternal(t: DataType): TypeCheckResult + override def eval(input: Row): Any = { val evalE1 = left.eval(input) if(evalE1 == null) { @@ -101,90 +114,67 @@ abstract class BinaryArithmetic extends BinaryExpression { } } - def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = - sys.error(s"BinaryExpressions must either override eval or evalInternal") + protected def evalInternal(evalE1: Any, evalE2: Any): Any = + sys.error(s"BinaryArithmetics must override either eval or evalInternal") } case class Add(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "+" - lazy val numeric = dataType match { - case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] - case other => sys.error(s"Type $other does not support numeric operations") - } + override lazy val resolved = + childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if(evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - numeric.plus(evalE1, evalE2) - } - } - } + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForNumericExpr(t, "operator " + symbol) + + private lazy val numeric = TypeUtils.getNumeric(dataType) + + protected override def evalInternal(evalE1: Any, evalE2: Any) = numeric.plus(evalE1, evalE2) } case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "-" - lazy val numeric = dataType match { - case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] - case other => sys.error(s"Type $other does not support numeric operations") - } + override lazy val resolved = + childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if(evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - numeric.minus(evalE1, evalE2) - } - } - } + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForNumericExpr(t, "operator " + symbol) + + private lazy val numeric = TypeUtils.getNumeric(dataType) + + protected override def evalInternal(evalE1: Any, evalE2: Any) = numeric.minus(evalE1, evalE2) } case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "*" - lazy val numeric = dataType match { - case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] - case other => sys.error(s"Type $other does not support numeric operations") - } + override lazy val resolved = + childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if(evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - numeric.times(evalE1, evalE2) - } - } - } + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForNumericExpr(t, "operator " + symbol) + + private lazy val numeric = TypeUtils.getNumeric(dataType) + + protected override def evalInternal(evalE1: Any, evalE2: Any) = numeric.times(evalE1, evalE2) } case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "/" - override def nullable: Boolean = true - lazy val div: (Any, Any) => Any = dataType match { + override lazy val resolved = + childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) + + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForNumericExpr(t, "operator " + symbol) + + private lazy val div: (Any, Any) => Any = dataType match { case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div case it: IntegralType => it.integral.asInstanceOf[Integral[Any]].quot - case other => sys.error(s"Type $other does not support numeric operations") } - + override def eval(input: Row): Any = { val evalE2 = right.eval(input) if (evalE2 == null || evalE2 == 0) { @@ -202,13 +192,17 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "%" - override def nullable: Boolean = true - lazy val integral = dataType match { + override lazy val resolved = + childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) + + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForNumericExpr(t, "operator " + symbol) + + private lazy val integral = dataType match { case i: IntegralType => i.integral.asInstanceOf[Integral[Any]] case i: FractionalType => i.asIntegral.asInstanceOf[Integral[Any]] - case other => sys.error(s"Type $other does not support numeric operations") } override def eval(input: Row): Any = { @@ -232,7 +226,10 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "&" - lazy val and: (Any, Any) => Any = dataType match { + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForBitwiseExpr(t, "operator " + symbol) + + private lazy val and: (Any, Any) => Any = dataType match { case ByteType => ((evalE1: Byte, evalE2: Byte) => (evalE1 & evalE2).toByte).asInstanceOf[(Any, Any) => Any] case ShortType => @@ -241,10 +238,9 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme ((evalE1: Int, evalE2: Int) => evalE1 & evalE2).asInstanceOf[(Any, Any) => Any] case LongType => ((evalE1: Long, evalE2: Long) => evalE1 & evalE2).asInstanceOf[(Any, Any) => Any] - case other => sys.error(s"Unsupported bitwise & operation on $other") } - override def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = and(evalE1, evalE2) + protected override def evalInternal(evalE1: Any, evalE2: Any) = and(evalE1, evalE2) } /** @@ -253,7 +249,10 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "|" - lazy val or: (Any, Any) => Any = dataType match { + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForBitwiseExpr(t, "operator " + symbol) + + private lazy val or: (Any, Any) => Any = dataType match { case ByteType => ((evalE1: Byte, evalE2: Byte) => (evalE1 | evalE2).toByte).asInstanceOf[(Any, Any) => Any] case ShortType => @@ -262,10 +261,9 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet ((evalE1: Int, evalE2: Int) => evalE1 | evalE2).asInstanceOf[(Any, Any) => Any] case LongType => ((evalE1: Long, evalE2: Long) => evalE1 | evalE2).asInstanceOf[(Any, Any) => Any] - case other => sys.error(s"Unsupported bitwise | operation on $other") } - override def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = or(evalE1, evalE2) + protected override def evalInternal(evalE1: Any, evalE2: Any) = or(evalE1, evalE2) } /** @@ -274,7 +272,10 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "^" - lazy val xor: (Any, Any) => Any = dataType match { + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForBitwiseExpr(t, "operator " + symbol) + + private lazy val xor: (Any, Any) => Any = dataType match { case ByteType => ((evalE1: Byte, evalE2: Byte) => (evalE1 ^ evalE2).toByte).asInstanceOf[(Any, Any) => Any] case ShortType => @@ -283,24 +284,21 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme ((evalE1: Int, evalE2: Int) => evalE1 ^ evalE2).asInstanceOf[(Any, Any) => Any] case LongType => ((evalE1: Long, evalE2: Long) => evalE1 ^ evalE2).asInstanceOf[(Any, Any) => Any] - case other => sys.error(s"Unsupported bitwise ^ operation on $other") } - override def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = xor(evalE1, evalE2) + protected override def evalInternal(evalE1: Any, evalE2: Any): Any = xor(evalE1, evalE2) } /** * A function that calculates bitwise not(~) of a number. */ -case class BitwiseNot(child: Expression) extends UnaryExpression { - type EvaluatedType = Any - - override def dataType: DataType = child.dataType - override def foldable: Boolean = child.foldable - override def nullable: Boolean = child.nullable +case class BitwiseNot(child: Expression) extends UnaryArithmetic { override def toString: String = s"~$child" - lazy val not: (Any) => Any = dataType match { + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForBitwiseExpr(child.dataType, "operator ~") + + private lazy val not: (Any) => Any = dataType match { case ByteType => ((evalE: Byte) => (~evalE).toByte).asInstanceOf[(Any) => Any] case ShortType => @@ -309,44 +307,18 @@ case class BitwiseNot(child: Expression) extends UnaryExpression { ((evalE: Int) => ~evalE).asInstanceOf[(Any) => Any] case LongType => ((evalE: Long) => ~evalE).asInstanceOf[(Any) => Any] - case other => sys.error(s"Unsupported bitwise ~ operation on $other") } - override def eval(input: Row): Any = { - val evalE = child.eval(input) - if (evalE == null) { - null - } else { - not(evalE) - } - } + protected override def evalInternal(evalE: Any) = not(evalE) } -case class MaxOf(left: Expression, right: Expression) extends Expression { - type EvaluatedType = Any - - override def foldable: Boolean = left.foldable && right.foldable - +case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { override def nullable: Boolean = left.nullable && right.nullable - override def children: Seq[Expression] = left :: right :: Nil - - override lazy val resolved = - left.resolved && right.resolved && - left.dataType == right.dataType + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForOrderingExpr(t, "function maxOf") - override def dataType: DataType = { - if (!resolved) { - throw new UnresolvedException(this, - s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") - } - left.dataType - } - - lazy val ordering = left.dataType match { - case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] - case other => sys.error(s"Type $other does not support ordered operations") - } + private lazy val ordering = TypeUtils.getOrdering(dataType) override def eval(input: Row): Any = { val evalE1 = left.eval(input) @@ -367,31 +339,13 @@ case class MaxOf(left: Expression, right: Expression) extends Expression { override def toString: String = s"MaxOf($left, $right)" } -case class MinOf(left: Expression, right: Expression) extends Expression { - type EvaluatedType = Any - - override def foldable: Boolean = left.foldable && right.foldable - +case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { override def nullable: Boolean = left.nullable && right.nullable - override def children: Seq[Expression] = left :: right :: Nil + protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForOrderingExpr(t, "function minOf") - override lazy val resolved = - left.resolved && right.resolved && - left.dataType == right.dataType - - override def dataType: DataType = { - if (!resolved) { - throw new UnresolvedException(this, - s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") - } - left.dataType - } - - lazy val ordering = left.dataType match { - case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] - case other => sys.error(s"Type $other does not support ordered operations") - } + private lazy val ordering = TypeUtils.getOrdering(dataType) override def eval(input: Row): Any = { val evalE1 = left.eval(input) @@ -411,29 +365,3 @@ case class MinOf(left: Expression, right: Expression) extends Expression { override def toString: String = s"MinOf($left, $right)" } - -/** - * A function that get the absolute value of the numeric value. - */ -case class Abs(child: Expression) extends UnaryExpression { - type EvaluatedType = Any - - override def dataType: DataType = child.dataType - override def foldable: Boolean = child.foldable - override def nullable: Boolean = child.nullable - override def toString: String = s"Abs($child)" - - lazy val numeric = dataType match { - case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]] - case other => sys.error(s"Type $other does not support numeric operations") - } - - override def eval(input: Row): Any = { - val evalE = child.eval(input) - if (evalE == null) { - null - } else { - numeric.abs(evalE) - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index d17af0e7ff87e..36964af68dd8d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -250,7 +250,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin case Cast(child @ DateType(), StringType) => child.castOrNull(c => q"""org.apache.spark.sql.types.UTF8String( - org.apache.spark.sql.types.DateUtils.toString($c))""", + org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""", StringType) case Cast(child @ NumericType(), IntegerType) => @@ -373,7 +373,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin // Uh, bad function name... child.castOrNull(c => q"!$c", BooleanType) - case Add(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 + $eval2" } + case Add(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 + $eval2" } case Subtract(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 - $eval2" } case Multiply(e1, e2) => (e1, e2) evaluate { case (eval1, eval2) => q"$eval1 * $eval2" } case Divide(e1, e2) => @@ -665,7 +665,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin protected def defaultPrimitive(dt: DataType) = dt match { case BooleanType => ru.Literal(Constant(false)) case FloatType => ru.Literal(Constant(-1.0.toFloat)) - case StringType => q"""org.apache.spark.sql.types.UTF8String("")""" + case StringType => q"""org.apache.spark.sql.types.UTF8String("")""" case ShortType => ru.Literal(Constant(-1.toShort)) case LongType => ru.Literal(Constant(-1L)) case ByteType => ru.Literal(Constant(-1.toByte)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 584f938445c8c..31c63a79ebc8c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -161,7 +161,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { } } - val hashValues = expressions.zipWithIndex.map { case (e,i) => + val hashValues = expressions.zipWithIndex.map { case (e, i) => val elementName = newTermName(s"c$i") val nonNull = e.dataType match { case BooleanType => q"if ($elementName) 0 else 1" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index 956a2429b0b61..6398b8f9e4ed7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -24,10 +24,9 @@ import org.apache.spark.sql.types._ * Returns an Array containing the evaluation of all children expressions. */ case class CreateArray(children: Seq[Expression]) extends Expression { - override type EvaluatedType = Any - + override def foldable: Boolean = children.forall(_.foldable) - + lazy val childTypes = children.map(_.dataType).distinct override lazy val resolved = @@ -54,7 +53,6 @@ case class CreateArray(children: Seq[Expression]) extends Expression { * TODO: [[CreateStruct]] does not support codegen. */ case class CreateStruct(children: Seq[NamedExpression]) extends Expression { - override type EvaluatedType = Row override def foldable: Boolean = children.forall(_.foldable) @@ -71,7 +69,7 @@ case class CreateStruct(children: Seq[NamedExpression]) extends Expression { override def nullable: Boolean = false - override def eval(input: Row): EvaluatedType = { + override def eval(input: Row): Any = { Row(children.map(_.eval(input)): _*) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala index adb94df7d1c7b..65ba18924afe1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -21,7 +21,6 @@ import org.apache.spark.sql.types._ /** Return the unscaled Long value of a Decimal, assuming it fits in a Long */ case class UnscaledValue(child: Expression) extends UnaryExpression { - override type EvaluatedType = Any override def dataType: DataType = LongType override def foldable: Boolean = child.foldable @@ -40,7 +39,6 @@ case class UnscaledValue(child: Expression) extends UnaryExpression { /** Create a Decimal from an unscaled Long value */ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends UnaryExpression { - override type EvaluatedType = Decimal override def dataType: DataType = DecimalType(precision, scale) override def foldable: Boolean = child.foldable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 9a6cb048af5ad..b6191eafba71b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -40,8 +40,6 @@ import org.apache.spark.sql.types._ abstract class Generator extends Expression { self: Product => - override type EvaluatedType = TraversableOnce[Row] - // TODO ideally we should return the type of ArrayType(StructType), // however, we don't keep the output field names in the Generator. override def dataType: DataType = throw new UnsupportedOperationException @@ -56,6 +54,12 @@ abstract class Generator extends Expression { /** Should be implemented by child classes to perform specific Generators. */ override def eval(input: Row): TraversableOnce[Row] + + /** + * Notifies that there are no more rows to process, clean up code, and additional + * rows can be made here. + */ + def terminate(): TraversableOnce[Row] = Nil } /** @@ -67,12 +71,23 @@ case class UserDefinedGenerator( children: Seq[Expression]) extends Generator { + @transient private[this] var inputRow: InterpretedProjection = _ + @transient private[this] var convertToScala: (Row) => Row = _ + + private def initializeConverters(): Unit = { + inputRow = new InterpretedProjection(children) + convertToScala = { + val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true))) + CatalystTypeConverters.createToScalaConverter(inputSchema) + }.asInstanceOf[(Row => Row)] + } + override def eval(input: Row): TraversableOnce[Row] = { - // TODO(davies): improve this + if (inputRow == null) { + initializeConverters() + } // Convert the objects into Scala Type before calling function, we need schema to support UDT - val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true))) - val inputRow = new InterpretedProjection(children) - function(CatalystTypeConverters.convertToScala(inputRow(input), inputSchema).asInstanceOf[Row]) + function(convertToScala(inputRow(input))) } override def toString: String = s"UserDefinedGenerator(${children.mkString(",")})" @@ -99,8 +114,8 @@ case class Explode(child: Expression) val inputArray = child.eval(input).asInstanceOf[Seq[Any]] if (inputArray == null) Nil else inputArray.map(v => new GenericRow(Array(v))) case MapType(_, _, _) => - val inputMap = child.eval(input).asInstanceOf[Map[Any,Any]] - if (inputMap == null) Nil else inputMap.map { case (k,v) => new GenericRow(Array(k,v)) } + val inputMap = child.eval(input).asInstanceOf[Map[Any, Any]] + if (inputMap == null) Nil else inputMap.map { case (k, v) => new GenericRow(Array(k, v)) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 18cba4cc46707..d3ca3d9a4b18b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types._ object Literal { @@ -77,14 +78,12 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres override def toString: String = if (value != null) value.toString else "null" - type EvaluatedType = Any override def eval(input: Row): Any = value } // TODO: Specialize case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean = true) extends LeafExpression { - type EvaluatedType = Any def update(expression: Expression, input: Row): Unit = { value = expression.eval(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala index fcc06d3aa1036..db853a2b97fad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions.mathfuncs -import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, BinaryExpression, Expression, Row} import org.apache.spark.sql.types._ @@ -27,27 +26,14 @@ import org.apache.spark.sql.types._ * @param f The math function. * @param name The short name of the function */ -abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) +abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) extends BinaryExpression with Serializable with ExpectsInputTypes { self: Product => - type EvaluatedType = Any - override def symbol: String = null + override def expectedChildTypes: Seq[DataType] = Seq(DoubleType, DoubleType) - override def nullable: Boolean = left.nullable || right.nullable override def toString: String = s"$name($left, $right)" - override lazy val resolved = - left.resolved && right.resolved && - left.dataType == right.dataType && - !DecimalType.isFixed(left.dataType) - - override def dataType: DataType = { - if (!resolved) { - throw new UnresolvedException(this, - s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") - } - left.dataType - } + override def dataType: DataType = DoubleType override def eval(input: Row): Any = { val evalE1 = left.eval(input) @@ -65,9 +51,9 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) } } -case class Atan2( - left: Expression, - right: Expression) extends BinaryMathExpression(math.atan2, "ATAN2") { +case class Atan2(left: Expression, right: Expression) + extends BinaryMathExpression(math.atan2, "ATAN2") { + override def eval(input: Row): Any = { val evalE1 = left.eval(input) if (evalE1 == null) { @@ -86,8 +72,7 @@ case class Atan2( } } -case class Hypot( - left: Expression, - right: Expression) extends BinaryMathExpression(math.hypot, "HYPOT") +case class Hypot(left: Expression, right: Expression) + extends BinaryMathExpression(math.hypot, "HYPOT") case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala index dc68469e060cb..41b422346a02d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala @@ -25,10 +25,9 @@ import org.apache.spark.sql.types._ * input format, therefore these functions extend `ExpectsInputTypes`. * @param name The short name of the function */ -abstract class MathematicalExpression(f: Double => Double, name: String) +abstract class UnaryMathExpression(f: Double => Double, name: String) extends UnaryExpression with Serializable with ExpectsInputTypes { self: Product => - type EvaluatedType = Any override def expectedChildTypes: Seq[DataType] = Seq(DoubleType) override def dataType: DataType = DoubleType @@ -47,46 +46,44 @@ abstract class MathematicalExpression(f: Double => Double, name: String) } } -case class Acos(child: Expression) extends MathematicalExpression(math.acos, "ACOS") +case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS") -case class Asin(child: Expression) extends MathematicalExpression(math.asin, "ASIN") +case class Asin(child: Expression) extends UnaryMathExpression(math.asin, "ASIN") -case class Atan(child: Expression) extends MathematicalExpression(math.atan, "ATAN") +case class Atan(child: Expression) extends UnaryMathExpression(math.atan, "ATAN") -case class Cbrt(child: Expression) extends MathematicalExpression(math.cbrt, "CBRT") +case class Cbrt(child: Expression) extends UnaryMathExpression(math.cbrt, "CBRT") -case class Ceil(child: Expression) extends MathematicalExpression(math.ceil, "CEIL") +case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL") -case class Cos(child: Expression) extends MathematicalExpression(math.cos, "COS") +case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS") -case class Cosh(child: Expression) extends MathematicalExpression(math.cosh, "COSH") +case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH") -case class Exp(child: Expression) extends MathematicalExpression(math.exp, "EXP") +case class Exp(child: Expression) extends UnaryMathExpression(math.exp, "EXP") -case class Expm1(child: Expression) extends MathematicalExpression(math.expm1, "EXPM1") +case class Expm1(child: Expression) extends UnaryMathExpression(math.expm1, "EXPM1") -case class Floor(child: Expression) extends MathematicalExpression(math.floor, "FLOOR") +case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLOOR") -case class Log(child: Expression) extends MathematicalExpression(math.log, "LOG") +case class Log(child: Expression) extends UnaryMathExpression(math.log, "LOG") -case class Log10(child: Expression) extends MathematicalExpression(math.log10, "LOG10") +case class Log10(child: Expression) extends UnaryMathExpression(math.log10, "LOG10") -case class Log1p(child: Expression) extends MathematicalExpression(math.log1p, "LOG1P") +case class Log1p(child: Expression) extends UnaryMathExpression(math.log1p, "LOG1P") -case class Rint(child: Expression) extends MathematicalExpression(math.rint, "ROUND") +case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND") -case class Signum(child: Expression) extends MathematicalExpression(math.signum, "SIGNUM") +case class Signum(child: Expression) extends UnaryMathExpression(math.signum, "SIGNUM") -case class Sin(child: Expression) extends MathematicalExpression(math.sin, "SIN") +case class Sin(child: Expression) extends UnaryMathExpression(math.sin, "SIN") -case class Sinh(child: Expression) extends MathematicalExpression(math.sinh, "SINH") +case class Sinh(child: Expression) extends UnaryMathExpression(math.sinh, "SINH") -case class Tan(child: Expression) extends MathematicalExpression(math.tan, "TAN") +case class Tan(child: Expression) extends UnaryMathExpression(math.tan, "TAN") -case class Tanh(child: Expression) extends MathematicalExpression(math.tanh, "TANH") +case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, "TANH") -case class ToDegrees(child: Expression) - extends MathematicalExpression(math.toDegrees, "DEGREES") +case class ToDegrees(child: Expression) extends UnaryMathExpression(math.toDegrees, "DEGREES") -case class ToRadians(child: Expression) - extends MathematicalExpression(math.toRadians, "RADIANS") +case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadians, "RADIANS") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index a9170589f8c6c..00565ec651a59 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -111,7 +111,6 @@ case class Alias(child: Expression, name: String)( val explicitMetadata: Option[Metadata] = None) extends NamedExpression with trees.UnaryNode[Expression] { - override type EvaluatedType = Any // Alias(Generator, xx) need to be transformed into Generate(generator, ...) override lazy val resolved = childrenResolved && !child.isInstanceOf[Generator] @@ -181,6 +180,11 @@ case class AttributeReference( case _ => false } + override def semanticEquals(other: Expression): Boolean = other match { + case ar: AttributeReference => sameRef(ar) + case _ => false + } + override def hashCode: Int = { // See http://stackoverflow.com/questions/113511/hash-code-implementation var h = 17 @@ -224,7 +228,7 @@ case class AttributeReference( } // Unresolved attributes are transient at compile time and don't get evaluated during execution. - override def eval(input: Row = null): EvaluatedType = + override def eval(input: Row = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") override def toString: String = s"$name#${exprId.id}$typeSuffix" @@ -235,7 +239,6 @@ case class AttributeReference( * expression id or the unresolved indicator. */ case class PrettyAttribute(name: String) extends Attribute with trees.LeafNode[Expression] { - type EvaluatedType = Any override def toString: String = name @@ -247,7 +250,7 @@ case class PrettyAttribute(name: String) extends Attribute with trees.LeafNode[E override def withName(newName: String): Attribute = throw new UnsupportedOperationException override def qualifiers: Seq[String] = throw new UnsupportedOperationException override def exprId: ExprId = throw new UnsupportedOperationException - override def eval(input: Row): EvaluatedType = throw new UnsupportedOperationException + override def eval(input: Row): Any = throw new UnsupportedOperationException override def nullable: Boolean = throw new UnsupportedOperationException override def dataType: DataType = NullType } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index f9161cf34f0c9..5070570b4740d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -22,7 +22,6 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.types.DataType case class Coalesce(children: Seq[Expression]) extends Expression { - type EvaluatedType = Any /** Coalesce is nullable if all of its children are nullable, or if it has no children. */ override def nullable: Boolean = !children.exists(!_.nullable) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 1d72a9eb834b9..807021d50e8e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.analysis.UnresolvedException -import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.types.{DataType, BinaryType, BooleanType, AtomicType} +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.types.{BinaryType, BooleanType, DataType} object InterpretedPredicate { def create(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) = @@ -35,8 +35,6 @@ trait Predicate extends Expression { self: Product => override def dataType: DataType = BooleanType - - type EvaluatedType = Any } trait PredicateHelper { @@ -173,22 +171,51 @@ case class Or(left: Expression, right: Expression) abstract class BinaryComparison extends BinaryExpression with Predicate { self: Product => -} -case class EqualTo(left: Expression, right: Expression) extends BinaryComparison { - override def symbol: String = "=" + override def checkInputDataTypes(): TypeCheckResult = { + if (left.dataType != right.dataType) { + TypeCheckResult.TypeCheckFailure( + s"differing types in ${this.getClass.getSimpleName} " + + s"(${left.dataType} and ${right.dataType}).") + } else { + checkTypesInternal(dataType) + } + } + + protected def checkTypesInternal(t: DataType): TypeCheckResult override def eval(input: Row): Any = { - val l = left.eval(input) - if (l == null) { + val evalE1 = left.eval(input) + if (evalE1 == null) { null } else { - val r = right.eval(input) - if (r == null) null - else if (left.dataType != BinaryType) l == r - else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + evalInternal(evalE1, evalE2) + } } } + + protected def evalInternal(evalE1: Any, evalE2: Any): Any = + sys.error(s"BinaryComparisons must override either eval or evalInternal") +} + +object BinaryComparison { + def unapply(b: BinaryComparison): Option[(Expression, Expression)] = + Some((b.left, b.right)) +} + +case class EqualTo(left: Expression, right: Expression) extends BinaryComparison { + override def symbol: String = "=" + + override protected def checkTypesInternal(t: DataType) = TypeCheckResult.TypeCheckSuccess + + protected override def evalInternal(l: Any, r: Any) = { + if (left.dataType != BinaryType) l == r + else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) + } } case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison { @@ -196,6 +223,8 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp override def nullable: Boolean = false + override protected def checkTypesInternal(t: DataType) = TypeCheckResult.TypeCheckSuccess + override def eval(input: Row): Any = { val l = left.eval(input) val r = right.eval(input) @@ -212,117 +241,45 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp case class LessThan(left: Expression, right: Expression) extends BinaryComparison { override def symbol: String = "<" - lazy val ordering: Ordering[Any] = { - if (left.dataType != right.dataType) { - throw new TreeNodeException(this, - s"Types do not match ${left.dataType} != ${right.dataType}") - } - left.dataType match { - case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] - case other => sys.error(s"Type $other does not support ordered operations") - } - } + override protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if (evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - ordering.lt(evalE1, evalE2) - } - } - } + private lazy val ordering = TypeUtils.getOrdering(left.dataType) + + protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.lt(evalE1, evalE2) } case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { override def symbol: String = "<=" - lazy val ordering: Ordering[Any] = { - if (left.dataType != right.dataType) { - throw new TreeNodeException(this, - s"Types do not match ${left.dataType} != ${right.dataType}") - } - left.dataType match { - case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] - case other => sys.error(s"Type $other does not support ordered operations") - } - } + override protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if (evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - ordering.lteq(evalE1, evalE2) - } - } - } + private lazy val ordering = TypeUtils.getOrdering(left.dataType) + + protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.lteq(evalE1, evalE2) } case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison { override def symbol: String = ">" - lazy val ordering: Ordering[Any] = { - if (left.dataType != right.dataType) { - throw new TreeNodeException(this, - s"Types do not match ${left.dataType} != ${right.dataType}") - } - left.dataType match { - case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] - case other => sys.error(s"Type $other does not support ordered operations") - } - } + override protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if(evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - ordering.gt(evalE1, evalE2) - } - } - } + private lazy val ordering = TypeUtils.getOrdering(left.dataType) + + protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.gt(evalE1, evalE2) } case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { override def symbol: String = ">=" - lazy val ordering: Ordering[Any] = { - if (left.dataType != right.dataType) { - throw new TreeNodeException(this, - s"Types do not match ${left.dataType} != ${right.dataType}") - } - left.dataType match { - case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] - case other => sys.error(s"Type $other does not support ordered operations") - } - } + override protected def checkTypesInternal(t: DataType) = + TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol) - override def eval(input: Row): Any = { - val evalE1 = left.eval(input) - if (evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - ordering.gteq(evalE1, evalE2) - } - } - } + private lazy val ordering = TypeUtils.getOrdering(left.dataType) + + protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.gteq(evalE1, evalE2) } case class If(predicate: Expression, trueValue: Expression, falseValue: Expression) @@ -331,17 +288,19 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi override def children: Seq[Expression] = predicate :: trueValue :: falseValue :: Nil override def nullable: Boolean = trueValue.nullable || falseValue.nullable - override lazy val resolved = childrenResolved && trueValue.dataType == falseValue.dataType - override def dataType: DataType = { - if (!resolved) { - throw new UnresolvedException( - this, - s"Can not resolve due to differing types ${trueValue.dataType}, ${falseValue.dataType}") + override def checkInputDataTypes(): TypeCheckResult = { + if (predicate.dataType != BooleanType) { + TypeCheckResult.TypeCheckFailure( + s"type of predicate expression in If should be boolean, not ${predicate.dataType}") + } else if (trueValue.dataType != falseValue.dataType) { + TypeCheckResult.TypeCheckFailure( + s"differing types in If (${trueValue.dataType} and ${falseValue.dataType}).") + } else { + TypeCheckResult.TypeCheckSuccess } - trueValue.dataType } - type EvaluatedType = Any + override def dataType: DataType = trueValue.dataType override def eval(input: Row): Any = { if (true == predicate.eval(input)) { @@ -357,8 +316,6 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi trait CaseWhenLike extends Expression { self: Product => - type EvaluatedType = Any - // Note that `branches` are considered in consecutive pairs (cond, val), and the optional last // element is the value for the default catch-all case (if provided). // Hence, `branches` consists of at least two elements, and can have an odd or even length. @@ -370,17 +327,23 @@ trait CaseWhenLike extends Expression { branches.sliding(2, 2).collect { case Seq(_, thenExpr) => thenExpr }.toSeq val elseValue = if (branches.length % 2 == 0) None else Option(branches.last) - // both then and else val should be considered. + // both then and else expressions should be considered. def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType) - def valueTypesEqual: Boolean = valueTypes.distinct.size <= 1 + def valueTypesEqual: Boolean = valueTypes.distinct.size == 1 - override def dataType: DataType = { - if (!resolved) { - throw new UnresolvedException(this, "cannot resolve due to differing types in some branches") + override def checkInputDataTypes(): TypeCheckResult = { + if (valueTypesEqual) { + checkTypesInternal() + } else { + TypeCheckResult.TypeCheckFailure( + "THEN and ELSE expressions should all be same type or coercible to a common type") } - valueTypes.head } + protected def checkTypesInternal(): TypeCheckResult + + override def dataType: DataType = thenList.head.dataType + override def nullable: Boolean = { // If no value is nullable and no elseValue is provided, the whole statement defaults to null. thenList.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true)) @@ -401,10 +364,16 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike { override def children: Seq[Expression] = branches - override lazy val resolved: Boolean = - childrenResolved && - whenList.forall(_.dataType == BooleanType) && - valueTypesEqual + override protected def checkTypesInternal(): TypeCheckResult = { + if (whenList.forall(_.dataType == BooleanType)) { + TypeCheckResult.TypeCheckSuccess + } else { + val index = whenList.indexWhere(_.dataType != BooleanType) + TypeCheckResult.TypeCheckFailure( + s"WHEN expressions in CaseWhen should all be boolean type, " + + s"but the ${index + 1}th when expression's type is ${whenList(index)}") + } + } /** Written in imperative fashion for performance considerations. */ override def eval(input: Row): Any = { @@ -447,8 +416,14 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW override def children: Seq[Expression] = key +: branches - override lazy val resolved: Boolean = - childrenResolved && valueTypesEqual + override protected def checkTypesInternal(): TypeCheckResult = { + if ((key +: whenList).map(_.dataType).distinct.size > 1) { + TypeCheckResult.TypeCheckFailure( + "key and WHEN expressions should all be same type or coercible to a common type") + } else { + TypeCheckResult.TypeCheckSuccess + } + } /** Written in imperative fashion for performance considerations. */ override def eval(input: Row): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala index 66d7c8b07cce8..b2647124c4e49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala @@ -24,7 +24,7 @@ import org.apache.spark.util.random.XORShiftRandom /** * A Random distribution generating expression. - * TODO: This can be made generic to generate any type of random distribution, or any type of + * TODO: This can be made generic to generate any type of random distribution, or any type of * StructType. * * Since this expression is stateful, it cannot be a case object. @@ -38,7 +38,7 @@ abstract class RDG(seed: Long) extends LeafExpression with Serializable { */ @transient protected lazy val rng = new XORShiftRandom(seed + TaskContext.get().partitionId()) - override type EvaluatedType = Double + override def deterministic: Boolean = false override def nullable: Boolean = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index 4c44182278207..b65bf165f21db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -51,7 +51,6 @@ private[sql] class OpenHashSetUDT( * Creates a new set of the specified type */ case class NewSet(elementType: DataType) extends LeafExpression { - type EvaluatedType = Any override def nullable: Boolean = false @@ -69,7 +68,6 @@ case class NewSet(elementType: DataType) extends LeafExpression { * For performance, this expression mutates its input during evaluation. */ case class AddItemToSet(item: Expression, set: Expression) extends Expression { - type EvaluatedType = Any override def children: Seq[Expression] = item :: set :: Nil @@ -101,7 +99,6 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { * For performance, this expression mutates its left input set during evaluation. */ case class CombineSets(left: Expression, right: Expression) extends BinaryExpression { - type EvaluatedType = Any override def nullable: Boolean = left.nullable || right.nullable @@ -133,7 +130,6 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres * Returns the number of elements in the input set. */ case class CountSet(child: Expression) extends UnaryExpression { - type EvaluatedType = Any override def nullable: Boolean = child.nullable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 7683e0990ce80..c4ef9c30907f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -25,8 +25,6 @@ import org.apache.spark.sql.types._ trait StringRegexExpression extends ExpectsInputTypes { self: BinaryExpression => - type EvaluatedType = Any - def escape(v: String): String def matches(regex: Pattern, str: String): Boolean @@ -40,14 +38,14 @@ trait StringRegexExpression extends ExpectsInputTypes { case _ => null } - protected def compile(str: String): Pattern = if(str == null) { + protected def compile(str: String): Pattern = if (str == null) { null } else { // Let it raise exception if couldn't compile the regex string Pattern.compile(escape(str)) } - protected def pattern(str: String) = if(cache == null) compile(str) else cache + protected def pattern(str: String) = if (cache == null) compile(str) else cache override def eval(input: Row): Any = { val l = left.eval(input) @@ -114,8 +112,6 @@ case class RLike(left: Expression, right: Expression) trait CaseConversionExpression extends ExpectsInputTypes { self: UnaryExpression => - type EvaluatedType = Any - def convert(v: UTF8String): UTF8String override def foldable: Boolean = child.foldable @@ -137,7 +133,7 @@ trait CaseConversionExpression extends ExpectsInputTypes { * A function that converts the characters of a string to uppercase. */ case class Upper(child: Expression) extends UnaryExpression with CaseConversionExpression { - + override def convert(v: UTF8String): UTF8String = v.toUpperCase() override def toString: String = s"Upper($child)" @@ -147,7 +143,7 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE * A function that converts the characters of a string to lowercase. */ case class Lower(child: Expression) extends UnaryExpression with CaseConversionExpression { - + override def convert(v: UTF8String): UTF8String = v.toLowerCase() override def toString: String = s"Lower($child)" @@ -159,8 +155,6 @@ trait StringComparison extends ExpectsInputTypes { def compare(l: UTF8String, r: UTF8String): Boolean - override type EvaluatedType = Any - override def nullable: Boolean = left.nullable || right.nullable override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType) @@ -211,8 +205,6 @@ case class EndsWith(left: Expression, right: Expression) */ case class Substring(str: Expression, pos: Expression, len: Expression) extends Expression with ExpectsInputTypes { - - type EvaluatedType = Any override def foldable: Boolean = str.foldable && pos.foldable && len.foldable @@ -231,7 +223,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) @inline def slicePos(startPos: Int, sliceLen: Int, length: () => Int): (Int, Int) = { // Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and - // negative indices for start positions. If a start index i is greater than 0, it + // negative indices for start positions. If a start index i is greater than 0, it // refers to element i-1 in the sequence. If a start index i is less than 0, it refers // to the -ith element before the end of the sequence. If a start index i is 0, it // refers to the first element. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 099d67ca7fee3..82c4d462cc322 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -66,9 +66,7 @@ case class WindowSpecDefinition( } } - type EvaluatedType = Any - - override def children: Seq[Expression] = partitionSpec ++ orderSpec + override def children: Seq[Expression] = partitionSpec ++ orderSpec override lazy val resolved: Boolean = childrenResolved && frameSpecification.isInstanceOf[SpecifiedWindowFrame] @@ -76,7 +74,7 @@ case class WindowSpecDefinition( override def toString: String = simpleString - override def eval(input: Row): EvaluatedType = throw new UnsupportedOperationException + override def eval(input: Row): Any = throw new UnsupportedOperationException override def nullable: Boolean = true override def foldable: Boolean = false override def dataType: DataType = throw new UnsupportedOperationException @@ -299,7 +297,7 @@ case class UnresolvedWindowFunction( override def get(index: Int): Any = throw new UnresolvedException(this, "get") // Unresolved functions are transient at compile time and don't get evaluated during execution. - override def eval(input: Row = null): EvaluatedType = + override def eval(input: Row = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") override def toString: String = s"'$name(${children.mkString(",")})" @@ -311,25 +309,25 @@ case class UnresolvedWindowFunction( case class UnresolvedWindowExpression( child: UnresolvedWindowFunction, windowSpec: WindowSpecReference) extends UnaryExpression { + override def dataType: DataType = throw new UnresolvedException(this, "dataType") override def foldable: Boolean = throw new UnresolvedException(this, "foldable") override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false // Unresolved functions are transient at compile time and don't get evaluated during execution. - override def eval(input: Row = null): EvaluatedType = + override def eval(input: Row = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } case class WindowExpression( windowFunction: WindowFunction, windowSpec: WindowSpecDefinition) extends Expression { - override type EvaluatedType = Any override def children: Seq[Expression] = windowFunction :: windowSpec :: Nil - override def eval(input: Row): EvaluatedType = + override def eval(input: Row): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") override def dataType: DataType = windowFunction.dataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index b163707cc9925..5c6379b8d44b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -156,6 +156,11 @@ object ColumnPruning extends Rule[LogicalPlan] { case Project(projectList, Limit(exp, child)) => Limit(exp, Project(projectList, child)) + // push down project if possible when the child is sort + case p @ Project(projectList, s @ Sort(_, _, grandChild)) + if s.references.subsetOf(p.outputSet) => + s.copy(child = Project(projectList, grandChild)) + // Eliminate no-op Projects case Project(projectList, child) if child.output == projectList => child } @@ -174,8 +179,17 @@ object ColumnPruning extends Rule[LogicalPlan] { * expressions into one single expression. */ object ProjectCollapsing extends Rule[LogicalPlan] { + + /** Returns true if any expression in projectList is non-deterministic. */ + private def hasNondeterministic(projectList: Seq[NamedExpression]): Boolean = { + projectList.exists(expr => expr.find(!_.deterministic).isDefined) + } + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case Project(projectList1, Project(projectList2, child)) => + // We only collapse these two Projects if the child Project's expressions are all + // deterministic. + case Project(projectList1, Project(projectList2, child)) + if !hasNondeterministic(projectList2) => // Create a map of Aliases to their values from the child projection. // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)). val aliasMap = AttributeMap(projectList2.collect { @@ -259,6 +273,10 @@ object NullPropagation extends Rule[LogicalPlan] { case e @ Substring(_, Literal(null, _), _) => Literal.create(null, e.dataType) case e @ Substring(_, _, Literal(null, _)) => Literal.create(null, e.dataType) + // MaxOf and MinOf can't do null propagation + case e: MaxOf => e + case e: MinOf => e + // Put exceptional cases above if any case e: BinaryArithmetic => e.children match { case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala index 51b5699affed5..73a21884a4710 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala @@ -51,9 +51,9 @@ abstract class QueryPlanner[PhysicalPlan <: TreeNode[PhysicalPlan]] { * filled in automatically by the QueryPlanner using the other execution strategies that are * available. */ - protected def planLater(plan: LogicalPlan) = apply(plan).next() + protected def planLater(plan: LogicalPlan) = this.plan(plan).next() - def apply(plan: LogicalPlan): Iterator[PhysicalPlan] = { + def plan(plan: LogicalPlan): Iterator[PhysicalPlan] = { // Obviously a lot to do here still... val iter = strategies.view.flatMap(_(plan)).toIterator assert(iter.hasNext, s"No plan for $plan") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index cd54d04814ea4..1dd75a8846303 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -159,9 +159,10 @@ object PartialAggregation { // Should trim aliases around `GetField`s. These aliases are introduced while // resolving struct field accesses, because `GetField` is not a `NamedExpression`. // (Should we just turn `GetField` into a `NamedExpression`?) + val trimmed = e.transform { case Alias(g: ExtractValue, _) => g } namedGroupingExpressions - .get(e.transform { case Alias(g: ExtractValue, _) => g }) - .map(_.toAttribute) + .find { case (k, v) => k semanticEquals trimmed } + .map(_._2.toAttribute) .getOrElse(e) }).asInstanceOf[Seq[NamedExpression]] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 7967189cacb24..eff5c61644944 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -84,7 +84,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy val newArgs = productIterator.map { case e: Expression => transformExpressionDown(e) case Some(e: Expression) => Some(transformExpressionDown(e)) - case m: Map[_,_] => m + case m: Map[_, _] => m case d: DataType => d // Avoid unpacking Structs case seq: Traversable[_] => seq.map { case e: Expression => transformExpressionDown(e) @@ -117,7 +117,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy val newArgs = productIterator.map { case e: Expression => transformExpressionUp(e) case Some(e: Expression) => Some(transformExpressionUp(e)) - case m: Map[_,_] => m + case m: Map[_, _] => m case d: DataType => d // Avoid unpacking Structs case seq: Traversable[_] => seq.map { case e: Expression => transformExpressionUp(e) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index dbb12d56f9497..dba69659afc80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -105,7 +105,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { } /** - * Optionally resolves the given string to a [[NamedExpression]] using the input from all child + * Optionally resolves the given strings to a [[NamedExpression]] using the input from all child * nodes of this LogicalPlan. The attribute is expressed as * as string in the following form: `[scope].AttributeName.[nested].[fields]...`. */ @@ -116,7 +116,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { resolve(nameParts, children.flatMap(_.output), resolver, throwErrors) /** - * Optionally resolves the given string to a [[NamedExpression]] based on the output of this + * Optionally resolves the given strings to a [[NamedExpression]] based on the output of this * LogicalPlan. The attribute is expressed as string in the following form: * `[scope].AttributeName.[nested].[fields]...`. */ @@ -126,6 +126,57 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { throwErrors: Boolean = false): Option[NamedExpression] = resolve(nameParts, output, resolver, throwErrors) + /** + * Given an attribute name, split it to name parts by dot, but + * don't split the name parts quoted by backticks, for example, + * `ab.cd`.`efg` should be split into two parts "ab.cd" and "efg". + */ + def resolveQuoted( + name: String, + resolver: Resolver): Option[NamedExpression] = { + resolve(parseAttributeName(name), resolver, true) + } + + /** + * Internal method, used to split attribute name by dot with backticks rule. + * Backticks must appear in pairs, and the quoted string must be a complete name part, + * which means `ab..c`e.f is not allowed. + * Escape character is not supported now, so we can't use backtick inside name part. + */ + private def parseAttributeName(name: String): Seq[String] = { + val e = new AnalysisException(s"syntax error in attribute name: $name") + val nameParts = scala.collection.mutable.ArrayBuffer.empty[String] + val tmp = scala.collection.mutable.ArrayBuffer.empty[Char] + var inBacktick = false + var i = 0 + while (i < name.length) { + val char = name(i) + if (inBacktick) { + if (char == '`') { + inBacktick = false + if (i + 1 < name.length && name(i + 1) != '.') throw e + } else { + tmp += char + } + } else { + if (char == '`') { + if (tmp.nonEmpty) throw e + inBacktick = true + } else if (char == '.') { + if (tmp.isEmpty) throw e + nameParts += tmp.mkString + tmp.clear() + } else { + tmp += char + } + } + i += 1 + } + if (tmp.isEmpty || inBacktick) throw e + nameParts += tmp.mkString + nameParts.toSeq + } + /** * Resolve the given `name` string against the given attribute, returning either 0 or 1 match. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 0f349f9d11415..33a9e55a47dee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -59,6 +59,9 @@ case class Generate( child: LogicalPlan) extends UnaryNode { + /** The set of all attributes produced by this node. */ + def generatedSet: AttributeSet = AttributeSet(generatorOutput) + override lazy val resolved: Boolean = { generator.resolved && childrenResolved && @@ -90,7 +93,7 @@ case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { override lazy val resolved: Boolean = childrenResolved && - left.output.zip(right.output).forall { case (l,r) => l.dataType == r.dataType } + left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } override def statistics: Statistics = { val sizeInBytes = left.statistics.sizeInBytes + right.statistics.sizeInBytes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index fb4217a44807b..80ba57a082a60 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -169,7 +169,7 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) override def keyExpressions: Seq[Expression] = expressions - override def eval(input: Row = null): EvaluatedType = + override def eval(input: Row = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } @@ -213,6 +213,6 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) override def keyExpressions: Seq[Expression] = ordering.map(_.child) - override def eval(input: Row): EvaluatedType = + override def eval(input: Row): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index bc2ad34523d2c..36d005d0e1684 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -254,7 +254,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } else { Some(arg) } - case m: Map[_,_] => m + case m: Map[_, _] => m case d: DataType => d // Avoid unpacking Structs case args: Traversable[_] => args.map { case arg: TreeNode[_] if children contains arg => @@ -311,7 +311,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } else { Some(arg) } - case m: Map[_,_] => m + case m: Map[_, _] => m case d: DataType => d // Avoid unpacking Structs case args: Traversable[_] => args.map { case arg: TreeNode[_] if children contains arg => @@ -385,6 +385,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { def argString: String = productIterator.flatMap { case tn: TreeNode[_] if children contains tn => Nil case tn: TreeNode[_] if tn.toString contains "\n" => s"(${tn.simpleString})" :: Nil + case seq: Seq[BaseType] if seq.toSet.subsetOf(children.toSet) => Nil case seq: Seq[_] => seq.mkString("[", ",", "]") :: Nil case set: Set[_] => set.mkString("{", ",", "}") :: Nil case other => other :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateUtils.scala similarity index 96% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateUtils.scala index d36a49159b87f..ad649acf536f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DateUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateUtils.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.types +package org.apache.spark.sql.catalyst.util import java.sql.Date import java.text.SimpleDateFormat @@ -24,7 +24,7 @@ import java.util.{Calendar, TimeZone} import org.apache.spark.sql.catalyst.expressions.Cast /** - * helper function to convert between Int value of days since 1970-01-01 and java.sql.Date + * Helper function to convert between Int value of days since 1970-01-01 and java.sql.Date */ object DateUtils { private val MILLIS_PER_DAY = 86400000 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala new file mode 100644 index 0000000000000..0bb12d2039ffc --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.types._ + +/** + * Helper function to check for valid data types + */ +object TypeUtils { + def checkForNumericExpr(t: DataType, caller: String): TypeCheckResult = { + if (t.isInstanceOf[NumericType] || t == NullType) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure(s"$caller accepts numeric types, not $t") + } + } + + def checkForBitwiseExpr(t: DataType, caller: String): TypeCheckResult = { + if (t.isInstanceOf[IntegralType] || t == NullType) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure(s"$caller accepts integral types, not $t") + } + } + + def checkForOrderingExpr(t: DataType, caller: String): TypeCheckResult = { + if (t.isInstanceOf[AtomicType] || t == NullType) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure(s"$caller accepts non-complex types, not $t") + } + } + + def getNumeric(t: DataType): Numeric[Any] = + t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]] + + def getOrdering(t: DataType): Ordering[Any] = + t.asInstanceOf[AtomicType].ordering.asInstanceOf[Ordering[Any]] +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index 9d613a940ee86..07054166a5e88 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -83,7 +83,7 @@ package object util { } def resourceToString( - resource:String, + resource: String, encoding: String = "UTF-8", classLoader: ClassLoader = Utils.getSparkClassLoader): String = { new String(resourceToBytes(resource, classLoader), encoding) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index a0b261649f66f..74677ddfcad65 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -107,7 +107,7 @@ protected[sql] abstract class AtomicType extends DataType { abstract class NumericType extends AtomicType { // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a - // type parameter and and add a numeric annotation (i.e., [JvmType : Numeric]). This gets + // type parameter and add a numeric annotation (i.e., [JvmType : Numeric]). This gets // desugared by the compiler into an argument to the objects constructor. This means there is no // longer an no argument constructor and thus the JVM cannot serialize the object anymore. private[sql] val numeric: Numeric[InternalType] @@ -165,6 +165,9 @@ object DataType { def fromJson(json: String): DataType = parseDataType(parse(json)) + /** + * @deprecated As of 1.2.0, replaced by `DataType.fromJson()` + */ @deprecated("Use DataType.fromJson instead", "1.2.0") def fromCaseClassString(string: String): DataType = CaseClassStringParser(string) @@ -271,7 +274,7 @@ object DataType { protected lazy val structField: Parser[StructField] = ("StructField(" ~> "[a-zA-Z0-9_]*".r) ~ ("," ~> dataType) ~ ("," ~> boolVal <~ ")") ^^ { - case name ~ tpe ~ nullable => + case name ~ tpe ~ nullable => StructField(name, tpe, nullable = nullable) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 994c5202c15dc..eb3c58c37f308 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -313,7 +313,7 @@ object Decimal { // See scala.math's Numeric.scala for examples for Scala's built-in types. /** Common methods for Decimal evidence parameters */ - trait DecimalIsConflicted extends Numeric[Decimal] { + private[sql] trait DecimalIsConflicted extends Numeric[Decimal] { override def plus(x: Decimal, y: Decimal): Decimal = x + y override def times(x: Decimal, y: Decimal): Decimal = x * y override def minus(x: Decimal, y: Decimal): Decimal = x - y @@ -327,12 +327,12 @@ object Decimal { } /** A [[scala.math.Fractional]] evidence parameter for Decimals. */ - object DecimalIsFractional extends DecimalIsConflicted with Fractional[Decimal] { + private[sql] object DecimalIsFractional extends DecimalIsConflicted with Fractional[Decimal] { override def div(x: Decimal, y: Decimal): Decimal = x / y } /** A [[scala.math.Integral]] evidence parameter for Decimals. */ - object DecimalAsIfIntegral extends DecimalIsConflicted with Integral[Decimal] { + private[sql] object DecimalAsIfIntegral extends DecimalIsConflicted with Integral[Decimal] { override def quot(x: Decimal, y: Decimal): Decimal = x / y override def rem(x: Decimal, y: Decimal): Decimal = x % y } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 0f8cecd28f7df..407dc27326c2e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -82,12 +82,12 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT object DecimalType { val Unlimited: DecimalType = DecimalType(None) - object Fixed { + private[sql] object Fixed { def unapply(t: DecimalType): Option[(Int, Int)] = t.precisionInfo.map(p => (p.precision, p.scale)) } - object Expression { + private[sql] object Expression { def unapply(e: Expression): Option[(Int, Int)] = e.dataType match { case t: DecimalType => t.precisionInfo.map(p => (p.precision, p.scale)) case _ => None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/SQLUserDefinedType.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/SQLUserDefinedType.java index a64d2bb7cde37..df64a878b6b36 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/SQLUserDefinedType.java +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/SQLUserDefinedType.java @@ -24,11 +24,11 @@ /** * ::DeveloperApi:: * A user-defined type which can be automatically recognized by a SQLContext and registered. - * + *

    * WARNING: This annotation will only work if both Java and Scala reflection return the same class * names (after erasure) for the UDT. This will NOT be the case when, e.g., the UDT class * is enclosed in an object (a singleton). - * + *

    * WARNING: UDTs are currently only supported from Scala. */ // TODO: Should I used @Documented ? diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 7e00a27dfe724..193c08a4d0df7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -230,10 +230,10 @@ object StructType { case (StructType(leftFields), StructType(rightFields)) => val newFields = ArrayBuffer.empty[StructField] + val rightMapped = fieldsMap(rightFields) leftFields.foreach { case leftField @ StructField(leftName, leftType, leftNullable, _) => - rightFields - .find(_.name == leftName) + rightMapped.get(leftName) .map { case rightField @ StructField(_, rightType, rightNullable, _) => leftField.copy( dataType = merge(leftType, rightType), @@ -243,8 +243,9 @@ object StructType { .foreach(newFields += _) } + val leftMapped = fieldsMap(leftFields) rightFields - .filterNot(f => leftFields.map(_.name).contains(f.name)) + .filterNot(f => leftMapped.get(f.name).nonEmpty) .foreach(newFields += _) StructType(newFields) @@ -264,4 +265,9 @@ object StructType { case _ => throw new SparkException(s"Failed to merge incompatible data types $left and $right") } + + private[sql] def fieldsMap(fields: Array[StructField]): Map[String, StructField] = { + import scala.collection.breakOut + fields.map(s => (s.name, s))(breakOut) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala index fc02ba6c9c43e..f5d8fcced362b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala @@ -19,15 +19,18 @@ package org.apache.spark.sql.types import java.util.Arrays +import org.apache.spark.annotation.DeveloperApi + /** - * A UTF-8 String, as internal representation of StringType in SparkSQL + * :: DeveloperApi :: + * A UTF-8 String, as internal representation of StringType in SparkSQL * - * A String encoded in UTF-8 as an Array[Byte], which can be used for comparison, - * search, see http://en.wikipedia.org/wiki/UTF-8 for details. + * A String encoded in UTF-8 as an Array[Byte], which can be used for comparison, + * search, see http://en.wikipedia.org/wiki/UTF-8 for details. * - * Note: This is not designed for general use cases, should not be used outside SQL. + * Note: This is not designed for general use cases, should not be used outside SQL. */ - +@DeveloperApi final class UTF8String extends Ordered[UTF8String] with Serializable { private[this] var bytes: Array[Byte] = _ @@ -180,6 +183,10 @@ final class UTF8String extends Ordered[UTF8String] with Serializable { } } +/** + * :: DeveloperApi :: + */ +@DeveloperApi object UTF8String { // number of tailing bytes in a UTF8 sequence for a code point // see http://en.wikipedia.org/wiki/UTF-8, 192-256 of Byte 1 @@ -196,7 +203,7 @@ object UTF8String { def apply(s: String): UTF8String = { if (s != null) { new UTF8String().set(s) - } else{ + } else { null } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala new file mode 100644 index 0000000000000..df0f04563edcf --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Row +import org.apache.spark.sql.types._ + +class CatalystTypeConvertersSuite extends SparkFunSuite { + + private val simpleTypes: Seq[DataType] = Seq( + StringType, + DateType, + BooleanType, + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType) + + test("null handling in rows") { + val schema = StructType(simpleTypes.map(t => StructField(t.getClass.getName, t))) + val convertToCatalyst = CatalystTypeConverters.createToCatalystConverter(schema) + val convertToScala = CatalystTypeConverters.createToScalaConverter(schema) + + val scalaRow = Row.fromSeq(Seq.fill(simpleTypes.length)(null)) + assert(convertToScala(convertToCatalyst(scalaRow)) === scalaRow) + } + + test("null handling for individual values") { + for (dataType <- simpleTypes) { + assert(CatalystTypeConverters.createToScalaConverter(dataType)(null) === null) + } + } + + test("option handling in convertToCatalyst") { + // convertToCatalyst doesn't handle unboxing from Options. This is inconsistent with + // createToCatalystConverter but it may not actually matter as this is only called internally + // in a handful of places where we don't expect to receive Options. + assert(CatalystTypeConverters.convertToCatalyst(Some(123)) === Some(123)) + } + + test("option handling in createToCatalystConverter") { + assert(CatalystTypeConverters.createToCatalystConverter(IntegerType)(Some(123)) === 123) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala index ea82cd2622de9..c046dbf4dc2c9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql.catalyst -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.plans.physical._ /* Implicit conversions */ import org.apache.spark.sql.catalyst.dsl.expressions._ -class DistributionSuite extends FunSuite { +class DistributionSuite extends SparkFunSuite { protected def checkSatisfied( inputPartitioning: Partitioning, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index bbc0b661a0c0c..9a24b23024e18 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -20,8 +20,7 @@ package org.apache.spark.sql.catalyst import java.math.BigInteger import java.sql.{Date, Timestamp} -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.types._ @@ -75,7 +74,7 @@ case class MultipleConstructorsData(a: Int, b: String, c: Double) { def this(b: String, a: Int) = this(a, b, c = 1.0) } -class ScalaReflectionSuite extends FunSuite { +class ScalaReflectionSuite extends SparkFunSuite { import ScalaReflection._ test("primitive data") { @@ -253,7 +252,7 @@ class ScalaReflectionSuite extends FunSuite { } assert(ArrayType(IntegerType) === typeOfObject3(Seq(1, 2, 3))) - assert(ArrayType(ArrayType(IntegerType)) === typeOfObject3(Seq(Seq(1,2,3)))) + assert(ArrayType(ArrayType(IntegerType)) === typeOfObject3(Seq(Seq(1, 2, 3)))) } test("convert PrimitiveData to catalyst") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala index 890ea2a84b82e..b93a3abc6ebd2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.catalyst +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.logical.Command -import org.scalatest.FunSuite private[sql] case class TestCommand(cmd: String) extends LogicalPlan with Command { override def output: Seq[Attribute] = Seq.empty @@ -28,7 +28,7 @@ private[sql] case class TestCommand(cmd: String) extends LogicalPlan with Comman } private[sql] class SuperLongKeywordTestParser extends AbstractSparkSQLParser { - protected val EXECUTE = Keyword("THISISASUPERLONGKEYWORDTEST") + protected val EXECUTE = Keyword("THISISASUPERLONGKEYWORDTEST") override protected lazy val start: Parser[LogicalPlan] = set @@ -39,7 +39,7 @@ private[sql] class SuperLongKeywordTestParser extends AbstractSparkSQLParser { } private[sql] class CaseInsensitiveTestParser extends AbstractSparkSQLParser { - protected val EXECUTE = Keyword("EXECUTE") + protected val EXECUTE = Keyword("EXECUTE") override protected lazy val start: Parser[LogicalPlan] = set @@ -49,7 +49,7 @@ private[sql] class CaseInsensitiveTestParser extends AbstractSparkSQLParser { } } -class SqlParserSuite extends FunSuite { +class SqlParserSuite extends SparkFunSuite { test("test long keyword") { val parser = new SuperLongKeywordTestParser diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 6f2f35564d12e..e09cd790a7187 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -17,8 +17,9 @@ package org.apache.spark.sql.catalyst.analysis -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -27,7 +28,7 @@ import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -class AnalysisSuite extends FunSuite with BeforeAndAfter { +class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { val caseSensitiveConf = new SimpleCatalystConf(true) val caseInsensitiveConf = new SimpleCatalystConf(false) @@ -72,6 +73,9 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { StructField("cField", StringType) :: Nil ))()) + val listRelation = LocalRelation( + AttributeReference("list", ArrayType(IntegerType))()) + before { caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) @@ -152,17 +156,35 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { caseSensitive: Boolean = true): Unit = { test(name) { val error = intercept[AnalysisException] { - if(caseSensitive) { + if (caseSensitive) { caseSensitiveAnalyze(plan) } else { caseInsensitiveAnalyze(plan) } } - errorMessages.foreach(m => assert(error.getMessage contains m)) + errorMessages.foreach(m => assert(error.getMessage.toLowerCase contains m.toLowerCase)) } } + errorTest( + "unresolved window function", + testRelation2.select( + WindowExpression( + UnresolvedWindowFunction( + "lead", + UnresolvedAttribute("c") :: Nil), + WindowSpecDefinition( + UnresolvedAttribute("a") :: Nil, + SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, + UnspecifiedFrame)).as('window)), + "lead" :: "window functions currently requires a HiveContext" :: Nil) + + errorTest( + "too many generators", + listRelation.select(Explode('list).as('a), Explode('list).as('b)), + "only one generator" :: "explode" :: Nil) + errorTest( "unresolved attributes", testRelation.select('abcd), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index 565b1cfe019c7..7bac97b7894f5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -17,14 +17,15 @@ package org.apache.spark.sql.catalyst.analysis -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{Union, Project, LocalRelation} import org.apache.spark.sql.types._ import org.apache.spark.sql.catalyst.SimpleCatalystConf -class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { +class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter { val conf = new SimpleCatalystConf(true) val catalog = new SimpleCatalog(conf) val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, conf) @@ -91,8 +92,10 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { } test("Comparison operations") { - checkComparison(LessThan(i, d1), DecimalType.Unlimited) - checkComparison(LessThanOrEqual(d1, d2), DecimalType.Unlimited) + checkComparison(EqualTo(i, d1), DecimalType(10, 1)) + checkComparison(EqualNullSafe(d2, d1), DecimalType(5, 2)) + checkComparison(LessThan(i, d1), DecimalType(10, 1)) + checkComparison(LessThanOrEqual(d1, d2), DecimalType(5, 2)) checkComparison(GreaterThan(d2, u), DecimalType.Unlimited) checkComparison(GreaterThanOrEqual(d1, f), DoubleType) checkComparison(GreaterThan(d2, d2), DecimalType(5, 2)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index fcd745f43cfbf..0df446636ea89 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -20,18 +20,19 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LocalRelation, Project} +import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ class HiveTypeCoercionSuite extends PlanTest { test("tightest common bound for types") { def widenTest(t1: DataType, t2: DataType, tightestCommon: Option[DataType]) { - var found = HiveTypeCoercion.findTightestCommonType(t1, t2) + var found = HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2) assert(found == tightestCommon, s"Expected $tightestCommon as tightest common type for $t1 and $t2, found $found") // Test both directions to make sure the widening is symmetric. - found = HiveTypeCoercion.findTightestCommonType(t2, t1) + found = HiveTypeCoercion.findTightestCommonTypeOfTwo(t2, t1) assert(found == tightestCommon, s"Expected $tightestCommon as tightest common type for $t2 and $t1, found $found") } @@ -104,31 +105,16 @@ class HiveTypeCoercionSuite extends PlanTest { widenTest(ArrayType(IntegerType), StructType(Seq()), None) } - test("boolean casts") { - val booleanCasts = new HiveTypeCoercion { }.BooleanCasts - def ruleTest(initial: Expression, transformed: Expression) { - val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) - comparePlans( - booleanCasts(Project(Seq(Alias(initial, "a")()), testRelation)), - Project(Seq(Alias(transformed, "a")()), testRelation)) - } - // Remove superflous boolean -> boolean casts. - ruleTest(Cast(Literal(true), BooleanType), Literal(true)) - // Stringify boolean when casting to string. - ruleTest( - Cast(Literal(false), StringType), - If(Literal(false), Literal("true"), Literal("false"))) + private def ruleTest(rule: Rule[LogicalPlan], initial: Expression, transformed: Expression) { + val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) + comparePlans( + rule(Project(Seq(Alias(initial, "a")()), testRelation)), + Project(Seq(Alias(transformed, "a")()), testRelation)) } test("coalesce casts") { val fac = new HiveTypeCoercion { }.FunctionArgumentConversion - def ruleTest(initial: Expression, transformed: Expression) { - val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) - comparePlans( - fac(Project(Seq(Alias(initial, "a")()), testRelation)), - Project(Seq(Alias(transformed, "a")()), testRelation)) - } - ruleTest( + ruleTest(fac, Coalesce(Literal(1.0) :: Literal(1) :: Literal.create(1.0, FloatType) @@ -137,7 +123,7 @@ class HiveTypeCoercionSuite extends PlanTest { :: Cast(Literal(1), DoubleType) :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) - ruleTest( + ruleTest(fac, Coalesce(Literal(1L) :: Literal(1) :: Literal(new java.math.BigDecimal("1000000000000000000000")) @@ -147,4 +133,36 @@ class HiveTypeCoercionSuite extends PlanTest { :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType()) :: Nil)) } + + test("type coercion for CaseKeyWhen") { + val cwc = new HiveTypeCoercion {}.CaseWhenCoercion + ruleTest(cwc, + CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))), + CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a"))) + ) + ruleTest(cwc, + CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))), + CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))) + ) + } + + test("type coercion simplification for equal to") { + val be = new HiveTypeCoercion {}.BooleanEqualization + ruleTest(be, + EqualTo(Literal(true), Literal(1)), + Literal(true) + ) + ruleTest(be, + EqualTo(Literal(true), Literal(0)), + Not(Literal(true)) + ) + ruleTest(be, + EqualNullSafe(Literal(true), Literal(1)), + And(IsNotNull(Literal(true)), Literal(true)) + ) + ruleTest(be, + EqualNullSafe(Literal(true), Literal(0)), + And(IsNotNull(Literal(true)), Not(Literal(true))) + ) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala index f2f3a84d19380..97cfb5f06dd73 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql.catalyst.expressions -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types.IntegerType -class AttributeSetSuite extends FunSuite { +class AttributeSetSuite extends SparkFunSuite { val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1)) val aLower = AttributeReference("a", IntegerType)(exprId = ExprId(1)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 04fd261d16aa3..b6927485f42bf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -22,17 +22,18 @@ import java.sql.{Date, Timestamp} import scala.collection.immutable.HashSet import org.scalactic.TripleEqualsSupport.Spread -import org.scalatest.FunSuite import org.scalatest.Matchers._ +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.mathfuncs._ +import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types._ -class ExpressionEvaluationBaseSuite extends FunSuite { +class ExpressionEvaluationBaseSuite extends SparkFunSuite { def evaluate(expression: Expression, inputRow: Row = EmptyRow): Any = { expression.eval(inputRow) @@ -42,8 +43,8 @@ class ExpressionEvaluationBaseSuite extends FunSuite { val actual = try evaluate(expression, inputRow) catch { case e: Exception => fail(s"Exception evaluating $expression", e) } - if(actual != expected) { - val input = if(inputRow == EmptyRow) "" else s", input: $inputRow" + if (actual != expected) { + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") } } @@ -125,37 +126,37 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } booleanLogicTest("AND", _ && _, - (true, true, true) :: - (true, false, false) :: - (true, null, null) :: - (false, true, false) :: + (true, true, true) :: + (true, false, false) :: + (true, null, null) :: + (false, true, false) :: (false, false, false) :: - (false, null, false) :: - (null, true, null) :: - (null, false, false) :: - (null, null, null) :: Nil) + (false, null, false) :: + (null, true, null) :: + (null, false, false) :: + (null, null, null) :: Nil) booleanLogicTest("OR", _ || _, - (true, true, true) :: - (true, false, true) :: - (true, null, true) :: - (false, true, true) :: + (true, true, true) :: + (true, false, true) :: + (true, null, true) :: + (false, true, true) :: (false, false, false) :: - (false, null, null) :: - (null, true, true) :: - (null, false, null) :: - (null, null, null) :: Nil) + (false, null, null) :: + (null, true, true) :: + (null, false, null) :: + (null, null, null) :: Nil) booleanLogicTest("=", _ === _, - (true, true, true) :: - (true, false, false) :: - (true, null, null) :: - (false, true, false) :: + (true, true, true) :: + (true, false, false) :: + (true, null, null) :: + (false, true, false) :: (false, false, true) :: - (false, null, null) :: - (null, true, null) :: - (null, false, null) :: - (null, null, null) :: Nil) + (false, null, null) :: + (null, true, null) :: + (null, false, null) :: + (null, null, null) :: Nil) def booleanLogicTest( name: String, @@ -163,7 +164,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { truthTable: Seq[(Any, Any, Any)]) { test(s"3VL $name") { truthTable.foreach { - case (l,r,answer) => + case (l, r, answer) => val expr = op(Literal.create(l, BooleanType), Literal.create(r, BooleanType)) checkEvaluation(expr, answer) } @@ -371,6 +372,8 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { DecimalType.Unlimited, ByteType), TimestampType), LongType), StringType), ShortType), 0) checkEvaluation(Literal(true) cast IntegerType, 1) checkEvaluation(Literal(false) cast IntegerType, 0) + checkEvaluation(Literal(true) cast StringType, "true") + checkEvaluation(Literal(false) cast StringType, "false") checkEvaluation(Cast(Literal(1) cast BooleanType, IntegerType), 1) checkEvaluation(Cast(Literal(0) cast BooleanType, IntegerType), 0) checkEvaluation("23" cast DoubleType, 23d) @@ -859,7 +862,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { val c5 = 'a.string.at(4) val c6 = 'a.string.at(5) - val literalNull = Literal.create(null, BooleanType) + val literalNull = Literal.create(null, IntegerType) val literalInt = Literal(1) val literalString = Literal("a") @@ -868,12 +871,12 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(CaseKeyWhen(c2, Seq(literalInt, c4, c5)), "a", row) checkEvaluation(CaseKeyWhen(c2, Seq(c1, c4, c5)), "b", row) checkEvaluation(CaseKeyWhen(c4, Seq(literalString, c2, c3)), 1, row) - checkEvaluation(CaseKeyWhen(c4, Seq(c1, c3, c5, c2, Literal(3))), 3, row) + checkEvaluation(CaseKeyWhen(c4, Seq(c6, c3, c5, c2, Literal(3))), 3, row) checkEvaluation(CaseKeyWhen(literalInt, Seq(c2, c4, c5)), "a", row) checkEvaluation(CaseKeyWhen(literalString, Seq(c5, c2, c4, c3)), 2, row) - checkEvaluation(CaseKeyWhen(literalInt, Seq(c5, c2, c4, c3)), null, row) - checkEvaluation(CaseKeyWhen(literalNull, Seq(c5, c2, c1, c3)), 2, row) + checkEvaluation(CaseKeyWhen(c6, Seq(c5, c2, c4, c3)), null, row) + checkEvaluation(CaseKeyWhen(literalNull, Seq(c2, c5, c1, c6)), "c", row) } test("complex type") { @@ -927,7 +930,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { :: StructField("b", StringType, nullable = false) :: Nil ) - assert(getStructField(BoundReference(2,typeS, nullable = true), "a").nullable === true) + assert(getStructField(BoundReference(2, typeS, nullable = true), "a").nullable === true) assert(getStructField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable === false) @@ -1206,7 +1209,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { } /** - * Used for testing math functions for DataFrames. + * Used for testing math functions for DataFrames. * @param c The DataFrame function * @param f The functions in scala.math * @param domain The set of values to run the function with @@ -1214,7 +1217,7 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { * @tparam T Generic type for primitives */ def unaryMathFunctionEvaluation[@specialized(Int, Double, Float, Long) T]( - c: Expression => Expression, + c: Expression => Expression, f: T => T, domain: Iterable[T] = (-20 to 20).map(_ * 0.1), expectNull: Boolean = false): Unit = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala new file mode 100644 index 0000000000000..dcb3635c5ccae --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.types.StringType + +class ExpressionTypeCheckingSuite extends SparkFunSuite { + + val testRelation = LocalRelation( + 'intField.int, + 'stringField.string, + 'booleanField.boolean, + 'complexField.array(StringType)) + + def assertError(expr: Expression, errorMessage: String): Unit = { + val e = intercept[AnalysisException] { + assertSuccess(expr) + } + assert(e.getMessage.contains( + s"cannot resolve '${expr.prettyString}' due to data type mismatch:")) + assert(e.getMessage.contains(errorMessage)) + } + + def assertSuccess(expr: Expression): Unit = { + val analyzed = testRelation.select(expr.as("c")).analyze + SimpleAnalyzer.checkAnalysis(analyzed) + } + + def assertErrorForDifferingTypes(expr: Expression): Unit = { + assertError(expr, + s"differing types in ${expr.getClass.getSimpleName} (IntegerType and BooleanType).") + } + + test("check types for unary arithmetic") { + assertError(UnaryMinus('stringField), "operator - accepts numeric type") + assertSuccess(Sqrt('stringField)) // We will cast String to Double for sqrt + assertError(Sqrt('booleanField), "function sqrt accepts numeric type") + assertError(Abs('stringField), "function abs accepts numeric type") + assertError(BitwiseNot('stringField), "operator ~ accepts integral type") + } + + test("check types for binary arithmetic") { + // We will cast String to Double for binary arithmetic + assertSuccess(Add('intField, 'stringField)) + assertSuccess(Subtract('intField, 'stringField)) + assertSuccess(Multiply('intField, 'stringField)) + assertSuccess(Divide('intField, 'stringField)) + assertSuccess(Remainder('intField, 'stringField)) + // checkAnalysis(BitwiseAnd('intField, 'stringField)) + + assertErrorForDifferingTypes(Add('intField, 'booleanField)) + assertErrorForDifferingTypes(Subtract('intField, 'booleanField)) + assertErrorForDifferingTypes(Multiply('intField, 'booleanField)) + assertErrorForDifferingTypes(Divide('intField, 'booleanField)) + assertErrorForDifferingTypes(Remainder('intField, 'booleanField)) + assertErrorForDifferingTypes(BitwiseAnd('intField, 'booleanField)) + assertErrorForDifferingTypes(BitwiseOr('intField, 'booleanField)) + assertErrorForDifferingTypes(BitwiseXor('intField, 'booleanField)) + assertErrorForDifferingTypes(MaxOf('intField, 'booleanField)) + assertErrorForDifferingTypes(MinOf('intField, 'booleanField)) + + assertError(Add('booleanField, 'booleanField), "operator + accepts numeric type") + assertError(Subtract('booleanField, 'booleanField), "operator - accepts numeric type") + assertError(Multiply('booleanField, 'booleanField), "operator * accepts numeric type") + assertError(Divide('booleanField, 'booleanField), "operator / accepts numeric type") + assertError(Remainder('booleanField, 'booleanField), "operator % accepts numeric type") + + assertError(BitwiseAnd('booleanField, 'booleanField), "operator & accepts integral type") + assertError(BitwiseOr('booleanField, 'booleanField), "operator | accepts integral type") + assertError(BitwiseXor('booleanField, 'booleanField), "operator ^ accepts integral type") + + assertError(MaxOf('complexField, 'complexField), "function maxOf accepts non-complex type") + assertError(MinOf('complexField, 'complexField), "function minOf accepts non-complex type") + } + + test("check types for predicates") { + // We will cast String to Double for binary comparison + assertSuccess(EqualTo('intField, 'stringField)) + assertSuccess(EqualNullSafe('intField, 'stringField)) + assertSuccess(LessThan('intField, 'stringField)) + assertSuccess(LessThanOrEqual('intField, 'stringField)) + assertSuccess(GreaterThan('intField, 'stringField)) + assertSuccess(GreaterThanOrEqual('intField, 'stringField)) + + // We will transform EqualTo with numeric and boolean types to CaseKeyWhen + assertSuccess(EqualTo('intField, 'booleanField)) + assertSuccess(EqualNullSafe('intField, 'booleanField)) + + assertError(EqualTo('intField, 'complexField), "differing types") + assertError(EqualNullSafe('intField, 'complexField), "differing types") + + assertErrorForDifferingTypes(LessThan('intField, 'booleanField)) + assertErrorForDifferingTypes(LessThanOrEqual('intField, 'booleanField)) + assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField)) + assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField)) + + assertError( + LessThan('complexField, 'complexField), "operator < accepts non-complex type") + assertError( + LessThanOrEqual('complexField, 'complexField), "operator <= accepts non-complex type") + assertError( + GreaterThan('complexField, 'complexField), "operator > accepts non-complex type") + assertError( + GreaterThanOrEqual('complexField, 'complexField), "operator >= accepts non-complex type") + + assertError( + If('intField, 'stringField, 'stringField), + "type of predicate expression in If should be boolean") + assertErrorForDifferingTypes(If('booleanField, 'intField, 'booleanField)) + + assertError( + CaseWhen(Seq('booleanField, 'intField, 'booleanField, 'complexField)), + "THEN and ELSE expressions should all be same type or coercible to a common type") + assertError( + CaseKeyWhen('intField, Seq('intField, 'stringField, 'intField, 'complexField)), + "THEN and ELSE expressions should all be same type or coercible to a common type") + assertError( + CaseWhen(Seq('booleanField, 'intField, 'intField, 'intField)), + "WHEN expressions in CaseWhen should all be boolean type") + + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala index b5ebe4b38e337..d7c437095e395 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedEvaluationSuite.scala @@ -41,9 +41,9 @@ class GeneratedEvaluationSuite extends ExpressionEvaluationSuite { """.stripMargin) } - val actual = plan(inputRow).apply(0) - if(actual != expected) { - val input = if(inputRow == EmptyRow) "" else s", input: $inputRow" + val actual = plan(inputRow).apply(0) + if (actual != expected) { + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala index 97af2e0fd0502..a40324b008e16 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratedMutableEvaluationSuite.scala @@ -53,7 +53,7 @@ class GeneratedMutableEvaluationSuite extends ExpressionEvaluationSuite { """.stripMargin) } if (actual != expectedRow) { - val input = if(inputRow == EmptyRow) "" else s", input: $inputRow" + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala index 7a19e511eb8b5..88a36aa121b55 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala @@ -20,12 +20,16 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.JavaConverters._ import scala.util.Random +import org.apache.spark.SparkFunSuite import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, TaskMemoryManager, MemoryAllocator} -import org.scalatest.{BeforeAndAfterEach, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfterEach, Matchers} import org.apache.spark.sql.types._ -class UnsafeFixedWidthAggregationMapSuite extends FunSuite with Matchers with BeforeAndAfterEach { +class UnsafeFixedWidthAggregationMapSuite + extends SparkFunSuite + with Matchers + with BeforeAndAfterEach { import UnsafeFixedWidthAggregationMap._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index 3a60c7fd32675..61722f1ffa462 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -19,13 +19,14 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Arrays -import org.scalatest.{FunSuite, Matchers} +import org.scalatest.Matchers +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ import org.apache.spark.unsafe.PlatformDependent import org.apache.spark.unsafe.array.ByteArrayMethods -class UnsafeRowConverterSuite extends FunSuite with Matchers { +class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { test("basic conversion with only primitive types") { val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index 6255578d7fa57..465a5e6914204 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -78,9 +78,9 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { test("(a && b && c && ...) || (a && b && d && ...) || (a && b && e && ...) ...") { checkCondition('b > 3 || 'c > 5, 'b > 3 || 'c > 5) - checkCondition(('a < 2 && 'a > 3 && 'b > 5) || 'a < 2, 'a < 2) + checkCondition(('a < 2 && 'a > 3 && 'b > 5) || 'a < 2, 'a < 2) - checkCondition('a < 2 || ('a < 2 && 'a > 3 && 'b > 5), 'a < 2) + checkCondition('a < 2 || ('a < 2 && 'a > 3 && 'b > 5), 'a < 2) val input = ('a === 'b && 'b > 3 && 'c > 2) || ('a === 'b && 'c < 1 && 'a === 5) || diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala index a30052b38fc11..06c592f4905a3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala @@ -71,7 +71,7 @@ class CombiningLimitsSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - + test("limits: combines two limits after ColumnPruning") { val originalQuery = testRelation @@ -79,7 +79,7 @@ class CombiningLimitsSuite extends PlanTest { .limit(2) .select('a) .limit(5) - + val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index 5697c2272b8e8..ec3b2f1edfa05 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -248,7 +248,7 @@ class ConstantFoldingSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - + test("Constant folding test: Fold In(v, list) into true or false") { var originalQuery = testRelation diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 0c428f7231b8e..17dc9124749e8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries -import org.apache.spark.sql.catalyst.expressions.{Count, Explode} +import org.apache.spark.sql.catalyst.expressions.{SortOrder, Ascending, Count, Explode} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{LeftSemi, PlanTest, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.rules._ @@ -93,11 +93,11 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - + test("column pruning for Project(ne, Limit)") { val originalQuery = testRelation - .select('a,'b) + .select('a, 'b) .limit(2) .select('a) @@ -109,7 +109,7 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - + // After this line is unimplemented. test("simple push down") { val originalQuery = @@ -542,4 +542,38 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, originalQuery) } + + test("push down project past sort") { + val x = testRelation.subquery('x) + + // push down valid + val originalQuery = { + x.select('a, 'b) + .sortBy(SortOrder('a, Ascending)) + .select('a) + } + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + x.select('a) + .sortBy(SortOrder('a, Ascending)).analyze + + comparePlans(optimized, analysis.EliminateSubQueries(correctAnswer)) + + // push down invalid + val originalQuery1 = { + x.select('a, 'b) + .sortBy(SortOrder('a, Ascending)) + .select('b) + } + + val optimized1 = Optimize.execute(originalQuery1.analyze) + val correctAnswer1 = + x.select('a, 'b) + .sortBy(SortOrder('a, Ascending)) + .select('b).analyze + + comparePlans(optimized1, analysis.EliminateSubQueries(correctAnswer1)) + + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 3eb399e68e70c..1d433275fed2e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -46,7 +46,7 @@ class OptimizeInSuite extends PlanTest { test("OptimizedIn test: In clause optimized to InSet") { val originalQuery = testRelation - .where(In(UnresolvedAttribute("a"), Seq(Literal(1),Literal(2)))) + .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2)))) .analyze val optimized = Optimize.execute(originalQuery.analyze) @@ -57,17 +57,17 @@ class OptimizeInSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - + test("OptimizedIn test: In clause not optimized in case filter has attributes") { val originalQuery = testRelation - .where(In(UnresolvedAttribute("a"), Seq(Literal(1),Literal(2), UnresolvedAttribute("b")))) + .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), UnresolvedAttribute("b")))) .analyze val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where(In(UnresolvedAttribute("a"), Seq(Literal(1),Literal(2), UnresolvedAttribute("b")))) + .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), UnresolvedAttribute("b")))) .analyze comparePlans(optimized, correctAnswer) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala new file mode 100644 index 0000000000000..151654bffbd66 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.Rand +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + + +class ProjectCollapsingSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", FixedPoint(10), EliminateSubQueries) :: + Batch("ProjectCollapsing", Once, ProjectCollapsing) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int) + + test("collapse two deterministic, independent projects into one") { + val query = testRelation + .select(('a + 1).as('a_plus_1), 'b) + .select('a_plus_1, ('b + 1).as('b_plus_1)) + + val optimized = Optimize.execute(query.analyze) + val correctAnswer = testRelation.select(('a + 1).as('a_plus_1), ('b + 1).as('b_plus_1)).analyze + + comparePlans(optimized, correctAnswer) + } + + test("collapse two deterministic, dependent projects into one") { + val query = testRelation + .select(('a + 1).as('a_plus_1), 'b) + .select(('a_plus_1 + 1).as('a_plus_2), 'b) + + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = testRelation.select( + (('a + 1).as('a_plus_1) + 1).as('a_plus_2), + 'b).analyze + + comparePlans(optimized, correctAnswer) + } + + test("do not collapse nondeterministic projects") { + val query = testRelation + .select(Rand(10).as('rand)) + .select(('rand + 1).as('rand1), ('rand + 2).as('rand2)) + + val optimized = Optimize.execute(query.analyze) + val correctAnswer = query.analyze + + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala index a3ad200800b02..35f50be46b76f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala @@ -33,8 +33,8 @@ class UnionPushdownSuite extends PlanTest { UnionPushdown) :: Nil } - val testRelation = LocalRelation('a.int, 'b.int, 'c.int) - val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) val testUnion = Union(testRelation, testRelation2) test("union: filter to each side") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index e7cafcc96de87..765c1e2dda99f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.catalyst.plans -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Filter, LogicalPlan} import org.apache.spark.sql.catalyst.util._ @@ -26,7 +25,7 @@ import org.apache.spark.sql.catalyst.util._ /** * Provides helper methods for comparing plans. */ -class PlanTest extends FunSuite { +class PlanTest extends SparkFunSuite { /** * Since attribute references are given globally unique ids during analysis, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala index 1273921f6394c..62d5f6ac74885 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.catalyst.plans -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.{ExprId, AttributeReference} @@ -28,7 +27,7 @@ import org.apache.spark.sql.catalyst.util._ /** * Tests for the sameResult function of [[LogicalPlan]]. */ -class SameResultSuite extends FunSuite { +class SameResultSuite extends SparkFunSuite { val testRelation = LocalRelation('a.int, 'b.int, 'c.int) val testRelation2 = LocalRelation('a.int, 'b.int, 'c.int) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala index 2a641c63f87bb..a7de7b052bdc3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.catalyst.trees -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.{Expression, IntegerLiteral, Literal} import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} -class RuleExecutorSuite extends FunSuite { +class RuleExecutorSuite extends SparkFunSuite { object DecrementLiterals extends Rule[Expression] { def apply(e: Expression): Expression = e transform { case IntegerLiteral(i) if i > 0 => Literal(i - 1) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 3d10dab5ba34c..67db3d5e6d751 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -19,21 +19,19 @@ package org.apache.spark.sql.catalyst.trees import scala.collection.mutable.ArrayBuffer -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{IntegerType, StringType, NullType} case class Dummy(optKey: Option[Expression]) extends Expression { - def children: Seq[Expression] = optKey.toSeq - def nullable: Boolean = true - def dataType: NullType = NullType + override def children: Seq[Expression] = optKey.toSeq + override def nullable: Boolean = true + override def dataType: NullType = NullType override lazy val resolved = true - type EvaluatedType = Any - def eval(input: Row): Any = null.asInstanceOf[Any] + override def eval(input: Row): Any = null.asInstanceOf[Any] } -class TreeNodeSuite extends FunSuite { +class TreeNodeSuite extends SparkFunSuite { test("top node changed") { val after = Literal(1) transform { case Literal(1, _) => Literal(2) } assert(after === Literal(2)) @@ -92,7 +90,7 @@ class TreeNodeSuite extends FunSuite { test("transform works on nodes with Option children") { val dummy1 = Dummy(Some(Literal.create("1", StringType))) val dummy2 = Dummy(None) - val toZero: PartialFunction[Expression, Expression] = { case Literal(_, _) => Literal(0) } + val toZero: PartialFunction[Expression, Expression] = { case Literal(_, _) => Literal(0) } var actual = dummy1 transformDown toZero assert(actual === Dummy(Some(Literal(0)))) @@ -105,7 +103,7 @@ class TreeNodeSuite extends FunSuite { } test("preserves origin") { - CurrentOrigin.setPosition(1,1) + CurrentOrigin.setPosition(1, 1) val add = Add(Literal(1), Literal(1)) CurrentOrigin.reset() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala index d7d60efee50fa..4030a1b1df358 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala @@ -18,11 +18,11 @@ package org.apache.spark.sql.catalyst.util import org.json4s.jackson.JsonMethods.parse -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types.{MetadataBuilder, Metadata} -class MetadataSuite extends FunSuite { +class MetadataSuite extends SparkFunSuite { val baseMetadata = new MetadataBuilder() .putString("purpose", "ml") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala index 3e7cf7cbb5e63..c6171b7b6916d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.types -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class DataTypeParserSuite extends FunSuite { +class DataTypeParserSuite extends SparkFunSuite { def checkDataType(dataTypeString: String, expectedDataType: DataType): Unit = { test(s"parse ${dataTypeString.replace("\n", "")}") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index d797510f36685..261c4fcad24aa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.types -import org.scalatest.FunSuite +import org.apache.spark.{SparkException, SparkFunSuite} -class DataTypeSuite extends FunSuite { +class DataTypeSuite extends SparkFunSuite { test("construct an ArrayType") { val array = ArrayType(StringType) @@ -69,6 +69,76 @@ class DataTypeSuite extends FunSuite { } } + test("fieldsMap returns map of name to StructField") { + val struct = StructType( + StructField("a", LongType) :: + StructField("b", FloatType) :: Nil) + + val mapped = StructType.fieldsMap(struct.fields) + + val expected = Map( + "a" -> StructField("a", LongType), + "b" -> StructField("b", FloatType)) + + assert(mapped === expected) + } + + test("merge where right is empty") { + val left = StructType( + StructField("a", LongType) :: + StructField("b", FloatType) :: Nil) + + val right = StructType(List()) + val merged = left.merge(right) + + assert(merged === left) + } + + test("merge where left is empty") { + + val left = StructType(List()) + + val right = StructType( + StructField("a", LongType) :: + StructField("b", FloatType) :: Nil) + + val merged = left.merge(right) + + assert(right === merged) + + } + + test("merge where both are non-empty") { + val left = StructType( + StructField("a", LongType) :: + StructField("b", FloatType) :: Nil) + + val right = StructType( + StructField("c", LongType) :: Nil) + + val expected = StructType( + StructField("a", LongType) :: + StructField("b", FloatType) :: + StructField("c", LongType) :: Nil) + + val merged = left.merge(right) + + assert(merged === expected) + } + + test("merge where right contains type conflict") { + val left = StructType( + StructField("a", LongType) :: + StructField("b", FloatType) :: Nil) + + val right = StructType( + StructField("b", LongType) :: Nil) + + intercept[SparkException] { + left.merge(right) + } + } + def checkDataTypeJsonRepr(dataType: DataType): Unit = { test(s"JSON - $dataType") { assert(DataType.fromJson(dataType.json) === dataType) @@ -120,7 +190,7 @@ class DataTypeSuite extends FunSuite { checkDefaultSize(DecimalType(10, 5), 4096) checkDefaultSize(DecimalType.Unlimited, 4096) checkDefaultSize(DateType, 4) - checkDefaultSize(TimestampType,12) + checkDefaultSize(TimestampType, 12) checkDefaultSize(StringType, 4096) checkDefaultSize(BinaryType, 4096) checkDefaultSize(ArrayType(DoubleType, true), 800) @@ -179,11 +249,11 @@ class DataTypeSuite extends FunSuite { expected = false) checkEqualsIgnoreCompatibleNullability( from = MapType(StringType, ArrayType(IntegerType, true), valueContainsNull = true), - to = MapType(StringType, ArrayType(IntegerType, false), valueContainsNull = true), + to = MapType(StringType, ArrayType(IntegerType, false), valueContainsNull = true), expected = false) checkEqualsIgnoreCompatibleNullability( from = MapType(StringType, ArrayType(IntegerType, false), valueContainsNull = true), - to = MapType(StringType, ArrayType(IntegerType, true), valueContainsNull = true), + to = MapType(StringType, ArrayType(IntegerType, true), valueContainsNull = true), expected = true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala index a22aa6f244c48..81d7ab010f394 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/UTF8StringSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.types -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite // scalastyle:off -class UTF8StringSuite extends FunSuite { +class UTF8StringSuite extends SparkFunSuite { test("basic") { def check(str: String, len: Int) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala index de6a2cd448c47..28b373e258311 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql.types.decimal +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types.Decimal -import org.scalatest.{PrivateMethodTester, FunSuite} +import org.scalatest.PrivateMethodTester import scala.language.postfixOps -class DecimalSuite extends FunSuite with PrivateMethodTester { +class DecimalSuite extends SparkFunSuite with PrivateMethodTester { test("creating decimals") { /** Check that a Decimal has the given string representation, precision and scale */ def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): Unit = { diff --git a/sql/core/pom.xml b/sql/core/pom.xml index ffe95bb49188f..8210c552603ea 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -41,6 +41,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-catalyst_${scala.binary.version} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 42f5bcda49cfb..d3efa83380d04 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -21,9 +21,10 @@ import scala.language.implicitConversions import org.apache.spark.annotation.Experimental import org.apache.spark.Logging +import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.lit import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedStar, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.types._ @@ -346,8 +347,9 @@ class Column(protected[sql] val expr: Expression) extends Logging { * }}} * * @group expr_ops + * @since 1.4.0 */ - def when(condition: Column, value: Any):Column = this.expr match { + def when(condition: Column, value: Any): Column = this.expr match { case CaseWhen(branches: Seq[Expression]) => CaseWhen(branches ++ Seq(lit(condition).expr, lit(value).expr)) case _ => @@ -374,8 +376,9 @@ class Column(protected[sql] val expr: Expression) extends Logging { * }}} * * @group expr_ops + * @since 1.4.0 */ - def otherwise(value: Any):Column = this.expr match { + def otherwise(value: Any): Column = this.expr match { case CaseWhen(branches: Seq[Expression]) => if (branches.size % 2 == 0) { CaseWhen(branches :+ lit(value).expr) @@ -713,6 +716,18 @@ class Column(protected[sql] val expr: Expression) extends Logging { */ def endsWith(literal: String): Column = this.endsWith(lit(literal)) + /** + * Gives the column an alias. Same as `as`. + * {{{ + * // Renames colA to colB in select output. + * df.select($"colA".alias("colB")) + * }}} + * + * @group expr_ops + * @since 1.4.0 + */ + def alias(alias: String): Column = as(alias) + /** * Gives the column an alias. * {{{ @@ -725,6 +740,30 @@ class Column(protected[sql] val expr: Expression) extends Logging { */ def as(alias: String): Column = Alias(expr, alias)() + /** + * (Scala-specific) Assigns the given aliases to the results of a table generating function. + * {{{ + * // Renames colA to colB in select output. + * df.select(explode($"myMap").as("key" :: "value" :: Nil)) + * }}} + * + * @group expr_ops + * @since 1.4.0 + */ + def as(aliases: Seq[String]): Column = MultiAlias(expr, aliases) + + /** + * Assigns the given aliases to the results of a table generating function. + * {{{ + * // Renames colA to colB in select output. + * df.select(explode($"myMap").as("key" :: "value" :: Nil)) + * }}} + * + * @group expr_ops + * @since 1.4.0 + */ + def as(aliases: Array[String]): Column = MultiAlias(expr, aliases) + /** * Gives the column an alias. * {{{ @@ -862,6 +901,22 @@ class Column(protected[sql] val expr: Expression) extends Logging { */ def bitwiseXOR(other: Any): Column = BitwiseXor(expr, lit(other).expr) + /** + * Define a windowing column. + * + * {{{ + * val w = Window.partitionBy("name").orderBy("id") + * df.select( + * sum("price").over(w.rangeBetween(Long.MinValue, 2)), + * avg("price").over(w.rowsBetween(0, 4)) + * ) + * }}} + * + * @group expr_ops + * @since 1.4.0 + */ + def over(window: expressions.WindowSpec): Column = window.withAggregate(this) + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index c820a673575ff..034d887901975 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql import java.io.CharArrayWriter -import java.sql.DriverManager import java.util.Properties import scala.collection.JavaConversions._ @@ -34,15 +33,14 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedAttribute, UnresolvedRelation} +import org.apache.spark.sql.catalyst.analysis.{MultiAlias, ResolvedStar, UnresolvedAttribute, UnresolvedRelation} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{Filter, _} import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} -import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters, ScalaReflection, SqlParser} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD} -import org.apache.spark.sql.jdbc.JDBCWriteDetails import org.apache.spark.sql.json.JacksonGenerator -import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, ResolvedDataSource} +import org.apache.spark.sql.sources.CreateTableUsingAsSelect import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -59,14 +57,11 @@ private[sql] object DataFrame { * :: Experimental :: * A distributed collection of data organized into named columns. * - * A [[DataFrame]] is equivalent to a relational table in Spark SQL. There are multiple ways - * to create a [[DataFrame]]: + * A [[DataFrame]] is equivalent to a relational table in Spark SQL. The following example creates + * a [[DataFrame]] by pointing Spark SQL to a Parquet data set. * {{{ - * // Create a DataFrame from Parquet files - * val people = sqlContext.parquetFile("...") - * - * // Create a DataFrame from data sources - * val df = sqlContext.load("...", "json") + * val people = sqlContext.read.parquet("...") // in Scala + * DataFrame people = sqlContext.read().parquet("...") // in Java * }}} * * Once created, it can be manipulated using the various domain-specific-language (DSL) functions @@ -88,8 +83,8 @@ private[sql] object DataFrame { * A more concrete example in Scala: * {{{ * // To create DataFrame using SQLContext - * val people = sqlContext.parquetFile("...") - * val department = sqlContext.parquetFile("...") + * val people = sqlContext.read.parquet("...") + * val department = sqlContext.read.parquet("...") * * people.filter("age > 30") * .join(department, people("deptId") === department("id")) @@ -100,8 +95,8 @@ private[sql] object DataFrame { * and in Java: * {{{ * // To create DataFrame using SQLContext - * DataFrame people = sqlContext.parquetFile("..."); - * DataFrame department = sqlContext.parquetFile("..."); + * DataFrame people = sqlContext.read().parquet("..."); + * DataFrame department = sqlContext.read().parquet("..."); * * people.filter("age".gt(30)) * .join(department, people.col("deptId").equalTo(department("id"))) @@ -160,7 +155,7 @@ class DataFrame private[sql]( } protected[sql] def resolve(colName: String): NamedExpression = { - queryExecution.analyzed.resolve(colName.split("\\."), sqlContext.analyzer.resolver).getOrElse { + queryExecution.analyzed.resolveQuoted(colName, sqlContext.analyzer.resolver).getOrElse { throw new AnalysisException( s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""") } @@ -168,7 +163,7 @@ class DataFrame private[sql]( protected[sql] def numericColumns: Seq[Expression] = { schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n => - queryExecution.analyzed.resolve(n.name.split("\\."), sqlContext.analyzer.resolver).get + queryExecution.analyzed.resolveQuoted(n.name, sqlContext.analyzer.resolver).get } } @@ -227,10 +222,6 @@ class DataFrame private[sql]( } } - /** Left here for backward compatibility. */ - @deprecated("1.3.0", "use toDF") - def toSchemaRDD: DataFrame = this - /** * Returns the object itself. * @group basic @@ -261,7 +252,7 @@ class DataFrame private[sql]( val newCols = logicalPlan.output.zip(colNames).map { case (oldAttribute, newName) => Column(oldAttribute).as(newName) } - select(newCols :_*) + select(newCols : _*) } /** @@ -417,7 +408,7 @@ class DataFrame private[sql]( joined.left, joined.right, joinType = Inner, - Some(expressions.EqualTo( + Some(catalyst.expressions.EqualTo( joined.left.resolve(usingColumn), joined.right.resolve(usingColumn)))) ) @@ -486,8 +477,9 @@ class DataFrame private[sql]( // By the time we get here, since we have already run analysis, all attributes should've been // resolved and become AttributeReference. val cond = plan.condition.map { _.transform { - case expressions.EqualTo(a: AttributeReference, b: AttributeReference) if a.sameRef(b) => - expressions.EqualTo(plan.left.resolve(a.name), plan.right.resolve(b.name)) + case catalyst.expressions.EqualTo(a: AttributeReference, b: AttributeReference) + if a.sameRef(b) => + catalyst.expressions.EqualTo(plan.left.resolve(a.name), plan.right.resolve(b.name)) }} plan.copy(condition = cond) } @@ -505,7 +497,7 @@ class DataFrame private[sql]( */ @scala.annotation.varargs def sort(sortCol: String, sortCols: String*): DataFrame = { - sort((sortCol +: sortCols).map(apply) :_*) + sort((sortCol +: sortCols).map(apply) : _*) } /** @@ -536,7 +528,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ @scala.annotation.varargs - def orderBy(sortCol: String, sortCols: String*): DataFrame = sort(sortCol, sortCols :_*) + def orderBy(sortCol: String, sortCols: String*): DataFrame = sort(sortCol, sortCols : _*) /** * Returns a new [[DataFrame]] sorted by the given expressions. @@ -545,7 +537,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ @scala.annotation.varargs - def orderBy(sortExprs: Column*): DataFrame = sort(sortExprs :_*) + def orderBy(sortExprs: Column*): DataFrame = sort(sortExprs : _*) /** * Selects column based on the column name and return it as a [[Column]]. @@ -593,6 +585,9 @@ class DataFrame private[sql]( def select(cols: Column*): DataFrame = { val namedExpressions = cols.map { case Column(expr: NamedExpression) => expr + // Leave an unaliased explode with an empty list of names since the analzyer will generate the + // correct defaults after the nested expression's type has been resolved. + case Column(explode: Explode) => MultiAlias(explode, Nil) case Column(expr: Expression) => Alias(expr, expr.prettyString)() } // When user continuously call `select`, speed up analysis by collapsing `Project` @@ -613,7 +608,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ @scala.annotation.varargs - def select(col: String, cols: String*): DataFrame = select((col +: cols).map(Column(_)) :_*) + def select(col: String, cols: String*): DataFrame = select((col +: cols).map(Column(_)) : _*) /** * Selects a set of SQL expressions. This is a variant of `select` that accepts @@ -688,7 +683,53 @@ class DataFrame private[sql]( * @since 1.3.0 */ @scala.annotation.varargs - def groupBy(cols: Column*): GroupedData = new GroupedData(this, cols.map(_.expr)) + def groupBy(cols: Column*): GroupedData = { + GroupedData(this, cols.map(_.expr), GroupedData.GroupByType) + } + + /** + * Create a multi-dimensional rollup for the current [[DataFrame]] using the specified columns, + * so we can run aggregation on them. + * See [[GroupedData]] for all the available aggregate functions. + * + * {{{ + * // Compute the average for all numeric columns rolluped by department and group. + * df.rollup($"department", $"group").avg() + * + * // Compute the max age and average salary, rolluped by department and gender. + * df.rollup($"department", $"gender").agg(Map( + * "salary" -> "avg", + * "age" -> "max" + * )) + * }}} + * @group dfops + * @since 1.4.0 + */ + @scala.annotation.varargs + def rollup(cols: Column*): GroupedData = { + GroupedData(this, cols.map(_.expr), GroupedData.RollupType) + } + + /** + * Create a multi-dimensional cube for the current [[DataFrame]] using the specified columns, + * so we can run aggregation on them. + * See [[GroupedData]] for all the available aggregate functions. + * + * {{{ + * // Compute the average for all numeric columns cubed by department and group. + * df.cube($"department", $"group").avg() + * + * // Compute the max age and average salary, cubed by department and gender. + * df.cube($"department", $"gender").agg(Map( + * "salary" -> "avg", + * "age" -> "max" + * )) + * }}} + * @group dfops + * @since 1.4.0 + */ + @scala.annotation.varargs + def cube(cols: Column*): GroupedData = GroupedData(this, cols.map(_.expr), GroupedData.CubeType) /** * Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them. @@ -713,7 +754,61 @@ class DataFrame private[sql]( @scala.annotation.varargs def groupBy(col1: String, cols: String*): GroupedData = { val colNames: Seq[String] = col1 +: cols - new GroupedData(this, colNames.map(colName => resolve(colName))) + GroupedData(this, colNames.map(colName => resolve(colName)), GroupedData.GroupByType) + } + + /** + * Create a multi-dimensional rollup for the current [[DataFrame]] using the specified columns, + * so we can run aggregation on them. + * See [[GroupedData]] for all the available aggregate functions. + * + * This is a variant of rollup that can only group by existing columns using column names + * (i.e. cannot construct expressions). + * + * {{{ + * // Compute the average for all numeric columns rolluped by department and group. + * df.rollup("department", "group").avg() + * + * // Compute the max age and average salary, rolluped by department and gender. + * df.rollup($"department", $"gender").agg(Map( + * "salary" -> "avg", + * "age" -> "max" + * )) + * }}} + * @group dfops + * @since 1.4.0 + */ + @scala.annotation.varargs + def rollup(col1: String, cols: String*): GroupedData = { + val colNames: Seq[String] = col1 +: cols + GroupedData(this, colNames.map(colName => resolve(colName)), GroupedData.RollupType) + } + + /** + * Create a multi-dimensional cube for the current [[DataFrame]] using the specified columns, + * so we can run aggregation on them. + * See [[GroupedData]] for all the available aggregate functions. + * + * This is a variant of cube that can only group by existing columns using column names + * (i.e. cannot construct expressions). + * + * {{{ + * // Compute the average for all numeric columns cubed by department and group. + * df.cube("department", "group").avg() + * + * // Compute the max age and average salary, cubed by department and gender. + * df.cube($"department", $"gender").agg(Map( + * "salary" -> "avg", + * "age" -> "max" + * )) + * }}} + * @group dfops + * @since 1.4.0 + */ + @scala.annotation.varargs + def cube(col1: String, cols: String*): GroupedData = { + val colNames: Seq[String] = col1 +: cols + GroupedData(this, colNames.map(colName => resolve(colName)), GroupedData.CubeType) } /** @@ -727,7 +822,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = { - groupBy().agg(aggExpr, aggExprs :_*) + groupBy().agg(aggExpr, aggExprs : _*) } /** @@ -765,7 +860,7 @@ class DataFrame private[sql]( * @since 1.3.0 */ @scala.annotation.varargs - def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs :_*) + def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs : _*) /** * Returns a new [[DataFrame]] by taking the first `n` rows. The difference between this function @@ -941,7 +1036,7 @@ class DataFrame private[sql]( val name = field.name if (resolver(name, colName)) col.as(colName) else Column(name) } - select(colNames :_*) + select(colNames : _*) } else { select(Column("*"), col.as(colName)) } @@ -1066,7 +1161,7 @@ class DataFrame private[sql]( val ret: Seq[Row] = if (outputCols.nonEmpty) { val aggExprs = statistics.flatMap { case (_, colToAgg) => - outputCols.map(c => Column(colToAgg(Column(c).expr)).as(c)) + outputCols.map(c => Column(Cast(colToAgg(Column(c).expr), StringType)).as(c)) } val row = agg(aggExprs.head, aggExprs.tail: _*).head().toSeq @@ -1080,9 +1175,9 @@ class DataFrame private[sql]( statistics.map { case (name, _) => Row(name) } } - // The first column is string type, and the rest are double type. + // All columns are string type val schema = StructType( - StructField("summary", StringType) :: outputCols.map(StructField(_, DoubleType))).toAttributes + StructField("summary", StringType) :: outputCols.map(StructField(_, StringType))).toAttributes LocalRelation(schema, ret) } @@ -1164,7 +1259,7 @@ class DataFrame private[sql]( * @group action * @since 1.3.0 */ - override def collectAsList(): java.util.List[Row] = java.util.Arrays.asList(rdd.collect() :_*) + override def collectAsList(): java.util.List[Row] = java.util.Arrays.asList(rdd.collect() : _*) /** * Returns the number of rows in the [[DataFrame]]. @@ -1286,23 +1381,123 @@ class DataFrame private[sql]( sqlContext.registerDataFrameAsTable(this, tableName) } + /** + * :: Experimental :: + * Interface for saving the content of the [[DataFrame]] out into external storage. + * + * @group output + * @since 1.4.0 + */ + @Experimental + def write: DataFrameWriter = new DataFrameWriter(this) + + /** + * Returns the content of the [[DataFrame]] as a RDD of JSON strings. + * @group rdd + * @since 1.3.0 + */ + def toJSON: RDD[String] = { + val rowSchema = this.schema + this.mapPartitions { iter => + val writer = new CharArrayWriter() + // create the Generator without separator inserted between 2 records + val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null) + + new Iterator[String] { + override def hasNext: Boolean = iter.hasNext + override def next(): String = { + JacksonGenerator(rowSchema, gen)(iter.next()) + gen.flush() + + val json = writer.toString + if (hasNext) { + writer.reset() + } else { + gen.close() + } + + json + } + } + } + } + + //////////////////////////////////////////////////////////////////////////// + // for Python API + //////////////////////////////////////////////////////////////////////////// + + /** + * Converts a JavaRDD to a PythonRDD. + */ + protected[sql] def javaToPython: JavaRDD[Array[Byte]] = { + val fieldTypes = schema.fields.map(_.dataType) + val jrdd = rdd.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD() + SerDeUtil.javaToPython(jrdd) + } + + //////////////////////////////////////////////////////////////////////////// + //////////////////////////////////////////////////////////////////////////// + // Deprecated methods + //////////////////////////////////////////////////////////////////////////// + //////////////////////////////////////////////////////////////////////////// + + /** + * @deprecated As of 1.3.0, replaced by `toDF()`. + */ + @deprecated("use toDF", "1.3.0") + def toSchemaRDD: DataFrame = this + + /** + * Save this [[DataFrame]] to a JDBC database at `url` under the table name `table`. + * This will run a `CREATE TABLE` and a bunch of `INSERT INTO` statements. + * If you pass `true` for `allowExisting`, it will drop any table with the + * given name; if you pass `false`, it will throw if the table already + * exists. + * @group output + * @deprecated As of 1.340, replaced by `write().jdbc()`. + */ + @deprecated("Use write.jdbc()", "1.4.0") + def createJDBCTable(url: String, table: String, allowExisting: Boolean): Unit = { + val w = if (allowExisting) write.mode(SaveMode.Overwrite) else write + w.jdbc(url, table, new Properties) + } + + /** + * Save this [[DataFrame]] to a JDBC database at `url` under the table name `table`. + * Assumes the table already exists and has a compatible schema. If you + * pass `true` for `overwrite`, it will `TRUNCATE` the table before + * performing the `INSERT`s. + * + * The table must already exist on the database. It must have a schema + * that is compatible with the schema of this RDD; inserting the rows of + * the RDD in order via the simple statement + * `INSERT INTO table VALUES (?, ?, ..., ?)` should not fail. + * @group output + * @deprecated As of 1.4.0, replaced by `write().jdbc()`. + */ + @deprecated("Use write.jdbc()", "1.4.0") + def insertIntoJDBC(url: String, table: String, overwrite: Boolean): Unit = { + val w = if (overwrite) write.mode(SaveMode.Overwrite) else write + w.jdbc(url, table, new Properties) + } + /** * Saves the contents of this [[DataFrame]] as a parquet file, preserving the schema. * Files that are written out using this method can be read back in as a [[DataFrame]] * using the `parquetFile` function in [[SQLContext]]. * @group output - * @since 1.3.0 + * @deprecated As of 1.4.0, replaced by `write().parquet()`. */ + @deprecated("Use write.parquet(path)", "1.4.0") def saveAsParquetFile(path: String): Unit = { if (sqlContext.conf.parquetUseDataSourceApi) { - save("org.apache.spark.sql.parquet", SaveMode.ErrorIfExists, Map("path" -> path)) + write.format("parquet").mode(SaveMode.ErrorIfExists).save(path) } else { sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd } } /** - * :: Experimental :: * Creates a table from the the contents of this DataFrame. * It will use the default data source configured by spark.sql.sources.default. * This will fail if the table already exists. @@ -1315,15 +1510,14 @@ class DataFrame private[sql]( * Also note that while this function can persist the table metadata into Hive's metastore, * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. * @group output - * @since 1.3.0 + * @deprecated As of 1.4.0, replaced by `write().saveAsTable(tableName)`. */ - @Experimental + @deprecated("Use write.saveAsTable(tableName)", "1.4.0") def saveAsTable(tableName: String): Unit = { - saveAsTable(tableName, SaveMode.ErrorIfExists) + write.mode(SaveMode.ErrorIfExists).saveAsTable(tableName) } /** - * :: Experimental :: * Creates a table from the the contents of this DataFrame, using the default data source * configured by spark.sql.sources.default and [[SaveMode.ErrorIfExists]] as the save mode. * @@ -1335,22 +1529,14 @@ class DataFrame private[sql]( * Also note that while this function can persist the table metadata into Hive's metastore, * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. * @group output - * @since 1.3.0 + * @deprecated As of 1.4.0, replaced by `write().mode(mode).saveAsTable(tableName)`. */ - @Experimental + @deprecated("Use write.mode(mode).saveAsTable(tableName)", "1.4.0") def saveAsTable(tableName: String, mode: SaveMode): Unit = { - if (sqlContext.catalog.tableExists(Seq(tableName)) && mode == SaveMode.Append) { - // If table already exists and the save mode is Append, - // we will just call insertInto to append the contents of this DataFrame. - insertInto(tableName, overwrite = false) - } else { - val dataSourceName = sqlContext.conf.defaultDataSourceName - saveAsTable(tableName, dataSourceName, mode) - } + write.mode(mode).saveAsTable(tableName) } /** - * :: Experimental :: * Creates a table at the given path from the the contents of this DataFrame * based on a given data source and a set of options, * using [[SaveMode.ErrorIfExists]] as the save mode. @@ -1363,11 +1549,11 @@ class DataFrame private[sql]( * Also note that while this function can persist the table metadata into Hive's metastore, * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. * @group output - * @since 1.3.0 + * @deprecated As of 1.4.0, replaced by `write().format(source).saveAsTable(tableName)`. */ - @Experimental + @deprecated("Use write.format(source).saveAsTable(tableName)", "1.4.0") def saveAsTable(tableName: String, source: String): Unit = { - saveAsTable(tableName, source, SaveMode.ErrorIfExists) + write.format(source).saveAsTable(tableName) } /** @@ -1383,15 +1569,14 @@ class DataFrame private[sql]( * Also note that while this function can persist the table metadata into Hive's metastore, * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. * @group output - * @since 1.3.0 + * @deprecated As of 1.4.0, replaced by `write().mode(mode).saveAsTable(tableName)`. */ - @Experimental + @deprecated("Use write.format(source).mode(mode).saveAsTable(tableName)", "1.4.0") def saveAsTable(tableName: String, source: String, mode: SaveMode): Unit = { - saveAsTable(tableName, source, mode, Map.empty[String, String]) + write.format(source).mode(mode).saveAsTable(tableName) } /** - * :: Experimental :: * Creates a table at the given path from the the contents of this DataFrame * based on a given data source, [[SaveMode]] specified by mode, and a set of options. * @@ -1403,42 +1588,20 @@ class DataFrame private[sql]( * Also note that while this function can persist the table metadata into Hive's metastore, * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. * @group output - * @since 1.3.0 + * @deprecated As of 1.4.0, replaced by + * `write().format(source).mode(mode).options(options).saveAsTable(tableName)`. */ - @Experimental + @deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName)", + "1.4.0") def saveAsTable( tableName: String, source: String, mode: SaveMode, options: java.util.Map[String, String]): Unit = { - saveAsTable(tableName, source, mode, options.toMap) - } - - /** - * :: Experimental :: - * Creates a table at the given path from the the contents of this DataFrame - * based on a given data source, [[SaveMode]] specified by mode, a set of options, and a list of - * partition columns. - * - * Note that this currently only works with DataFrames that are created from a HiveContext as - * there is no notion of a persisted catalog in a standard SQL context. Instead you can write - * an RDD out to a parquet file, and then register that file as a table. This "table" can then - * be the target of an `insertInto`. - * @group output - * @since 1.4.0 - */ - @Experimental - def saveAsTable( - tableName: String, - source: String, - mode: SaveMode, - options: java.util.Map[String, String], - partitionColumns: java.util.List[String]): Unit = { - saveAsTable(tableName, source, mode, options.toMap, partitionColumns) + write.format(source).mode(mode).options(options).saveAsTable(tableName) } /** - * :: Experimental :: * (Scala-specific) * Creates a table from the the contents of this DataFrame based on a given data source, * [[SaveMode]] specified by mode, and a set of options. @@ -1451,328 +1614,123 @@ class DataFrame private[sql]( * Also note that while this function can persist the table metadata into Hive's metastore, * the table will NOT be accessible from Hive, until SPARK-7550 is resolved. * @group output - * @since 1.3.0 + * @deprecated As of 1.4.0, replaced by + * `write().format(source).mode(mode).options(options).saveAsTable(tableName)`. */ - @Experimental + @deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName)", + "1.4.0") def saveAsTable( tableName: String, source: String, mode: SaveMode, options: Map[String, String]): Unit = { - val cmd = - CreateTableUsingAsSelect( - tableName, - source, - temporary = false, - Array.empty[String], - mode, - options, - logicalPlan) - - sqlContext.executePlan(cmd).toRdd + write.format(source).mode(mode).options(options).saveAsTable(tableName) } /** - * :: Experimental :: - * Creates a table at the given path from the the contents of this DataFrame - * based on a given data source, [[SaveMode]] specified by mode, a set of options, and a list of - * partition columns. - * - * Note that this currently only works with DataFrames that are created from a HiveContext as - * there is no notion of a persisted catalog in a standard SQL context. Instead you can write - * an RDD out to a parquet file, and then register that file as a table. This "table" can then - * be the target of an `insertInto`. - * @group output - * @since 1.4.0 - */ - @Experimental - def saveAsTable( - tableName: String, - source: String, - mode: SaveMode, - options: Map[String, String], - partitionColumns: Seq[String]): Unit = { - sqlContext.executePlan( - CreateTableUsingAsSelect( - tableName, - source, - temporary = false, - partitionColumns.toArray, - mode, - options, - logicalPlan)).toRdd - } - - /** - * :: Experimental :: * Saves the contents of this DataFrame to the given path, * using the default data source configured by spark.sql.sources.default and * [[SaveMode.ErrorIfExists]] as the save mode. * @group output - * @since 1.3.0 + * @deprecated As of 1.4.0, replaced by `write().save(path)`. */ - @Experimental + @deprecated("Use write.save(path)", "1.4.0") def save(path: String): Unit = { - save(path, SaveMode.ErrorIfExists) + write.save(path) } /** - * :: Experimental :: * Saves the contents of this DataFrame to the given path and [[SaveMode]] specified by mode, * using the default data source configured by spark.sql.sources.default. * @group output - * @since 1.3.0 + * @deprecated As of 1.4.0, replaced by `write().mode(mode).save(path)`. */ - @Experimental + @deprecated("Use write.mode(mode).save(path)", "1.4.0") def save(path: String, mode: SaveMode): Unit = { - val dataSourceName = sqlContext.conf.defaultDataSourceName - save(path, dataSourceName, mode) + write.mode(mode).save(path) } /** - * :: Experimental :: * Saves the contents of this DataFrame to the given path based on the given data source, * using [[SaveMode.ErrorIfExists]] as the save mode. * @group output - * @since 1.3.0 + * @deprecated As of 1.4.0, replaced by `write().format(source).save(path)`. */ - @Experimental + @deprecated("Use write.format(source).save(path)", "1.4.0") def save(path: String, source: String): Unit = { - save(source, SaveMode.ErrorIfExists, Map("path" -> path)) + write.format(source).save(path) } /** - * :: Experimental :: * Saves the contents of this DataFrame to the given path based on the given data source and * [[SaveMode]] specified by mode. * @group output - * @since 1.3.0 + * @deprecated As of 1.4.0, replaced by `write().format(source).mode(mode).save(path)`. */ - @Experimental + @deprecated("Use write.format(source).mode(mode).save(path)", "1.4.0") def save(path: String, source: String, mode: SaveMode): Unit = { - save(source, mode, Map("path" -> path)) + write.format(source).mode(mode).save(path) } /** - * :: Experimental :: * Saves the contents of this DataFrame based on the given data source, * [[SaveMode]] specified by mode, and a set of options. * @group output - * @since 1.3.0 + * @deprecated As of 1.4.0, replaced by + * `write().format(source).mode(mode).options(options).save(path)`. */ - @Experimental + @deprecated("Use write.format(source).mode(mode).options(options).save()", "1.4.0") def save( source: String, mode: SaveMode, options: java.util.Map[String, String]): Unit = { - save(source, mode, options.toMap) - } - - /** - * :: Experimental :: - * Saves the contents of this DataFrame to the given path based on the given data source, - * [[SaveMode]] specified by mode, and partition columns specified by `partitionColumns`. - * @group output - * @since 1.4.0 - */ - @Experimental - def save( - source: String, - mode: SaveMode, - options: java.util.Map[String, String], - partitionColumns: java.util.List[String]): Unit = { - save(source, mode, options.toMap, partitionColumns) + write.format(source).mode(mode).options(options).save() } /** - * :: Experimental :: * (Scala-specific) * Saves the contents of this DataFrame based on the given data source, * [[SaveMode]] specified by mode, and a set of options * @group output - * @since 1.3.0 + * @deprecated As of 1.4.0, replaced by + * `write().format(source).mode(mode).options(options).save(path)`. */ - @Experimental + @deprecated("Use write.format(source).mode(mode).options(options).save()", "1.4.0") def save( source: String, mode: SaveMode, options: Map[String, String]): Unit = { - ResolvedDataSource(sqlContext, source, Array.empty[String], mode, options, this) + write.format(source).mode(mode).options(options).save() } - /** - * :: Experimental :: - * Saves the contents of this DataFrame to the given path based on the given data source, - * [[SaveMode]] specified by mode, and partition columns specified by `partitionColumns`. - * @group output - * @since 1.4.0 - */ - @Experimental - def save( - source: String, - mode: SaveMode, - options: Map[String, String], - partitionColumns: Seq[String]): Unit = { - ResolvedDataSource(sqlContext, source, partitionColumns.toArray, mode, options, this) - } /** - * :: Experimental :: * Adds the rows from this RDD to the specified table, optionally overwriting the existing data. * @group output - * @since 1.3.0 + * @deprecated As of 1.4.0, replaced by + * `write().mode(SaveMode.Append|SaveMode.Overwrite).saveAsTable(tableName)`. */ - @Experimental + @deprecated("Use write.mode(SaveMode.Append|SaveMode.Overwrite).saveAsTable(tableName)", "1.4.0") def insertInto(tableName: String, overwrite: Boolean): Unit = { - sqlContext.executePlan(InsertIntoTable(UnresolvedRelation(Seq(tableName)), - Map.empty, logicalPlan, overwrite, ifNotExists = false)).toRdd + write.mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append).insertInto(tableName) } /** - * :: Experimental :: * Adds the rows from this RDD to the specified table. * Throws an exception if the table already exists. * @group output - * @since 1.3.0 - */ - @Experimental - def insertInto(tableName: String): Unit = insertInto(tableName, overwrite = false) - - /** - * Returns the content of the [[DataFrame]] as a RDD of JSON strings. - * @group rdd - * @since 1.3.0 + * @deprecated As of 1.4.0, replaced by + * `write().mode(SaveMode.Append).saveAsTable(tableName)`. */ - def toJSON: RDD[String] = { - val rowSchema = this.schema - this.mapPartitions { iter => - val writer = new CharArrayWriter() - // create the Generator without separator inserted between 2 records - val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null) - - new Iterator[String] { - override def hasNext: Boolean = iter.hasNext - override def next(): String = { - JacksonGenerator(rowSchema, gen)(iter.next()) - gen.flush() - - val json = writer.toString - if (hasNext) { - writer.reset() - } else { - gen.close() - } - - json - } - } - } + @deprecated("Use write.mode(SaveMode.Append).saveAsTable(tableName)", "1.4.0") + def insertInto(tableName: String): Unit = { + write.mode(SaveMode.Append).insertInto(tableName) } //////////////////////////////////////////////////////////////////////////// - // JDBC Write Support //////////////////////////////////////////////////////////////////////////// - - /** - * Save this [[DataFrame]] to a JDBC database at `url` under the table name `table`. - * This will run a `CREATE TABLE` and a bunch of `INSERT INTO` statements. - * If you pass `true` for `allowExisting`, it will drop any table with the - * given name; if you pass `false`, it will throw if the table already - * exists. - * @group output - * @since 1.3.0 - */ - def createJDBCTable(url: String, table: String, allowExisting: Boolean): Unit = { - createJDBCTable(url, table, allowExisting, new Properties()) - } - - /** - * Save this [[DataFrame]] to a JDBC database at `url` under the table name `table` - * using connection properties defined in `properties`. - * This will run a `CREATE TABLE` and a bunch of `INSERT INTO` statements. - * If you pass `true` for `allowExisting`, it will drop any table with the - * given name; if you pass `false`, it will throw if the table already - * exists. - * @group output - * @since 1.4.0 - */ - def createJDBCTable( - url: String, - table: String, - allowExisting: Boolean, - properties: Properties): Unit = { - val conn = DriverManager.getConnection(url, properties) - try { - if (allowExisting) { - val sql = s"DROP TABLE IF EXISTS $table" - conn.prepareStatement(sql).executeUpdate() - } - val schema = JDBCWriteDetails.schemaString(this, url) - val sql = s"CREATE TABLE $table ($schema)" - conn.prepareStatement(sql).executeUpdate() - } finally { - conn.close() - } - JDBCWriteDetails.saveTable(this, url, table, properties) - } - - /** - * Save this [[DataFrame]] to a JDBC database at `url` under the table name `table`. - * Assumes the table already exists and has a compatible schema. If you - * pass `true` for `overwrite`, it will `TRUNCATE` the table before - * performing the `INSERT`s. - * - * The table must already exist on the database. It must have a schema - * that is compatible with the schema of this RDD; inserting the rows of - * the RDD in order via the simple statement - * `INSERT INTO table VALUES (?, ?, ..., ?)` should not fail. - * @group output - * @since 1.3.0 - */ - def insertIntoJDBC(url: String, table: String, overwrite: Boolean): Unit = { - insertIntoJDBC(url, table, overwrite, new Properties()) - } - - /** - * Save this [[DataFrame]] to a JDBC database at `url` under the table name `table` - * using connection properties defined in `properties`. - * Assumes the table already exists and has a compatible schema. If you - * pass `true` for `overwrite`, it will `TRUNCATE` the table before - * performing the `INSERT`s. - * - * The table must already exist on the database. It must have a schema - * that is compatible with the schema of this RDD; inserting the rows of - * the RDD in order via the simple statement - * `INSERT INTO table VALUES (?, ?, ..., ?)` should not fail. - * @group output - * @since 1.4.0 - */ - def insertIntoJDBC( - url: String, - table: String, - overwrite: Boolean, - properties: Properties): Unit = { - if (overwrite) { - val conn = DriverManager.getConnection(url, properties) - try { - val sql = s"TRUNCATE TABLE $table" - conn.prepareStatement(sql).executeUpdate() - } finally { - conn.close() - } - } - JDBCWriteDetails.saveTable(this, url, table, properties) - } + // End of deprecated methods //////////////////////////////////////////////////////////////////////////// - // for Python API //////////////////////////////////////////////////////////////////////////// - /** - * Converts a JavaRDD to a PythonRDD. - */ - protected[sql] def javaToPython: JavaRDD[Array[Byte]] = { - val fieldTypes = schema.fields.map(_.dataType) - val jrdd = rdd.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD() - SerDeUtil.javaToPython(jrdd) - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala index b87efb58d51e5..2f19ec0403017 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala @@ -28,5 +28,5 @@ private[sql] case class DataFrameHolder(df: DataFrame) { // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. def toDF(): DataFrame = df - def toDF(colNames: String*): DataFrame = df.toDF(colNames :_*) + def toDF(colNames: String*): DataFrame = df.toDF(colNames : _*) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala new file mode 100644 index 0000000000000..b44d4c86ac5d3 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -0,0 +1,289 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import java.util.Properties + +import org.apache.hadoop.fs.Path +import org.apache.spark.Partition + +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} +import org.apache.spark.sql.json.{JsonRDD, JSONRelation} +import org.apache.spark.sql.parquet.ParquetRelation2 +import org.apache.spark.sql.sources.{LogicalRelation, ResolvedDataSource} +import org.apache.spark.sql.types.StructType + +/** + * :: Experimental :: + * Interface used to load a [[DataFrame]] from external storage systems (e.g. file systems, + * key-value stores, etc). Use [[SQLContext.read]] to access this. + * + * @since 1.4.0 + */ +@Experimental +class DataFrameReader private[sql](sqlContext: SQLContext) { + + /** + * Specifies the input data source format. + * + * @since 1.4.0 + */ + def format(source: String): DataFrameReader = { + this.source = source + this + } + + /** + * Specifies the input schema. Some data sources (e.g. JSON) can infer the input schema + * automatically from data. By specifying the schema here, the underlying data source can + * skip the schema inference step, and thus speed up data loading. + * + * @since 1.4.0 + */ + def schema(schema: StructType): DataFrameReader = { + this.userSpecifiedSchema = Option(schema) + this + } + + /** + * Adds an input option for the underlying data source. + * + * @since 1.4.0 + */ + def option(key: String, value: String): DataFrameReader = { + this.extraOptions += (key -> value) + this + } + + /** + * (Scala-specific) Adds input options for the underlying data source. + * + * @since 1.4.0 + */ + def options(options: scala.collection.Map[String, String]): DataFrameReader = { + this.extraOptions ++= options + this + } + + /** + * Adds input options for the underlying data source. + * + * @since 1.4.0 + */ + def options(options: java.util.Map[String, String]): DataFrameReader = { + this.options(scala.collection.JavaConversions.mapAsScalaMap(options)) + this + } + + /** + * Loads input in as a [[DataFrame]], for data sources that require a path (e.g. data backed by + * a local or distributed file system). + * + * @since 1.4.0 + */ + def load(path: String): DataFrame = { + option("path", path).load() + } + + /** + * Loads input in as a [[DataFrame]], for data sources that don't require a path (e.g. external + * key-value stores). + * + * @since 1.4.0 + */ + def load(): DataFrame = { + val resolved = ResolvedDataSource( + sqlContext, + userSpecifiedSchema = userSpecifiedSchema, + partitionColumns = Array.empty[String], + provider = source, + options = extraOptions.toMap) + DataFrame(sqlContext, LogicalRelation(resolved.relation)) + } + + /** + * Construct a [[DataFrame]] representing the database table accessible via JDBC URL + * url named table and connection properties. + * + * @since 1.4.0 + */ + def jdbc(url: String, table: String, properties: Properties): DataFrame = { + jdbc(url, table, JDBCRelation.columnPartition(null), properties) + } + + /** + * Construct a [[DataFrame]] representing the database table accessible via JDBC URL + * url named table. Partitions of the table will be retrieved in parallel based on the parameters + * passed to this function. + * + * Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash + * your external database systems. + * + * @param url JDBC database url of the form `jdbc:subprotocol:subname` + * @param table Name of the table in the external database. + * @param columnName the name of a column of integral type that will be used for partitioning. + * @param lowerBound the minimum value of `columnName` used to decide partition stride + * @param upperBound the maximum value of `columnName` used to decide partition stride + * @param numPartitions the number of partitions. the range `minValue`-`maxValue` will be split + * evenly into this many partitions + * @param connectionProperties JDBC database connection arguments, a list of arbitrary string + * tag/value. Normally at least a "user" and "password" property + * should be included. + * + * @since 1.4.0 + */ + def jdbc( + url: String, + table: String, + columnName: String, + lowerBound: Long, + upperBound: Long, + numPartitions: Int, + connectionProperties: Properties): DataFrame = { + val partitioning = JDBCPartitioningInfo(columnName, lowerBound, upperBound, numPartitions) + val parts = JDBCRelation.columnPartition(partitioning) + jdbc(url, table, parts, connectionProperties) + } + + /** + * Construct a [[DataFrame]] representing the database table accessible via JDBC URL + * url named table using connection properties. The `predicates` parameter gives a list + * expressions suitable for inclusion in WHERE clauses; each one defines one partition + * of the [[DataFrame]]. + * + * Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash + * your external database systems. + * + * @param url JDBC database url of the form `jdbc:subprotocol:subname` + * @param table Name of the table in the external database. + * @param predicates Condition in the where clause for each partition. + * @param connectionProperties JDBC database connection arguments, a list of arbitrary string + * tag/value. Normally at least a "user" and "password" property + * should be included. + * @since 1.4.0 + */ + def jdbc( + url: String, + table: String, + predicates: Array[String], + connectionProperties: Properties): DataFrame = { + val parts: Array[Partition] = predicates.zipWithIndex.map { case (part, i) => + JDBCPartition(part, i) : Partition + } + jdbc(url, table, parts, connectionProperties) + } + + private def jdbc( + url: String, + table: String, + parts: Array[Partition], + connectionProperties: Properties): DataFrame = { + val relation = JDBCRelation(url, table, parts, connectionProperties)(sqlContext) + sqlContext.baseRelationToDataFrame(relation) + } + + /** + * Loads a JSON file (one object per line) and returns the result as a [[DataFrame]]. + * + * This function goes through the input once to determine the input schema. If you know the + * schema in advance, use the version that specifies the schema to avoid the extra scan. + * + * @param path input path + * @since 1.4.0 + */ + def json(path: String): DataFrame = format("json").load(path) + + /** + * Loads an `JavaRDD[String]` storing JSON objects (one object per record) and + * returns the result as a [[DataFrame]]. + * + * Unless the schema is specified using [[schema]] function, this function goes through the + * input once to determine the input schema. + * + * @param jsonRDD input RDD with one JSON object per record + * @since 1.4.0 + */ + def json(jsonRDD: JavaRDD[String]): DataFrame = json(jsonRDD.rdd) + + /** + * Loads an `RDD[String]` storing JSON objects (one object per record) and + * returns the result as a [[DataFrame]]. + * + * Unless the schema is specified using [[schema]] function, this function goes through the + * input once to determine the input schema. + * + * @param jsonRDD input RDD with one JSON object per record + * @since 1.4.0 + */ + def json(jsonRDD: RDD[String]): DataFrame = { + val samplingRatio = extraOptions.getOrElse("samplingRatio", "1.0").toDouble + if (sqlContext.conf.useJacksonStreamingAPI) { + sqlContext.baseRelationToDataFrame( + new JSONRelation(() => jsonRDD, None, samplingRatio, userSpecifiedSchema)(sqlContext)) + } else { + val columnNameOfCorruptJsonRecord = sqlContext.conf.columnNameOfCorruptRecord + val appliedSchema = userSpecifiedSchema.getOrElse( + JsonRDD.nullTypeToStringType( + JsonRDD.inferSchema(jsonRDD, 1.0, columnNameOfCorruptJsonRecord))) + val rowRDD = JsonRDD.jsonStringToRow(jsonRDD, appliedSchema, columnNameOfCorruptJsonRecord) + sqlContext.createDataFrame(rowRDD, appliedSchema, needsConversion = false) + } + } + + /** + * Loads a Parquet file, returning the result as a [[DataFrame]]. This function returns an empty + * [[DataFrame]] if no paths are passed in. + * + * @since 1.4.0 + */ + @scala.annotation.varargs + def parquet(paths: String*): DataFrame = { + if (paths.isEmpty) { + sqlContext.emptyDataFrame + } else { + val globbedPaths = paths.map(new Path(_)).flatMap(SparkHadoopUtil.get.globPath).toArray + sqlContext.baseRelationToDataFrame( + new ParquetRelation2( + globbedPaths.map(_.toString), None, None, Map.empty[String, String])(sqlContext)) + } + } + + /** + * Returns the specified table as a [[DataFrame]]. + * + * @since 1.4.0 + */ + def table(tableName: String): DataFrame = { + DataFrame(sqlContext, sqlContext.catalog.lookupRelation(Seq(tableName))) + } + + /////////////////////////////////////////////////////////////////////////////////////// + // Builder pattern config options + /////////////////////////////////////////////////////////////////////////////////////// + + private var source: String = sqlContext.conf.defaultDataSourceName + + private var userSpecifiedSchema: Option[StructType] = None + + private var extraOptions = new scala.collection.mutable.HashMap[String, String] + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 5d106c1ac2674..edb9ed7bba56a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -43,7 +43,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { /** * Calculates the correlation of two columns of a DataFrame. Currently only supports the Pearson - * Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in + * Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in * MLlib's Statistics. * * @param col1 the name of the column @@ -97,6 +97,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]]. * The `support` should be greater than 1e-4. * + * This function is meant for exploratory data analysis, as we make no guarantee about the + * backward compatibility of the schema of the resulting [[DataFrame]]. + * * @param cols the names of the columns to search frequent items in. * @param support The minimum frequency for an item to be considered `frequent`. Should be greater * than 1e-4. @@ -114,6 +117,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]]. * Uses a `default` support of 1%. * + * This function is meant for exploratory data analysis, as we make no guarantee about the + * backward compatibility of the schema of the resulting [[DataFrame]]. + * * @param cols the names of the columns to search frequent items in. * @return A Local DataFrame with the Array of frequent items for each column. * @@ -128,6 +134,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * frequent element count algorithm described in * [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]]. * + * This function is meant for exploratory data analysis, as we make no guarantee about the + * backward compatibility of the schema of the resulting [[DataFrame]]. + * * @param cols the names of the columns to search frequent items in. * @return A Local DataFrame with the Array of frequent items for each column. * @@ -143,6 +152,9 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]]. * Uses a `default` support of 1%. * + * This function is meant for exploratory data analysis, as we make no guarantee about the + * backward compatibility of the schema of the resulting [[DataFrame]]. + * * @param cols the names of the columns to search frequent items in. * @return A Local DataFrame with the Array of frequent items for each column. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala new file mode 100644 index 0000000000000..5548b26cb8f80 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -0,0 +1,295 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import java.util.Properties + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable +import org.apache.spark.sql.jdbc.{JDBCWriteDetails, JdbcUtils} +import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect} + + +/** + * :: Experimental :: + * Interface used to write a [[DataFrame]] to external storage systems (e.g. file systems, + * key-value stores, etc). Use [[DataFrame.write]] to access this. + * + * @since 1.4.0 + */ +@Experimental +final class DataFrameWriter private[sql](df: DataFrame) { + + /** + * Specifies the behavior when data or table already exists. Options include: + * - `SaveMode.Overwrite`: overwrite the existing data. + * - `SaveMode.Append`: append the data. + * - `SaveMode.Ignore`: ignore the operation (i.e. no-op). + * - `SaveMode.ErrorIfExists`: default option, throw an exception at runtime. + * + * @since 1.4.0 + */ + def mode(saveMode: SaveMode): DataFrameWriter = { + this.mode = saveMode + this + } + + /** + * Specifies the behavior when data or table already exists. Options include: + * - `overwrite`: overwrite the existing data. + * - `append`: append the data. + * - `ignore`: ignore the operation (i.e. no-op). + * - `error`: default option, throw an exception at runtime. + * + * @since 1.4.0 + */ + def mode(saveMode: String): DataFrameWriter = { + this.mode = saveMode.toLowerCase match { + case "overwrite" => SaveMode.Overwrite + case "append" => SaveMode.Append + case "ignore" => SaveMode.Ignore + case "error" | "default" => SaveMode.ErrorIfExists + case _ => throw new IllegalArgumentException(s"Unknown save mode: $saveMode. " + + "Accepted modes are 'overwrite', 'append', 'ignore', 'error'.") + } + this + } + + /** + * Specifies the underlying output data source. Built-in options include "parquet", "json", etc. + * + * @since 1.4.0 + */ + def format(source: String): DataFrameWriter = { + this.source = source + this + } + + /** + * Adds an output option for the underlying data source. + * + * @since 1.4.0 + */ + def option(key: String, value: String): DataFrameWriter = { + this.extraOptions += (key -> value) + this + } + + /** + * (Scala-specific) Adds output options for the underlying data source. + * + * @since 1.4.0 + */ + def options(options: scala.collection.Map[String, String]): DataFrameWriter = { + this.extraOptions ++= options + this + } + + /** + * Adds output options for the underlying data source. + * + * @since 1.4.0 + */ + def options(options: java.util.Map[String, String]): DataFrameWriter = { + this.options(scala.collection.JavaConversions.mapAsScalaMap(options)) + this + } + + /** + * Partitions the output by the given columns on the file system. If specified, the output is + * laid out on the file system similar to Hive's partitioning scheme. + * + * This is only applicable for Parquet at the moment. + * + * @since 1.4.0 + */ + @scala.annotation.varargs + def partitionBy(colNames: String*): DataFrameWriter = { + this.partitioningColumns = Option(colNames) + this + } + + /** + * Saves the content of the [[DataFrame]] at the specified path. + * + * @since 1.4.0 + */ + def save(path: String): Unit = { + this.extraOptions += ("path" -> path) + save() + } + + /** + * Saves the content of the [[DataFrame]] as the specified table. + * + * @since 1.4.0 + */ + def save(): Unit = { + ResolvedDataSource( + df.sqlContext, + source, + partitioningColumns.map(_.toArray).getOrElse(Array.empty[String]), + mode, + extraOptions.toMap, + df) + } + + /** + * Inserts the content of the [[DataFrame]] to the specified table. It requires that + * the schema of the [[DataFrame]] is the same as the schema of the table. + * + * Because it inserts data to an existing table, format or options will be ignored. + * + * @since 1.4.0 + */ + def insertInto(tableName: String): Unit = { + val partitions = + partitioningColumns.map(_.map(col => col -> (None: Option[String])).toMap) + val overwrite = (mode == SaveMode.Overwrite) + df.sqlContext.executePlan(InsertIntoTable( + UnresolvedRelation(Seq(tableName)), + partitions.getOrElse(Map.empty[String, Option[String]]), + df.logicalPlan, + overwrite, + ifNotExists = false)).toRdd + } + + /** + * Saves the content of the [[DataFrame]] as the specified table. + * + * In the case the table already exists, behavior of this function depends on the + * save mode, specified by the `mode` function (default to throwing an exception). + * When `mode` is `Overwrite`, the schema of the [[DataFrame]] does not need to be + * the same as that of the existing table. + * When `mode` is `Append`, the schema of the [[DataFrame]] need to be + * the same as that of the existing table, and format or options will be ignored. + * + * @since 1.4.0 + */ + def saveAsTable(tableName: String): Unit = { + if (df.sqlContext.catalog.tableExists(tableName :: Nil) && mode != SaveMode.Overwrite) { + mode match { + case SaveMode.Ignore => + // Do nothing + + case SaveMode.ErrorIfExists => + throw new AnalysisException(s"Table $tableName already exists.") + + case SaveMode.Append => + // If it is Append, we just ask insertInto to handle it. We will not use insertInto + // to handle saveAsTable with Overwrite because saveAsTable can change the schema of + // the table. But, insertInto with Overwrite requires the schema of data be the same + // the schema of the table. + insertInto(tableName) + } + } else { + val cmd = + CreateTableUsingAsSelect( + tableName, + source, + temporary = false, + partitioningColumns.map(_.toArray).getOrElse(Array.empty[String]), + mode, + extraOptions.toMap, + df.logicalPlan) + df.sqlContext.executePlan(cmd).toRdd + } + } + + /** + * Saves the content of the [[DataFrame]] to a external database table via JDBC. In the case the + * table already exists in the external database, behavior of this function depends on the + * save mode, specified by the `mode` function (default to throwing an exception). + * + * Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash + * your external database systems. + * + * @param url JDBC database url of the form `jdbc:subprotocol:subname` + * @param table Name of the table in the external database. + * @param connectionProperties JDBC database connection arguments, a list of arbitrary string + * tag/value. Normally at least a "user" and "password" property + * should be included. + */ + def jdbc(url: String, table: String, connectionProperties: Properties): Unit = { + val conn = JdbcUtils.createConnection(url, connectionProperties) + + try { + var tableExists = JdbcUtils.tableExists(conn, table) + + if (mode == SaveMode.Ignore && tableExists) { + return + } + + if (mode == SaveMode.ErrorIfExists && tableExists) { + sys.error(s"Table $table already exists.") + } + + if (mode == SaveMode.Overwrite && tableExists) { + JdbcUtils.dropTable(conn, table) + tableExists = false + } + + // Create the table if the table didn't exist. + if (!tableExists) { + val schema = JDBCWriteDetails.schemaString(df, url) + val sql = s"CREATE TABLE $table ($schema)" + conn.prepareStatement(sql).executeUpdate() + } + } finally { + conn.close() + } + + JDBCWriteDetails.saveTable(df, url, table, connectionProperties) + } + + /** + * Saves the content of the [[DataFrame]] in JSON format at the specified path. + * This is equivalent to: + * {{{ + * format("json").save(path) + * }}} + * + * @since 1.4.0 + */ + def json(path: String): Unit = format("json").save(path) + + /** + * Saves the content of the [[DataFrame]] in Parquet format at the specified path. + * This is equivalent to: + * {{{ + * format("parquet").save(path) + * }}} + * + * @since 1.4.0 + */ + def parquet(path: String): Unit = format("parquet").save(path) + + /////////////////////////////////////////////////////////////////////////////////////// + // Builder pattern config options + /////////////////////////////////////////////////////////////////////////////////////// + + private var source: String = df.sqlContext.conf.defaultDataSourceName + + private var mode: SaveMode = SaveMode.ErrorIfExists + + private var extraOptions = new scala.collection.mutable.HashMap[String, String] + + private var partitioningColumns: Option[Seq[String]] = None + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 1381b9f1a6080..45b3e1bc627d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -23,9 +23,40 @@ import scala.language.implicitConversions import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.analysis.Star import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.Aggregate +import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate} import org.apache.spark.sql.types.NumericType +/** + * Companion object for GroupedData + */ +private[sql] object GroupedData { + def apply( + df: DataFrame, + groupingExprs: Seq[Expression], + groupType: GroupType): GroupedData = { + new GroupedData(df, groupingExprs, groupType: GroupType) + } + + /** + * The Grouping Type + */ + private[sql] trait GroupType + + /** + * To indicate it's the GroupBy + */ + private[sql] object GroupByType extends GroupType + + /** + * To indicate it's the CUBE + */ + private[sql] object CubeType extends GroupType + + /** + * To indicate it's the ROLLUP + */ + private[sql] object RollupType extends GroupType +} /** * :: Experimental :: @@ -34,19 +65,37 @@ import org.apache.spark.sql.types.NumericType * @since 1.3.0 */ @Experimental -class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) { +class GroupedData protected[sql]( + df: DataFrame, + groupingExprs: Seq[Expression], + private val groupType: GroupedData.GroupType) { - private[sql] implicit def toDF(aggExprs: Seq[NamedExpression]): DataFrame = { - val namedGroupingExprs = groupingExprs.map { - case expr: NamedExpression => expr - case expr: Expression => Alias(expr, expr.prettyString)() + private[this] def toDF(aggExprs: Seq[NamedExpression]): DataFrame = { + val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) { + val retainedExprs = groupingExprs.map { + case expr: NamedExpression => expr + case expr: Expression => Alias(expr, expr.prettyString)() + } + retainedExprs ++ aggExprs + } else { + aggExprs + } + + groupType match { + case GroupedData.GroupByType => + DataFrame( + df.sqlContext, Aggregate(groupingExprs, aggregates, df.logicalPlan)) + case GroupedData.RollupType => + DataFrame( + df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aggregates)) + case GroupedData.CubeType => + DataFrame( + df.sqlContext, Cube(groupingExprs, df.logicalPlan, aggregates)) } - DataFrame( - df.sqlContext, Aggregate(groupingExprs, namedGroupingExprs ++ aggExprs, df.logicalPlan)) } private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => Expression) - : Seq[NamedExpression] = { + : DataFrame = { val columnExprs = if (colNames.isEmpty) { // No columns specified. Use all numeric columns. @@ -63,10 +112,10 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) namedExpr } } - columnExprs.map { c => + toDF(columnExprs.map { c => val a = f(c) Alias(a, a.prettyString)() - } + }) } private[this] def strToExpr(expr: String): (Expression => Expression) = { @@ -119,10 +168,10 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) * @since 1.3.0 */ def agg(exprs: Map[String, String]): DataFrame = { - exprs.map { case (colName, expr) => + toDF(exprs.map { case (colName, expr) => val a = strToExpr(expr)(df(colName).expr) Alias(a, a.prettyString)() - }.toSeq + }.toSeq) } /** @@ -175,19 +224,10 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) */ @scala.annotation.varargs def agg(expr: Column, exprs: Column*): DataFrame = { - val aggExprs = (expr +: exprs).map(_.expr).map { + toDF((expr +: exprs).map(_.expr).map { case expr: NamedExpression => expr case expr: Expression => Alias(expr, expr.prettyString)() - } - if (df.sqlContext.conf.dataFrameRetainGroupColumns) { - val retainedExprs = groupingExprs.map { - case expr: NamedExpression => expr - case expr: Expression => Alias(expr, expr.prettyString)() - } - DataFrame(df.sqlContext, Aggregate(groupingExprs, retainedExprs ++ aggExprs, df.logicalPlan)) - } else { - DataFrame(df.sqlContext, Aggregate(groupingExprs, aggExprs, df.logicalPlan)) - } + }) } /** @@ -196,7 +236,7 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) * * @since 1.3.0 */ - def count(): DataFrame = Seq(Alias(Count(Literal(1)), "count")()) + def count(): DataFrame = toDF(Seq(Alias(Count(Literal(1)), "count")())) /** * Compute the average value for each numeric columns for each group. This is an alias for `avg`. @@ -207,9 +247,9 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) */ @scala.annotation.varargs def mean(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames:_*)(Average) + aggregateNumericColumns(colNames : _*)(Average) } - + /** * Compute the max value for each numeric columns for each group. * The resulting [[DataFrame]] will also contain the grouping columns. @@ -219,7 +259,7 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) */ @scala.annotation.varargs def max(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames:_*)(Max) + aggregateNumericColumns(colNames : _*)(Max) } /** @@ -231,7 +271,7 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) */ @scala.annotation.varargs def avg(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames:_*)(Average) + aggregateNumericColumns(colNames : _*)(Average) } /** @@ -243,7 +283,7 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) */ @scala.annotation.varargs def min(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames:_*)(Min) + aggregateNumericColumns(colNames : _*)(Min) } /** @@ -255,6 +295,6 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) */ @scala.annotation.varargs def sum(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames:_*)(Sum) - } + aggregateNumericColumns(colNames : _*)(Sum) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index f07bb196c11ec..77c6af27d1007 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -43,6 +43,8 @@ private[spark] object SQLConf { val PARQUET_FILTER_PUSHDOWN_ENABLED = "spark.sql.parquet.filterPushdown" val PARQUET_USE_DATA_SOURCE_API = "spark.sql.parquet.useDataSourceApi" + val ORC_FILTER_PUSHDOWN_ENABLED = "spark.sql.orc.filterPushdown" + val HIVE_VERIFY_PARTITIONPATH = "spark.sql.hive.verifyPartitionPath" val COLUMN_NAME_OF_CORRUPT_RECORD = "spark.sql.columnNameOfCorruptRecord" @@ -69,6 +71,10 @@ private[spark] object SQLConf { // Whether to perform partition discovery when loading external data sources. Default to true. val PARTITION_DISCOVERY_ENABLED = "spark.sql.sources.partitionDiscovery.enabled" + // The output committer class used by FSBasedRelation. The specified class needs to be a + // subclass of org.apache.hadoop.mapreduce.OutputCommitter. + val OUTPUT_COMMITTER_CLASS = "spark.sql.sources.outputCommitterClass" + // Whether to perform eager analysis when constructing a dataframe. // Set to false when debugging requires the ability to look at invalid query plans. val DATAFRAME_EAGER_ANALYSIS = "spark.sql.eagerAnalysis" @@ -143,6 +149,9 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def parquetUseDataSourceApi = getConf(PARQUET_USE_DATA_SOURCE_API, "true").toBoolean + private[spark] def orcFilterPushDown = + getConf(ORC_FILTER_PUSHDOWN_ENABLED, "false").toBoolean + /** When true uses verifyPartitionPath to prune the path which is not exists. */ private[spark] def verifyPartitionPath = getConf(HIVE_VERIFY_PARTITIONPATH, "true").toBoolean @@ -254,7 +263,7 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def dataFrameRetainGroupColumns: Boolean = getConf(DATAFRAME_RETAIN_GROUP_COLUMNS, "true").toBoolean - + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 975498c11fa23..91e6385dec81b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import java.beans.Introspector import java.util.Properties +import java.util.concurrent.atomic.AtomicReference import scala.collection.JavaConversions._ import scala.collection.immutable @@ -26,11 +27,11 @@ import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal -import com.google.common.reflect.TypeToken - +import org.apache.spark.SparkContext import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.errors.DialectException @@ -38,50 +39,10 @@ import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.ParserDialect -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, expressions} import org.apache.spark.sql.execution.{Filter, _} -import org.apache.spark.sql.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} -import org.apache.spark.sql.json._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -import org.apache.spark.{Partition, SparkContext} - -/** - * Currently we support the default dialect named "sql", associated with the class - * [[DefaultParserDialect]] - * - * And we can also provide custom SQL Dialect, for example in Spark SQL CLI: - * {{{ - *-- switch to "hiveql" dialect - * spark-sql>SET spark.sql.dialect=hiveql; - * spark-sql>SELECT * FROM src LIMIT 1; - * - *-- switch to "sql" dialect - * spark-sql>SET spark.sql.dialect=sql; - * spark-sql>SELECT * FROM src LIMIT 1; - * - *-- register the new SQL dialect - * spark-sql> SET spark.sql.dialect=com.xxx.xxx.SQL99Dialect; - * spark-sql> SELECT * FROM src LIMIT 1; - * - *-- register the non-exist SQL dialect - * spark-sql> SET spark.sql.dialect=NotExistedClass; - * spark-sql> SELECT * FROM src LIMIT 1; - * - *-- Exception will be thrown and switch to dialect - *-- "sql" (for SQLContext) or - *-- "hiveql" (for HiveContext) - * }}} - */ -private[spark] class DefaultParserDialect extends ParserDialect { - @transient - protected val sqlParser = new catalyst.SqlParser - - override def parse(sqlText: String): LogicalPlan = { - sqlParser.parse(sqlText) - } -} /** * The entry point for working with structured data (rows and columns) in Spark. Allows the @@ -159,7 +120,7 @@ class SQLContext(@transient val sparkContext: SparkContext) // TODO how to handle the temp function per user session? @transient - protected[sql] lazy val functionRegistry: FunctionRegistry = new SimpleFunctionRegistry(true) + protected[sql] lazy val functionRegistry: FunctionRegistry = new SimpleFunctionRegistry(conf) @transient protected[sql] lazy val analyzer: Analyzer = @@ -221,9 +182,28 @@ class SQLContext(@transient val sparkContext: SparkContext) conf.dialect } - sparkContext.getConf.getAll.foreach { - case (key, value) if key.startsWith("spark.sql") => setConf(key, value) - case _ => + { + // We extract spark sql settings from SparkContext's conf and put them to + // Spark SQL's conf. + // First, we populate the SQLConf (conf). So, we can make sure that other values using + // those settings in their construction can get the correct settings. + // For example, metadataHive in HiveContext may need both spark.sql.hive.metastore.version + // and spark.sql.hive.metastore.jars to get correctly constructed. + val properties = new Properties + sparkContext.getConf.getAll.foreach { + case (key, value) if key.startsWith("spark.sql") => properties.setProperty(key, value) + case _ => + } + // We directly put those settings to conf to avoid of calling setConf, which may have + // side-effects. For example, in HiveContext, setConf may cause executionHive and metadataHive + // get constructed. If we call setConf directly, the constructed metadataHive may have + // wrong settings, or the construction may fail. + conf.setConf(properties) + // After we have populated SQLConf, we call setConf to populate other confs in the subclass + // (e.g. hiveconf in HiveContext). + properties.foreach { + case (key, value) => setConf(key, value) + } } @transient @@ -337,7 +317,7 @@ class SQLContext(@transient val sparkContext: SparkContext) */ implicit class StringToColumn(val sc: StringContext) { def $(args: Any*): ColumnName = { - new ColumnName(sc.s(args :_*)) + new ColumnName(sc.s(args : _*)) } } @@ -429,7 +409,7 @@ class SQLContext(@transient val sparkContext: SparkContext) SparkPlan.currentContext.set(self) val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] val attributeSeq = schema.toAttributes - val rowRDD = RDDConversions.productToRowRdd(rdd, schema) + val rowRDD = RDDConversions.productToRowRdd(rdd, schema.map(_.dataType)) DataFrame(self, LogicalRDD(attributeSeq, rowRDD)(self)) } @@ -569,634 +549,232 @@ class SQLContext(@transient val sparkContext: SparkContext) } /** - * :: DeveloperApi :: - * Creates a [[DataFrame]] from an [[RDD]] containing [[Row]]s by applying a schema to this RDD. - * It is important to make sure that the structure of every [[Row]] of the provided RDD matches - * the provided schema. Otherwise, there will be runtime exception. - * Example: + * :: Experimental :: + * Returns a [[DataFrameReader]] that can be used to read data in as a [[DataFrame]]. * {{{ - * import org.apache.spark.sql._ - * import org.apache.spark.sql.types._ - * val sqlContext = new org.apache.spark.sql.SQLContext(sc) - * - * val schema = - * StructType( - * StructField("name", StringType, false) :: - * StructField("age", IntegerType, true) :: Nil) - * - * val people = - * sc.textFile("examples/src/main/resources/people.txt").map( - * _.split(",")).map(p => Row(p(0), p(1).trim.toInt)) - * val dataFrame = sqlContext. applySchema(people, schema) - * dataFrame.printSchema - * // root - * // |-- name: string (nullable = false) - * // |-- age: integer (nullable = true) - * - * dataFrame.registerTempTable("people") - * sqlContext.sql("select name from people").collect.foreach(println) + * sqlContext.read.parquet("/path/to/file.parquet") + * sqlContext.read.schema(schema).json("/path/to/file.json") * }}} + * + * @group genericdata + * @since 1.4.0 */ - @deprecated("use createDataFrame", "1.3.0") - def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = { - createDataFrame(rowRDD, schema) - } - - @deprecated("use createDataFrame", "1.3.0") - def applySchema(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = { - createDataFrame(rowRDD, schema) - } + @Experimental + def read: DataFrameReader = new DataFrameReader(this) /** - * Applies a schema to an RDD of Java Beans. + * :: Experimental :: + * Creates an external table from the given path and returns the corresponding DataFrame. + * It will use the default data source configured by spark.sql.sources.default. * - * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, - * SELECT * queries will return the columns in an undefined order. + * @group ddl_ops + * @since 1.3.0 */ - @deprecated("use createDataFrame", "1.3.0") - def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = { - createDataFrame(rdd, beanClass) + @Experimental + def createExternalTable(tableName: String, path: String): DataFrame = { + val dataSourceName = conf.defaultDataSourceName + createExternalTable(tableName, path, dataSourceName) } /** - * Applies a schema to an RDD of Java Beans. + * :: Experimental :: + * Creates an external table from the given path based on a data source + * and returns the corresponding DataFrame. * - * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, - * SELECT * queries will return the columns in an undefined order. + * @group ddl_ops + * @since 1.3.0 */ - @deprecated("use createDataFrame", "1.3.0") - def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = { - createDataFrame(rdd, beanClass) + @Experimental + def createExternalTable( + tableName: String, + path: String, + source: String): DataFrame = { + createExternalTable(tableName, source, Map("path" -> path)) } /** - * Loads a Parquet file, returning the result as a [[DataFrame]]. This function returns an empty - * [[DataFrame]] if no paths are passed in. + * :: Experimental :: + * Creates an external table from the given path based on a data source and a set of options. + * Then, returns the corresponding DataFrame. * - * @group specificdata + * @group ddl_ops * @since 1.3.0 */ - @scala.annotation.varargs - def parquetFile(paths: String*): DataFrame = { - if (paths.isEmpty) { - emptyDataFrame - } else if (conf.parquetUseDataSourceApi) { - baseRelationToDataFrame(parquet.ParquetRelation2(paths, Map.empty)(this)) - } else { - DataFrame(this, parquet.ParquetRelation( - paths.mkString(","), Some(sparkContext.hadoopConfiguration), this)) - } + @Experimental + def createExternalTable( + tableName: String, + source: String, + options: java.util.Map[String, String]): DataFrame = { + createExternalTable(tableName, source, options.toMap) } /** - * Loads a JSON file (one object per line), returning the result as a [[DataFrame]]. - * It goes through the entire dataset once to determine the schema. + * :: Experimental :: + * (Scala-specific) + * Creates an external table from the given path based on a data source and a set of options. + * Then, returns the corresponding DataFrame. * - * @group specificdata + * @group ddl_ops * @since 1.3.0 */ - def jsonFile(path: String): DataFrame = jsonFile(path, 1.0) + @Experimental + def createExternalTable( + tableName: String, + source: String, + options: Map[String, String]): DataFrame = { + val cmd = + CreateTableUsing( + tableName, + userSpecifiedSchema = None, + source, + temporary = false, + options, + allowExisting = false, + managedIfNoPath = false) + executePlan(cmd).toRdd + table(tableName) + } /** * :: Experimental :: - * Loads a JSON file (one object per line) and applies the given schema, - * returning the result as a [[DataFrame]]. + * Create an external table from the given path based on a data source, a schema and + * a set of options. Then, returns the corresponding DataFrame. * - * @group specificdata + * @group ddl_ops * @since 1.3.0 */ @Experimental - def jsonFile(path: String, schema: StructType): DataFrame = - load("json", schema, Map("path" -> path)) + def createExternalTable( + tableName: String, + source: String, + schema: StructType, + options: java.util.Map[String, String]): DataFrame = { + createExternalTable(tableName, source, schema, options.toMap) + } /** * :: Experimental :: - * @group specificdata + * (Scala-specific) + * Create an external table from the given path based on a data source, a schema and + * a set of options. Then, returns the corresponding DataFrame. + * + * @group ddl_ops * @since 1.3.0 */ @Experimental - def jsonFile(path: String, samplingRatio: Double): DataFrame = - load("json", Map("path" -> path, "samplingRatio" -> samplingRatio.toString)) + def createExternalTable( + tableName: String, + source: String, + schema: StructType, + options: Map[String, String]): DataFrame = { + val cmd = + CreateTableUsing( + tableName, + userSpecifiedSchema = Some(schema), + source, + temporary = false, + options, + allowExisting = false, + managedIfNoPath = false) + executePlan(cmd).toRdd + table(tableName) + } /** - * Loads an RDD[String] storing JSON objects (one object per record), returning the result as a - * [[DataFrame]]. - * It goes through the entire dataset once to determine the schema. - * - * @group specificdata - * @since 1.3.0 + * Registers the given [[DataFrame]] as a temporary table in the catalog. Temporary tables exist + * only during the lifetime of this instance of SQLContext. */ - def jsonRDD(json: RDD[String]): DataFrame = jsonRDD(json, 1.0) - + private[sql] def registerDataFrameAsTable(df: DataFrame, tableName: String): Unit = { + catalog.registerTable(Seq(tableName), df.logicalPlan) + } /** - * Loads an RDD[String] storing JSON objects (one object per record), returning the result as a - * [[DataFrame]]. - * It goes through the entire dataset once to determine the schema. + * Drops the temporary table with the given table name in the catalog. If the table has been + * cached/persisted before, it's also unpersisted. * - * @group specificdata + * @param tableName the name of the table to be unregistered. + * + * @group basic * @since 1.3.0 */ - def jsonRDD(json: JavaRDD[String]): DataFrame = jsonRDD(json.rdd, 1.0) + def dropTempTable(tableName: String): Unit = { + cacheManager.tryUncacheQuery(table(tableName)) + catalog.unregisterTable(Seq(tableName)) + } /** * :: Experimental :: - * Loads an RDD[String] storing JSON objects (one object per record) and applies the given schema, - * returning the result as a [[DataFrame]]. + * Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements + * in an range from `start` to `end`(exclusive) with step value 1. * - * @group specificdata - * @since 1.3.0 + * @since 1.4.0 + * @group dataframe */ @Experimental - def jsonRDD(json: RDD[String], schema: StructType): DataFrame = { - if (conf.useJacksonStreamingAPI) { - baseRelationToDataFrame(new JSONRelation(() => json, None, 1.0, Some(schema))(this)) - } else { - val columnNameOfCorruptJsonRecord = conf.columnNameOfCorruptRecord - val appliedSchema = - Option(schema).getOrElse( - JsonRDD.nullTypeToStringType( - JsonRDD.inferSchema(json, 1.0, columnNameOfCorruptJsonRecord))) - val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord) - createDataFrame(rowRDD, appliedSchema, needsConversion = false) - } + def range(start: Long, end: Long): DataFrame = { + createDataFrame( + sparkContext.range(start, end).map(Row(_)), + StructType(StructField("id", LongType, nullable = false) :: Nil)) } /** * :: Experimental :: - * Loads an JavaRDD storing JSON objects (one object per record) and applies the given - * schema, returning the result as a [[DataFrame]]. + * Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements + * in an range from `start` to `end`(exclusive) with an step value, with partition number + * specified. * - * @group specificdata - * @since 1.3.0 + * @since 1.4.0 + * @group dataframe */ @Experimental - def jsonRDD(json: JavaRDD[String], schema: StructType): DataFrame = { - jsonRDD(json.rdd, schema) + def range(start: Long, end: Long, step: Long, numPartitions: Int): DataFrame = { + createDataFrame( + sparkContext.range(start, end, step, numPartitions).map(Row(_)), + StructType(StructField("id", LongType, nullable = false) :: Nil)) } /** - * :: Experimental :: - * Loads an RDD[String] storing JSON objects (one object per record) inferring the - * schema, returning the result as a [[DataFrame]]. + * Executes a SQL query using Spark, returning the result as a [[DataFrame]]. The dialect that is + * used for SQL parsing can be configured with 'spark.sql.dialect'. * - * @group specificdata + * @group basic * @since 1.3.0 */ - @Experimental - def jsonRDD(json: RDD[String], samplingRatio: Double): DataFrame = { - if (conf.useJacksonStreamingAPI) { - baseRelationToDataFrame(new JSONRelation(() => json, None, samplingRatio, None)(this)) - } else { - val columnNameOfCorruptJsonRecord = conf.columnNameOfCorruptRecord - val appliedSchema = - JsonRDD.nullTypeToStringType( - JsonRDD.inferSchema(json, samplingRatio, columnNameOfCorruptJsonRecord)) - val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord) - createDataFrame(rowRDD, appliedSchema, needsConversion = false) - } + def sql(sqlText: String): DataFrame = { + DataFrame(this, parseSql(sqlText)) } /** - * :: Experimental :: - * Loads a JavaRDD[String] storing JSON objects (one object per record) inferring the - * schema, returning the result as a [[DataFrame]]. + * Returns the specified table as a [[DataFrame]]. * - * @group specificdata + * @group ddl_ops * @since 1.3.0 */ - @Experimental - def jsonRDD(json: JavaRDD[String], samplingRatio: Double): DataFrame = { - jsonRDD(json.rdd, samplingRatio); - } + def table(tableName: String): DataFrame = + DataFrame(this, catalog.lookupRelation(Seq(tableName))) /** - * :: Experimental :: - * Returns the dataset stored at path as a DataFrame, - * using the default data source configured by spark.sql.sources.default. + * Returns a [[DataFrame]] containing names of existing tables in the current database. + * The returned DataFrame has two columns, tableName and isTemporary (a Boolean + * indicating if a table is a temporary one or not). * - * @group genericdata + * @group ddl_ops * @since 1.3.0 */ - @Experimental - def load(path: String): DataFrame = { - val dataSourceName = conf.defaultDataSourceName - load(path, dataSourceName) + def tables(): DataFrame = { + DataFrame(this, ShowTablesCommand(None)) } /** - * :: Experimental :: - * Returns the dataset stored at path as a DataFrame, using the given data source. + * Returns a [[DataFrame]] containing names of existing tables in the given database. + * The returned DataFrame has two columns, tableName and isTemporary (a Boolean + * indicating if a table is a temporary one or not). * - * @group genericdata + * @group ddl_ops * @since 1.3.0 */ - @Experimental - def load(path: String, source: String): DataFrame = { - load(source, Map("path" -> path)) - } - - /** - * :: Experimental :: - * (Java-specific) Returns the dataset specified by the given data source and - * a set of options as a DataFrame. - * - * @group genericdata - * @since 1.3.0 - */ - @Experimental - def load(source: String, options: java.util.Map[String, String]): DataFrame = { - load(source, options.toMap) - } - - /** - * :: Experimental :: - * (Scala-specific) Returns the dataset specified by the given data source and - * a set of options as a DataFrame. - * - * @group genericdata - * @since 1.3.0 - */ - @Experimental - def load(source: String, options: Map[String, String]): DataFrame = { - val resolved = ResolvedDataSource(this, None, Array.empty[String], source, options) - DataFrame(this, LogicalRelation(resolved.relation)) - } - - /** - * :: Experimental :: - * (Java-specific) Returns the dataset specified by the given data source and - * a set of options as a DataFrame, using the given schema as the schema of the DataFrame. - * - * @group genericdata - * @since 1.3.0 - */ - @Experimental - def load( - source: String, - schema: StructType, - options: java.util.Map[String, String]): DataFrame = { - load(source, schema, options.toMap) - } - - /** - * :: Experimental :: - * (Java-specific) Returns the dataset specified by the given data source and - * a set of options as a DataFrame, using the given schema as the schema of the DataFrame. - * - * @group genericdata - * @since 1.3.0 - */ - @Experimental - def load( - source: String, - schema: StructType, - partitionColumns: Array[String], - options: java.util.Map[String, String]): DataFrame = { - load(source, schema, partitionColumns, options.toMap) - } - - /** - * :: Experimental :: - * (Scala-specific) Returns the dataset specified by the given data source and - * a set of options as a DataFrame, using the given schema as the schema of the DataFrame. - * @group genericdata - * @since 1.3.0 - */ - @Experimental - def load( - source: String, - schema: StructType, - options: Map[String, String]): DataFrame = { - val resolved = ResolvedDataSource(this, Some(schema), Array.empty[String], source, options) - DataFrame(this, LogicalRelation(resolved.relation)) - } - - /** - * :: Experimental :: - * (Scala-specific) Returns the dataset specified by the given data source and - * a set of options as a DataFrame, using the given schema as the schema of the DataFrame. - * @group genericdata - * @since 1.3.0 - */ - @Experimental - def load( - source: String, - schema: StructType, - partitionColumns: Array[String], - options: Map[String, String]): DataFrame = { - val resolved = ResolvedDataSource(this, Some(schema), partitionColumns, source, options) - DataFrame(this, LogicalRelation(resolved.relation)) - } - - /** - * :: Experimental :: - * Creates an external table from the given path and returns the corresponding DataFrame. - * It will use the default data source configured by spark.sql.sources.default. - * - * @group ddl_ops - * @since 1.3.0 - */ - @Experimental - def createExternalTable(tableName: String, path: String): DataFrame = { - val dataSourceName = conf.defaultDataSourceName - createExternalTable(tableName, path, dataSourceName) - } - - /** - * :: Experimental :: - * Creates an external table from the given path based on a data source - * and returns the corresponding DataFrame. - * - * @group ddl_ops - * @since 1.3.0 - */ - @Experimental - def createExternalTable( - tableName: String, - path: String, - source: String): DataFrame = { - createExternalTable(tableName, source, Map("path" -> path)) - } - - /** - * :: Experimental :: - * Creates an external table from the given path based on a data source and a set of options. - * Then, returns the corresponding DataFrame. - * - * @group ddl_ops - * @since 1.3.0 - */ - @Experimental - def createExternalTable( - tableName: String, - source: String, - options: java.util.Map[String, String]): DataFrame = { - createExternalTable(tableName, source, options.toMap) - } - - /** - * :: Experimental :: - * (Scala-specific) - * Creates an external table from the given path based on a data source and a set of options. - * Then, returns the corresponding DataFrame. - * - * @group ddl_ops - * @since 1.3.0 - */ - @Experimental - def createExternalTable( - tableName: String, - source: String, - options: Map[String, String]): DataFrame = { - val cmd = - CreateTableUsing( - tableName, - userSpecifiedSchema = None, - source, - temporary = false, - options, - allowExisting = false, - managedIfNoPath = false) - executePlan(cmd).toRdd - table(tableName) - } - - /** - * :: Experimental :: - * Create an external table from the given path based on a data source, a schema and - * a set of options. Then, returns the corresponding DataFrame. - * - * @group ddl_ops - * @since 1.3.0 - */ - @Experimental - def createExternalTable( - tableName: String, - source: String, - schema: StructType, - options: java.util.Map[String, String]): DataFrame = { - createExternalTable(tableName, source, schema, options.toMap) - } - - /** - * :: Experimental :: - * (Scala-specific) - * Create an external table from the given path based on a data source, a schema and - * a set of options. Then, returns the corresponding DataFrame. - * - * @group ddl_ops - * @since 1.3.0 - */ - @Experimental - def createExternalTable( - tableName: String, - source: String, - schema: StructType, - options: Map[String, String]): DataFrame = { - val cmd = - CreateTableUsing( - tableName, - userSpecifiedSchema = Some(schema), - source, - temporary = false, - options, - allowExisting = false, - managedIfNoPath = false) - executePlan(cmd).toRdd - table(tableName) - } - - /** - * :: Experimental :: - * Construct a [[DataFrame]] representing the database table accessible via JDBC URL - * url named table. - * - * @group specificdata - * @since 1.3.0 - */ - @Experimental - def jdbc(url: String, table: String): DataFrame = { - jdbc(url, table, JDBCRelation.columnPartition(null), new Properties()) - } - - /** - * :: Experimental :: - * Construct a [[DataFrame]] representing the database table accessible via JDBC URL - * url named table and connection properties. - * - * @group specificdata - * @since 1.4.0 - */ - @Experimental - def jdbc(url: String, table: String, properties: Properties): DataFrame = { - jdbc(url, table, JDBCRelation.columnPartition(null), properties) - } - - /** - * :: Experimental :: - * Construct a [[DataFrame]] representing the database table accessible via JDBC URL - * url named table. Partitions of the table will be retrieved in parallel based on the parameters - * passed to this function. - * - * @param columnName the name of a column of integral type that will be used for partitioning. - * @param lowerBound the minimum value of `columnName` used to decide partition stride - * @param upperBound the maximum value of `columnName` used to decide partition stride - * @param numPartitions the number of partitions. the range `minValue`-`maxValue` will be split - * evenly into this many partitions - * @group specificdata - * @since 1.3.0 - */ - @Experimental - def jdbc( - url: String, - table: String, - columnName: String, - lowerBound: Long, - upperBound: Long, - numPartitions: Int): DataFrame = { - jdbc(url, table, columnName, lowerBound, upperBound, numPartitions, new Properties()) - } - - /** - * :: Experimental :: - * Construct a [[DataFrame]] representing the database table accessible via JDBC URL - * url named table. Partitions of the table will be retrieved in parallel based on the parameters - * passed to this function. - * - * @param columnName the name of a column of integral type that will be used for partitioning. - * @param lowerBound the minimum value of `columnName` used to decide partition stride - * @param upperBound the maximum value of `columnName` used to decide partition stride - * @param numPartitions the number of partitions. the range `minValue`-`maxValue` will be split - * evenly into this many partitions - * @param properties connection properties - * @group specificdata - * @since 1.4.0 - */ - @Experimental - def jdbc( - url: String, - table: String, - columnName: String, - lowerBound: Long, - upperBound: Long, - numPartitions: Int, - properties: Properties): DataFrame = { - val partitioning = JDBCPartitioningInfo(columnName, lowerBound, upperBound, numPartitions) - val parts = JDBCRelation.columnPartition(partitioning) - jdbc(url, table, parts, properties) - } - - /** - * :: Experimental :: - * Construct a [[DataFrame]] representing the database table accessible via JDBC URL - * url named table. The theParts parameter gives a list expressions - * suitable for inclusion in WHERE clauses; each one defines one partition - * of the [[DataFrame]]. - * - * @group specificdata - * @since 1.3.0 - */ - @Experimental - def jdbc(url: String, table: String, theParts: Array[String]): DataFrame = { - jdbc(url, table, theParts, new Properties()) - } - - /** - * :: Experimental :: - * Construct a [[DataFrame]] representing the database table accessible via JDBC URL - * url named table using connection properties. The theParts parameter gives a list expressions - * suitable for inclusion in WHERE clauses; each one defines one partition - * of the [[DataFrame]]. - * - * @group specificdata - * @since 1.4.0 - */ - @Experimental - def jdbc( - url: String, - table: String, - theParts: Array[String], - properties: Properties): DataFrame = { - val parts: Array[Partition] = theParts.zipWithIndex.map { case (part, i) => - JDBCPartition(part, i) : Partition - } - jdbc(url, table, parts, properties) - } - - private def jdbc( - url: String, - table: String, - parts: Array[Partition], - properties: Properties): DataFrame = { - val relation = JDBCRelation(url, table, parts, properties)(this) - baseRelationToDataFrame(relation) - } - - /** - * Registers the given [[DataFrame]] as a temporary table in the catalog. Temporary tables exist - * only during the lifetime of this instance of SQLContext. - */ - private[sql] def registerDataFrameAsTable(df: DataFrame, tableName: String): Unit = { - catalog.registerTable(Seq(tableName), df.logicalPlan) - } - - /** - * Drops the temporary table with the given table name in the catalog. If the table has been - * cached/persisted before, it's also unpersisted. - * - * @param tableName the name of the table to be unregistered. - * - * @group basic - * @since 1.3.0 - */ - def dropTempTable(tableName: String): Unit = { - cacheManager.tryUncacheQuery(table(tableName)) - catalog.unregisterTable(Seq(tableName)) - } - - /** - * Executes a SQL query using Spark, returning the result as a [[DataFrame]]. The dialect that is - * used for SQL parsing can be configured with 'spark.sql.dialect'. - * - * @group basic - * @since 1.3.0 - */ - def sql(sqlText: String): DataFrame = { - DataFrame(this, parseSql(sqlText)) - } - - /** - * Returns the specified table as a [[DataFrame]]. - * - * @group ddl_ops - * @since 1.3.0 - */ - def table(tableName: String): DataFrame = - DataFrame(this, catalog.lookupRelation(Seq(tableName))) - - /** - * Returns a [[DataFrame]] containing names of existing tables in the current database. - * The returned DataFrame has two columns, tableName and isTemporary (a Boolean - * indicating if a table is a temporary one or not). - * - * @group ddl_ops - * @since 1.3.0 - */ - def tables(): DataFrame = { - DataFrame(this, ShowTablesCommand(None)) - } - - /** - * Returns a [[DataFrame]] containing names of existing tables in the given database. - * The returned DataFrame has two columns, tableName and isTemporary (a Boolean - * indicating if a table is a temporary one or not). - * - * @group ddl_ops - * @since 1.3.0 - */ - def tables(databaseName: String): DataFrame = { - DataFrame(this, ShowTablesCommand(Some(databaseName))) + def tables(databaseName: String): DataFrame = { + DataFrame(this, ShowTablesCommand(Some(databaseName))) } /** @@ -1270,7 +848,7 @@ class SQLContext(@transient val sparkContext: SparkContext) val projectSet = AttributeSet(projectList.flatMap(_.references)) val filterSet = AttributeSet(filterPredicates.flatMap(_.references)) val filterCondition = - prunePushedDownFilters(filterPredicates).reduceLeftOption(expressions.And) + prunePushedDownFilters(filterPredicates).reduceLeftOption(catalyst.expressions.And) // Right now we still use a projection even if the only evaluation is applying an alias // to a column. Since this is a no-op, it could be avoided. However, using this @@ -1351,7 +929,7 @@ class SQLContext(@transient val sparkContext: SparkContext) // TODO: Don't just pick the first one... lazy val sparkPlan: SparkPlan = { SparkPlan.currentContext.set(self) - planner(optimizedPlan).next() + planner.plan(optimizedPlan).next() } // executedPlan should not be used to initialize any SparkPlan. It should be // only used for execution. @@ -1450,12 +1028,346 @@ class SQLContext(@transient val sparkContext: SparkContext) * Returns a Catalyst Schema for the given java bean class. */ protected def getSchema(beanClass: Class[_]): Seq[AttributeReference] = { - val (dataType, _) = JavaTypeInference.inferDataType(TypeToken.of(beanClass)) + val (dataType, _) = JavaTypeInference.inferDataType(beanClass) dataType.asInstanceOf[StructType].fields.map { f => AttributeReference(f.name, f.dataType, f.nullable)() } } + //////////////////////////////////////////////////////////////////////////// + //////////////////////////////////////////////////////////////////////////// + // Deprecated methods + //////////////////////////////////////////////////////////////////////////// + //////////////////////////////////////////////////////////////////////////// + + /** + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. + */ + @deprecated("use createDataFrame", "1.3.0") + def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = { + createDataFrame(rowRDD, schema) + } + + /** + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. + */ + @deprecated("use createDataFrame", "1.3.0") + def applySchema(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = { + createDataFrame(rowRDD, schema) + } + + /** + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. + */ + @deprecated("use createDataFrame", "1.3.0") + def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = { + createDataFrame(rdd, beanClass) + } + + /** + * @deprecated As of 1.3.0, replaced by `createDataFrame()`. + */ + @deprecated("use createDataFrame", "1.3.0") + def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = { + createDataFrame(rdd, beanClass) + } + + /** + * Loads a Parquet file, returning the result as a [[DataFrame]]. This function returns an empty + * [[DataFrame]] if no paths are passed in. + * + * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().parquet()`. + */ + @deprecated("Use read.parquet()", "1.4.0") + @scala.annotation.varargs + def parquetFile(paths: String*): DataFrame = { + if (paths.isEmpty) { + emptyDataFrame + } else if (conf.parquetUseDataSourceApi) { + read.parquet(paths : _*) + } else { + DataFrame(this, parquet.ParquetRelation( + paths.mkString(","), Some(sparkContext.hadoopConfiguration), this)) + } + } + + /** + * Loads a JSON file (one object per line), returning the result as a [[DataFrame]]. + * It goes through the entire dataset once to determine the schema. + * + * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. + */ + @deprecated("Use read.json()", "1.4.0") + def jsonFile(path: String): DataFrame = { + read.json(path) + } + + /** + * Loads a JSON file (one object per line) and applies the given schema, + * returning the result as a [[DataFrame]]. + * + * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. + */ + @deprecated("Use read.json()", "1.4.0") + def jsonFile(path: String, schema: StructType): DataFrame = { + read.schema(schema).json(path) + } + + /** + * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. + */ + @deprecated("Use read.json()", "1.4.0") + def jsonFile(path: String, samplingRatio: Double): DataFrame = { + read.option("samplingRatio", samplingRatio.toString).json(path) + } + + /** + * Loads an RDD[String] storing JSON objects (one object per record), returning the result as a + * [[DataFrame]]. + * It goes through the entire dataset once to determine the schema. + * + * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. + */ + @deprecated("Use read.json()", "1.4.0") + def jsonRDD(json: RDD[String]): DataFrame = read.json(json) + + /** + * Loads an RDD[String] storing JSON objects (one object per record), returning the result as a + * [[DataFrame]]. + * It goes through the entire dataset once to determine the schema. + * + * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. + */ + @deprecated("Use read.json()", "1.4.0") + def jsonRDD(json: JavaRDD[String]): DataFrame = read.json(json) + + /** + * Loads an RDD[String] storing JSON objects (one object per record) and applies the given schema, + * returning the result as a [[DataFrame]]. + * + * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. + */ + @deprecated("Use read.json()", "1.4.0") + def jsonRDD(json: RDD[String], schema: StructType): DataFrame = { + read.schema(schema).json(json) + } + + /** + * Loads an JavaRDD storing JSON objects (one object per record) and applies the given + * schema, returning the result as a [[DataFrame]]. + * + * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. + */ + @deprecated("Use read.json()", "1.4.0") + def jsonRDD(json: JavaRDD[String], schema: StructType): DataFrame = { + read.schema(schema).json(json) + } + + /** + * Loads an RDD[String] storing JSON objects (one object per record) inferring the + * schema, returning the result as a [[DataFrame]]. + * + * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. + */ + @deprecated("Use read.json()", "1.4.0") + def jsonRDD(json: RDD[String], samplingRatio: Double): DataFrame = { + read.option("samplingRatio", samplingRatio.toString).json(json) + } + + /** + * Loads a JavaRDD[String] storing JSON objects (one object per record) inferring the + * schema, returning the result as a [[DataFrame]]. + * + * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().json()`. + */ + @deprecated("Use read.json()", "1.4.0") + def jsonRDD(json: JavaRDD[String], samplingRatio: Double): DataFrame = { + read.option("samplingRatio", samplingRatio.toString).json(json) + } + + /** + * Returns the dataset stored at path as a DataFrame, + * using the default data source configured by spark.sql.sources.default. + * + * @group genericdata + * @deprecated As of 1.4.0, replaced by `read().load(path)`. + */ + @deprecated("Use read.load(path)", "1.4.0") + def load(path: String): DataFrame = { + read.load(path) + } + + /** + * Returns the dataset stored at path as a DataFrame, using the given data source. + * + * @group genericdata + * @deprecated As of 1.4.0, replaced by `read().format(source).load(path)`. + */ + @deprecated("Use read.format(source).load(path)", "1.4.0") + def load(path: String, source: String): DataFrame = { + read.format(source).load(path) + } + + /** + * (Java-specific) Returns the dataset specified by the given data source and + * a set of options as a DataFrame. + * + * @group genericdata + * @deprecated As of 1.4.0, replaced by `read().format(source).options(options).load()`. + */ + @deprecated("Use read.format(source).options(options).load()", "1.4.0") + def load(source: String, options: java.util.Map[String, String]): DataFrame = { + read.options(options).format(source).load() + } + + /** + * (Scala-specific) Returns the dataset specified by the given data source and + * a set of options as a DataFrame. + * + * @group genericdata + * @deprecated As of 1.4.0, replaced by `read().format(source).options(options).load()`. + */ + @deprecated("Use read.format(source).options(options).load()", "1.4.0") + def load(source: String, options: Map[String, String]): DataFrame = { + read.options(options).format(source).load() + } + + /** + * (Java-specific) Returns the dataset specified by the given data source and + * a set of options as a DataFrame, using the given schema as the schema of the DataFrame. + * + * @group genericdata + * @deprecated As of 1.4.0, replaced by + * `read().format(source).schema(schema).options(options).load()`. + */ + @deprecated("Use read.format(source).schema(schema).options(options).load()", "1.4.0") + def load(source: String, schema: StructType, options: java.util.Map[String, String]): DataFrame = + { + read.format(source).schema(schema).options(options).load() + } + + /** + * (Scala-specific) Returns the dataset specified by the given data source and + * a set of options as a DataFrame, using the given schema as the schema of the DataFrame. + * + * @group genericdata + * @deprecated As of 1.4.0, replaced by + * `read().format(source).schema(schema).options(options).load()`. + */ + @deprecated("Use read.format(source).schema(schema).options(options).load()", "1.4.0") + def load(source: String, schema: StructType, options: Map[String, String]): DataFrame = { + read.format(source).schema(schema).options(options).load() + } + + /** + * Construct a [[DataFrame]] representing the database table accessible via JDBC URL + * url named table. + * + * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().jdbc()`. + */ + @deprecated("use read.jdbc()", "1.4.0") + def jdbc(url: String, table: String): DataFrame = { + read.jdbc(url, table, new Properties) + } + + /** + * Construct a [[DataFrame]] representing the database table accessible via JDBC URL + * url named table. Partitions of the table will be retrieved in parallel based on the parameters + * passed to this function. + * + * @param columnName the name of a column of integral type that will be used for partitioning. + * @param lowerBound the minimum value of `columnName` used to decide partition stride + * @param upperBound the maximum value of `columnName` used to decide partition stride + * @param numPartitions the number of partitions. the range `minValue`-`maxValue` will be split + * evenly into this many partitions + * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().jdbc()`. + */ + @deprecated("use read.jdbc()", "1.4.0") + def jdbc( + url: String, + table: String, + columnName: String, + lowerBound: Long, + upperBound: Long, + numPartitions: Int): DataFrame = { + read.jdbc(url, table, columnName, lowerBound, upperBound, numPartitions, new Properties) + } + + /** + * Construct a [[DataFrame]] representing the database table accessible via JDBC URL + * url named table. The theParts parameter gives a list expressions + * suitable for inclusion in WHERE clauses; each one defines one partition + * of the [[DataFrame]]. + * + * @group specificdata + * @deprecated As of 1.4.0, replaced by `read().jdbc()`. + */ + @deprecated("use read.jdbc()", "1.4.0") + def jdbc(url: String, table: String, theParts: Array[String]): DataFrame = { + read.jdbc(url, table, theParts, new Properties) + } + + //////////////////////////////////////////////////////////////////////////// + //////////////////////////////////////////////////////////////////////////// + // End of deprecated methods + //////////////////////////////////////////////////////////////////////////// + //////////////////////////////////////////////////////////////////////////// + + + // Register a succesfully instantiatd context to the singleton. This should be at the end of + // the class definition so that the singleton is updated only if there is no exception in the + // construction of the instance. + SQLContext.setLastInstantiatedContext(self) } +/** + * This SQLContext object contains utility functions to create a singleton SQLContext instance, + * or to get the last created SQLContext instance. + */ +object SQLContext { + + private val INSTANTIATION_LOCK = new Object() + + /** + * Reference to the last created SQLContext. + */ + @transient private val lastInstantiatedContext = new AtomicReference[SQLContext]() + + /** + * Get the singleton SQLContext if it exists or create a new one using the given SparkContext. + * This function can be used to create a singleton SQLContext object that can be shared across + * the JVM. + */ + def getOrCreate(sparkContext: SparkContext): SQLContext = { + INSTANTIATION_LOCK.synchronized { + if (lastInstantiatedContext.get() == null) { + new SQLContext(sparkContext) + } + } + lastInstantiatedContext.get() + } + + private[sql] def clearLastInstantiatedContext(): Unit = { + INSTANTIATION_LOCK.synchronized { + lastInstantiatedContext.set(null) + } + } + private[sql] def setLastInstantiatedContext(sqlContext: SQLContext): Unit = { + INSTANTIATION_LOCK.synchronized { + lastInstantiatedContext.set(sqlContext) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala index 6b1ae81972e4e..305b306a79871 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala @@ -54,15 +54,15 @@ private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends Abstr } } - protected val AS = Keyword("AS") - protected val CACHE = Keyword("CACHE") - protected val CLEAR = Keyword("CLEAR") - protected val IN = Keyword("IN") - protected val LAZY = Keyword("LAZY") - protected val SET = Keyword("SET") - protected val SHOW = Keyword("SHOW") - protected val TABLE = Keyword("TABLE") - protected val TABLES = Keyword("TABLES") + protected val AS = Keyword("AS") + protected val CACHE = Keyword("CACHE") + protected val CLEAR = Keyword("CLEAR") + protected val IN = Keyword("IN") + protected val LAZY = Keyword("LAZY") + protected val SET = Keyword("SET") + protected val SHOW = Keyword("SHOW") + protected val TABLE = Keyword("TABLE") + protected val TABLES = Keyword("TABLES") protected val UNCACHE = Keyword("UNCACHE") override protected lazy val start: Parser[LogicalPlan] = cache | uncache | set | show | others diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index dc3389c41bbfa..3cc5c2441d8a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -46,6 +46,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { envVars: JMap[String, String], pythonIncludes: JList[String], pythonExec: String, + pythonVer: String, broadcastVars: JList[Broadcast[PythonBroadcast]], accumulator: Accumulator[JList[Array[Byte]]], stringDataType: String): Unit = { @@ -70,6 +71,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { envVars, pythonIncludes, pythonExec, + pythonVer, broadcastVars, accumulator, dataType, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala index 505ab1301ec96..a02e202d2eebc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala @@ -58,14 +58,15 @@ private[sql] case class UserDefinedPythonFunction( envVars: JMap[String, String], pythonIncludes: JList[String], pythonExec: String, + pythonVer: String, broadcastVars: JList[Broadcast[PythonBroadcast]], accumulator: Accumulator[JList[Array[Byte]]], dataType: DataType) { /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */ def apply(exprs: Column*): Column = { - val udf = PythonUDF(name, command, envVars, pythonIncludes, pythonExec, broadcastVars, - accumulator, dataType, exprs.map(_.expr)) + val udf = PythonUDF(name, command, envVars, pythonIncludes, pythonExec, pythonVer, + broadcastVars, accumulator, dataType, exprs.map(_.expr)) Column(udf) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index 423ecdff5804a..604f3124e23ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -106,7 +106,7 @@ private[r] object SQLUtils { dfCols.map { col => colToRBytes(col) - } + } } def convertRowsToColumns(localDF: Array[Row], numCols: Int): Array[Array[Any]] = { @@ -121,7 +121,7 @@ private[r] object SQLUtils { val numRows = col.length val bos = new ByteArrayOutputStream() val dos = new DataOutputStream(bos) - + SerDe.writeInt(dos, numRows) col.map { item => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 0ded1cce68391..3db26fad2b92f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -236,7 +236,7 @@ private[sql] case class InMemoryColumnarTableScan( case GreaterThanOrEqual(a: AttributeReference, l: Literal) => l <= statsFor(a).upperBound case GreaterThanOrEqual(l: Literal, a: AttributeReference) => statsFor(a).lowerBound <= l - case IsNull(a: Attribute) => statsFor(a).nullCount > 0 + case IsNull(a: Attribute) => statsFor(a).nullCount > 0 case IsNotNull(a: Attribute) => statsFor(a).count - statsFor(a).nullCount > 0 } @@ -314,7 +314,7 @@ private[sql] case class InMemoryColumnarTableScan( columnAccessors(i).extractTo(nextRow, i) i += 1 } - nextRow + if (attributes.isEmpty) Row.empty else nextRow } override def hasNext: Boolean = columnAccessors(0).hasNext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala similarity index 97% rename from sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 18584c2dcf797..5fcc48a67948b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -15,18 +15,19 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.execution import java.util.concurrent.locks.ReentrantReadWriteLock import org.apache.spark.Logging import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.columnar.InMemoryRelation +import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.storage.StorageLevel import org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK /** Holds a cached logical plan and its data */ -private case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryRelation) +private[sql] case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryRelation) /** * Provides support in a SQLContext for caching query results and automatically using these cached diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index c3d2c7019a54a..f25d10fec0411 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -17,17 +17,18 @@ package org.apache.spark.sql.execution -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv} +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.{RDD, ShuffledRDD} import org.apache.spark.serializer.Serializer -import org.apache.spark.sql.{SQLContext, Row} +import org.apache.spark.shuffle.sort.SortShuffleManager +import org.apache.spark.shuffle.unsafe.UnsafeShuffleManager import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.util.MutablePair object Exchange { @@ -85,7 +86,9 @@ case class Exchange( // corner-cases where a partitioner constructed with `numPartitions` partitions may output // fewer partitions (like RangePartitioner, for example). val conf = child.sqlContext.sparkContext.conf - val sortBasedShuffleOn = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager] + val shuffleManager = SparkEnv.get.shuffleManager + val sortBasedShuffleOn = shuffleManager.isInstanceOf[SortShuffleManager] || + shuffleManager.isInstanceOf[UnsafeShuffleManager] val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) val serializeMapOutputs = conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true) if (newOrdering.nonEmpty) { @@ -93,11 +96,11 @@ case class Exchange( // which requires a defensive copy. true } else if (sortBasedShuffleOn) { - // Spark's sort-based shuffle also uses `ExternalSorter` to buffer records in memory. - // However, there are two special cases where we can avoid the copy, described below: - if (partitioner.numPartitions <= bypassMergeThreshold) { - // If the number of output partitions is sufficiently small, then Spark will fall back to - // the old hash-based shuffle write path which doesn't buffer deserialized records. + val bypassIsSupported = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager] + if (bypassIsSupported && partitioner.numPartitions <= bypassMergeThreshold) { + // If we're using the original SortShuffleManager and the number of output partitions is + // sufficiently small, then Spark will fall back to the hash-based shuffle write path, which + // doesn't buffer deserialized records. // Note that we'll have to remove this case if we fix SPARK-6026 and remove this bypass. false } else if (serializeMapOutputs && serializer.supportsRelocationOfSerializedObjects) { @@ -105,9 +108,14 @@ case class Exchange( // them. This optimization is guarded by a feature-flag and is only applied in cases where // shuffle dependency does not specify an ordering and the record serializer has certain // properties. If this optimization is enabled, we can safely avoid the copy. + // + // This optimization also applies to UnsafeShuffleManager (added in SPARK-7081). false } else { - // None of the special cases held, so we must copy. + // Spark's SortShuffleManager uses `ExternalSorter` to buffer records in memory. This code + // path is used both when SortShuffleManager is used and when UnsafeShuffleManager falls + // back to SortShuffleManager to perform a shuffle that the new fast path can't handle. In + // both cases, we must copy. true } } else { @@ -288,7 +296,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ .sliding(2) .map { case Seq(a) => true - case Seq(a,b) => a compatibleWith b + case Seq(a, b) => a.compatibleWith(b) }.exists(!_) // Adds Exchange or Sort operators as required diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index a500269f3cdcf..f931dc95ef575 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -21,9 +21,9 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow, SpecificMutableRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.DataType import org.apache.spark.sql.{Row, SQLContext} /** @@ -31,26 +31,19 @@ import org.apache.spark.sql.{Row, SQLContext} */ @DeveloperApi object RDDConversions { - def productToRowRdd[A <: Product](data: RDD[A], schema: StructType): RDD[Row] = { + def productToRowRdd[A <: Product](data: RDD[A], outputTypes: Seq[DataType]): RDD[Row] = { data.mapPartitions { iterator => - if (iterator.isEmpty) { - Iterator.empty - } else { - val bufferedIterator = iterator.buffered - val mutableRow = new SpecificMutableRow(schema.fields.map(_.dataType)) - val schemaFields = schema.fields.toArray - val converters = schemaFields.map { - f => CatalystTypeConverters.createToCatalystConverter(f.dataType) - } - bufferedIterator.map { r => - var i = 0 - while (i < mutableRow.length) { - mutableRow(i) = converters(i)(r.productElement(i)) - i += 1 - } - - mutableRow + val numColumns = outputTypes.length + val mutableRow = new GenericMutableRow(numColumns) + val converters = outputTypes.map(CatalystTypeConverters.createToCatalystConverter) + iterator.map { r => + var i = 0 + while (i < numColumns) { + mutableRow(i) = converters(i)(r.productElement(i)) + i += 1 } + + mutableRow } } } @@ -58,26 +51,19 @@ object RDDConversions { /** * Convert the objects inside Row into the types Catalyst expected. */ - def rowToRowRdd(data: RDD[Row], schema: StructType): RDD[Row] = { + def rowToRowRdd(data: RDD[Row], outputTypes: Seq[DataType]): RDD[Row] = { data.mapPartitions { iterator => - if (iterator.isEmpty) { - Iterator.empty - } else { - val bufferedIterator = iterator.buffered - val mutableRow = new GenericMutableRow(bufferedIterator.head.toSeq.toArray) - val schemaFields = schema.fields.toArray - val converters = schemaFields.map { - f => CatalystTypeConverters.createToCatalystConverter(f.dataType) - } - bufferedIterator.map { r => - var i = 0 - while (i < mutableRow.length) { - mutableRow(i) = converters(i)(r(i)) - i += 1 - } - - mutableRow + val numColumns = outputTypes.length + val mutableRow = new GenericMutableRow(numColumns) + val converters = outputTypes.map(CatalystTypeConverters.createToCatalystConverter) + iterator.map { r => + var i = 0 + while (i < numColumns) { + mutableRow(i) = converters(i)(r(i)) + i += 1 } + + mutableRow } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index 08d9079335132..dd02c1f4573bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -21,6 +21,18 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ +/** + * For lazy computing, be sure the generator.terminate() called in the very last + * TODO reusing the CompletionIterator? + */ +private[execution] sealed case class LazyIterator(func: () => TraversableOnce[Row]) + extends Iterator[Row] { + + lazy val results = func().toIterator + override def hasNext: Boolean = results.hasNext + override def next(): Row = results.next() +} + /** * :: DeveloperApi :: * Applies a [[catalyst.expressions.Generator Generator]] to a stream of input rows, combining the @@ -47,27 +59,33 @@ case class Generate( val boundGenerator = BindReferences.bindReference(generator, child.output) protected override def doExecute(): RDD[Row] = { + // boundGenerator.terminate() should be triggered after all of the rows in the partition if (join) { child.execute().mapPartitions { iter => - val nullValues = Seq.fill(generator.elementTypes.size)(Literal(null)) - // Used to produce rows with no matches when outer = true. - val outerProjection = - newProjection(child.output ++ nullValues, child.output) - - val joinProjection = newProjection(output, output) + val generatorNullRow = Row.fromSeq(Seq.fill[Any](generator.elementTypes.size)(null)) val joinedRow = new JoinedRow - iter.flatMap {row => + iter.flatMap { row => + // we should always set the left (child output) + joinedRow.withLeft(row) val outputRows = boundGenerator.eval(row) if (outer && outputRows.isEmpty) { - outerProjection(row) :: Nil + joinedRow.withRight(generatorNullRow) :: Nil } else { - outputRows.map(or => joinProjection(joinedRow(row, or))) + outputRows.map(or => joinedRow.withRight(or)) } + } ++ LazyIterator(() => boundGenerator.terminate()).map { row => + // we leave the left side as the last element of its child output + // keep it the same as Hive does + joinedRow.withRight(row) } } } else { - child.execute().mapPartitions(iter => iter.flatMap(row => boundGenerator.eval(row))) + child.execute().mapPartitions { iter => + iter.flatMap(row => boundGenerator.eval(row)) ++ + LazyIterator(() => boundGenerator.terminate()) + } } } } + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index 2ec7d4fbc92de..3e27c1bde2dfd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -138,15 +138,15 @@ case class GeneratedAggregate( case UnscaledValue(e) => e case _ => expr } - // partial sum result can be null only when no input rows present + // partial sum result can be null only when no input rows present val updateFunction = If( IsNotNull(actualExpr), Coalesce( Add( - Coalesce(currentSum :: zero :: Nil), + Coalesce(currentSum :: zero :: Nil), Cast(expr, calcType)) :: currentSum :: zero :: Nil), currentSum) - + val result = expr.dataType match { case DecimalType.Fixed(_, _) => @@ -155,7 +155,7 @@ case class GeneratedAggregate( } AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result) - + case m @ Max(expr) => val currentMax = AttributeReference("currentMax", expr.dataType, nullable = true)() val initialValue = Literal.create(null, expr.dataType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index af0029cb84f9a..d0a1ad00560d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -243,8 +243,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case (predicate, None) => predicate // Filter needs to be applied above when it contains partitioning // columns - case (predicate, _) if(!predicate.references.map(_.name).toSet - .intersect (partitionColNames).isEmpty) => predicate + case (predicate, _) + if !predicate.references.map(_.name).toSet.intersect(partitionColNames).isEmpty => + predicate } } } else { @@ -270,7 +271,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { projectList, filters, identity[Seq[Expression]], // All filters still need to be evaluated. - InMemoryColumnarTableScan(_, filters, mem)) :: Nil + InMemoryColumnarTableScan(_, filters, mem)) :: Nil case _ => Nil } } @@ -354,10 +355,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case c: CreateTableUsingAsSelect if !c.temporary => sys.error("Tables created with SQLContext must be TEMPORARY. Use a HiveContext instead.") - case LogicalDescribeCommand(table, isExtended) => + case describe @ LogicalDescribeCommand(table, isExtended) => val resultPlan = self.sqlContext.executePlan(table).executedPlan ExecutedCommand( - RunnableDescribeCommand(resultPlan, resultPlan.output, isExtended)) :: Nil + RunnableDescribeCommand(resultPlan, describe.output, isExtended)) :: Nil case _ => Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 6cb67b4bbbb65..a30ade86441ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -65,7 +65,7 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { * :: DeveloperApi :: * Sample the dataset. * @param lowerBound Lower-bound of the sampling probability (usually 0.0) - * @param upperBound Upper-bound of the sampling probability. The expected fraction sampled + * @param upperBound Upper-bound of the sampling probability. The expected fraction sampled * will be ub - lb. * @param withReplacement Whether to sample with replacement. * @param seed the random seed diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala index 9ac732b55b188..e228a60c9029f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala @@ -39,8 +39,6 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression { */ @transient private[this] var count: Long = 0L - override type EvaluatedType = Long - override def nullable: Boolean = false override def dataType: DataType = LongType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala index c2c6cbd491598..1272793f88cd0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala @@ -27,8 +27,6 @@ import org.apache.spark.sql.types.{IntegerType, DataType} */ private[sql] case object SparkPartitionID extends LeafExpression { - override type EvaluatedType = Int - override def nullable: Boolean = false override def dataType: DataType = IntegerType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 05dd5681edfac..b8b12be8756f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.rdd.RDD +import org.apache.spark.util.ThreadUtils import scala.concurrent._ import scala.concurrent.duration._ -import scala.concurrent.ExecutionContext.Implicits.global import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.expressions.{Row, Expression} @@ -64,7 +64,7 @@ case class BroadcastHashJoin( val input: Array[Row] = buildPlan.execute().map(_.copy()).collect() val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.length) sparkContext.broadcast(hashed) - } + }(BroadcastHashJoin.broadcastHashJoinExecutionContext) protected override def doExecute(): RDD[Row] = { val broadcastRelation = Await.result(broadcastFuture, timeout) @@ -74,3 +74,9 @@ case class BroadcastHashJoin( } } } + +object BroadcastHashJoin { + + private val broadcastHashJoinExecutionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("broadcast-hash-join", 128)) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala index 640fc26ba3baa..a32e5fc4f7ea4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala @@ -39,7 +39,7 @@ case class BroadcastLeftSemiJoinHash( override def output: Seq[Attribute] = left.output protected override def doExecute(): RDD[Row] = { - val buildIter= buildPlan.execute().map(_.copy()).collect().toIterator + val buildIter = buildPlan.execute().map(_.copy()).collect().toIterator val hashSet = new java.util.HashSet[Row]() var currentRow: Row = null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index 3dbc3837950e0..55f3ff4709013 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -19,20 +19,21 @@ package org.apache.spark.sql.execution import java.util.{List => JList, Map => JMap} -import org.apache.spark.rdd.RDD - import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ import net.razorvine.pickle.{Pickler, Unpickler} + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.python.{PythonBroadcast, PythonRDD} import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types._ import org.apache.spark.{Accumulator, Logging => SparkLogging} @@ -45,6 +46,7 @@ private[spark] case class PythonUDF( envVars: JMap[String, String], pythonIncludes: JList[String], pythonExec: String, + pythonVer: String, broadcastVars: JList[Broadcast[PythonBroadcast]], accumulator: Accumulator[JList[Array[Byte]]], dataType: DataType, @@ -54,7 +56,7 @@ private[spark] case class PythonUDF( def nullable: Boolean = true - override def eval(input: Row): PythonUDF.this.EvaluatedType = { + override def eval(input: Row): Any = { sys.error("PythonUDFs can not be directly evaluated.") } } @@ -250,6 +252,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: udf.pythonIncludes, false, udf.pythonExec, + udf.pythonVer, udf.broadcastVars, udf.accumulator ).mapPartitions { iter => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala index 5ae7e107544f8..c41c21c0eeb50 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala @@ -62,7 +62,7 @@ private[sql] object FrequentItems extends Logging { } /** - * Finding frequent items for columns, possibly with false positives. Using the + * Finding frequent items for columns, possibly with false positives. Using the * frequent element count algorithm described in * [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]]. * The `support` should be greater than 1e-4. @@ -75,7 +75,7 @@ private[sql] object FrequentItems extends Logging { * @return A Local DataFrame with the Array of frequent items for each column. */ private[sql] def singlePassFreqItems( - df: DataFrame, + df: DataFrame, cols: Seq[String], support: Double): DataFrame = { require(support >= 1e-4, s"support ($support) must be greater than 1e-4.") @@ -88,8 +88,8 @@ private[sql] object FrequentItems extends Logging { val index = originalSchema.fieldIndex(name) (name, originalSchema.fields(index).dataType) } - - val freqItems = df.select(cols.map(Column(_)):_*).rdd.aggregate(countMaps)( + + val freqItems = df.select(cols.map(Column(_)) : _*).rdd.aggregate(countMaps)( seqOp = (counts, row) => { var i = 0 while (i < numCols) { @@ -110,7 +110,7 @@ private[sql] object FrequentItems extends Logging { } ) val justItems = freqItems.map(m => m.baseMap.keys.toSeq) - val resultRow = Row(justItems:_*) + val resultRow = Row(justItems : _*) // append frequent Items to the column name for easy debugging val outputCols = colInfo.map { v => StructField(v._1 + "_freqItems", ArrayType(v._2, false)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index d22f5fd2d439c..93383e5a62f11 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -18,14 +18,14 @@ package org.apache.spark.sql.execution.stat import org.apache.spark.Logging -import org.apache.spark.sql.{Column, DataFrame} +import org.apache.spark.sql.{Row, Column, DataFrame} import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Cast} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ private[sql] object StatFunctions extends Logging { - + /** Calculate the Pearson Correlation Coefficient for the given columns */ private[sql] def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = { val counts = collectStatisticalData(df, cols) @@ -116,7 +116,10 @@ private[sql] object StatFunctions extends Logging { s"exceed 1e4. Currently $columnSize") val table = counts.groupBy(_.get(0)).map { case (col1Item, rows) => val countsRow = new GenericMutableRow(columnSize + 1) - rows.foreach { row => + rows.foreach { (row: Row) => + // row.get(0) is column 1 + // row.get(1) is column 2 + // row.get(3) is the frequency countsRow.setLong(distinctCol2.get(row.get(1)).get + 1, row.getLong(2)) } // the value of col1 is the first value, the rest are the counts @@ -126,6 +129,6 @@ private[sql] object StatFunctions extends Logging { val headerNames = distinctCol2.map(r => StructField(r._1.toString, LongType)).toSeq val schema = StructType(StructField(tableName, StringType) +: headerNames) - new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)) + new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)).na.fill(0.0) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala new file mode 100644 index 0000000000000..e9b60841fc28c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.expressions + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.expressions._ + +/** + * :: Experimental :: + * Utility functions for defining window in DataFrames. + * + * {{{ + * // PARTITION BY country ORDER BY date ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + * Window.partitionBy("country").orderBy("date").rowsBetween(Long.MinValue, 0) + * + * // PARTITION BY country ORDER BY date ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING + * Window.partitionBy("country").orderBy("date").rowsBetween(-3, 3) + * }}} + * + * @since 1.4.0 + */ +@Experimental +object Window { + + /** + * Creates a [[WindowSpec]] with the partitioning defined. + * @since 1.4.0 + */ + @scala.annotation.varargs + def partitionBy(colName: String, colNames: String*): WindowSpec = { + spec.partitionBy(colName, colNames : _*) + } + + /** + * Creates a [[WindowSpec]] with the partitioning defined. + * @since 1.4.0 + */ + @scala.annotation.varargs + def partitionBy(cols: Column*): WindowSpec = { + spec.partitionBy(cols : _*) + } + + /** + * Creates a [[WindowSpec]] with the ordering defined. + * @since 1.4.0 + */ + @scala.annotation.varargs + def orderBy(colName: String, colNames: String*): WindowSpec = { + spec.orderBy(colName, colNames : _*) + } + + /** + * Creates a [[WindowSpec]] with the ordering defined. + * @since 1.4.0 + */ + @scala.annotation.varargs + def orderBy(cols: Column*): WindowSpec = { + spec.orderBy(cols : _*) + } + + private def spec: WindowSpec = { + new WindowSpec(Seq.empty, Seq.empty, UnspecifiedFrame) + } + +} + +/** + * :: Experimental :: + * Utility functions for defining window in DataFrames. + * + * {{{ + * // PARTITION BY country ORDER BY date ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + * Window.partitionBy("country").orderBy("date").rowsBetween(Long.MinValue, 0) + * + * // PARTITION BY country ORDER BY date ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING + * Window.partitionBy("country").orderBy("date").rowsBetween(-3, 3) + * }}} + * + * @since 1.4.0 + */ +@Experimental +class Window private() // So we can see Window in JavaDoc. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala new file mode 100644 index 0000000000000..c3d2246297021 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.expressions + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.{Column, catalyst} +import org.apache.spark.sql.catalyst.expressions._ + + +/** + * :: Experimental :: + * A window specification that defines the partitioning, ordering, and frame boundaries. + * + * Use the static methods in [[Window]] to create a [[WindowSpec]]. + * + * @since 1.4.0 + */ +@Experimental +class WindowSpec private[sql]( + partitionSpec: Seq[Expression], + orderSpec: Seq[SortOrder], + frame: catalyst.expressions.WindowFrame) { + + /** + * Defines the partitioning columns in a [[WindowSpec]]. + * @since 1.4.0 + */ + @scala.annotation.varargs + def partitionBy(colName: String, colNames: String*): WindowSpec = { + partitionBy((colName +: colNames).map(Column(_)): _*) + } + + /** + * Defines the partitioning columns in a [[WindowSpec]]. + * @since 1.4.0 + */ + @scala.annotation.varargs + def partitionBy(cols: Column*): WindowSpec = { + new WindowSpec(cols.map(_.expr), orderSpec, frame) + } + + /** + * Defines the ordering columns in a [[WindowSpec]]. + * @since 1.4.0 + */ + @scala.annotation.varargs + def orderBy(colName: String, colNames: String*): WindowSpec = { + orderBy((colName +: colNames).map(Column(_)): _*) + } + + /** + * Defines the ordering columns in a [[WindowSpec]]. + * @since 1.4.0 + */ + @scala.annotation.varargs + def orderBy(cols: Column*): WindowSpec = { + val sortOrder: Seq[SortOrder] = cols.map { col => + col.expr match { + case expr: SortOrder => + expr + case expr: Expression => + SortOrder(expr, Ascending) + } + } + new WindowSpec(partitionSpec, sortOrder, frame) + } + + /** + * Defines the frame boundaries, from `start` (inclusive) to `end` (inclusive). + * + * Both `start` and `end` are relative positions from the current row. For example, "0" means + * "current row", while "-1" means the row before the current row, and "5" means the fifth row + * after the current row. + * + * @param start boundary start, inclusive. + * The frame is unbounded if this is the minimum long value. + * @param end boundary end, inclusive. + * The frame is unbounded if this is the maximum long value. + * @since 1.4.0 + */ + def rowsBetween(start: Long, end: Long): WindowSpec = { + between(RowFrame, start, end) + } + + /** + * Defines the frame boundaries, from `start` (inclusive) to `end` (inclusive). + * + * Both `start` and `end` are relative from the current row. For example, "0" means "current row", + * while "-1" means one off before the current row, and "5" means the five off after the + * current row. + * + * @param start boundary start, inclusive. + * The frame is unbounded if this is the minimum long value. + * @param end boundary end, inclusive. + * The frame is unbounded if this is the maximum long value. + * @since 1.4.0 + */ + def rangeBetween(start: Long, end: Long): WindowSpec = { + between(RangeFrame, start, end) + } + + private def between(typ: FrameType, start: Long, end: Long): WindowSpec = { + val boundaryStart = start match { + case 0 => CurrentRow + case Long.MinValue => UnboundedPreceding + case x if x < 0 => ValuePreceding(-start.toInt) + case x if x > 0 => ValueFollowing(start.toInt) + } + + val boundaryEnd = end match { + case 0 => CurrentRow + case Long.MaxValue => UnboundedFollowing + case x if x < 0 => ValuePreceding(-end.toInt) + case x if x > 0 => ValueFollowing(end.toInt) + } + + new WindowSpec( + partitionSpec, + orderSpec, + SpecifiedWindowFrame(typ, boundaryStart, boundaryEnd)) + } + + /** + * Converts this [[WindowSpec]] into a [[Column]] with an aggregate expression. + */ + private[sql] def withAggregate(aggregate: Column): Column = { + val windowExpr = aggregate.expr match { + case Average(child) => WindowExpression( + UnresolvedWindowFunction("avg", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Sum(child) => WindowExpression( + UnresolvedWindowFunction("sum", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Count(child) => WindowExpression( + UnresolvedWindowFunction("count", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case First(child) => WindowExpression( + // TODO this is a hack for Hive UDAF first_value + UnresolvedWindowFunction("first_value", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Last(child) => WindowExpression( + // TODO this is a hack for Hive UDAF last_value + UnresolvedWindowFunction("last_value", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Min(child) => WindowExpression( + UnresolvedWindowFunction("min", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Max(child) => WindowExpression( + UnresolvedWindowFunction("max", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case wf: WindowFunction => WindowExpression( + wf, + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case x => + throw new UnsupportedOperationException(s"$x is not supported in window operation.") + } + new Column(windowExpr) + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 099e1d8f03272..77327f2b84eaa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -37,6 +37,7 @@ import org.apache.spark.util.Utils * @groupname sort_funcs Sorting functions * @groupname normal_funcs Non-aggregate functions * @groupname math_funcs Math functions + * @groupname window_funcs Window functions * @groupname Ungrouped Support functions for DataFrames. * @since 1.3.0 */ @@ -186,7 +187,7 @@ object functions { */ @scala.annotation.varargs def countDistinct(columnName: String, columnNames: String*): Column = - countDistinct(Column(columnName), columnNames.map(Column.apply) :_*) + countDistinct(Column(columnName), columnNames.map(Column.apply) : _*) /** * Aggregate function: returns the approximate number of distinct items in a group. @@ -320,6 +321,218 @@ object functions { */ def max(columnName: String): Column = max(Column(columnName)) + ////////////////////////////////////////////////////////////////////////////////////////////// + // Window functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Window function: returns the value that is `offset` rows before the current row, and + * `null` if there is less than `offset` rows before the current row. For example, + * an `offset` of one will return the previous row at any given point in the window partition. + * + * This is equivalent to the LAG function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def lag(e: Column, offset: Int): Column = { + lag(e, offset, null) + } + + /** + * Window function: returns the value that is `offset` rows before the current row, and + * `null` if there is less than `offset` rows before the current row. For example, + * an `offset` of one will return the previous row at any given point in the window partition. + * + * This is equivalent to the LAG function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def lag(columnName: String, offset: Int): Column = { + lag(columnName, offset, null) + } + + /** + * Window function: returns the value that is `offset` rows before the current row, and + * `defaultValue` if there is less than `offset` rows before the current row. For example, + * an `offset` of one will return the previous row at any given point in the window partition. + * + * This is equivalent to the LAG function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def lag(columnName: String, offset: Int, defaultValue: Any): Column = { + lag(Column(columnName), offset, defaultValue) + } + + /** + * Window function: returns the value that is `offset` rows before the current row, and + * `defaultValue` if there is less than `offset` rows before the current row. For example, + * an `offset` of one will return the previous row at any given point in the window partition. + * + * This is equivalent to the LAG function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def lag(e: Column, offset: Int, defaultValue: Any): Column = { + UnresolvedWindowFunction("lag", e.expr :: Literal(offset) :: Literal(defaultValue) :: Nil) + } + + /** + * Window function: returns the value that is `offset` rows after the current row, and + * `null` if there is less than `offset` rows after the current row. For example, + * an `offset` of one will return the next row at any given point in the window partition. + * + * This is equivalent to the LEAD function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def lead(columnName: String, offset: Int): Column = { + lead(columnName, offset, null) + } + + /** + * Window function: returns the value that is `offset` rows after the current row, and + * `null` if there is less than `offset` rows after the current row. For example, + * an `offset` of one will return the next row at any given point in the window partition. + * + * This is equivalent to the LEAD function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def lead(e: Column, offset: Int): Column = { + lead(e, offset, null) + } + + /** + * Window function: returns the value that is `offset` rows after the current row, and + * `defaultValue` if there is less than `offset` rows after the current row. For example, + * an `offset` of one will return the next row at any given point in the window partition. + * + * This is equivalent to the LEAD function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def lead(columnName: String, offset: Int, defaultValue: Any): Column = { + lead(Column(columnName), offset, defaultValue) + } + + /** + * Window function: returns the value that is `offset` rows after the current row, and + * `defaultValue` if there is less than `offset` rows after the current row. For example, + * an `offset` of one will return the next row at any given point in the window partition. + * + * This is equivalent to the LEAD function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def lead(e: Column, offset: Int, defaultValue: Any): Column = { + UnresolvedWindowFunction("lead", e.expr :: Literal(offset) :: Literal(defaultValue) :: Nil) + } + + /** + * Window function: returns the ntile group id (from 1 to `n` inclusive) in an ordered window + * partition. Fow example, if `n` is 4, the first quarter of the rows will get value 1, the second + * quarter will get 2, the third quarter will get 3, and the last quarter will get 4. + * + * This is equivalent to the NTILE function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def ntile(n: Int): Column = { + UnresolvedWindowFunction("ntile", lit(n).expr :: Nil) + } + + /** + * Window function: returns a sequential number starting at 1 within a window partition. + * + * This is equivalent to the ROW_NUMBER function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def rowNumber(): Column = { + UnresolvedWindowFunction("row_number", Nil) + } + + /** + * Window function: returns the rank of rows within a window partition, without any gaps. + * + * The difference between rank and denseRank is that denseRank leaves no gaps in ranking + * sequence when there are ties. That is, if you were ranking a competition using denseRank + * and had three people tie for second place, you would say that all three were in second + * place and that the next person came in third. + * + * This is equivalent to the DENSE_RANK function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def denseRank(): Column = { + UnresolvedWindowFunction("dense_rank", Nil) + } + + /** + * Window function: returns the rank of rows within a window partition. + * + * The difference between rank and denseRank is that denseRank leaves no gaps in ranking + * sequence when there are ties. That is, if you were ranking a competition using denseRank + * and had three people tie for second place, you would say that all three were in second + * place and that the next person came in third. + * + * This is equivalent to the RANK function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def rank(): Column = { + UnresolvedWindowFunction("rank", Nil) + } + + /** + * Window function: returns the cumulative distribution of values within a window partition, + * i.e. the fraction of rows that are below the current row. + * + * {{{ + * N = total number of rows in the partition + * cumeDist(x) = number of values before (and including) x / N + * }}} + * + * + * This is equivalent to the CUME_DIST function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def cumeDist(): Column = { + UnresolvedWindowFunction("cume_dist", Nil) + } + + /** + * Window function: returns the relative rank (i.e. percentile) of rows within a window partition. + * + * This is computed by: + * {{{ + * (rank of row in its partition - 1) / (number of rows in the partition - 1) + * }}} + * + * This is equivalent to the PERCENT_RANK function in SQL. + * + * @group window_funcs + * @since 1.4.0 + */ + def percentRank(): Column = { + UnresolvedWindowFunction("percent_rank", Nil) + } + ////////////////////////////////////////////////////////////////////////////////////////////// // Non-aggregate functions ////////////////////////////////////////////////////////////////////////////////////////////// @@ -363,6 +576,11 @@ object functions { @scala.annotation.varargs def coalesce(e: Column*): Column = Coalesce(e.map(_.expr)) + /** + * Creates a new row for each element in the given array or map column. + */ + def explode(e: Column): Column = Explode(e.expr) + /** * Converts a string exprsesion to lower case. * @@ -438,6 +656,7 @@ object functions { * }}} * * @group normal_funcs + * @since 1.4.0 */ def when(condition: Column, value: Any): Column = { CaseWhen(Seq(condition.expr, lit(value).expr)) @@ -1080,7 +1299,7 @@ object functions { * @since 1.4.0 */ def toRadians(columnName: String): Column = toRadians(Column(columnName)) - + ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala deleted file mode 100644 index 0feabc4282f4a..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.jdbc - -import org.apache.spark.sql.types._ - -import java.sql.Types - - -/** - * Encapsulates workarounds for the extensions, quirks, and bugs in various - * databases. Lots of databases define types that aren't explicitly supported - * by the JDBC spec. Some JDBC drivers also report inaccurate - * information---for instance, BIT(n>1) being reported as a BIT type is quite - * common, even though BIT in JDBC is meant for single-bit values. Also, there - * does not appear to be a standard name for an unbounded string or binary - * type; we use BLOB and CLOB by default but override with database-specific - * alternatives when these are absent or do not behave correctly. - * - * Currently, the only thing DriverQuirks does is handle type mapping. - * `getCatalystType` is used when reading from a JDBC table and `getJDBCType` - * is used when writing to a JDBC table. If `getCatalystType` returns `null`, - * the default type handling is used for the given JDBC type. Similarly, - * if `getJDBCType` returns `(null, None)`, the default type handling is used - * for the given Catalyst type. - */ -private[sql] abstract class DriverQuirks { - def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType - def getJDBCType(dt: DataType): (String, Option[Int]) -} - -private[sql] object DriverQuirks { - /** - * Fetch the DriverQuirks class corresponding to a given database url. - */ - def get(url: String): DriverQuirks = { - if (url.startsWith("jdbc:mysql")) { - new MySQLQuirks() - } else if (url.startsWith("jdbc:postgresql")) { - new PostgresQuirks() - } else { - new NoQuirks() - } - } -} - -private[sql] class NoQuirks extends DriverQuirks { - def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType = - null - def getJDBCType(dt: DataType): (String, Option[Int]) = (null, None) -} - -private[sql] class PostgresQuirks extends DriverQuirks { - def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType = { - if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) { - BinaryType - } else if (sqlType == Types.OTHER && typeName.equals("cidr")) { - StringType - } else if (sqlType == Types.OTHER && typeName.equals("inet")) { - StringType - } else null - } - - def getJDBCType(dt: DataType): (String, Option[Int]) = dt match { - case StringType => ("TEXT", Some(java.sql.Types.CHAR)) - case BinaryType => ("BYTEA", Some(java.sql.Types.BINARY)) - case BooleanType => ("BOOLEAN", Some(java.sql.Types.BOOLEAN)) - case _ => (null, None) - } -} - -private[sql] class MySQLQuirks extends DriverQuirks { - def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType = { - if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) { - // This could instead be a BinaryType if we'd rather return bit-vectors of up to 64 bits as - // byte arrays instead of longs. - md.putLong("binarylong", 1) - LongType - } else if (sqlType == Types.BIT && typeName.equals("TINYINT")) { - BooleanType - } else null - } - def getJDBCType(dt: DataType): (String, Option[Int]) = (null, None) -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index a03ade3881f59..40b604d710dce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -25,24 +25,38 @@ import org.apache.commons.lang3.StringUtils import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{Row, SpecificMutableRow} +import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types._ import org.apache.spark.sql.sources._ -import org.apache.spark.util.Utils + +/** + * Data corresponding to one partition of a JDBCRDD. + */ +private[sql] case class JDBCPartition(whereClause: String, idx: Int) extends Partition { + override def index: Int = idx +} + private[sql] object JDBCRDD extends Logging { + /** * Maps a JDBC type to a Catalyst type. This function is called only when - * the DriverQuirks class corresponding to your database driver returns null. + * the JdbcDialect class corresponding to your database driver returns null. * * @param sqlType - A field of java.sql.Types * @return The Catalyst type corresponding to sqlType. */ - private def getCatalystType(sqlType: Int, precision: Int, scale: Int): DataType = { + private def getCatalystType( + sqlType: Int, + precision: Int, + scale: Int, + signed: Boolean): DataType = { val answer = sqlType match { + // scalastyle:off case java.sql.Types.ARRAY => null - case java.sql.Types.BIGINT => LongType + case java.sql.Types.BIGINT => if (signed) { LongType } else { DecimalType.Unlimited } case java.sql.Types.BINARY => BinaryType - case java.sql.Types.BIT => BooleanType // Per JDBC; Quirks handles quirky drivers. + case java.sql.Types.BIT => BooleanType // @see JdbcDialect for quirks case java.sql.Types.BLOB => BinaryType case java.sql.Types.BOOLEAN => BooleanType case java.sql.Types.CHAR => StringType @@ -55,7 +69,7 @@ private[sql] object JDBCRDD extends Logging { case java.sql.Types.DISTINCT => null case java.sql.Types.DOUBLE => DoubleType case java.sql.Types.FLOAT => FloatType - case java.sql.Types.INTEGER => IntegerType + case java.sql.Types.INTEGER => if (signed) { IntegerType } else { LongType } case java.sql.Types.JAVA_OBJECT => null case java.sql.Types.LONGNVARCHAR => StringType case java.sql.Types.LONGVARBINARY => BinaryType @@ -79,7 +93,8 @@ private[sql] object JDBCRDD extends Logging { case java.sql.Types.TINYINT => IntegerType case java.sql.Types.VARBINARY => BinaryType case java.sql.Types.VARCHAR => StringType - case _ => null + case _ => null + // scalastyle:on } if (answer == null) throw new SQLException("Unsupported type " + sqlType) @@ -99,7 +114,7 @@ private[sql] object JDBCRDD extends Logging { * @throws SQLException if the table contains an unsupported type. */ def resolveTable(url: String, table: String, properties: Properties): StructType = { - val quirks = DriverQuirks.get(url) + val dialect = JdbcDialects.get(url) val conn: Connection = DriverManager.getConnection(url, properties) try { val rs = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0").executeQuery() @@ -114,10 +129,12 @@ private[sql] object JDBCRDD extends Logging { val typeName = rsmd.getColumnTypeName(i + 1) val fieldSize = rsmd.getPrecision(i + 1) val fieldScale = rsmd.getScale(i + 1) + val isSigned = rsmd.isSigned(i + 1) val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls val metadata = new MetadataBuilder().putString("name", columnName) - var columnType = quirks.getCatalystType(dataType, typeName, fieldSize, metadata) - if (columnType == null) columnType = getCatalystType(dataType, fieldSize, fieldScale) + val columnType = + dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse( + getCatalystType(dataType, fieldSize, fieldScale, isSigned)) fields(i) = StructField(columnName, columnType, nullable, metadata.build()) i = i + 1 } @@ -168,6 +185,7 @@ private[sql] object JDBCRDD extends Logging { DriverManager.getConnection(url, properties) } } + /** * Build and return JDBCRDD from the given information. * @@ -193,18 +211,15 @@ private[sql] object JDBCRDD extends Logging { requiredColumns: Array[String], filters: Array[Filter], parts: Array[Partition]): RDD[Row] = { - - val prunedSchema = pruneSchema(schema, requiredColumns) - - return new - JDBCRDD( - sc, - getConnector(driver, url, properties), - prunedSchema, - fqTable, - requiredColumns, - filters, - parts) + new JDBCRDD( + sc, + getConnector(driver, url, properties), + pruneSchema(schema, requiredColumns), + fqTable, + requiredColumns, + filters, + parts, + properties) } } @@ -220,7 +235,8 @@ private[sql] class JDBCRDD( fqTable: String, columns: Array[String], filters: Array[Filter], - partitions: Array[Partition]) + partitions: Array[Partition], + properties: Properties) extends RDD[Row](sc, Nil) { /** @@ -246,7 +262,7 @@ private[sql] class JDBCRDD( } private def escapeSql(value: String): String = - if (value == null) null else StringUtils.replace(value, "'", "''") + if (value == null) null else StringUtils.replace(value, "'", "''") /** * Turns a single Filter into a String representing a SQL expression. @@ -288,13 +304,13 @@ private[sql] class JDBCRDD( // Each JDBC-to-Catalyst conversion corresponds to a tag defined here so that // we don't have to potentially poke around in the Metadata once for every - // row. + // row. // Is there a better way to do this? I'd rather be using a type that // contains only the tags I define. abstract class JDBCConversion case object BooleanConversion extends JDBCConversion case object DateConversion extends JDBCConversion - case object DecimalConversion extends JDBCConversion + case class DecimalConversion(precisionInfo: Option[(Int, Int)]) extends JDBCConversion case object DoubleConversion extends JDBCConversion case object FloatConversion extends JDBCConversion case object IntegerConversion extends JDBCConversion @@ -309,19 +325,19 @@ private[sql] class JDBCRDD( */ def getConversions(schema: StructType): Array[JDBCConversion] = { schema.fields.map(sf => sf.dataType match { - case BooleanType => BooleanConversion - case DateType => DateConversion - case DecimalType.Unlimited => DecimalConversion - case DecimalType.Fixed(d) => DecimalConversion - case DoubleType => DoubleConversion - case FloatType => FloatConversion - case IntegerType => IntegerConversion - case LongType => + case BooleanType => BooleanConversion + case DateType => DateConversion + case DecimalType.Unlimited => DecimalConversion(None) + case DecimalType.Fixed(d) => DecimalConversion(Some(d)) + case DoubleType => DoubleConversion + case FloatType => FloatConversion + case IntegerType => IntegerConversion + case LongType => if (sf.metadata.contains("binarylong")) BinaryLongConversion else LongConversion - case StringType => StringConversion - case TimestampType => TimestampConversion - case BinaryType => BinaryConversion - case _ => throw new IllegalArgumentException(s"Unsupported field $sf") + case StringType => StringConversion + case TimestampType => TimestampConversion + case BinaryType => BinaryConversion + case _ => throw new IllegalArgumentException(s"Unsupported field $sf") }).toArray } @@ -349,6 +365,8 @@ private[sql] class JDBCRDD( val sqlText = s"SELECT $columnList FROM $fqTable $myWhereClause" val stmt = conn.prepareStatement(sqlText, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) + val fetchSize = properties.getProperty("fetchSize", "0").toInt + stmt.setFetchSize(fetchSize) val rs = stmt.executeQuery() val conversions = getConversions(schema) @@ -360,8 +378,8 @@ private[sql] class JDBCRDD( while (i < conversions.length) { val pos = i + 1 conversions(i) match { - case BooleanConversion => mutableRow.setBoolean(i, rs.getBoolean(pos)) - case DateConversion => + case BooleanConversion => mutableRow.setBoolean(i, rs.getBoolean(pos)) + case DateConversion => // DateUtils.fromJavaDate does not handle null value, so we need to check it. val dateVal = rs.getDate(pos) if (dateVal != null) { @@ -369,21 +387,36 @@ private[sql] class JDBCRDD( } else { mutableRow.update(i, null) } - case DecimalConversion => + // When connecting with Oracle DB through JDBC, the precision and scale of BigDecimal + // object returned by ResultSet.getBigDecimal is not correctly matched to the table + // schema reported by ResultSetMetaData.getPrecision and ResultSetMetaData.getScale. + // If inserting values like 19999 into a column with NUMBER(12, 2) type, you get through + // a BigDecimal object with scale as 0. But the dataframe schema has correct type as + // DecimalType(12, 2). Thus, after saving the dataframe into parquet file and then + // retrieve it, you will get wrong result 199.99. + // So it is needed to set precision and scale for Decimal based on JDBC metadata. + case DecimalConversion(Some((p, s))) => + val decimalVal = rs.getBigDecimal(pos) + if (decimalVal == null) { + mutableRow.update(i, null) + } else { + mutableRow.update(i, Decimal(decimalVal, p, s)) + } + case DecimalConversion(None) => val decimalVal = rs.getBigDecimal(pos) if (decimalVal == null) { mutableRow.update(i, null) } else { mutableRow.update(i, Decimal(decimalVal)) } - case DoubleConversion => mutableRow.setDouble(i, rs.getDouble(pos)) - case FloatConversion => mutableRow.setFloat(i, rs.getFloat(pos)) - case IntegerConversion => mutableRow.setInt(i, rs.getInt(pos)) - case LongConversion => mutableRow.setLong(i, rs.getLong(pos)) + case DoubleConversion => mutableRow.setDouble(i, rs.getDouble(pos)) + case FloatConversion => mutableRow.setFloat(i, rs.getFloat(pos)) + case IntegerConversion => mutableRow.setInt(i, rs.getInt(pos)) + case LongConversion => mutableRow.setLong(i, rs.getLong(pos)) // TODO(davies): use getBytes for better performance, if the encoding is UTF-8 - case StringConversion => mutableRow.setString(i, rs.getString(pos)) - case TimestampConversion => mutableRow.update(i, rs.getTimestamp(pos)) - case BinaryConversion => mutableRow.update(i, rs.getBytes(pos)) + case StringConversion => mutableRow.setString(i, rs.getString(pos)) + case TimestampConversion => mutableRow.update(i, rs.getTimestamp(pos)) + case BinaryConversion => mutableRow.update(i, rs.getBytes(pos)) case BinaryLongConversion => { val bytes = rs.getBytes(pos) var ans = 0L diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala index d6b3fb3291a2e..30f9190d45bf8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala @@ -17,25 +17,16 @@ package org.apache.spark.sql.jdbc -import java.sql.DriverManager import java.util.Properties import scala.collection.mutable.ArrayBuffer import org.apache.spark.Partition import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext} import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType -import org.apache.spark.util.Utils - -/** - * Data corresponding to one partition of a JDBCRDD. - */ -private[sql] case class JDBCPartition(whereClause: String, idx: Int) extends Partition { - override def index: Int = idx -} /** * Instructions on how to partition the table among workers. @@ -63,7 +54,7 @@ private[sql] object JDBCRelation { if (numPartitions == 1) return Array[Partition](JDBCPartition(null, 0)) // Overflow and silliness can happen if you subtract then divide. // Here we get a little roundoff, but that's (hopefully) OK. - val stride: Long = (partitioning.upperBound / numPartitions + val stride: Long = (partitioning.upperBound / numPartitions - partitioning.lowerBound / numPartitions) var i: Int = 0 var currentValue: Long = partitioning.lowerBound @@ -129,7 +120,8 @@ private[sql] case class JDBCRelation( parts: Array[Partition], properties: Properties = new Properties())(@transient val sqlContext: SQLContext) extends BaseRelation - with PrunedFilteredScan { + with PrunedFilteredScan + with InsertableRelation { override val needConversion: Boolean = false @@ -148,4 +140,10 @@ private[sql] case class JDBCRelation( filters, parts) } + + override def insert(data: DataFrame, overwrite: Boolean): Unit = { + data.write + .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append) + .jdbc(url, table, properties) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala new file mode 100644 index 0000000000000..6a169e106b968 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import org.apache.spark.sql.types._ +import org.apache.spark.annotation.DeveloperApi + +import java.sql.Types + +/** + * :: DeveloperApi :: + * A database type definition coupled with the jdbc type needed to send null + * values to the database. + * @param databaseTypeDefinition The database type definition + * @param jdbcNullType The jdbc type (as defined in java.sql.Types) used to + * send a null value to the database. + */ +@DeveloperApi +case class JdbcType(databaseTypeDefinition : String, jdbcNullType : Int) + +/** + * :: DeveloperApi :: + * Encapsulates everything (extensions, workarounds, quirks) to handle the + * SQL dialect of a certain database or jdbc driver. + * Lots of databases define types that aren't explicitly supported + * by the JDBC spec. Some JDBC drivers also report inaccurate + * information---for instance, BIT(n>1) being reported as a BIT type is quite + * common, even though BIT in JDBC is meant for single-bit values. Also, there + * does not appear to be a standard name for an unbounded string or binary + * type; we use BLOB and CLOB by default but override with database-specific + * alternatives when these are absent or do not behave correctly. + * + * Currently, the only thing done by the dialect is type mapping. + * `getCatalystType` is used when reading from a JDBC table and `getJDBCType` + * is used when writing to a JDBC table. If `getCatalystType` returns `null`, + * the default type handling is used for the given JDBC type. Similarly, + * if `getJDBCType` returns `(null, None)`, the default type handling is used + * for the given Catalyst type. + */ +@DeveloperApi +abstract class JdbcDialect { + /** + * Check if this dialect instance can handle a certain jdbc url. + * @param url the jdbc url. + * @return True if the dialect can be applied on the given jdbc url. + * @throws NullPointerException if the url is null. + */ + def canHandle(url : String): Boolean + + /** + * Get the custom datatype mapping for the given jdbc meta information. + * @param sqlType The sql type (see java.sql.Types) + * @param typeName The sql type name (e.g. "BIGINT UNSIGNED") + * @param size The size of the type. + * @param md Result metadata associated with this type. + * @return The actual DataType (subclasses of [[org.apache.spark.sql.types.DataType]]) + * or null if the default type mapping should be used. + */ + def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = None + + /** + * Retrieve the jdbc / sql type for a given datatype. + * @param dt The datatype (e.g. [[org.apache.spark.sql.types.StringType]]) + * @return The new JdbcType if there is an override for this DataType + */ + def getJDBCType(dt: DataType): Option[JdbcType] = None +} + +/** + * :: DeveloperApi :: + * Registry of dialects that apply to every new jdbc [[org.apache.spark.sql.DataFrame]]. + * + * If multiple matching dialects are registered then all matching ones will be + * tried in reverse order. A user-added dialect will thus be applied first, + * overwriting the defaults. + * + * Note that all new dialects are applied to new jdbc DataFrames only. Make + * sure to register your dialects first. + */ +@DeveloperApi +object JdbcDialects { + + private var dialects = List[JdbcDialect]() + + /** + * Register a dialect for use on all new matching jdbc [[org.apache.spark.sql.DataFrame]]. + * Readding an existing dialect will cause a move-to-front. + * @param dialect The new dialect. + */ + def registerDialect(dialect: JdbcDialect) : Unit = { + dialects = dialect :: dialects.filterNot(_ == dialect) + } + + /** + * Unregister a dialect. Does nothing if the dialect is not registered. + * @param dialect The jdbc dialect. + */ + def unregisterDialect(dialect : JdbcDialect) : Unit = { + dialects = dialects.filterNot(_ == dialect) + } + + registerDialect(MySQLDialect) + registerDialect(PostgresDialect) + + /** + * Fetch the JdbcDialect class corresponding to a given database url. + */ + private[sql] def get(url: String): JdbcDialect = { + val matchingDialects = dialects.filter(_.canHandle(url)) + matchingDialects.length match { + case 0 => NoopDialect + case 1 => matchingDialects.head + case _ => new AggregatedDialect(matchingDialects) + } + } +} + +/** + * :: DeveloperApi :: + * AggregatedDialect can unify multiple dialects into one virtual Dialect. + * Dialects are tried in order, and the first dialect that does not return a + * neutral element will will. + * @param dialects List of dialects. + */ +@DeveloperApi +class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect { + + require(!dialects.isEmpty) + + def canHandle(url : String): Boolean = + dialects.map(_.canHandle(url)).reduce(_ && _) + + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = + dialects.map(_.getCatalystType(sqlType, typeName, size, md)).flatten.headOption + + override def getJDBCType(dt: DataType): Option[JdbcType] = + dialects.map(_.getJDBCType(dt)).flatten.headOption + +} + +/** + * :: DeveloperApi :: + * NOOP dialect object, always returning the neutral element. + */ +@DeveloperApi +case object NoopDialect extends JdbcDialect { + def canHandle(url : String): Boolean = true +} + +/** + * :: DeveloperApi :: + * Default postgres dialect, mapping bit/cidr/inet on read and string/binary/boolean on write. + */ +@DeveloperApi +case object PostgresDialect extends JdbcDialect { + def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql") + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) { + Some(BinaryType) + } else if (sqlType == Types.OTHER && typeName.equals("cidr")) { + Some(StringType) + } else if (sqlType == Types.OTHER && typeName.equals("inet")) { + Some(StringType) + } else None + } + + override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { + case StringType => Some(JdbcType("TEXT", java.sql.Types.CHAR)) + case BinaryType => Some(JdbcType("BYTEA", java.sql.Types.BINARY)) + case BooleanType => Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN)) + case _ => None + } +} + +/** + * :: DeveloperApi :: + * Default mysql dialect to read bit/bitsets correctly. + */ +@DeveloperApi +case object MySQLDialect extends JdbcDialect { + def canHandle(url : String): Boolean = url.startsWith("jdbc:mysql") + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) { + // This could instead be a BinaryType if we'd rather return bit-vectors of up to 64 bits as + // byte arrays instead of longs. + md.putLong("binarylong", 1) + Some(LongType) + } else if (sqlType == Types.BIT && typeName.equals("TINYINT")) { + Some(BooleanType) + } else None + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcUtils.scala new file mode 100644 index 0000000000000..cc918c237192b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcUtils.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import java.sql.{Connection, DriverManager} +import java.util.Properties + +import scala.util.Try + +/** + * Util functions for JDBC tables. + */ +private[sql] object JdbcUtils { + + /** + * Establishes a JDBC connection. + */ + def createConnection(url: String, connectionProperties: Properties): Connection = { + DriverManager.getConnection(url, connectionProperties) + } + + /** + * Returns true if the table already exists in the JDBC database. + */ + def tableExists(conn: Connection, table: String): Boolean = { + // Somewhat hacky, but there isn't a good way to identify whether a table exists for all + // SQL database systems, considering "table" could also include the database name. + Try(conn.prepareStatement(s"SELECT 1 FROM $table LIMIT 1").executeQuery().next()).isSuccess + } + + /** + * Drops a table from the JDBC database. + */ + def dropTable(conn: Connection, table: String): Unit = { + conn.prepareStatement(s"DROP TABLE $table").executeUpdate() + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala index c099881a01226..dd8aaf6474895 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala @@ -129,25 +129,26 @@ package object jdbc { */ def schemaString(df: DataFrame, url: String): String = { val sb = new StringBuilder() - val quirks = DriverQuirks.get(url) + val dialect = JdbcDialects.get(url) df.schema.fields foreach { field => { val name = field.name - var typ: String = quirks.getJDBCType(field.dataType)._1 - if (typ == null) typ = field.dataType match { - case IntegerType => "INTEGER" - case LongType => "BIGINT" - case DoubleType => "DOUBLE PRECISION" - case FloatType => "REAL" - case ShortType => "INTEGER" - case ByteType => "BYTE" - case BooleanType => "BIT(1)" - case StringType => "TEXT" - case BinaryType => "BLOB" - case TimestampType => "TIMESTAMP" - case DateType => "DATE" - case DecimalType.Unlimited => "DECIMAL(40,20)" - case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC") - } + val typ: String = + dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition).getOrElse( + field.dataType match { + case IntegerType => "INTEGER" + case LongType => "BIGINT" + case DoubleType => "DOUBLE PRECISION" + case FloatType => "REAL" + case ShortType => "INTEGER" + case ByteType => "BYTE" + case BooleanType => "BIT(1)" + case StringType => "TEXT" + case BinaryType => "BLOB" + case TimestampType => "TIMESTAMP" + case DateType => "DATE" + case DecimalType.Unlimited => "DECIMAL(40,20)" + case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC") + }) val nullable = if (field.nullable) "" else "NOT NULL" sb.append(s", $name $typ $nullable") }} @@ -162,10 +163,9 @@ package object jdbc { url: String, table: String, properties: Properties = new Properties()) { - val quirks = DriverQuirks.get(url) - var nullTypes: Array[Int] = df.schema.fields.map(field => { - var nullType: Option[Int] = quirks.getJDBCType(field.dataType)._2 - if (nullType.isEmpty) { + val dialect = JdbcDialects.get(url) + val nullTypes: Array[Int] = df.schema.fields.map { field => + dialect.getJDBCType(field.dataType).map(_.jdbcNullType).getOrElse( field.dataType match { case IntegerType => java.sql.Types.INTEGER case LongType => java.sql.Types.BIGINT @@ -181,9 +181,8 @@ package object jdbc { case DecimalType.Unlimited => java.sql.Types.DECIMAL case _ => throw new IllegalArgumentException( s"Can't translate null value for field $field") - } - } else nullType.get - }).toArray + }) + } val rddSchema = df.schema df.foreachPartition { iterator => @@ -241,10 +240,10 @@ package object jdbc { } } } - + def getDriverClassName(url: String): String = DriverManager.getDriver(url) match { case wrapper: DriverWrapper => wrapper.wrapped.getClass.getCanonicalName - case driver => driver.getClass.getCanonicalName + case driver => driver.getClass.getCanonicalName } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala index 9c58b8e4bb16a..565d10247f10e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala @@ -124,7 +124,7 @@ private[sql] object InferSchema { case ArrayType(NullType, containsNull) => ArrayType(StringType, containsNull) case ArrayType(struct: StructType, containsNull) => ArrayType(nullTypeToStringType(struct), containsNull) - case struct: StructType =>nullTypeToStringType(struct) + case struct: StructType => nullTypeToStringType(struct) case other: DataType => other } @@ -147,7 +147,7 @@ private[sql] object InferSchema { * Returns the most general data type for two given data types. */ private[json] def compatibleType(t1: DataType, t2: DataType): DataType = { - HiveTypeCoercion.findTightestCommonType(t1, t2).getOrElse { + HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2).getOrElse { // t1 or t2 is a StructType, ArrayType, or an unexpected type. (t1, t2) match { case (other: DataType, NullType) => other diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala index 80bf74aa02602..325f54b6808a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonGenerator.scala @@ -33,7 +33,7 @@ private[sql] object JacksonGenerator { */ def apply(rowSchema: StructType, gen: JsonGenerator)(row: Row): Unit = { def valWriter: (DataType, Any) => Unit = { - case (_, null) | (NullType, _) => gen.writeNull() + case (_, null) | (NullType, _) => gen.writeNull() case (StringType, v: String) => gen.writeString(v) case (TimestampType, v: java.sql.Timestamp) => gen.writeString(v.toString) case (IntegerType, v: Int) => gen.writeNumber(v) @@ -48,16 +48,16 @@ private[sql] object JacksonGenerator { case (DateType, v) => gen.writeString(v.toString) case (udt: UserDefinedType[_], v) => valWriter(udt.sqlType, udt.serialize(v)) - case (ArrayType(ty, _), v: Seq[_] ) => + case (ArrayType(ty, _), v: Seq[_]) => gen.writeStartArray() - v.foreach(valWriter(ty,_)) + v.foreach(valWriter(ty, _)) gen.writeEndArray() - case (MapType(kv,vv, _), v: Map[_,_]) => + case (MapType(kv, vv, _), v: Map[_, _]) => gen.writeStartObject() v.foreach { p => gen.writeFieldName(p._1.toString) - valWriter(vv,p._2) + valWriter(vv, p._2) } gen.writeEndObject() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala index a8e69ae61174f..0e223758051a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala @@ -26,6 +26,7 @@ import com.fasterxml.jackson.core._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.json.JacksonUtils.nextUntil import org.apache.spark.sql.types._ @@ -149,10 +150,10 @@ private[sql] object JacksonParser { private def convertMap( factory: JsonFactory, parser: JsonParser, - valueType: DataType): Map[String, Any] = { - val builder = Map.newBuilder[String, Any] + valueType: DataType): Map[UTF8String, Any] = { + val builder = Map.newBuilder[UTF8String, Any] while (nextUntil(parser, JsonToken.END_OBJECT)) { - builder += parser.getCurrentName -> convertField(factory, parser, valueType) + builder += UTF8String(parser.getCurrentName) -> convertField(factory, parser, valueType) } builder.result() @@ -180,7 +181,7 @@ private[sql] object JacksonParser { val row = new GenericMutableRow(schema.length) for (corruptIndex <- schema.getFieldIndex(columnNameOfCorruptRecords)) { require(schema(corruptIndex).dataType == StringType) - row.update(corruptIndex, record) + row.update(corruptIndex, UTF8String(record)) } Seq(row) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index f62973d5fcfab..7e1e21f5fbb99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -20,17 +20,18 @@ package org.apache.spark.sql.json import java.sql.Timestamp import scala.collection.Map -import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper} +import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper} -import com.fasterxml.jackson.core.{JsonGenerator, JsonProcessingException} +import com.fasterxml.jackson.core.JsonProcessingException import com.fasterxml.jackson.databind.ObjectMapper +import org.apache.spark.Logging import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types._ -import org.apache.spark.Logging private[sql] object JsonRDD extends Logging { @@ -140,7 +141,7 @@ private[sql] object JsonRDD extends Logging { case ArrayType(NullType, containsNull) => ArrayType(StringType, containsNull) case ArrayType(struct: StructType, containsNull) => ArrayType(nullTypeToStringType(struct), containsNull) - case struct: StructType =>nullTypeToStringType(struct) + case struct: StructType => nullTypeToStringType(struct) case other: DataType => other } StructField(fieldName, newType, nullable) @@ -154,7 +155,7 @@ private[sql] object JsonRDD extends Logging { * Returns the most general data type for two given data types. */ private[json] def compatibleType(t1: DataType, t2: DataType): DataType = { - HiveTypeCoercion.findTightestCommonType(t1, t2) match { + HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2) match { case Some(commonType) => commonType case None => // t1 or t2 is a StructType, ArrayType, or an unexpected type. @@ -215,7 +216,7 @@ private[sql] object JsonRDD extends Logging { case map: Map[_, _] => StructType(Nil) // We have an array of arrays. If those element arrays do not have the same // element types, we will return ArrayType[StringType]. - case seq: Seq[_] => typeOfArray(seq) + case seq: Seq[_] => typeOfArray(seq) case value => typeOfPrimitiveValue(value) } }.reduce((type1: DataType, type2: DataType) => compatibleType(type1, type2)) @@ -317,7 +318,8 @@ private[sql] object JsonRDD extends Logging { parsed } catch { - case e: JsonProcessingException => Map(columnNameOfCorruptRecords -> record) :: Nil + case e: JsonProcessingException => + Map(columnNameOfCorruptRecords -> UTF8String(record)) :: Nil } } }) @@ -404,7 +406,7 @@ private[sql] object JsonRDD extends Logging { } } - private[json] def enforceCorrectType(value: Any, desiredType: DataType): Any ={ + private[json] def enforceCorrectType(value: Any, desiredType: DataType): Any = { if (value == null) { null } else { @@ -421,7 +423,10 @@ private[sql] object JsonRDD extends Logging { value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType)) case MapType(StringType, valueType, _) => val map = value.asInstanceOf[Map[String, Any]] - map.mapValues(enforceCorrectType(_, valueType)).map(identity) + map.map { + case (k, v) => + (UTF8String(k), enforceCorrectType(v, valueType)) + }.map(identity) case struct: StructType => asRow(value.asInstanceOf[Map[String, Any]], struct) case DateType => toDate(value) case TimestampType => toTimestamp(value) @@ -429,7 +434,7 @@ private[sql] object JsonRDD extends Logging { } } - private def asRow(json: Map[String,Any], schema: StructType): Row = { + private def asRow(json: Map[String, Any], schema: StructType): Row = { // TODO: Reuse the row instead of creating a new one for every record. val row = new GenericMutableRow(schema.fields.length) schema.fields.zipWithIndex.foreach { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 3f97a11ceb97d..4e94fd07a8771 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -44,6 +44,7 @@ package object sql { /** * Type alias for [[DataFrame]]. Kept here for backward source compatibility for Scala. + * @deprecated As of 1.3.0, replaced by `DataFrame`. */ @deprecated("1.3.0", "use DataFrame") type SchemaRDD = DataFrame diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index 36cb5e03bbca7..caa9f045537d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -243,8 +243,10 @@ private[parquet] abstract class CatalystConverter extends GroupConverter { /** * Read a decimal value from a Parquet Binary into "dest". Only supports decimals that fit in * a long (i.e. precision <= 18) + * + * Returned value is needed by CatalystConverter, which doesn't reuse the Decimal object. */ - protected[parquet] def readDecimal(dest: Decimal, value: Binary, ctype: DecimalType): Unit = { + protected[parquet] def readDecimal(dest: Decimal, value: Binary, ctype: DecimalType): Decimal = { val precision = ctype.precisionInfo.get.precision val scale = ctype.precisionInfo.get.scale val bytes = value.getBytes @@ -480,7 +482,7 @@ private[parquet] class CatalystPrimitiveStringConverter(parent: CatalystConverte override def hasDictionarySupport: Boolean = true - override def setDictionary(dictionary: Dictionary):Unit = + override def setDictionary(dictionary: Dictionary): Unit = dict = Array.tabulate(dictionary.getMaxId + 1) { dictionary.decodeToBinary(_).getBytes } override def addValueFromDictionary(dictionaryId: Int): Unit = @@ -591,8 +593,8 @@ private[parquet] class CatalystArrayConverter( CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, elementType, false), - fieldIndex=0, - parent=this) + fieldIndex = 0, + parent = this) override def getConverter(fieldIndex: Int): Converter = converter @@ -601,7 +603,7 @@ private[parquet] class CatalystArrayConverter( override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = { // fieldIndex is ignored (assumed to be zero but not checked) - if(value == null) { + if (value == null) { throw new IllegalArgumentException("Null values inside Parquet arrays are not supported!") } buffer += value @@ -654,8 +656,8 @@ private[parquet] class CatalystNativeArrayConverter( CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, elementType, false), - fieldIndex=0, - parent=this) + fieldIndex = 0, + parent = this) override def getConverter(fieldIndex: Int): Converter = converter diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala index 5eb1c6abc2432..f0f4e7d147e75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala @@ -29,128 +29,184 @@ import parquet.io.api.Binary import org.apache.spark.SparkEnv import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.sources import org.apache.spark.sql.types._ private[sql] object ParquetFilters { val PARQUET_FILTER_DATA = "org.apache.spark.sql.parquet.row.filter" def createRecordFilter(filterExpressions: Seq[Expression]): Option[Filter] = { - filterExpressions.flatMap(createFilter).reduceOption(FilterApi.and).map(FilterCompat.get) + filterExpressions.flatMap { filter => + createFilter(filter) + }.reduceOption(FilterApi.and).map(FilterCompat.get) } - def createFilter(predicate: Expression): Option[FilterPredicate] = { - val makeEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case BooleanType => - (n: String, v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean]) - case IntegerType => - (n: String, v: Any) => FilterApi.eq(intColumn(n), v.asInstanceOf[Integer]) - case LongType => - (n: String, v: Any) => FilterApi.eq(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => - (n: String, v: Any) => FilterApi.eq(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => - (n: String, v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - - // Binary.fromString and Binary.fromByteArray don't accept null values - case StringType => - (n: String, v: Any) => FilterApi.eq( - binaryColumn(n), - Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull) - case BinaryType => - (n: String, v: Any) => FilterApi.eq( - binaryColumn(n), - Option(v).map(b => Binary.fromByteArray(v.asInstanceOf[Array[Byte]])).orNull) - } + private val makeEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { + case BooleanType => + (n: String, v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean]) + case IntegerType => + (n: String, v: Any) => FilterApi.eq(intColumn(n), v.asInstanceOf[Integer]) + case LongType => + (n: String, v: Any) => FilterApi.eq(longColumn(n), v.asInstanceOf[java.lang.Long]) + case FloatType => + (n: String, v: Any) => FilterApi.eq(floatColumn(n), v.asInstanceOf[java.lang.Float]) + case DoubleType => + (n: String, v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - val makeNotEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case BooleanType => - (n: String, v: Any) => FilterApi.notEq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean]) - case IntegerType => - (n: String, v: Any) => FilterApi.notEq(intColumn(n), v.asInstanceOf[Integer]) - case LongType => - (n: String, v: Any) => FilterApi.notEq(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => - (n: String, v: Any) => FilterApi.notEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => - (n: String, v: Any) => FilterApi.notEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - case StringType => - (n: String, v: Any) => FilterApi.notEq( - binaryColumn(n), - Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull) - case BinaryType => - (n: String, v: Any) => FilterApi.notEq( - binaryColumn(n), - Option(v).map(b => Binary.fromByteArray(v.asInstanceOf[Array[Byte]])).orNull) - } + // Binary.fromString and Binary.fromByteArray don't accept null values + case StringType => + (n: String, v: Any) => FilterApi.eq( + binaryColumn(n), + Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull) + case BinaryType => + (n: String, v: Any) => FilterApi.eq( + binaryColumn(n), + Option(v).map(b => Binary.fromByteArray(v.asInstanceOf[Array[Byte]])).orNull) + } - val makeLt: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case IntegerType => - (n: String, v: Any) => FilterApi.lt(intColumn(n), v.asInstanceOf[Integer]) - case LongType => - (n: String, v: Any) => FilterApi.lt(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => - (n: String, v: Any) => FilterApi.lt(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => - (n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - case StringType => - (n: String, v: Any) => - FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) - case BinaryType => - (n: String, v: Any) => - FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) - } + private val makeNotEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { + case BooleanType => + (n: String, v: Any) => FilterApi.notEq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean]) + case IntegerType => + (n: String, v: Any) => FilterApi.notEq(intColumn(n), v.asInstanceOf[Integer]) + case LongType => + (n: String, v: Any) => FilterApi.notEq(longColumn(n), v.asInstanceOf[java.lang.Long]) + case FloatType => + (n: String, v: Any) => FilterApi.notEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) + case DoubleType => + (n: String, v: Any) => FilterApi.notEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) + case StringType => + (n: String, v: Any) => FilterApi.notEq( + binaryColumn(n), + Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull) + case BinaryType => + (n: String, v: Any) => FilterApi.notEq( + binaryColumn(n), + Option(v).map(b => Binary.fromByteArray(v.asInstanceOf[Array[Byte]])).orNull) + } - val makeLtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case IntegerType => - (n: String, v: Any) => FilterApi.ltEq(intColumn(n), v.asInstanceOf[java.lang.Integer]) - case LongType => - (n: String, v: Any) => FilterApi.ltEq(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => - (n: String, v: Any) => FilterApi.ltEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => - (n: String, v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - case StringType => - (n: String, v: Any) => - FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) - case BinaryType => - (n: String, v: Any) => - FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) - } + private val makeLt: PartialFunction[DataType, (String, Any) => FilterPredicate] = { + case IntegerType => + (n: String, v: Any) => FilterApi.lt(intColumn(n), v.asInstanceOf[Integer]) + case LongType => + (n: String, v: Any) => FilterApi.lt(longColumn(n), v.asInstanceOf[java.lang.Long]) + case FloatType => + (n: String, v: Any) => FilterApi.lt(floatColumn(n), v.asInstanceOf[java.lang.Float]) + case DoubleType => + (n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) + case StringType => + (n: String, v: Any) => + FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) + case BinaryType => + (n: String, v: Any) => + FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) + } - val makeGt: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case IntegerType => - (n: String, v: Any) => FilterApi.gt(intColumn(n), v.asInstanceOf[java.lang.Integer]) - case LongType => - (n: String, v: Any) => FilterApi.gt(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => - (n: String, v: Any) => FilterApi.gt(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => - (n: String, v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - case StringType => - (n: String, v: Any) => - FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) - case BinaryType => - (n: String, v: Any) => - FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) - } + private val makeLtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { + case IntegerType => + (n: String, v: Any) => FilterApi.ltEq(intColumn(n), v.asInstanceOf[java.lang.Integer]) + case LongType => + (n: String, v: Any) => FilterApi.ltEq(longColumn(n), v.asInstanceOf[java.lang.Long]) + case FloatType => + (n: String, v: Any) => FilterApi.ltEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) + case DoubleType => + (n: String, v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) + case StringType => + (n: String, v: Any) => + FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) + case BinaryType => + (n: String, v: Any) => + FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) + } - val makeGtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { - case IntegerType => - (n: String, v: Any) => FilterApi.gtEq(intColumn(n), v.asInstanceOf[java.lang.Integer]) - case LongType => - (n: String, v: Any) => FilterApi.gtEq(longColumn(n), v.asInstanceOf[java.lang.Long]) - case FloatType => - (n: String, v: Any) => FilterApi.gtEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) - case DoubleType => - (n: String, v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) - case StringType => - (n: String, v: Any) => - FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) - case BinaryType => - (n: String, v: Any) => - FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) + private val makeGt: PartialFunction[DataType, (String, Any) => FilterPredicate] = { + case IntegerType => + (n: String, v: Any) => FilterApi.gt(intColumn(n), v.asInstanceOf[java.lang.Integer]) + case LongType => + (n: String, v: Any) => FilterApi.gt(longColumn(n), v.asInstanceOf[java.lang.Long]) + case FloatType => + (n: String, v: Any) => FilterApi.gt(floatColumn(n), v.asInstanceOf[java.lang.Float]) + case DoubleType => + (n: String, v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) + case StringType => + (n: String, v: Any) => + FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) + case BinaryType => + (n: String, v: Any) => + FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) + } + + private val makeGtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { + case IntegerType => + (n: String, v: Any) => FilterApi.gtEq(intColumn(n), v.asInstanceOf[java.lang.Integer]) + case LongType => + (n: String, v: Any) => FilterApi.gtEq(longColumn(n), v.asInstanceOf[java.lang.Long]) + case FloatType => + (n: String, v: Any) => FilterApi.gtEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) + case DoubleType => + (n: String, v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) + case StringType => + (n: String, v: Any) => + FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes)) + case BinaryType => + (n: String, v: Any) => + FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) + } + + /** + * Converts data sources filters to Parquet filter predicates. + */ + def createFilter(schema: StructType, predicate: sources.Filter): Option[FilterPredicate] = { + val dataTypeOf = schema.map(f => f.name -> f.dataType).toMap + + // NOTE: + // + // For any comparison operator `cmp`, both `a cmp NULL` and `NULL cmp a` evaluate to `NULL`, + // which can be casted to `false` implicitly. Please refer to the `eval` method of these + // operators and the `SimplifyFilters` rule for details. + predicate match { + case sources.IsNull(name) => + makeEq.lift(dataTypeOf(name)).map(_(name, null)) + case sources.IsNotNull(name) => + makeNotEq.lift(dataTypeOf(name)).map(_(name, null)) + + case sources.EqualTo(name, value) => + makeEq.lift(dataTypeOf(name)).map(_(name, value)) + case sources.Not(sources.EqualTo(name, value)) => + makeNotEq.lift(dataTypeOf(name)).map(_(name, value)) + + case sources.LessThan(name, value) => + makeLt.lift(dataTypeOf(name)).map(_(name, value)) + case sources.LessThanOrEqual(name, value) => + makeLtEq.lift(dataTypeOf(name)).map(_(name, value)) + + case sources.GreaterThan(name, value) => + makeGt.lift(dataTypeOf(name)).map(_(name, value)) + case sources.GreaterThanOrEqual(name, value) => + makeGtEq.lift(dataTypeOf(name)).map(_(name, value)) + + case sources.And(lhs, rhs) => + (createFilter(schema, lhs) ++ createFilter(schema, rhs)).reduceOption(FilterApi.and) + + case sources.Or(lhs, rhs) => + for { + lhsFilter <- createFilter(schema, lhs) + rhsFilter <- createFilter(schema, rhs) + } yield FilterApi.or(lhsFilter, rhsFilter) + + case sources.Not(pred) => + createFilter(schema, pred).map(FilterApi.not) + + case _ => None } + } + /** + * Converts Catalyst predicate expressions to Parquet filter predicates. + * + * @todo This can be removed once we get rid of the old Parquet support. + */ + def createFilter(predicate: Expression): Option[FilterPredicate] = { // NOTE: // // For any comparison operator `cmp`, both `a cmp NULL` and `NULL cmp a` evaluate to `NULL`, @@ -170,7 +226,7 @@ private[sql] object ParquetFilters { makeEq.lift(dataType).map(_(name, value)) case EqualTo(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => makeEq.lift(dataType).map(_(name, value)) - + case Not(EqualTo(NamedExpression(name, _), NonNullLiteral(value, dataType))) => makeNotEq.lift(dataType).map(_(name, value)) case Not(EqualTo(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _))) => @@ -192,7 +248,7 @@ private[sql] object ParquetFilters { case LessThanOrEqual(NamedExpression(name, _), NonNullLiteral(value, dataType)) => makeLtEq.lift(dataType).map(_(name, value)) case LessThanOrEqual(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => - makeLtEq.lift(dataType).map(_(name, value)) + makeLtEq.lift(dataType).map(_(name, value)) case LessThanOrEqual(NonNullLiteral(value, dataType), NamedExpression(name, _)) => makeGtEq.lift(dataType).map(_(name, value)) case LessThanOrEqual(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => @@ -201,7 +257,7 @@ private[sql] object ParquetFilters { case GreaterThan(NamedExpression(name, _), NonNullLiteral(value, dataType)) => makeGt.lift(dataType).map(_(name, value)) case GreaterThan(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => - makeGt.lift(dataType).map(_(name, value)) + makeGt.lift(dataType).map(_(name, value)) case GreaterThan(NonNullLiteral(value, dataType), NamedExpression(name, _)) => makeLt.lift(dataType).map(_(name, value)) case GreaterThan(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => @@ -210,7 +266,7 @@ private[sql] object ParquetFilters { case GreaterThanOrEqual(NamedExpression(name, _), NonNullLiteral(value, dataType)) => makeGtEq.lift(dataType).map(_(name, value)) case GreaterThanOrEqual(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => - makeGtEq.lift(dataType).map(_(name, value)) + makeGtEq.lift(dataType).map(_(name, value)) case GreaterThanOrEqual(NonNullLiteral(value, dataType), NamedExpression(name, _)) => makeLtEq.lift(dataType).map(_(name, value)) case GreaterThanOrEqual(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 75ac52d4a98ff..cb7ae246d0d75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -541,7 +541,7 @@ private[parquet] class FilteringParquetRowInputFormat val splits = mutable.ArrayBuffer.empty[ParquetInputSplit] val filter: Filter = ParquetInputFormat.getFilter(configuration) var rowGroupsDropped: Long = 0 - var totalRowGroups: Long = 0 + var totalRowGroups: Long = 0 // Ugly hack, stuck with it until PR: // https://github.com/apache/incubator-parquet-mr/pull/17 @@ -664,7 +664,7 @@ private[parquet] object FileSystemHelper { s"ParquetTableOperations: path $path does not exist or is not a directory") } fs.globStatus(path) - .flatMap { status => if(status.isDir) fs.listStatus(status.getPath) else List(status) } + .flatMap { status => if (status.isDir) fs.listStatus(status.getPath) else List(status) } .map(_.getPath) } @@ -674,7 +674,7 @@ private[parquet] object FileSystemHelper { def findMaxTaskId(pathStr: String, conf: Configuration): Int = { val files = FileSystemHelper.listFiles(pathStr, conf) // filename pattern is part-r-.parquet - val nameP = new scala.util.matching.Regex("""part-r-(\d{1,}).parquet""", "taskid") + val nameP = new scala.util.matching.Regex("""part-.-(\d{1,}).*""", "taskid") val hiddenFileP = new scala.util.matching.Regex("_.*") files.map(_.getName).map { case nameP(taskid) => taskid.toInt diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index c45c431438efc..70a220cc43ab9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -129,7 +129,7 @@ private[parquet] object RowReadSupport { } /** - * A `parquet.hadoop.api.WriteSupport` for Row ojects. + * A `parquet.hadoop.api.WriteSupport` for Row objects. */ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index 1dc819b5d7b9b..6698b19c7477d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -489,7 +489,7 @@ private[parquet] object ParquetTypesConverter extends Logging { val children = fs .globStatus(path) - .flatMap { status => if(status.isDir) fs.listStatus(status.getPath) else List(status) } + .flatMap { status => if (status.isDir) fs.listStatus(status.getPath) else List(status) } .filterNot { status => val name = status.getPath.getName (name(0) == '.' || name(0) == '_') && name != ParquetFileWriter.PARQUET_METADATA_FILE diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index ee4b1c72a2148..824ae36968c32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -14,205 +14,318 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.apache.spark.sql.parquet -import java.io.IOException -import java.lang.{Double => JDouble, Float => JFloat, Long => JLong} -import java.math.{BigDecimal => JBigDecimal} import java.net.URI -import java.text.SimpleDateFormat -import java.util.{Date, List => JList} +import java.util.{List => JList} import scala.collection.JavaConversions._ -import scala.collection.mutable.ArrayBuffer import scala.util.Try +import com.google.common.base.Objects import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.Writable +import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat -import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat -import org.apache.hadoop.mapreduce.{InputSplit, Job, JobContext} import parquet.filter2.predicate.FilterApi -import parquet.format.converter.ParquetMetadataConverter +import parquet.hadoop._ import parquet.hadoop.metadata.CompressionCodecName import parquet.hadoop.util.ContextUtil -import parquet.hadoop.{ParquetInputFormat, _} -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.{Partition => SparkPartition, SerializableWritable, Logging, SparkException} +import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.mapred.SparkHadoopMapRedUtil -import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil -import org.apache.spark.rdd.{NewHadoopPartition, NewHadoopRDD, RDD} -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, expressions} -import org.apache.spark.sql.parquet.ParquetTypesConverter._ +import org.apache.spark.rdd.RDD._ +import org.apache.spark.rdd.RDD import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{IntegerType, StructField, StructType, _} -import org.apache.spark.sql.{DataFrame, Row, SQLConf, SQLContext, SaveMode} -import org.apache.spark.{Logging, SerializableWritable, SparkException, TaskContext, Partition => SparkPartition} - -/** - * Allows creation of Parquet based tables using the syntax: - * {{{ - * CREATE TEMPORARY TABLE ... USING org.apache.spark.sql.parquet OPTIONS (...) - * }}} - * - * Supported options include: - * - * - `path`: Required. When reading Parquet files, `path` should point to the location of the - * Parquet file(s). It can be either a single raw Parquet file, or a directory of Parquet files. - * In the latter case, this data source tries to discover partitioning information if the the - * directory is structured in the same style of Hive partitioned tables. When writing Parquet - * file, `path` should point to the destination folder. - * - * - `mergeSchema`: Optional. Indicates whether we should merge potentially different (but - * compatible) schemas stored in all Parquet part-files. - * - * - `partition.defaultName`: Optional. Partition name used when a value of a partition column is - * null or empty string. This is similar to the `hive.exec.default.partition.name` configuration - * in Hive. - */ -private[sql] class DefaultSource - extends RelationProvider - with SchemaRelationProvider - with CreatableRelationProvider { +import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.{Row, SQLConf, SQLContext} +import org.apache.spark.util.Utils - private def checkPath(parameters: Map[String, String]): String = { - parameters.getOrElse("path", sys.error("'path' must be specified for parquet tables.")) - } - - /** Returns a new base relation with the given parameters. */ - override def createRelation( - sqlContext: SQLContext, - parameters: Map[String, String]): BaseRelation = { - ParquetRelation2(Seq(checkPath(parameters)), parameters, None)(sqlContext) - } - - /** Returns a new base relation with the given parameters and schema. */ +private[sql] class DefaultSource extends HadoopFsRelationProvider { override def createRelation( sqlContext: SQLContext, - parameters: Map[String, String], - schema: StructType): BaseRelation = { - ParquetRelation2(Seq(checkPath(parameters)), parameters, Some(schema))(sqlContext) + paths: Array[String], + schema: Option[StructType], + partitionColumns: Option[StructType], + parameters: Map[String, String]): HadoopFsRelation = { + new ParquetRelation2(paths, schema, None, partitionColumns, parameters)(sqlContext) } +} - /** Returns a new base relation with the given parameters and save given data into it. */ - override def createRelation( - sqlContext: SQLContext, - mode: SaveMode, - parameters: Map[String, String], - data: DataFrame): BaseRelation = { - val path = checkPath(parameters) - val filesystemPath = new Path(path) - val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) - val doInsertion = (mode, fs.exists(filesystemPath)) match { - case (SaveMode.ErrorIfExists, true) => - sys.error(s"path $path already exists.") - case (SaveMode.Append, _) | (SaveMode.Overwrite, _) | (SaveMode.ErrorIfExists, false) => - true - case (SaveMode.Ignore, exists) => - !exists - } +// NOTE: This class is instantiated and used on executor side only, no need to be serializable. +private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext) + extends OutputWriter { + + private val recordWriter: RecordWriter[Void, Row] = { + val conf = context.getConfiguration + val outputFormat = { + // When appending new Parquet files to an existing Parquet file directory, to avoid + // overwriting existing data files, we need to find out the max task ID encoded in these data + // file names. + // TODO Make this snippet a utility function for other data source developers + val maxExistingTaskId = { + // Note that `path` may point to a temporary location. Here we retrieve the real + // destination path from the configuration + val outputPath = new Path(conf.get("spark.sql.sources.output.path")) + val fs = outputPath.getFileSystem(conf) + + if (fs.exists(outputPath)) { + // Pattern used to match task ID in part file names, e.g.: + // + // part-r-00001.gz.parquet + // ^~~~~ + val partFilePattern = """part-.-(\d{1,}).*""".r + + fs.listStatus(outputPath).map(_.getPath.getName).map { + case partFilePattern(id) => id.toInt + case name if name.startsWith("_") => 0 + case name if name.startsWith(".") => 0 + case name => sys.error( + s"Trying to write Parquet files to directory $outputPath, " + + s"but found items with illegal name '$name'.") + }.reduceOption(_ max _).getOrElse(0) + } else { + 0 + } + } - val relation = if (doInsertion) { - // This is a hack. We always set nullable/containsNull/valueContainsNull to true - // for the schema of a parquet data. - val df = - sqlContext.createDataFrame( - data.queryExecution.toRdd, - data.schema.asNullable, - needsConversion = false) - val createdRelation = - createRelation(sqlContext, parameters, df.schema).asInstanceOf[ParquetRelation2] - createdRelation.insert(df, overwrite = mode == SaveMode.Overwrite) - createdRelation - } else { - // If the save mode is Ignore, we will just create the relation based on existing data. - createRelation(sqlContext, parameters) + new ParquetOutputFormat[Row]() { + // Here we override `getDefaultWorkFile` for two reasons: + // + // 1. To allow appending. We need to generate output file name based on the max available + // task ID computed above. + // + // 2. To allow dynamic partitioning. Default `getDefaultWorkFile` uses + // `FileOutputCommitter.getWorkPath()`, which points to the base directory of all + // partitions in the case of dynamic partitioning. + override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { + val split = context.getTaskAttemptID.getTaskID.getId + maxExistingTaskId + 1 + new Path(path, f"part-r-$split%05d$extension") + } + } } - relation + outputFormat.getRecordWriter(context) } + + override def write(row: Row): Unit = recordWriter.write(null, row) + + override def close(): Unit = recordWriter.close(context) } -/** - * An alternative to [[ParquetRelation]] that plugs in using the data sources API. This class is - * intended as a full replacement of the Parquet support in Spark SQL. The old implementation will - * be deprecated and eventually removed once this version is proved to be stable enough. - * - * Compared with the old implementation, this class has the following notable differences: - * - * - Partitioning discovery: Hive style multi-level partitions are auto discovered. - * - Metadata discovery: Parquet is a format comes with schema evolving support. This data source - * can detect and merge schemas from all Parquet part-files as long as they are compatible. - * Also, metadata and [[FileStatus]]es are cached for better performance. - * - Statistics: Statistics for the size of the table are automatically populated during schema - * discovery. - */ -@DeveloperApi -private[sql] case class ParquetRelation2( - paths: Seq[String], - parameters: Map[String, String], - maybeSchema: Option[StructType] = None, - maybePartitionSpec: Option[PartitionSpec] = None)( - @transient val sqlContext: SQLContext) - extends BaseRelation - with CatalystScan - with InsertableRelation - with SparkHadoopMapReduceUtil +private[sql] class ParquetRelation2( + override val paths: Array[String], + private val maybeDataSchema: Option[StructType], + // This is for metastore conversion. + private val maybePartitionSpec: Option[PartitionSpec], + override val userDefinedPartitionColumns: Option[StructType], + parameters: Map[String, String])( + val sqlContext: SQLContext) + extends HadoopFsRelation(maybePartitionSpec) with Logging { + private[sql] def this( + paths: Array[String], + maybeDataSchema: Option[StructType], + maybePartitionSpec: Option[PartitionSpec], + parameters: Map[String, String])( + sqlContext: SQLContext) = { + this( + paths, + maybeDataSchema, + maybePartitionSpec, + maybePartitionSpec.map(_.partitionColumns), + parameters)(sqlContext) + } + // Should we merge schemas from all Parquet part-files? private val shouldMergeSchemas = parameters.getOrElse(ParquetRelation2.MERGE_SCHEMA, "true").toBoolean - // Optional Metastore schema, used when converting Hive Metastore Parquet table - private val maybeMetastoreSchema = - parameters - .get(ParquetRelation2.METASTORE_SCHEMA) - .map(s => DataType.fromJson(s).asInstanceOf[StructType]) + private val maybeMetastoreSchema = parameters + .get(ParquetRelation2.METASTORE_SCHEMA) + .map(DataType.fromJson(_).asInstanceOf[StructType]) - // Hive uses this as part of the default partition name when the partition column value is null - // or empty string - private val defaultPartitionName = parameters.getOrElse( - ParquetRelation2.DEFAULT_PARTITION_NAME, "__HIVE_DEFAULT_PARTITION__") + private lazy val metadataCache: MetadataCache = { + val meta = new MetadataCache + meta.refresh() + meta + } override def equals(other: Any): Boolean = other match { - case relation: ParquetRelation2 => - // If schema merging is required, we don't compare the actual schemas since they may evolve. + case that: ParquetRelation2 => val schemaEquality = if (shouldMergeSchemas) { - shouldMergeSchemas == relation.shouldMergeSchemas + this.shouldMergeSchemas == that.shouldMergeSchemas } else { - schema == relation.schema + this.dataSchema == that.dataSchema && + this.schema == that.schema } - paths.toSet == relation.paths.toSet && + this.paths.toSet == that.paths.toSet && schemaEquality && - maybeMetastoreSchema == relation.maybeMetastoreSchema && - maybePartitionSpec == relation.maybePartitionSpec + this.maybeDataSchema == that.maybeDataSchema && + this.partitionColumns == that.partitionColumns case _ => false } override def hashCode(): Int = { if (shouldMergeSchemas) { - com.google.common.base.Objects.hashCode( - shouldMergeSchemas: java.lang.Boolean, + Objects.hashCode( + Boolean.box(shouldMergeSchemas), paths.toSet, - maybeMetastoreSchema, - maybePartitionSpec) + maybeDataSchema, + partitionColumns) } else { - com.google.common.base.Objects.hashCode( - shouldMergeSchemas: java.lang.Boolean, - schema, + Objects.hashCode( + Boolean.box(shouldMergeSchemas), paths.toSet, - maybeMetastoreSchema, - maybePartitionSpec) + dataSchema, + schema, + maybeDataSchema, + partitionColumns) + } + } + + override def dataSchema: StructType = maybeDataSchema.getOrElse(metadataCache.dataSchema) + + override private[sql] def refresh(): Unit = { + super.refresh() + metadataCache.refresh() + } + + // Parquet data source always uses Catalyst internal representations. + override val needConversion: Boolean = false + + override def sizeInBytes: Long = metadataCache.dataStatuses.map(_.getLen).sum + + override def prepareJobForWrite(job: Job): OutputWriterFactory = { + val conf = ContextUtil.getConfiguration(job) + + val committerClass = + conf.getClass( + "spark.sql.parquet.output.committer.class", + classOf[ParquetOutputCommitter], + classOf[ParquetOutputCommitter]) + + conf.setClass( + SQLConf.OUTPUT_COMMITTER_CLASS, + committerClass, + classOf[ParquetOutputCommitter]) + + // TODO There's no need to use two kinds of WriteSupport + // We should unify them. `SpecificMutableRow` can process both atomic (primitive) types and + // complex types. + val writeSupportClass = + if (dataSchema.map(_.dataType).forall(ParquetTypesConverter.isPrimitiveType)) { + classOf[MutableRowWriteSupport] + } else { + classOf[RowWriteSupport] + } + + ParquetOutputFormat.setWriteSupportClass(job, writeSupportClass) + RowWriteSupport.setSchema(dataSchema.toAttributes, conf) + + // Sets compression scheme + conf.set( + ParquetOutputFormat.COMPRESSION, + ParquetRelation + .shortParquetCompressionCodecNames + .getOrElse( + sqlContext.conf.parquetCompressionCodec.toUpperCase, + CompressionCodecName.UNCOMPRESSED).name()) + + new OutputWriterFactory { + override def newInstance( + path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { + new ParquetOutputWriter(path, context) + } } } - private[sql] def sparkContext = sqlContext.sparkContext + override def buildScan( + requiredColumns: Array[String], + filters: Array[Filter], + inputFiles: Array[FileStatus], + broadcastedConf: Broadcast[SerializableWritable[Configuration]]): RDD[Row] = { + val useMetadataCache = sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA, "true").toBoolean + val parquetFilterPushDown = sqlContext.conf.parquetFilterPushDown + // Create the function to set variable Parquet confs at both driver and executor side. + val initLocalJobFuncOpt = + ParquetRelation2.initializeLocalJobFunc( + requiredColumns, + filters, + dataSchema, + useMetadataCache, + parquetFilterPushDown) _ + // Create the function to set input paths at the driver side. + val setInputPaths = ParquetRelation2.initializeDriverSideJobFunc(inputFiles) _ + + val footers = inputFiles.map(f => metadataCache.footers(f.getPath)) + + Utils.withDummyCallSite(sqlContext.sparkContext) { + // TODO Stop using `FilteringParquetRowInputFormat` and overriding `getPartition`. + // After upgrading to Parquet 1.6.0, we should be able to stop caching `FileStatus` objects + // and footers. Especially when a global arbitrative schema (either from metastore or data + // source DDL) is available. + new SqlNewHadoopRDD( + sc = sqlContext.sparkContext, + broadcastedConf = broadcastedConf, + initDriverSideJobFuncOpt = Some(setInputPaths), + initLocalJobFuncOpt = Some(initLocalJobFuncOpt), + inputFormatClass = classOf[FilteringParquetRowInputFormat], + keyClass = classOf[Void], + valueClass = classOf[Row]) { + + val cacheMetadata = useMetadataCache + + @transient val cachedStatuses = inputFiles.map { f => + // In order to encode the authority of a Path containing special characters such as '/' + // (which does happen in some S3N credentials), we need to use the string returned by the + // URI of the path to create a new Path. + val pathWithEscapedAuthority = escapePathUserInfo(f.getPath) + new FileStatus( + f.getLen, f.isDir, f.getReplication, f.getBlockSize, f.getModificationTime, + f.getAccessTime, f.getPermission, f.getOwner, f.getGroup, pathWithEscapedAuthority) + }.toSeq + + @transient val cachedFooters = footers.map { f => + // In order to encode the authority of a Path containing special characters such as /, + // we need to use the string returned by the URI of the path to create a new Path. + new Footer(escapePathUserInfo(f.getFile), f.getParquetMetadata) + }.toSeq + + private def escapePathUserInfo(path: Path): Path = { + val uri = path.toUri + new Path(new URI( + uri.getScheme, uri.getRawUserInfo, uri.getHost, uri.getPort, uri.getPath, + uri.getQuery, uri.getFragment)) + } + + // Overridden so we can inject our own cached files statuses. + override def getPartitions: Array[SparkPartition] = { + val inputFormat = if (cacheMetadata) { + new FilteringParquetRowInputFormat { + override def listStatus(jobContext: JobContext): JList[FileStatus] = cachedStatuses + override def getFooters(jobContext: JobContext): JList[Footer] = cachedFooters + } + } else { + new FilteringParquetRowInputFormat + } + + val jobContext = newJobContext(getConf(isDriverSide = true), jobId) + val rawSplits = inputFormat.getSplits(jobContext) + + Array.tabulate[SparkPartition](rawSplits.size) { i => + new SqlNewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) + } + } + }.values + } + } private class MetadataCache { // `FileStatus` objects of all "_metadata" files. @@ -222,127 +335,75 @@ private[sql] case class ParquetRelation2( private var commonMetadataStatuses: Array[FileStatus] = _ // Parquet footer cache. - var footers: Map[FileStatus, Footer] = _ + var footers: Map[Path, Footer] = _ // `FileStatus` objects of all data files (Parquet part-files). var dataStatuses: Array[FileStatus] = _ - // Partition spec of this table, including names, data types, and values of each partition - // column, and paths of each partition. - var partitionSpec: PartitionSpec = _ - // Schema of the actual Parquet files, without partition columns discovered from partition // directory paths. - var parquetSchema: StructType = _ + var dataSchema: StructType = null // Schema of the whole table, including partition columns. var schema: StructType = _ - // Indicates whether partition columns are also included in Parquet data file schema. If not, - // we need to fill in partition column values into read rows when scanning the table. - var partitionKeysIncludedInParquetSchema: Boolean = _ - - def prepareMetadata(path: Path, schema: StructType, conf: Configuration): Unit = { - conf.set( - ParquetOutputFormat.COMPRESSION, - ParquetRelation - .shortParquetCompressionCodecNames - .getOrElse( - sqlContext.conf.parquetCompressionCodec.toUpperCase, - CompressionCodecName.UNCOMPRESSED).name()) - - ParquetRelation.enableLogForwarding() - ParquetTypesConverter.writeMetaData(schema.toAttributes, path, conf) - } - /** * Refreshes `FileStatus`es, footers, partition spec, and table schema. */ def refresh(): Unit = { - // Support either reading a collection of raw Parquet part-files, or a collection of folders - // containing Parquet files (e.g. partitioned Parquet table). - val baseStatuses = paths.distinct.map { p => - val fs = FileSystem.get(URI.create(p), sparkContext.hadoopConfiguration) - val path = new Path(p) - val qualified = path.makeQualified(fs.getUri, fs.getWorkingDirectory) - - if (!fs.exists(qualified) && maybeSchema.isDefined) { - fs.mkdirs(qualified) - prepareMetadata(qualified, maybeSchema.get, sparkContext.hadoopConfiguration) - } - - fs.getFileStatus(qualified) - }.toArray - assert(baseStatuses.forall(!_.isDir) || baseStatuses.forall(_.isDir)) - // Lists `FileStatus`es of all leaf nodes (files) under all base directories. - val leaves = baseStatuses.flatMap { f => - val fs = FileSystem.get(f.getPath.toUri, sparkContext.hadoopConfiguration) - SparkHadoopUtil.get.listLeafStatuses(fs, f.getPath).filter { f => - isSummaryFile(f.getPath) || - !(f.getPath.getName.startsWith("_") || f.getPath.getName.startsWith(".")) - } - } + val leaves = cachedLeafStatuses().filter { f => + isSummaryFile(f.getPath) || + !(f.getPath.getName.startsWith("_") || f.getPath.getName.startsWith(".")) + }.toArray dataStatuses = leaves.filterNot(f => isSummaryFile(f.getPath)) metadataStatuses = leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_METADATA_FILE) commonMetadataStatuses = leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE) - footers = (dataStatuses ++ metadataStatuses ++ commonMetadataStatuses).par.map { f => - val parquetMetadata = ParquetFileReader.readFooter( - sparkContext.hadoopConfiguration, f, ParquetMetadataConverter.NO_FILTER) - f -> new Footer(f.getPath, parquetMetadata) - }.seq.toMap - - partitionSpec = maybePartitionSpec.getOrElse { - val partitionDirs = leaves - .filterNot(baseStatuses.contains) - .map(_.getPath.getParent) - .distinct - - if (partitionDirs.nonEmpty) { - // Parses names and values of partition columns, and infer their data types. - PartitioningUtils.parsePartitions(partitionDirs, defaultPartitionName) + footers = { + val conf = SparkHadoopUtil.get.conf + val taskSideMetaData = conf.getBoolean(ParquetInputFormat.TASK_SIDE_METADATA, true) + val rawFooters = if (shouldMergeSchemas) { + ParquetFileReader.readAllFootersInParallel( + conf, seqAsJavaList(leaves), taskSideMetaData) } else { - // No partition directories found, makes an empty specification - PartitionSpec(StructType(Seq.empty[StructField]), Seq.empty[Partition]) + ParquetFileReader.readAllFootersInParallelUsingSummaryFiles( + conf, seqAsJavaList(leaves), taskSideMetaData) } + + rawFooters.map(footer => footer.getFile -> footer).toMap } - // To get the schema. We first try to get the schema defined in maybeSchema. - // If maybeSchema is not defined, we will try to get the schema from existing parquet data - // (through readSchema). If data does not exist, we will try to get the schema defined in - // maybeMetastoreSchema (defined in the options of the data source). - // Finally, if we still could not get the schema. We throw an error. - parquetSchema = - maybeSchema - .orElse(readSchema()) - .orElse(maybeMetastoreSchema) - .getOrElse(sys.error("Failed to get the schema.")) - - partitionKeysIncludedInParquetSchema = - isPartitioned && - partitionColumns.forall(f => parquetSchema.fieldNames.contains(f.name)) - - schema = { - val fullRelationSchema = if (partitionKeysIncludedInParquetSchema) { - parquetSchema - } else { - StructType(parquetSchema.fields ++ partitionColumns.fields) + // If we already get the schema, don't need to re-compute it since the schema merging is + // time-consuming. + if (dataSchema == null) { + dataSchema = { + val dataSchema0 = + maybeDataSchema + .orElse(readSchema()) + .orElse(maybeMetastoreSchema) + .getOrElse(sys.error("Failed to get the schema.")) + + // If this Parquet relation is converted from a Hive Metastore table, must reconcile case + // case insensitivity issue and possible schema mismatch (probably caused by schema + // evolution). + maybeMetastoreSchema + .map(ParquetRelation2.mergeMetastoreParquetSchema(_, dataSchema0)) + .getOrElse(dataSchema0) } - - // If this Parquet relation is converted from a Hive Metastore table, must reconcile case - // insensitivity issue and possible schema mismatch. - maybeMetastoreSchema - .map(ParquetRelation2.mergeMetastoreParquetSchema(_, fullRelationSchema)) - .getOrElse(fullRelationSchema) } } + private def isSummaryFile(file: Path): Boolean = { + file.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE || + file.getName == ParquetFileWriter.PARQUET_METADATA_FILE + } + private def readSchema(): Option[StructType] = { // Sees which file(s) we need to touch in order to figure out the schema. - val filesToTouch = + // // Always tries the summary files first if users don't require a merged schema. In this case, // "_common_metadata" is more preferable than "_metadata" because it doesn't contain row // groups information, and could be much smaller for large Parquet files with lots of row @@ -361,6 +422,7 @@ private[sql] case class ParquetRelation2( // Here we tend to be pessimistic and take the second case into account. Basically this means // we can't trust the summary files if users require a merged schema, and must touch all part- // files to do the merge. + val filesToTouch = if (shouldMergeSchemas) { // Also includes summary files, 'cause there might be empty partition directories. (metadataStatuses ++ commonMetadataStatuses ++ dataStatuses).toSeq @@ -378,356 +440,67 @@ private[sql] case class ParquetRelation2( .toSeq } - ParquetRelation2.readSchema(filesToTouch.map(footers.apply), sqlContext) + assert( + filesToTouch.nonEmpty || maybeDataSchema.isDefined || maybeMetastoreSchema.isDefined, + "No schema defined, " + + s"and no Parquet data file or summary file found under ${paths.mkString(", ")}.") + + ParquetRelation2.readSchema(filesToTouch.map(f => footers.apply(f.getPath)), sqlContext) } } +} - @transient private val metadataCache = new MetadataCache - metadataCache.refresh() - - def partitionSpec: PartitionSpec = metadataCache.partitionSpec - - def partitionColumns: StructType = metadataCache.partitionSpec.partitionColumns - - def partitions: Seq[Partition] = metadataCache.partitionSpec.partitions - - def isPartitioned: Boolean = partitionColumns.nonEmpty - - private def partitionKeysIncludedInDataSchema = metadataCache.partitionKeysIncludedInParquetSchema - - private def parquetSchema = metadataCache.parquetSchema - - override def schema: StructType = metadataCache.schema - - private def isSummaryFile(file: Path): Boolean = { - file.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE || - file.getName == ParquetFileWriter.PARQUET_METADATA_FILE - } +private[sql] object ParquetRelation2 extends Logging { + // Whether we should merge schemas collected from all Parquet part-files. + private[sql] val MERGE_SCHEMA = "mergeSchema" - // Skip type conversion - override val needConversion: Boolean = false + // Hive Metastore schema, used when converting Metastore Parquet tables. This option is only used + // internally. + private[sql] val METASTORE_SCHEMA = "metastoreSchema" - // TODO Should calculate per scan size - // It's common that a query only scans a fraction of a large Parquet file. Returning size of the - // whole Parquet file disables some optimizations in this case (e.g. broadcast join). - override val sizeInBytes = metadataCache.dataStatuses.map(_.getLen).sum - - // This is mostly a hack so that we can use the existing parquet filter code. - override def buildScan(output: Seq[Attribute], predicates: Seq[Expression]): RDD[Row] = { - val job = new Job(sparkContext.hadoopConfiguration) - ParquetInputFormat.setReadSupportClass(job, classOf[RowReadSupport]) - val jobConf: Configuration = ContextUtil.getConfiguration(job) - - val selectedPartitions = prunePartitions(predicates, partitions) - val selectedFiles = if (isPartitioned) { - selectedPartitions.flatMap { p => - metadataCache.dataStatuses.filter(_.getPath.getParent.toString == p.path) - } - } else { - metadataCache.dataStatuses.toSeq - } - val selectedFooters = selectedFiles.map(metadataCache.footers) - - // FileInputFormat cannot handle empty lists. - if (selectedFiles.nonEmpty) { - // In order to encode the authority of a Path containning special characters such as /, - // we need to use the string retruned by the URI of the path to create a new Path. - val selectedPaths = selectedFiles.map(status => new Path(status.getPath.toUri.toString)) - FileInputFormat.setInputPaths(job, selectedPaths: _*) - } + /** This closure sets various Parquet configurations at both driver side and executor side. */ + private[parquet] def initializeLocalJobFunc( + requiredColumns: Array[String], + filters: Array[Filter], + dataSchema: StructType, + useMetadataCache: Boolean, + parquetFilterPushDown: Boolean)(job: Job): Unit = { + val conf = job.getConfiguration + conf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[RowReadSupport].getName()) // Try to push down filters when filter push-down is enabled. - if (sqlContext.conf.parquetFilterPushDown) { - val partitionColNames = partitionColumns.map(_.name).toSet - predicates - // Don't push down predicates which reference partition columns - .filter { pred => - val referencedColNames = pred.references.map(_.name).toSet - referencedColNames.intersect(partitionColNames).isEmpty - } + if (parquetFilterPushDown) { + filters // Collects all converted Parquet filter predicates. Notice that not all predicates can be // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` // is used here. - .flatMap(ParquetFilters.createFilter) + .flatMap(ParquetFilters.createFilter(dataSchema, _)) .reduceOption(FilterApi.and) - .foreach(ParquetInputFormat.setFilterPredicate(jobConf, _)) + .foreach(ParquetInputFormat.setFilterPredicate(conf, _)) } - if (isPartitioned) { - logInfo { - val percentRead = selectedPartitions.size.toDouble / partitions.size.toDouble * 100 - s"Reading $percentRead% of partitions" - } - } - - val requiredColumns = output.map(_.name) - val requestedSchema = StructType(requiredColumns.map(schema(_))) + conf.set(RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA, { + val requestedSchema = StructType(requiredColumns.map(dataSchema(_))) + ParquetTypesConverter.convertToString(requestedSchema.toAttributes) + }) - // Store both requested and original schema in `Configuration` - jobConf.set( - RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA, - convertToString(requestedSchema.toAttributes)) - jobConf.set( + conf.set( RowWriteSupport.SPARK_ROW_SCHEMA, - convertToString(schema.toAttributes)) + ParquetTypesConverter.convertToString(dataSchema.toAttributes)) // Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata - val useCache = sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA, "true").toBoolean - jobConf.set(SQLConf.PARQUET_CACHE_METADATA, useCache.toString) - - val baseRDD = - new NewHadoopRDD( - sparkContext, - classOf[FilteringParquetRowInputFormat], - classOf[Void], - classOf[Row], - jobConf) { - val cacheMetadata = useCache - - @transient - val cachedStatus = selectedFiles.map { st => - // In order to encode the authority of a Path containning special characters such as /, - // we need to use the string retruned by the URI of the path to create a new Path. - val newPath = new Path(st.getPath.toUri.toString) - - new FileStatus( - st.getLen, - st.isDir, - st.getReplication, - st.getBlockSize, - st.getModificationTime, - st.getAccessTime, - st.getPermission, - st.getOwner, - st.getGroup, - newPath) - } - - @transient - val cachedFooters = selectedFooters.map { f => - // In order to encode the authority of a Path containning special characters such as /, - // we need to use the string retruned by the URI of the path to create a new Path. - new Footer(new Path(f.getFile.toUri.toString), f.getParquetMetadata) - } - - - // Overridden so we can inject our own cached files statuses. - override def getPartitions: Array[SparkPartition] = { - val inputFormat = if (cacheMetadata) { - new FilteringParquetRowInputFormat { - override def listStatus(jobContext: JobContext): JList[FileStatus] = cachedStatus - - override def getFooters(jobContext: JobContext): JList[Footer] = cachedFooters - } - } else { - new FilteringParquetRowInputFormat - } - - val jobContext = newJobContext(getConf, jobId) - val rawSplits = inputFormat.getSplits(jobContext) - - Array.tabulate[SparkPartition](rawSplits.size) { i => - new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) - } - } - } - - // The ordinals for partition keys in the result row, if requested. - val partitionKeyLocations = partitionColumns.fieldNames.zipWithIndex.map { - case (name, index) => index -> requiredColumns.indexOf(name) - }.toMap.filter { - case (_, index) => index >= 0 - } - - // When the data does not include the key and the key is requested then we must fill it in - // based on information from the input split. - if (!partitionKeysIncludedInDataSchema && partitionKeyLocations.nonEmpty) { - // This check is based on CatalystConverter.createRootConverter. - val primitiveRow = - requestedSchema.forall(a => ParquetTypesConverter.isPrimitiveType(a.dataType)) - - baseRDD.mapPartitionsWithInputSplit { case (split: ParquetInputSplit, iterator) => - val partValues = selectedPartitions.collectFirst { - case p if split.getPath.getParent.toString == p.path => - CatalystTypeConverters.convertToCatalyst(p.values).asInstanceOf[Row] - }.get - - val requiredPartOrdinal = partitionKeyLocations.keys.toSeq - - if (primitiveRow) { - iterator.map { pair => - // We are using CatalystPrimitiveRowConverter and it returns a SpecificMutableRow. - val row = pair._2.asInstanceOf[SpecificMutableRow] - var i = 0 - while (i < requiredPartOrdinal.size) { - // TODO Avoids boxing cost here! - val partOrdinal = requiredPartOrdinal(i) - row.update(partitionKeyLocations(partOrdinal), partValues(partOrdinal)) - i += 1 - } - row - } - } else { - // Create a mutable row since we need to fill in values from partition columns. - val mutableRow = new GenericMutableRow(requestedSchema.size) - iterator.map { pair => - // We are using CatalystGroupConverter and it returns a GenericRow. - // Since GenericRow is not mutable, we just cast it to a Row. - val row = pair._2.asInstanceOf[Row] - var i = 0 - while (i < row.size) { - // TODO Avoids boxing cost here! - mutableRow(i) = row(i) - i += 1 - } - - i = 0 - while (i < requiredPartOrdinal.size) { - // TODO Avoids boxing cost here! - val partOrdinal = requiredPartOrdinal(i) - mutableRow.update(partitionKeyLocations(partOrdinal), partValues(partOrdinal)) - i += 1 - } - mutableRow - } - } - } - } else { - baseRDD.map(_._2) - } + conf.set(SQLConf.PARQUET_CACHE_METADATA, useMetadataCache.toString) } - private def prunePartitions( - predicates: Seq[Expression], - partitions: Seq[Partition]): Seq[Partition] = { - val partitionColumnNames = partitionColumns.map(_.name).toSet - val partitionPruningPredicates = predicates.filter { - _.references.map(_.name).toSet.subsetOf(partitionColumnNames) - } - - val rawPredicate = - partitionPruningPredicates.reduceOption(expressions.And).getOrElse(Literal(true)) - val boundPredicate = InterpretedPredicate.create(rawPredicate transform { - case a: AttributeReference => - val index = partitionColumns.indexWhere(a.name == _.name) - BoundReference(index, partitionColumns(index).dataType, nullable = true) - }) - - if (isPartitioned && partitionPruningPredicates.nonEmpty) { - partitions.filter(p => boundPredicate(p.values)) - } else { - partitions + /** This closure sets input paths at the driver side. */ + private[parquet] def initializeDriverSideJobFunc( + inputFiles: Array[FileStatus])(job: Job): Unit = { + // We side the input paths at the driver side. + if (inputFiles.nonEmpty) { + FileInputFormat.setInputPaths(job, inputFiles.map(_.getPath): _*) } } - override def insert(data: DataFrame, overwrite: Boolean): Unit = { - assert(paths.size == 1, s"Can't write to multiple destinations: ${paths.mkString(",")}") - - // TODO: currently we do not check whether the "schema"s are compatible - // That means if one first creates a table and then INSERTs data with - // and incompatible schema the execution will fail. It would be nice - // to catch this early one, maybe having the planner validate the schema - // before calling execute(). - - val job = new Job(sqlContext.sparkContext.hadoopConfiguration) - val writeSupport = - if (parquetSchema.map(_.dataType).forall(ParquetTypesConverter.isPrimitiveType)) { - log.debug("Initializing MutableRowWriteSupport") - classOf[MutableRowWriteSupport] - } else { - classOf[RowWriteSupport] - } - - ParquetOutputFormat.setWriteSupportClass(job, writeSupport) - - val conf = ContextUtil.getConfiguration(job) - RowWriteSupport.setSchema(data.schema.toAttributes, conf) - - val destinationPath = new Path(paths.head) - - if (overwrite) { - val fs = destinationPath.getFileSystem(conf) - if (fs.exists(destinationPath)) { - var success: Boolean = false - try { - success = fs.delete(destinationPath, true) - } catch { - case e: IOException => - throw new IOException( - s"Unable to clear output directory ${destinationPath.toString} prior" + - s" to writing to Parquet table:\n${e.toString}") - } - if (!success) { - throw new IOException( - s"Unable to clear output directory ${destinationPath.toString} prior" + - s" to writing to Parquet table.") - } - } - } - - job.setOutputKeyClass(classOf[Void]) - job.setOutputValueClass(classOf[Row]) - FileOutputFormat.setOutputPath(job, destinationPath) - - val wrappedConf = new SerializableWritable(job.getConfiguration) - val jobTrackerId = new SimpleDateFormat("yyyyMMddHHmm").format(new Date()) - val stageId = sqlContext.sparkContext.newRddId() - - val taskIdOffset = if (overwrite) { - 1 - } else { - FileSystemHelper.findMaxTaskId( - FileOutputFormat.getOutputPath(job).toString, job.getConfiguration) + 1 - } - - def writeShard(context: TaskContext, iterator: Iterator[Row]): Unit = { - /* "reduce task" */ - val attemptId = newTaskAttemptID( - jobTrackerId, stageId, isMap = false, context.partitionId(), context.attemptNumber()) - val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) - val format = new AppendingParquetOutputFormat(taskIdOffset) - val committer = format.getOutputCommitter(hadoopContext) - committer.setupTask(hadoopContext) - val writer = format.getRecordWriter(hadoopContext) - try { - while (iterator.hasNext) { - val row = iterator.next() - writer.write(null, row) - } - } finally { - writer.close(hadoopContext) - } - - SparkHadoopMapRedUtil.commitTask(committer, hadoopContext, context) - } - val jobFormat = new AppendingParquetOutputFormat(taskIdOffset) - /* apparently we need a TaskAttemptID to construct an OutputCommitter; - * however we're only going to use this local OutputCommitter for - * setupJob/commitJob, so we just use a dummy "map" task. - */ - val jobAttemptId = newTaskAttemptID(jobTrackerId, stageId, isMap = true, 0, 0) - val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId) - val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext) - - jobCommitter.setupJob(jobTaskContext) - sqlContext.sparkContext.runJob(data.queryExecution.executedPlan.execute(), writeShard _) - jobCommitter.commitJob(jobTaskContext) - - metadataCache.refresh() - } -} - -private[sql] object ParquetRelation2 extends Logging { - // Whether we should merge schemas collected from all Parquet part-files. - val MERGE_SCHEMA = "mergeSchema" - - // Default partition name to use when the partition column value is null or empty string. - val DEFAULT_PARTITION_NAME = "partition.defaultName" - - // Hive Metastore schema, used when converting Metastore Parquet tables. This option is only used - // internally. - private[sql] val METASTORE_SCHEMA = "metastoreSchema" - private[parquet] def readSchema( footers: Seq[Footer], sqlContext: SQLContext): Option[StructType] = { footers.map { footer => @@ -762,7 +535,7 @@ private[sql] object ParquetRelation2 extends Logging { // Falls back to Parquet schema if Spark SQL schema is absent. StructType.fromAttributes( // TODO Really no need to use `Attribute` here, we only need to know the data type. - convertToAttributes( + ParquetTypesConverter.convertToAttributes( parquetSchema, sqlContext.conf.isParquetBinaryAsString, sqlContext.conf.isParquetINT96AsTimestamp)) @@ -801,6 +574,7 @@ private[sql] object ParquetRelation2 extends Logging { val ordinalMap = metastoreSchema.zipWithIndex.map { case (field, index) => field.name.toLowerCase -> index }.toMap + val reorderedParquetSchema = mergedParquetSchema.sortBy(f => ordinalMap.getOrElse(f.name.toLowerCase, metastoreSchema.size + 1)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala index a5410cda0fe6b..c6a4dabbab05e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -17,20 +17,18 @@ package org.apache.spark.sql.sources -import org.apache.hadoop.fs.Path - -import org.apache.spark.Logging +import org.apache.spark.{Logging, SerializableWritable, TaskContext} import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row +import org.apache.spark.rdd.{MapPartitionsRDD, RDD, UnionRDD} import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.types.{StructType, UTF8String, StringType} -import org.apache.spark.sql._ +import org.apache.spark.sql.types.{StringType, StructType, UTF8String} +import org.apache.spark.sql.{SaveMode, Strategy, execution, sources} +import org.apache.spark.util.Utils /** * A Strategy for planning scans over data sources defined using the sources API. @@ -58,8 +56,8 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { filters, (a, _) => t.buildScan(a)) :: Nil - // Scanning partitioned FSBasedRelation - case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: FSBasedRelation)) + // Scanning partitioned HadoopFsRelation + case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: HadoopFsRelation)) if t.partitionSpec.partitionColumns.nonEmpty => val selectedPartitions = prunePartitions(filters, t.partitionSpec).toArray @@ -86,22 +84,18 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { t.partitionSpec.partitionColumns, selectedPartitions) :: Nil - // Scanning non-partitioned FSBasedRelation - case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: FSBasedRelation)) => - val inputPaths = t.paths.map(new Path(_)).flatMap { path => - val fs = path.getFileSystem(t.sqlContext.sparkContext.hadoopConfiguration) - val qualifiedPath = path.makeQualified(fs.getUri, fs.getWorkingDirectory) - SparkHadoopUtil.get.listLeafStatuses(fs, qualifiedPath).map(_.getPath).filterNot { path => - val name = path.getName - name.startsWith("_") || name.startsWith(".") - }.map(fs.makeQualified(_).toString) - } - + // Scanning non-partitioned HadoopFsRelation + case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: HadoopFsRelation)) => + // See buildPartitionedTableScan for the reason that we need to create a shard + // broadcast HadoopConf. + val sharedHadoopConf = SparkHadoopUtil.get.conf + val confBroadcast = + t.sqlContext.sparkContext.broadcast(new SerializableWritable(sharedHadoopConf)) pruneFilterProject( l, projectList, filters, - (a, f) => t.buildScan(a, f, inputPaths)) :: Nil + (a, f) => t.buildScan(a, f, t.paths, confBroadcast)) :: Nil case l @ LogicalRelation(t: TableScan) => createPhysicalRDD(l.relation, l.output, t.buildScan()) :: Nil @@ -111,10 +105,9 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { execution.ExecutedCommand(InsertIntoDataSource(l, query, overwrite)) :: Nil case i @ logical.InsertIntoTable( - l @ LogicalRelation(t: FSBasedRelation), part, query, overwrite, false) if part.isEmpty => + l @ LogicalRelation(t: HadoopFsRelation), part, query, overwrite, false) => val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Append - execution.ExecutedCommand( - InsertIntoFSBasedRelation(t, query, Array.empty[String], mode)) :: Nil + execution.ExecutedCommand(InsertIntoHadoopFsRelation(t, query, mode)) :: Nil case _ => Nil } @@ -126,20 +119,16 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { partitionColumns: StructType, partitions: Array[Partition]) = { val output = projections.map(_.toAttribute) - val relation = logicalRelation.relation.asInstanceOf[FSBasedRelation] + val relation = logicalRelation.relation.asInstanceOf[HadoopFsRelation] + + // Because we are creating one RDD per partition, we need to have a shared HadoopConf. + // Otherwise, the cost of broadcasting HadoopConf in every RDD will be high. + val sharedHadoopConf = SparkHadoopUtil.get.conf + val confBroadcast = + relation.sqlContext.sparkContext.broadcast(new SerializableWritable(sharedHadoopConf)) // Builds RDD[Row]s for each selected partition. val perPartitionRows = partitions.map { case Partition(partitionValues, dir) => - // Paths to all data files within this partition - val dataFilePaths = { - val dirPath = new Path(dir) - val fs = dirPath.getFileSystem(SparkHadoopUtil.get.conf) - fs.listStatus(dirPath).map(_.getPath).filterNot { path => - val name = path.getName - name.startsWith("_") || name.startsWith(".") - }.map(fs.makeQualified(_).toString) - } - // The table scan operator (PhysicalRDD) which retrieves required columns from data files. // Notice that the schema of data files, represented by `relation.dataSchema`, may contain // some partition column(s). @@ -155,7 +144,8 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { // assuming partition columns data stored in data files are always consistent with those // partition values encoded in partition directory paths. val nonPartitionColumns = requiredColumns.filterNot(partitionColNames.contains) - val dataRows = relation.buildScan(nonPartitionColumns, filters, dataFilePaths) + val dataRows = + relation.buildScan(nonPartitionColumns, filters, Array(dir), confBroadcast) // Merges data values with partition values. mergeWithPartitionValues( @@ -169,9 +159,12 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { scan.execute() } - val unionedRows = perPartitionRows.reduceOption(_ ++ _).getOrElse { - relation.sqlContext.emptyResult - } + val unionedRows = + if (perPartitionRows.length == 0) { + relation.sqlContext.emptyResult + } else { + new UnionRDD(relation.sqlContext.sparkContext, perPartitionRows) + } createPhysicalRDD(logicalRelation.relation, output, unionedRows) } @@ -204,7 +197,10 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { } } - dataRows.mapPartitions { iterator => + // Since we know for sure that this closure is serializable, we can avoid the overhead + // of cleaning a closure for each RDD by creating our own MapPartitionsRDD. Functionally + // this is equivalent to calling `dataRows.mapPartitions(mapPartitionsFunc)` (SPARK-7718). + val mapPartitionsFunc = (_: TaskContext, _: Int, iterator: Iterator[Row]) => { val dataTypes = requiredColumns.map(schema(_).dataType) val mutableRow = new SpecificMutableRow(dataTypes) iterator.map { dataRow => @@ -216,6 +212,14 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { mutableRow.asInstanceOf[expressions.Row] } } + + // This is an internal RDD whose call site the user should not be concerned with + // Since we create many of these (one per partition), the time spent on computing + // the call site may add up. + Utils.withDummyCallSite(dataRows.sparkContext) { + new MapPartitionsRDD(dataRows, mapPartitionsFunc, preservesPartitioning = false) + } + } else { dataRows } @@ -305,7 +309,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { output: Seq[Attribute], rdd: RDD[Row]): SparkPlan = { val converted = if (relation.needConversion) { - execution.RDDConversions.rowToRowRdd(rdd, relation.schema) + execution.RDDConversions.rowToRowRdd(rdd, output.map(_.dataType)) } else { rdd } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala index d30f7f65e21c0..c4c99de5a38dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/PartitioningUtils.scala @@ -23,8 +23,8 @@ import java.math.{BigDecimal => JBigDecimal} import scala.collection.mutable.ArrayBuffer import scala.util.Try -import com.google.common.cache.{CacheBuilder, Cache} -import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.fs.Path +import org.apache.hadoop.util.Shell import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} @@ -34,7 +34,15 @@ private[sql] case class Partition(values: Row, path: String) private[sql] case class PartitionSpec(partitionColumns: StructType, partitions: Seq[Partition]) +private[sql] object PartitionSpec { + val emptySpec = PartitionSpec(StructType(Seq.empty[StructField]), Seq.empty[Partition]) +} + private[sql] object PartitioningUtils { + // This duplicates default value of Hive `ConfVars.DEFAULTPARTITIONNAME`, since sql/core doesn't + // depend on Hive. + private[sql] val DEFAULT_PARTITION_NAME = "__HIVE_DEFAULT_PARTITION__" + private[sql] case class PartitionValues(columnNames: Seq[String], literals: Seq[Literal]) { require(columnNames.size == literals.size) } @@ -65,20 +73,37 @@ private[sql] object PartitioningUtils { private[sql] def parsePartitions( paths: Seq[Path], defaultPartitionName: String): PartitionSpec = { - val partitionValues = resolvePartitions(paths.map(parsePartition(_, defaultPartitionName))) - val fields = { - val (PartitionValues(columnNames, literals)) = partitionValues.head - columnNames.zip(literals).map { case (name, Literal(_, dataType)) => - StructField(name, dataType, nullable = true) - } + // First, we need to parse every partition's path and see if we can find partition values. + val pathsWithPartitionValues = paths.flatMap { path => + parsePartition(path, defaultPartitionName).map(path -> _) } - val partitions = partitionValues.zip(paths).map { - case (PartitionValues(_, literals), path) => - Partition(Row(literals.map(_.value): _*), path.toString) - } + if (pathsWithPartitionValues.isEmpty) { + // This dataset is not partitioned. + PartitionSpec.emptySpec + } else { + // This dataset is partitioned. We need to check whether all partitions have the same + // partition columns and resolve potential type conflicts. + val resolvedPartitionValues = resolvePartitions(pathsWithPartitionValues.map(_._2)) - PartitionSpec(StructType(fields), partitions) + // Creates the StructType which represents the partition columns. + val fields = { + val PartitionValues(columnNames, literals) = resolvedPartitionValues.head + columnNames.zip(literals).map { case (name, Literal(_, dataType)) => + // We always assume partition columns are nullable since we've no idea whether null values + // will be appended in the future. + StructField(name, dataType, nullable = true) + } + } + + // Finally, we create `Partition`s based on paths and resolved partition values. + val partitions = resolvedPartitionValues.zip(pathsWithPartitionValues).map { + case (PartitionValues(_, literals), (path, _)) => + Partition(Row.fromSeq(literals.map(_.value)), path.toString) + } + + PartitionSpec(StructType(fields), partitions) + } } /** @@ -99,21 +124,31 @@ private[sql] object PartitioningUtils { */ private[sql] def parsePartition( path: Path, - defaultPartitionName: String): PartitionValues = { + defaultPartitionName: String): Option[PartitionValues] = { val columns = ArrayBuffer.empty[(String, Literal)] // Old Hadoop versions don't have `Path.isRoot` var finished = path.getParent == null var chopped = path while (!finished) { + // Sometimes (e.g., when speculative task is enabled), temporary directories may be left + // uncleaned. Here we simply ignore them. + if (chopped.getName.toLowerCase == "_temporary") { + return None + } + val maybeColumn = parsePartitionColumn(chopped.getName, defaultPartitionName) maybeColumn.foreach(columns += _) chopped = chopped.getParent finished = maybeColumn.isEmpty || chopped.getParent == null } - val (columnNames, values) = columns.reverse.unzip - PartitionValues(columnNames, values) + if (columns.isEmpty) { + None + } else { + val (columnNames, values) = columns.reverse.unzip + Some(PartitionValues(columnNames, values)) + } } private def parsePartitionColumn( @@ -147,20 +182,25 @@ private[sql] object PartitioningUtils { private[sql] def resolvePartitions(values: Seq[PartitionValues]): Seq[PartitionValues] = { // Column names of all partitions must match val distinctPartitionsColNames = values.map(_.columnNames).distinct - assert(distinctPartitionsColNames.size == 1, { - val list = distinctPartitionsColNames.mkString("\t", "\n", "") - s"Conflicting partition column names detected:\n$list" - }) - - // Resolves possible type conflicts for each column - val columnCount = values.head.columnNames.size - val resolvedValues = (0 until columnCount).map { i => - resolveTypeConflicts(values.map(_.literals(i))) - } - // Fills resolved literals back to each partition - values.zipWithIndex.map { case (d, index) => - d.copy(literals = resolvedValues.map(_(index))) + if (distinctPartitionsColNames.isEmpty) { + Seq.empty + } else { + assert(distinctPartitionsColNames.size == 1, { + val list = distinctPartitionsColNames.mkString("\t", "\n\t", "") + s"Conflicting partition column names detected:\n$list" + }) + + // Resolves possible type conflicts for each column + val columnCount = values.head.columnNames.size + val resolvedValues = (0 until columnCount).map { i => + resolveTypeConflicts(values.map(_.literals(i))) + } + + // Fills resolved literals back to each partition + values.zipWithIndex.map { case (d, index) => + d.copy(literals = resolvedValues.map(_(index))) + } } } @@ -182,7 +222,7 @@ private[sql] object PartitioningUtils { // Then falls back to string .getOrElse { if (raw == defaultPartitionName) Literal.create(null, NullType) - else Literal.create(raw, StringType) + else Literal.create(unescapePathName(raw), StringType) } } @@ -204,4 +244,77 @@ private[sql] object PartitioningUtils { Literal.create(Cast(l, desiredType).eval(), desiredType) } } + + ////////////////////////////////////////////////////////////////////////////////////////////////// + // The following string escaping code is mainly copied from Hive (o.a.h.h.common.FileUtils). + ////////////////////////////////////////////////////////////////////////////////////////////////// + + val charToEscape = { + val bitSet = new java.util.BitSet(128) + + /** + * ASCII 01-1F are HTTP control characters that need to be escaped. + * \u000A and \u000D are \n and \r, respectively. + */ + val clist = Array( + '\u0001', '\u0002', '\u0003', '\u0004', '\u0005', '\u0006', '\u0007', '\u0008', '\u0009', + '\n', '\u000B', '\u000C', '\r', '\u000E', '\u000F', '\u0010', '\u0011', '\u0012', '\u0013', + '\u0014', '\u0015', '\u0016', '\u0017', '\u0018', '\u0019', '\u001A', '\u001B', '\u001C', + '\u001D', '\u001E', '\u001F', '"', '#', '%', '\'', '*', '/', ':', '=', '?', '\\', '\u007F', + '{', '[', ']', '^') + + clist.foreach(bitSet.set(_)) + + if (Shell.WINDOWS) { + Array(' ', '<', '>', '|').foreach(bitSet.set(_)) + } + + bitSet + } + + def needsEscaping(c: Char): Boolean = { + c >= 0 && c < charToEscape.size() && charToEscape.get(c) + } + + def escapePathName(path: String): String = { + val builder = new StringBuilder() + path.foreach { c => + if (needsEscaping(c)) { + builder.append('%') + builder.append(f"${c.asInstanceOf[Int]}%02x") + } else { + builder.append(c) + } + } + + builder.toString() + } + + def unescapePathName(path: String): String = { + val sb = new StringBuilder + var i = 0 + + while (i < path.length) { + val c = path.charAt(i) + if (c == '%' && i + 2 < path.length) { + val code: Int = try { + Integer.valueOf(path.substring(i + 1, i + 3), 16) + } catch { case e: Exception => + -1: Integer + } + if (code >= 0) { + sb.append(code.asInstanceOf[Char]) + i += 3 + } else { + sb.append(c) + i += 1 + } + } else { + sb.append(c) + i += 1 + } + } + + sb.toString() + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala new file mode 100644 index 0000000000000..ebad0c1564ec0 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/SqlNewHadoopRDD.scala @@ -0,0 +1,264 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import java.text.SimpleDateFormat +import java.util.Date + +import org.apache.hadoop.conf.{Configurable, Configuration} +import org.apache.hadoop.io.Writable +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit} +import org.apache.spark.broadcast.Broadcast + +import org.apache.spark.{Partition => SparkPartition, _} +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.executor.DataReadMethod +import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil +import org.apache.spark.rdd.{RDD, HadoopRDD} +import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD +import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils + +import scala.reflect.ClassTag + +private[spark] class SqlNewHadoopPartition( + rddId: Int, + val index: Int, + @transient rawSplit: InputSplit with Writable) + extends SparkPartition { + + val serializableHadoopSplit = new SerializableWritable(rawSplit) + + override def hashCode(): Int = 41 * (41 + rddId) + index +} + +/** + * An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS, + * sources in HBase, or S3), using the new MapReduce API (`org.apache.hadoop.mapreduce`). + * It is based on [[org.apache.spark.rdd.NewHadoopRDD]]. It has three additions. + * 1. A shared broadcast Hadoop Configuration. + * 2. An optional closure `initDriverSideJobFuncOpt` that set configurations at the driver side + * to the shared Hadoop Configuration. + * 3. An optional closure `initLocalJobFuncOpt` that set configurations at both the driver side + * and the executor side to the shared Hadoop Configuration. + * + * Note: This is RDD is basically a cloned version of [[org.apache.spark.rdd.NewHadoopRDD]] with + * changes based on [[org.apache.spark.rdd.HadoopRDD]]. In future, this functionality will be + * folded into core. + */ +private[sql] class SqlNewHadoopRDD[K, V]( + @transient sc : SparkContext, + broadcastedConf: Broadcast[SerializableWritable[Configuration]], + @transient initDriverSideJobFuncOpt: Option[Job => Unit], + initLocalJobFuncOpt: Option[Job => Unit], + inputFormatClass: Class[_ <: InputFormat[K, V]], + keyClass: Class[K], + valueClass: Class[V]) + extends RDD[(K, V)](sc, Nil) + with SparkHadoopMapReduceUtil + with Logging { + + protected def getJob(): Job = { + val conf: Configuration = broadcastedConf.value.value + // "new Job" will make a copy of the conf. Then, it is + // safe to mutate conf properties with initLocalJobFuncOpt + // and initDriverSideJobFuncOpt. + val newJob = new Job(conf) + initLocalJobFuncOpt.map(f => f(newJob)) + newJob + } + + def getConf(isDriverSide: Boolean): Configuration = { + val job = getJob() + if (isDriverSide) { + initDriverSideJobFuncOpt.map(f => f(job)) + } + job.getConfiguration + } + + private val jobTrackerId: String = { + val formatter = new SimpleDateFormat("yyyyMMddHHmm") + formatter.format(new Date()) + } + + @transient protected val jobId = new JobID(jobTrackerId, id) + + override def getPartitions: Array[SparkPartition] = { + val conf = getConf(isDriverSide = true) + val inputFormat = inputFormatClass.newInstance + inputFormat match { + case configurable: Configurable => + configurable.setConf(conf) + case _ => + } + val jobContext = newJobContext(conf, jobId) + val rawSplits = inputFormat.getSplits(jobContext).toArray + val result = new Array[SparkPartition](rawSplits.size) + for (i <- 0 until rawSplits.size) { + result(i) = + new SqlNewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) + } + result + } + + override def compute( + theSplit: SparkPartition, + context: TaskContext): InterruptibleIterator[(K, V)] = { + val iter = new Iterator[(K, V)] { + val split = theSplit.asInstanceOf[SqlNewHadoopPartition] + logInfo("Input split: " + split.serializableHadoopSplit) + val conf = getConf(isDriverSide = false) + + val inputMetrics = context.taskMetrics + .getInputMetricsForReadMethod(DataReadMethod.Hadoop) + + // Find a function that will return the FileSystem bytes read by this thread. Do this before + // creating RecordReader, because RecordReader's constructor might read some bytes + val bytesReadCallback = inputMetrics.bytesReadCallback.orElse { + split.serializableHadoopSplit.value match { + case _: FileSplit | _: CombineFileSplit => + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() + case _ => None + } + } + inputMetrics.setBytesReadCallback(bytesReadCallback) + + val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) + val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) + val format = inputFormatClass.newInstance + format match { + case configurable: Configurable => + configurable.setConf(conf) + case _ => + } + val reader = format.createRecordReader( + split.serializableHadoopSplit.value, hadoopAttemptContext) + reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) + + // Register an on-task-completion callback to close the input stream. + context.addTaskCompletionListener(context => close()) + var havePair = false + var finished = false + var recordsSinceMetricsUpdate = 0 + + override def hasNext: Boolean = { + if (!finished && !havePair) { + finished = !reader.nextKeyValue + havePair = !finished + } + !finished + } + + override def next(): (K, V) = { + if (!hasNext) { + throw new java.util.NoSuchElementException("End of stream") + } + havePair = false + if (!finished) { + inputMetrics.incRecordsRead(1) + } + (reader.getCurrentKey, reader.getCurrentValue) + } + + private def close() { + try { + reader.close() + if (bytesReadCallback.isDefined) { + inputMetrics.updateBytesRead() + } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || + split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { + // If we can't get the bytes read from the FS stats, fall back to the split size, + // which may be inaccurate. + try { + inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength) + } catch { + case e: java.io.IOException => + logWarning("Unable to get input size to set InputMetrics for task", e) + } + } + } catch { + case e: Exception => { + if (!Utils.inShutdown()) { + logWarning("Exception in RecordReader.close()", e) + } + } + } + } + } + new InterruptibleIterator(context, iter) + } + + /** Maps over a partition, providing the InputSplit that was used as the base of the partition. */ + @DeveloperApi + def mapPartitionsWithInputSplit[U: ClassTag]( + f: (InputSplit, Iterator[(K, V)]) => Iterator[U], + preservesPartitioning: Boolean = false): RDD[U] = { + new NewHadoopMapPartitionsWithSplitRDD(this, f, preservesPartitioning) + } + + override def getPreferredLocations(hsplit: SparkPartition): Seq[String] = { + val split = hsplit.asInstanceOf[SqlNewHadoopPartition].serializableHadoopSplit.value + val locs = HadoopRDD.SPLIT_INFO_REFLECTIONS match { + case Some(c) => + try { + val infos = c.newGetLocationInfo.invoke(split).asInstanceOf[Array[AnyRef]] + Some(HadoopRDD.convertSplitLocationInfo(infos)) + } catch { + case e : Exception => + logDebug("Failed to use InputSplit#getLocationInfo.", e) + None + } + case None => None + } + locs.getOrElse(split.getLocations.filter(_ != "localhost")) + } + + override def persist(storageLevel: StorageLevel): this.type = { + if (storageLevel.deserialized) { + logWarning("Caching NewHadoopRDDs as deserialized objects usually leads to undesired" + + " behavior because Hadoop's RecordReader reuses the same Writable object for all records." + + " Use a map transformation to make copies of the records.") + } + super.persist(storageLevel) + } +} + +private[spark] object SqlNewHadoopRDD { + /** + * Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to + * the given function rather than the index of the partition. + */ + private[spark] class NewHadoopMapPartitionsWithSplitRDD[U: ClassTag, T: ClassTag]( + prev: RDD[T], + f: (InputSplit, Iterator[T]) => Iterator[U], + preservesPartitioning: Boolean = false) + extends RDD[U](prev) { + + override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None + + override def getPartitions: Array[SparkPartition] = firstParent[T].partitions + + override def compute(split: SparkPartition, context: TaskContext): Iterator[U] = { + val partition = split.asInstanceOf[SqlNewHadoopPartition] + val inputSplit = partition.serializableHadoopSplit.value + f(inputSplit, firstParent[T].iterator(split, context)) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index 8372d2c34acc7..71f016b1f14de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -23,19 +23,20 @@ import scala.collection.mutable import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter, FileOutputFormat} -import org.apache.hadoop.util.Shell +import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter, FileOutputFormat} import parquet.hadoop.util.ContextUtil import org.apache.spark._ import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan} import org.apache.spark.sql.execution.RunnableCommand -import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext, SaveMode} private[sql] case class InsertIntoDataSource( logicalRelation: LogicalRelation, @@ -58,10 +59,9 @@ private[sql] case class InsertIntoDataSource( } } -private[sql] case class InsertIntoFSBasedRelation( - @transient relation: FSBasedRelation, +private[sql] case class InsertIntoHadoopFsRelation( + @transient relation: HadoopFsRelation, @transient query: LogicalPlan, - partitionColumns: Array[String], mode: SaveMode) extends RunnableCommand { @@ -93,16 +93,28 @@ private[sql] case class InsertIntoFSBasedRelation( job.setOutputValueClass(classOf[Row]) FileOutputFormat.setOutputPath(job, qualifiedOutputPath) - val df = sqlContext.createDataFrame( - DataFrame(sqlContext, query).queryExecution.toRdd, - relation.schema, - needsConversion = false) + // We create a DataFrame by applying the schema of relation to the data to make sure. + // We are writing data based on the expected schema, + val df = { + // For partitioned relation r, r.schema's column ordering can be different from the column + // ordering of data.logicalPlan (partition columns are all moved after data column). We + // need a Project to adjust the ordering, so that inside InsertIntoHadoopFsRelation, we can + // safely apply the schema of r.schema to the data. + val project = Project( + relation.schema.map(field => new UnresolvedAttribute(Seq(field.name))), query) + + sqlContext.createDataFrame( + DataFrame(sqlContext, project).queryExecution.toRdd, + relation.schema, + needsConversion = false) + } + val partitionColumns = relation.partitionColumns.fieldNames if (partitionColumns.isEmpty) { insert(new DefaultWriterContainer(relation, job), df) } else { val writerContainer = new DynamicPartitionWriterContainer( - relation, job, partitionColumns, "__HIVE_DEFAULT_PARTITION__") + relation, job, partitionColumns, PartitioningUtils.DEFAULT_PARTITION_NAME) insertWithDynamicPartitions(sqlContext, writerContainer, df, partitionColumns) } } @@ -121,6 +133,7 @@ private[sql] case class InsertIntoFSBasedRelation( writerContainer.commitJob() relation.refresh() } catch { case cause: Throwable => + logError("Aborting job.", cause) writerContainer.abortJob() throw new SparkException("Job aborted.", cause) } @@ -143,6 +156,7 @@ private[sql] case class InsertIntoFSBasedRelation( } writerContainer.commitTask() } catch { case cause: Throwable => + logError("Aborting task.", cause) writerContainer.abortTask() throw new SparkException("Task failed while writing rows.", cause) } @@ -204,9 +218,11 @@ private[sql] case class InsertIntoFSBasedRelation( writerContainer.outputWriterForRow(partitionPart).write(convertedDataPart) } } else { + val partitionSchema = StructType.fromAttributes(partitionOutput) + val converter = CatalystTypeConverters.createToScalaConverter(partitionSchema) while (iterator.hasNext) { val row = iterator.next() - val partitionPart = partitionProj(row) + val partitionPart = converter(partitionProj(row)).asInstanceOf[Row] val dataPart = dataProj(row) writerContainer.outputWriterForRow(partitionPart).write(dataPart) } @@ -232,7 +248,7 @@ private[sql] case class InsertIntoFSBasedRelation( } private[sql] abstract class BaseWriterContainer( - @transient val relation: FSBasedRelation, + @transient val relation: HadoopFsRelation, @transient job: Job) extends SparkHadoopMapReduceUtil with Logging @@ -244,7 +260,7 @@ private[sql] abstract class BaseWriterContainer( @transient private val jobContext: JobContext = job // The following fields are initialized and used on both driver and executor side. - @transient protected var outputCommitter: FileOutputCommitter = _ + @transient protected var outputCommitter: OutputCommitter = _ @transient private var jobId: JobID = _ @transient private var taskId: TaskID = _ @transient private var taskAttemptId: TaskAttemptID = _ @@ -259,7 +275,7 @@ private[sql] abstract class BaseWriterContainer( protected val dataSchema = relation.dataSchema - protected val outputWriterClass: Class[_ <: OutputWriter] = relation.outputWriterClass + protected var outputWriterFactory: OutputWriterFactory = _ private var outputFormatClass: Class[_ <: OutputFormat[_, _]] = _ @@ -267,7 +283,7 @@ private[sql] abstract class BaseWriterContainer( setupIDs(0, 0, 0) setupConf() taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId) - relation.prepareForWrite(job) + outputWriterFactory = relation.prepareJobForWrite(job) outputFormatClass = job.getOutputFormatClass outputCommitter = newOutputCommitter(taskAttemptContext) outputCommitter.setupJob(jobContext) @@ -282,11 +298,39 @@ private[sql] abstract class BaseWriterContainer( initWriters() } - private def newOutputCommitter(context: TaskAttemptContext): FileOutputCommitter = { - outputFormatClass.newInstance().getOutputCommitter(context) match { - case f: FileOutputCommitter => f - case f => sys.error( - s"FileOutputCommitter or its subclass is expected, but got a ${f.getClass.getName}.") + protected def getWorkPath: String = { + outputCommitter match { + // FileOutputCommitter writes to a temporary location returned by `getWorkPath`. + case f: MapReduceFileOutputCommitter => f.getWorkPath.toString + case _ => outputPath + } + } + + private def newOutputCommitter(context: TaskAttemptContext): OutputCommitter = { + val committerClass = context.getConfiguration.getClass( + SQLConf.OUTPUT_COMMITTER_CLASS, null, classOf[OutputCommitter]) + + Option(committerClass).map { clazz => + // Every output format based on org.apache.hadoop.mapreduce.lib.output.OutputFormat + // has an associated output committer. To override this output committer, + // we will first try to use the output committer set in SQLConf.OUTPUT_COMMITTER_CLASS. + // If a data source needs to override the output committer, it needs to set the + // output committer in prepareForWrite method. + if (classOf[MapReduceFileOutputCommitter].isAssignableFrom(clazz)) { + // The specified output committer is a FileOutputCommitter. + // So, we will use the FileOutputCommitter-specified constructor. + val ctor = clazz.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext]) + ctor.newInstance(new Path(outputPath), context) + } else { + // The specified output committer is just a OutputCommitter. + // So, we will use the no-argument constructor. + val ctor = clazz.getDeclaredConstructor() + ctor.newInstance() + } + }.getOrElse { + // If output committer class is not set, we will use the one associated with the + // file output format. + outputFormatClass.newInstance().getOutputCommitter(context) } } @@ -331,32 +375,41 @@ private[sql] abstract class BaseWriterContainer( } private[sql] class DefaultWriterContainer( - @transient relation: FSBasedRelation, + @transient relation: HadoopFsRelation, @transient job: Job) extends BaseWriterContainer(relation, job) { @transient private var writer: OutputWriter = _ override protected def initWriters(): Unit = { - writer = outputWriterClass.newInstance() - writer.init(outputCommitter.getWorkPath.toString, dataSchema, taskAttemptContext) + taskAttemptContext.getConfiguration.set("spark.sql.sources.output.path", outputPath) + writer = outputWriterFactory.newInstance(getWorkPath, dataSchema, taskAttemptContext) } override def outputWriterForRow(row: Row): OutputWriter = writer override def commitTask(): Unit = { - writer.close() - super.commitTask() + try { + writer.close() + super.commitTask() + } catch { + case cause: Throwable => + super.abortTask() + throw new RuntimeException("Failed to commit task", cause) + } } override def abortTask(): Unit = { - writer.close() - super.abortTask() + try { + writer.close() + } finally { + super.abortTask() + } } } private[sql] class DynamicPartitionWriterContainer( - @transient relation: FSBasedRelation, + @transient relation: HadoopFsRelation, @transient job: Job, partitionColumns: Array[String], defaultPartitionName: String) @@ -375,73 +428,35 @@ private[sql] class DynamicPartitionWriterContainer( val valueString = if (string == null || string.isEmpty) { defaultPartitionName } else { - DynamicPartitionWriterContainer.escapePathName(string) + PartitioningUtils.escapePathName(string) } s"/$col=$valueString" - }.mkString + }.mkString.stripPrefix(Path.SEPARATOR) outputWriters.getOrElseUpdate(partitionPath, { - val path = new Path(outputCommitter.getWorkPath, partitionPath.stripPrefix(Path.SEPARATOR)) - val writer = outputWriterClass.newInstance() - writer.init(path.toString, dataSchema, taskAttemptContext) - writer + val path = new Path(getWorkPath, partitionPath) + taskAttemptContext.getConfiguration.set( + "spark.sql.sources.output.path", + new Path(outputPath, partitionPath).toString) + outputWriterFactory.newInstance(path.toString, dataSchema, taskAttemptContext) }) } override def commitTask(): Unit = { - outputWriters.values.foreach(_.close()) - super.commitTask() - } - - override def abortTask(): Unit = { - outputWriters.values.foreach(_.close()) - super.abortTask() - } -} - -private[sql] object DynamicPartitionWriterContainer { - ////////////////////////////////////////////////////////////////////////////////////////////////// - // The following string escaping code is mainly copied from Hive (o.a.h.h.common.FileUtils). - ////////////////////////////////////////////////////////////////////////////////////////////////// - - val charToEscape = { - val bitSet = new java.util.BitSet(128) - - /** - * ASCII 01-1F are HTTP control characters that need to be escaped. - * \u000A and \u000D are \n and \r, respectively. - */ - val clist = Array( - '\u0001', '\u0002', '\u0003', '\u0004', '\u0005', '\u0006', '\u0007', '\u0008', '\u0009', - '\n', '\u000B', '\u000C', '\r', '\u000E', '\u000F', '\u0010', '\u0011', '\u0012', '\u0013', - '\u0014', '\u0015', '\u0016', '\u0017', '\u0018', '\u0019', '\u001A', '\u001B', '\u001C', - '\u001D', '\u001E', '\u001F', '"', '#', '%', '\'', '*', '/', ':', '=', '?', '\\', '\u007F', - '{', '[', ']', '^') - - clist.foreach(bitSet.set(_)) - - if (Shell.WINDOWS) { - Array(' ', '<', '>', '|').foreach(bitSet.set(_)) + try { + outputWriters.values.foreach(_.close()) + super.commitTask() + } catch { case cause: Throwable => + super.abortTask() + throw new RuntimeException("Failed to commit task", cause) } - - bitSet } - def needsEscaping(c: Char): Boolean = { - c >= 0 && c < charToEscape.size() && charToEscape.get(c) - } - - def escapePathName(path: String): String = { - val builder = new StringBuilder() - path.foreach { c => - if (DynamicPartitionWriterContainer.needsEscaping(c)) { - builder.append('%') - builder.append(f"${c.asInstanceOf[Int]}%02x") - } else { - builder.append(c) - } + override def abortTask(): Unit = { + try { + outputWriters.values.foreach(_.close()) + } finally { + super.abortTask() } - - builder.toString() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index 595c5eb40e295..20afd60cb7767 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -130,7 +130,7 @@ private[sql] class DDLParser( } } - protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")" + protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")" /* * describe [extended] table avroTable @@ -138,7 +138,7 @@ private[sql] class DDLParser( */ protected lazy val describeTable: Parser[LogicalPlan] = (DESCRIBE ~> opt(EXTENDED)) ~ (ident <~ ".").? ~ ident ^^ { - case e ~ db ~ tbl => + case e ~ db ~ tbl => val tblIdentifier = db match { case Some(dbName) => Seq(dbName, tbl) @@ -171,7 +171,7 @@ private[sql] class DDLParser( } protected lazy val pair: Parser[(String, String)] = - optionName ~ stringLit ^^ { case k ~ v => (k,v) } + optionName ~ stringLit ^^ { case k ~ v => (k, v) } protected lazy val column: Parser[StructField] = ident ~ dataType ~ (COMMENT ~> stringLit).? ^^ { case columnName ~ typ ~ cm => @@ -188,18 +188,20 @@ private[sql] class DDLParser( private[sql] object ResolvedDataSource { private val builtinSources = Map( - "jdbc" -> classOf[org.apache.spark.sql.jdbc.DefaultSource], - "json" -> classOf[org.apache.spark.sql.json.DefaultSource], - "parquet" -> classOf[org.apache.spark.sql.parquet.DefaultSource] + "jdbc" -> "org.apache.spark.sql.jdbc.DefaultSource", + "json" -> "org.apache.spark.sql.json.DefaultSource", + "parquet" -> "org.apache.spark.sql.parquet.DefaultSource", + "orc" -> "org.apache.spark.sql.hive.orc.DefaultSource" ) /** Given a provider name, look up the data source class definition. */ def lookupDataSource(provider: String): Class[_] = { + val loader = Utils.getContextOrSparkClassLoader + if (builtinSources.contains(provider)) { - return builtinSources(provider) + return loader.loadClass(builtinSources(provider)) } - val loader = Utils.getContextOrSparkClassLoader try { loader.loadClass(provider) } catch { @@ -208,7 +210,11 @@ private[sql] object ResolvedDataSource { loader.loadClass(provider + ".DefaultSource") } catch { case cnf: java.lang.ClassNotFoundException => - sys.error(s"Failed to load class for data source: $provider") + if (provider.startsWith("org.apache.spark.sql.hive.orc")) { + sys.error("The ORC data source must be used with Hive support enabled.") + } else { + sys.error(s"Failed to load class for data source: $provider") + } } } } @@ -226,25 +232,26 @@ private[sql] object ResolvedDataSource { case Some(schema: StructType) => clazz.newInstance() match { case dataSource: SchemaRelationProvider => dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options), schema) - case dataSource: FSBasedRelationProvider => + case dataSource: HadoopFsRelationProvider => val maybePartitionsSchema = if (partitionColumns.isEmpty) { None } else { Some(partitionColumnsSchema(schema, partitionColumns)) } - val caseInsensitiveOptions= new CaseInsensitiveMap(options) + val caseInsensitiveOptions = new CaseInsensitiveMap(options) val paths = { val patternPath = new Path(caseInsensitiveOptions("path")) SparkHadoopUtil.get.globPath(patternPath).map(_.toString).toArray } - val dataSchema = StructType(schema.filterNot(f => partitionColumns.contains(f.name))) + val dataSchema = + StructType(schema.filterNot(f => partitionColumns.contains(f.name))).asNullable dataSource.createRelation( sqlContext, paths, - Some(schema), + Some(dataSchema), maybePartitionsSchema, caseInsensitiveOptions) case dataSource: org.apache.spark.sql.sources.RelationProvider => @@ -256,7 +263,7 @@ private[sql] object ResolvedDataSource { case None => clazz.newInstance() match { case dataSource: RelationProvider => dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options)) - case dataSource: FSBasedRelationProvider => + case dataSource: HadoopFsRelationProvider => val caseInsensitiveOptions = new CaseInsensitiveMap(options) val paths = { val patternPath = new Path(caseInsensitiveOptions("path")) @@ -296,7 +303,7 @@ private[sql] object ResolvedDataSource { val relation = clazz.newInstance() match { case dataSource: CreatableRelationProvider => dataSource.createRelation(sqlContext, mode, options, data) - case dataSource: FSBasedRelationProvider => + case dataSource: HadoopFsRelationProvider => // Don't glob path for the write path. The contracts here are: // 1. Only one output path can be specified on the write path; // 2. Output path must be a legal HDFS style file system path; @@ -314,11 +321,14 @@ private[sql] object ResolvedDataSource { Some(dataSchema.asNullable), Some(partitionColumnsSchema(data.schema, partitionColumns)), caseInsensitiveOptions) + + // For partitioned relation r, r.schema's column ordering can be different from the column + // ordering of data.logicalPlan (partition columns are all moved after data column). This + // will be adjusted within InsertIntoHadoopFsRelation. sqlContext.executePlan( - InsertIntoFSBasedRelation( + InsertIntoHadoopFsRelation( r, data.logicalPlan, - partitionColumns.toArray, mode)).toRdd r case _ => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 6f315305c11d6..f5bd2d2941ca0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -17,19 +17,21 @@ package org.apache.spark.sql.sources +import scala.collection.mutable import scala.util.Try import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.annotation.{DeveloperApi, Experimental} -import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD +import org.apache.spark.SerializableWritable import org.apache.spark.sql.{Row, _} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection -import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.sql.types.StructType /** * ::DeveloperApi:: @@ -91,10 +93,10 @@ trait SchemaRelationProvider { } /** - * ::DeveloperApi:: + * ::Experimental:: * Implemented by objects that produce relations for a specific kind of data source * with a given schema and partitioned columns. When Spark SQL is given a DDL operation with a - * USING clause specified (to specify the implemented [[FSBasedRelationProvider]]), a user defined + * USING clause specified (to specify the implemented [[HadoopFsRelationProvider]]), a user defined * schema, and an optional list of partition columns, this interface is used to pass in the * parameters specified by a user. * @@ -105,26 +107,29 @@ trait SchemaRelationProvider { * * A new instance of this class with be instantiated each time a DDL call is made. * - * The difference between a [[RelationProvider]] and a [[FSBasedRelationProvider]] is + * The difference between a [[RelationProvider]] and a [[HadoopFsRelationProvider]] is * that users need to provide a schema and a (possibly empty) list of partition columns when * using a SchemaRelationProvider. A relation provider can inherits both [[RelationProvider]], - * and [[FSBasedRelationProvider]] if it can support schema inference, user-specified + * and [[HadoopFsRelationProvider]] if it can support schema inference, user-specified * schemas, and accessing partitioned relations. * * @since 1.4.0 */ -trait FSBasedRelationProvider { +@Experimental +trait HadoopFsRelationProvider { /** * Returns a new base relation with the given parameters, a user defined schema, and a list of * partition columns. Note: the parameters' keywords are case insensitive and this insensitivity * is enforced by the Map that is passed to the function. + * + * @param dataSchema Schema of data columns (i.e., columns that are not partition columns). */ def createRelation( sqlContext: SQLContext, paths: Array[String], - schema: Option[StructType], + dataSchema: Option[StructType], partitionColumns: Option[StructType], - parameters: Map[String, String]): FSBasedRelation + parameters: Map[String, String]): HadoopFsRelation } /** @@ -280,33 +285,42 @@ trait CatalystScan { /** * ::Experimental:: - * [[OutputWriter]] is used together with [[FSBasedRelation]] for persisting rows to the - * underlying file system. Subclasses of [[OutputWriter]] must provide a zero-argument constructor. - * An [[OutputWriter]] instance is created and initialized when a new output file is opened on - * executor side. This instance is used to persist rows to this single output file. + * A factory that produces [[OutputWriter]]s. A new [[OutputWriterFactory]] is created on driver + * side for each write job issued when writing to a [[HadoopFsRelation]], and then gets serialized + * to executor side to create actual [[OutputWriter]]s on the fly. * * @since 1.4.0 */ @Experimental -abstract class OutputWriter { +abstract class OutputWriterFactory extends Serializable { /** - * Initializes this [[OutputWriter]] before any rows are persisted. + * When writing to a [[HadoopFsRelation]], this method gets called by each task on executor side + * to instantiate new [[OutputWriter]]s. * * @param path Path of the file to which this [[OutputWriter]] is supposed to write. Note that * this may not point to the final output file. For example, `FileOutputFormat` writes to * temporary directories and then merge written files back to the final destination. In * this case, `path` points to a temporary output file under the temporary directory. * @param dataSchema Schema of the rows to be written. Partition columns are not included in the - * schema if the corresponding relation is partitioned. + * schema if the relation being written is partitioned. * @param context The Hadoop MapReduce task context. * * @since 1.4.0 */ - def init( - path: String, - dataSchema: StructType, - context: TaskAttemptContext): Unit = () + def newInstance(path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter +} +/** + * ::Experimental:: + * [[OutputWriter]] is used together with [[HadoopFsRelation]] for persisting rows to the + * underlying file system. Subclasses of [[OutputWriter]] must provide a zero-argument constructor. + * An [[OutputWriter]] instance is created and initialized when a new output file is opened on + * executor side. This instance is used to persist rows to this single output file. + * + * @since 1.4.0 + */ +@Experimental +abstract class OutputWriter { /** * Persists a single row. Invoked on the executor side. When writing to dynamically partitioned * tables, dynamic partition columns are not included in rows to be written. @@ -333,96 +347,152 @@ abstract class OutputWriter { * filter using selected predicates before producing an RDD containing all matching tuples as * [[Row]] objects. In addition, when reading from Hive style partitioned tables stored in file * systems, it's able to discover partitioning information from the paths of input directories, and - * perform partition pruning before start reading the data. Subclasses of [[FSBasedRelation()]] must - * override one of the three `buildScan` methods to implement the read path. + * perform partition pruning before start reading the data. Subclasses of [[HadoopFsRelation()]] + * must override one of the three `buildScan` methods to implement the read path. * * For the write path, it provides the ability to write to both non-partitioned and partitioned * tables. Directory layout of the partitioned tables is compatible with Hive. * * @constructor This constructor is for internal uses only. The [[PartitionSpec]] argument is for * implementing metastore table conversion. - * @param paths Base paths of this relation. For partitioned relations, it should be the root - * directories of all partition directories. - * @param maybePartitionSpec An [[FSBasedRelation]] can be created with an optional + * + * @param maybePartitionSpec An [[HadoopFsRelation]] can be created with an optional * [[PartitionSpec]], so that partition discovery can be skipped. * * @since 1.4.0 */ @Experimental -abstract class FSBasedRelation private[sql]( - val paths: Array[String], - maybePartitionSpec: Option[PartitionSpec]) +abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[PartitionSpec]) extends BaseRelation { + def this() = this(None) + + private val hadoopConf = new Configuration(sqlContext.sparkContext.hadoopConfiguration) + + private val codegenEnabled = sqlContext.conf.codegenEnabled + + private var _partitionSpec: PartitionSpec = _ + + private class FileStatusCache { + var leafFiles = mutable.Map.empty[Path, FileStatus] + + var leafDirToChildrenFiles = mutable.Map.empty[Path, Array[FileStatus]] + + def refresh(): Unit = { + // We don't filter files/directories whose name start with "_" except "_temporary" here, as + // specific data sources may take advantages over them (e.g. Parquet _metadata and + // _common_metadata files). "_temporary" directories are explicitly ignored since failed + // tasks/jobs may leave partial/corrupted data files there. + def listLeafFilesAndDirs(fs: FileSystem, status: FileStatus): Set[FileStatus] = { + if (status.getPath.getName.toLowerCase == "_temporary") { + Set.empty + } else { + val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) + val leafDirs = if (dirs.isEmpty) Set(status) else Set.empty[FileStatus] + files.toSet ++ leafDirs ++ dirs.flatMap(dir => listLeafFilesAndDirs(fs, dir)) + } + } + + leafFiles.clear() + + val statuses = paths.flatMap { path => + val hdfsPath = new Path(path) + val fs = hdfsPath.getFileSystem(hadoopConf) + val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + Try(fs.getFileStatus(qualified)).toOption.toArray.flatMap(listLeafFilesAndDirs(fs, _)) + }.filterNot { status => + // SPARK-8037: Ignores files like ".DS_Store" and other hidden files/directories + status.getPath.getName.startsWith(".") + } + + val files = statuses.filterNot(_.isDir) + leafFiles ++= files.map(f => f.getPath -> f).toMap + leafDirToChildrenFiles ++= files.groupBy(_.getPath.getParent) + } + } + + private lazy val fileStatusCache = { + val cache = new FileStatusCache + cache.refresh() + cache + } + + protected def cachedLeafStatuses(): Set[FileStatus] = { + fileStatusCache.leafFiles.values.toSet + } + + final private[sql] def partitionSpec: PartitionSpec = { + if (_partitionSpec == null) { + _partitionSpec = maybePartitionSpec + .flatMap { + case spec if spec.partitions.nonEmpty => + Some(spec.copy(partitionColumns = spec.partitionColumns.asNullable)) + case _ => + None + } + .orElse { + // We only know the partition columns and their data types. We need to discover + // partition values. + userDefinedPartitionColumns.map { partitionSchema => + val spec = discoverPartitions() + val castedPartitions = spec.partitions.map { case p @ Partition(values, path) => + val literals = values.toSeq.zip(spec.partitionColumns.map(_.dataType)).map { + case (value, dataType) => Literal.create(value, dataType) + } + val castedValues = partitionSchema.zip(literals).map { case (field, literal) => + Cast(literal, field.dataType).eval() + } + p.copy(values = Row.fromSeq(castedValues)) + } + PartitionSpec(partitionSchema, castedPartitions) + } + } + .getOrElse { + if (sqlContext.conf.partitionDiscoveryEnabled()) { + discoverPartitions() + } else { + PartitionSpec(StructType(Nil), Array.empty[Partition]) + } + } + } + _partitionSpec + } + /** - * Constructs an [[FSBasedRelation]]. - * - * @param paths Base paths of this relation. For partitioned relations, it should be either root - * directories of all partition directories. - * @param partitionColumns Partition columns of this relation. + * Base paths of this relation. For partitioned relations, it should be either root directories + * of all partition directories. * * @since 1.4.0 */ - def this(paths: Array[String], partitionColumns: StructType) = - this(paths, { - if (partitionColumns.isEmpty) None - else Some(PartitionSpec(partitionColumns, Array.empty[Partition])) - }) + def paths: Array[String] /** - * Constructs an [[FSBasedRelation]]. - * - * @param paths Base paths of this relation. For partitioned relations, it should be root - * directories of all partition directories. + * Partition columns. Can be either defined by [[userDefinedPartitionColumns]] or automatically + * discovered. Note that they should always be nullable. * * @since 1.4.0 */ - def this(paths: Array[String]) = this(paths, None) - - private val hadoopConf = new Configuration(sqlContext.sparkContext.hadoopConfiguration) - - private val codegenEnabled = sqlContext.conf.codegenEnabled - - private var _partitionSpec: PartitionSpec = maybePartitionSpec.map { spec => - spec.copy(partitionColumns = spec.partitionColumns.asNullable) - }.getOrElse { - if (sqlContext.conf.partitionDiscoveryEnabled()) { - discoverPartitions() - } else { - PartitionSpec(StructType(Nil), Array.empty[Partition]) - } - } - - private[sql] def partitionSpec: PartitionSpec = _partitionSpec + final def partitionColumns: StructType = + userDefinedPartitionColumns.getOrElse(partitionSpec.partitionColumns) /** - * Partition columns. Note that they are always nullable. + * Optional user defined partition columns. * * @since 1.4.0 */ - def partitionColumns: StructType = partitionSpec.partitionColumns + def userDefinedPartitionColumns: Option[StructType] = None private[sql] def refresh(): Unit = { + fileStatusCache.refresh() if (sqlContext.conf.partitionDiscoveryEnabled()) { _partitionSpec = discoverPartitions() } } private def discoverPartitions(): PartitionSpec = { - val basePaths = paths.map(new Path(_)) - val leafDirs = basePaths.flatMap { path => - val fs = path.getFileSystem(hadoopConf) - Try(fs.getFileStatus(path.makeQualified(fs.getUri, fs.getWorkingDirectory))) - .filter(_.isDir) - .map(SparkHadoopUtil.get.listLeafDirStatuses(fs, _)) - .getOrElse(Seq.empty[FileStatus]) - }.map(_.getPath) - - if (leafDirs.nonEmpty) { - PartitioningUtils.parsePartitions(leafDirs, "__HIVE_DEFAULT_PARTITION__") - } else { - PartitionSpec(StructType(Array.empty[StructField]), Array.empty[Partition]) - } + // We use leaf dirs containing data files to discover the schema. + val leafDirs = fileStatusCache.leafDirToChildrenFiles.keys.toSeq + PartitioningUtils.parsePartitions(leafDirs, PartitioningUtils.DEFAULT_PARTITION_NAME) } /** @@ -433,11 +503,33 @@ abstract class FSBasedRelation private[sql]( */ override lazy val schema: StructType = { val dataSchemaColumnNames = dataSchema.map(_.name.toLowerCase).toSet - StructType(dataSchema ++ partitionSpec.partitionColumns.filterNot { column => + StructType(dataSchema ++ partitionColumns.filterNot { column => dataSchemaColumnNames.contains(column.name.toLowerCase) }) } + private[sources] final def buildScan( + requiredColumns: Array[String], + filters: Array[Filter], + inputPaths: Array[String], + broadcastedConf: Broadcast[SerializableWritable[Configuration]]): RDD[Row] = { + val inputStatuses = inputPaths.flatMap { input => + val path = new Path(input) + + // First assumes `input` is a directory path, and tries to get all files contained in it. + fileStatusCache.leafDirToChildrenFiles.getOrElse( + path, + // Otherwise, `input` might be a file path + fileStatusCache.leafFiles.get(path).toArray + ).filter { status => + val name = status.getPath.getName + !name.startsWith("_") && !name.startsWith(".") + } + } + + buildScan(requiredColumns, filters, inputStatuses, broadcastedConf) + } + /** * Specifies schema of actual data files. For partitioned relations, if one or more partitioned * columns are contained in the data files, they should also appear in `dataSchema`. @@ -451,14 +543,14 @@ abstract class FSBasedRelation private[sql]( * this relation. For partitioned relations, this method is called for each selected partition, * and builds an `RDD[Row]` containing all rows within that single partition. * - * @param inputPaths For a non-partitioned relation, it contains paths of all data files in the + * @param inputFiles For a non-partitioned relation, it contains paths of all data files in the * relation. For a partitioned relation, it contains paths of all data files in a single * selected partition. * * @since 1.4.0 */ - def buildScan(inputPaths: Array[String]): RDD[Row] = { - throw new RuntimeException( + def buildScan(inputFiles: Array[FileStatus]): RDD[Row] = { + throw new UnsupportedOperationException( "At least one buildScan() method should be overridden to read the relation.") } @@ -468,13 +560,13 @@ abstract class FSBasedRelation private[sql]( * and builds an `RDD[Row]` containing all rows within that single partition. * * @param requiredColumns Required columns. - * @param inputPaths For a non-partitioned relation, it contains paths of all data files in the + * @param inputFiles For a non-partitioned relation, it contains paths of all data files in the * relation. For a partitioned relation, it contains paths of all data files in a single * selected partition. * * @since 1.4.0 */ - def buildScan(requiredColumns: Array[String], inputPaths: Array[String]): RDD[Row] = { + def buildScan(requiredColumns: Array[String], inputFiles: Array[FileStatus]): RDD[Row] = { // Yeah, to workaround serialization... val dataSchema = this.dataSchema val codegenEnabled = this.codegenEnabled @@ -484,7 +576,7 @@ abstract class FSBasedRelation private[sql]( BoundReference(dataSchema.fieldIndex(col), field.dataType, field.nullable) }.toSeq - buildScan(inputPaths).mapPartitions { rows => + buildScan(inputFiles).mapPartitions { rows => val buildProjection = if (codegenEnabled) { GenerateMutableProjection.generate(requiredOutput, dataSchema.toAttributes) } else { @@ -506,7 +598,7 @@ abstract class FSBasedRelation private[sql]( * of all `filters`. The pushed down filters are currently purely an optimization as they * will all be evaluated again. This means it is safe to use them with methods that produce * false positives such as filtering partitions based on a bloom filter. - * @param inputPaths For a non-partitioned relation, it contains paths of all data files in the + * @param inputFiles For a non-partitioned relation, it contains paths of all data files in the * relation. For a partitioned relation, it contains paths of all data files in a single * selected partition. * @@ -515,27 +607,48 @@ abstract class FSBasedRelation private[sql]( def buildScan( requiredColumns: Array[String], filters: Array[Filter], - inputPaths: Array[String]): RDD[Row] = { - buildScan(requiredColumns, inputPaths) + inputFiles: Array[FileStatus]): RDD[Row] = { + buildScan(requiredColumns, inputFiles) } /** - * Client side preparation for data writing can be put here. For example, user defined output - * committer can be configured here. + * For a non-partitioned relation, this method builds an `RDD[Row]` containing all rows within + * this relation. For partitioned relations, this method is called for each selected partition, + * and builds an `RDD[Row]` containing all rows within that single partition. * - * Note that the only side effect expected here is mutating `job` via its setters. Especially, - * Spark SQL caches [[BaseRelation]] instances for performance, mutating relation internal states - * may cause unexpected behaviors. + * Note: This interface is subject to change in future. + * + * @param requiredColumns Required columns. + * @param filters Candidate filters to be pushed down. The actual filter should be the conjunction + * of all `filters`. The pushed down filters are currently purely an optimization as they + * will all be evaluated again. This means it is safe to use them with methods that produce + * false positives such as filtering partitions based on a bloom filter. + * @param inputFiles For a non-partitioned relation, it contains paths of all data files in the + * relation. For a partitioned relation, it contains paths of all data files in a single + * selected partition. + * @param broadcastedConf A shared broadcast Hadoop Configuration, which can be used to reduce the + * overhead of broadcasting the Configuration for every Hadoop RDD. * * @since 1.4.0 */ - def prepareForWrite(job: Job): Unit = () + private[sql] def buildScan( + requiredColumns: Array[String], + filters: Array[Filter], + inputFiles: Array[FileStatus], + broadcastedConf: Broadcast[SerializableWritable[Configuration]]): RDD[Row] = { + buildScan(requiredColumns, filters, inputFiles) + } /** - * This method is responsible for producing a new [[OutputWriter]] for each newly opened output - * file on the executor side. + * Prepares a write job and returns an [[OutputWriterFactory]]. Client side job preparation can + * be put here. For example, user defined output committer can be configured here + * by setting the output committer class in the conf of spark.sql.sources.outputCommitterClass. + * + * Note that the only side effect expected here is mutating `job` via its setters. Especially, + * Spark SQL caches [[BaseRelation]] instances for performance, mutating relation internal states + * may cause unexpected behaviors. * * @since 1.4.0 */ - def outputWriterClass: Class[_ <: OutputWriter] + def prepareJobForWrite(job: Job): OutputWriterFactory } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala index aad1d248d0a28..a3fd7f13b3db7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala @@ -35,9 +35,9 @@ private[sql] object PreInsertCastAndRename extends Rule[LogicalPlan] { // Wait until children are resolved. case p: LogicalPlan if !p.childrenResolved => p - // We are inserting into an InsertableRelation. + // We are inserting into an InsertableRelation or HadoopFsRelation. case i @ InsertIntoTable( - l @ LogicalRelation(r: InsertableRelation), partition, child, overwrite, ifNotExists) => { + l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation), _, child, _, _) => { // First, make sure the data to be inserted have the same number of fields with the // schema of the relation. if (l.output.size != child.output.size) { @@ -101,8 +101,20 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => } } - case logical.InsertIntoTable(LogicalRelation(_: InsertableRelation), _, _, _, _) => // OK - case logical.InsertIntoTable(LogicalRelation(_: FSBasedRelation), _, _, _, _) => // OK + case logical.InsertIntoTable(LogicalRelation(r: HadoopFsRelation), part, _, _, _) => + // We need to make sure the partition columns specified by users do match partition + // columns of the relation. + val existingPartitionColumns = r.partitionColumns.fieldNames.toSet + val specifiedPartitionColumns = part.keySet + if (existingPartitionColumns != specifiedPartitionColumns) { + failAnalysis(s"Specified partition columns " + + s"(${specifiedPartitionColumns.mkString(", ")}) " + + s"do not match the partition columns of the table. Please use " + + s"(${existingPartitionColumns.mkString(", ")}) as the partition columns.") + } else { + // OK + } + case logical.InsertIntoTable(l: LogicalRelation, _, _, _, _) => // The relation in l is not an InsertableRelation. failAnalysis(s"$l does not allow insertion.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/README.md b/sql/core/src/main/scala/org/apache/spark/sql/test/README.md new file mode 100644 index 0000000000000..d867f181b9728 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/README.md @@ -0,0 +1,7 @@ +README +====== + +Please do not add any class in this place unless it is used by `sql/console` or Python tests. +If you need to create any classes or traits that will be used by tests from both `sql/core` and +`sql/hive`, you can add them in the `src/test` of `sql/core` (tests of `sql/hive` +depend on the test jar of `sql/core`). diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java index c344a9b095c52..fcb8f5499cf84 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java @@ -187,14 +187,14 @@ public void applySchemaToJSON() { null, "this is another simple string.")); - DataFrame df1 = sqlContext.jsonRDD(jsonRDD); + DataFrame df1 = sqlContext.read().json(jsonRDD); StructType actualSchema1 = df1.schema(); Assert.assertEquals(expectedSchema, actualSchema1); df1.registerTempTable("jsonTable1"); List actual1 = sqlContext.sql("select * from jsonTable1").collectAsList(); Assert.assertEquals(expectedResult, actual1); - DataFrame df2 = sqlContext.jsonRDD(jsonRDD, expectedSchema); + DataFrame df2 = sqlContext.read().schema(expectedSchema).json(jsonRDD); StructType actualSchema2 = df2.schema(); Assert.assertEquals(expectedSchema, actualSchema2); df2.registerTempTable("jsonTable2"); diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java index b76f7d421f643..2706e01bd28af 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java @@ -67,7 +67,7 @@ public void setUp() throws IOException { jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}"); } JavaRDD rdd = sc.parallelize(jsonObjects); - df = sqlContext.jsonRDD(rdd); + df = sqlContext.read().json(rdd); df.registerTempTable("jsonTable"); } @@ -75,10 +75,8 @@ public void setUp() throws IOException { public void saveAndLoad() { Map options = new HashMap(); options.put("path", path.toString()); - df.save("org.apache.spark.sql.json", SaveMode.ErrorIfExists, options); - - DataFrame loadedDF = sqlContext.load("org.apache.spark.sql.json", options); - + df.write().mode(SaveMode.ErrorIfExists).format("json").options(options).save(); + DataFrame loadedDF = sqlContext.read().format("json").options(options).load(); checkAnswer(loadedDF, df.collectAsList()); } @@ -86,12 +84,12 @@ public void saveAndLoad() { public void saveAndLoadWithSchema() { Map options = new HashMap(); options.put("path", path.toString()); - df.save("org.apache.spark.sql.json", SaveMode.ErrorIfExists, options); + df.write().format("json").mode(SaveMode.ErrorIfExists).options(options).save(); List fields = new ArrayList(); fields.add(DataTypes.createStructField("b", DataTypes.StringType, true)); StructType schema = DataTypes.createStructType(fields); - DataFrame loadedDF = sqlContext.load("org.apache.spark.sql.json", schema, options); + DataFrame loadedDF = sqlContext.read().format("json").schema(schema).options(options).load(); checkAnswer(loadedDF, sqlContext.sql("SELECT b FROM jsonTable").collectAsList()); } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 269e185543059..bfba379d9a518 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import org.scalatest.Matchers._ +import org.apache.spark.sql.execution.Project import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext.implicits._ @@ -27,6 +28,72 @@ import org.apache.spark.sql.types._ class ColumnExpressionSuite extends QueryTest { import org.apache.spark.sql.TestData._ + test("alias") { + val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") + assert(df.select(df("a").as("b")).columns.head === "b") + assert(df.select(df("a").alias("b")).columns.head === "b") + } + + test("single explode") { + val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") + checkAnswer( + df.select(explode('intList)), + Row(1) :: Row(2) :: Row(3) :: Nil) + } + + test("explode and other columns") { + val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") + + checkAnswer( + df.select($"a", explode('intList)), + Row(1, 1) :: + Row(1, 2) :: + Row(1, 3) :: Nil) + + checkAnswer( + df.select($"*", explode('intList)), + Row(1, Seq(1, 2, 3), 1) :: + Row(1, Seq(1, 2, 3), 2) :: + Row(1, Seq(1, 2, 3), 3) :: Nil) + } + + test("aliased explode") { + val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") + + checkAnswer( + df.select(explode('intList).as('int)).select('int), + Row(1) :: Row(2) :: Row(3) :: Nil) + + checkAnswer( + df.select(explode('intList).as('int)).select(sum('int)), + Row(6) :: Nil) + } + + test("explode on map") { + val df = Seq((1, Map("a" -> "b"))).toDF("a", "map") + + checkAnswer( + df.select(explode('map)), + Row("a", "b")) + } + + test("explode on map with aliases") { + val df = Seq((1, Map("a" -> "b"))).toDF("a", "map") + + checkAnswer( + df.select(explode('map).as("key1" :: "value1" :: Nil)).select("key1", "value1"), + Row("a", "b")) + } + + test("self join explode") { + val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") + val exploded = df.select(explode('intList).as('i)) + + checkAnswer( + exploded.join(exploded, exploded("i") === exploded("i")).agg(count("*")), + Row(3) :: Nil) + } + test("collect on column produced by a binary operator") { val df = Seq((1, 2, 3)).toDF("a", "b", "c") checkAnswer(df.select(df("a") + df("b")), Seq(Row(3))) @@ -386,13 +453,51 @@ class ColumnExpressionSuite extends QueryTest { } test("rand") { - val randCol = testData.select('key, rand(5L).as("rand")) + val randCol = testData.select($"key", rand(5L).as("rand")) randCol.columns.length should be (2) val rows = randCol.collect() rows.foreach { row => assert(row.getDouble(1) <= 1.0) assert(row.getDouble(1) >= 0.0) } + + def checkNumProjects(df: DataFrame, expectedNumProjects: Int): Unit = { + val projects = df.queryExecution.executedPlan.collect { + case project: Project => project + } + assert(projects.size === expectedNumProjects) + } + + // We first create a plan with two Projects. + // Project [rand + 1 AS rand1, rand - 1 AS rand2] + // Project [key, (Rand 5 + 1) AS rand] + // LogicalRDD [key, value] + // Because Rand function is not deterministic, the column rand is not deterministic. + // So, in the optimizer, we will not collapse Project [rand + 1 AS rand1, rand - 1 AS rand2] + // and Project [key, Rand 5 AS rand]. The final plan still has two Projects. + val dfWithTwoProjects = + testData + .select($"key", (rand(5L) + 1).as("rand")) + .select(($"rand" + 1).as("rand1"), ($"rand" - 1).as("rand2")) + checkNumProjects(dfWithTwoProjects, 2) + + // Now, we add one more project rand1 - rand2 on top of the query plan. + // Since rand1 and rand2 are deterministic (they basically apply +/- to the generated + // rand value), we can collapse rand1 - rand2 to the Project generating rand1 and rand2. + // So, the plan will be optimized from ... + // Project [(rand1 - rand2) AS (rand1 - rand2)] + // Project [rand + 1 AS rand1, rand - 1 AS rand2] + // Project [key, (Rand 5 + 1) AS rand] + // LogicalRDD [key, value] + // to ... + // Project [((rand + 1 AS rand1) - (rand - 1 AS rand2)) AS (rand1 - rand2)] + // Project [key, Rand 5 AS rand] + // LogicalRDD [key, value] + val dfWithThreeProjects = dfWithTwoProjects.select($"rand1" - $"rand2") + checkNumProjects(dfWithThreeProjects, 2) + dfWithThreeProjects.collect().foreach { row => + assert(row.getDouble(0) === 2.0 +- 0.0001) + } } test("randn") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 35a574f354741..232f05c00918f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -148,12 +148,12 @@ class DataFrameAggregateSuite extends QueryTest { test("null count") { checkAnswer( testData3.groupBy('a).agg(count('b)), - Seq(Row(1,0), Row(2, 1)) + Seq(Row(1, 0), Row(2, 1)) ) checkAnswer( testData3.groupBy('a).agg(count('a + 'b)), - Seq(Row(1,0), Row(2, 1)) + Seq(Row(1, 0), Row(2, 1)) ) checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 46b1845a9180c..438f479459dfe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -17,14 +17,14 @@ package org.apache.spark.sql -import org.scalatest.FunSuite import org.scalatest.Matchers._ +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext.implicits._ -class DataFrameStatSuite extends FunSuite { - +class DataFrameStatSuite extends SparkFunSuite { + val sqlCtx = TestSQLContext def toLetter(i: Int): String = (i + 97).toChar.toString @@ -74,10 +74,10 @@ class DataFrameStatSuite extends FunSuite { val rows: Array[Row] = crosstab.collect().sortBy(_.getString(0)) assert(rows(0).get(0).toString === "0") assert(rows(0).getLong(1) === 2L) - assert(rows(0).get(2) === null) + assert(rows(0).get(2) === 0L) assert(rows(1).get(0).toString === "1") assert(rows(1).getLong(1) === 1L) - assert(rows(1).get(2) === null) + assert(rows(1).get(2) === 0L) assert(rows(2).get(0).toString === "2") assert(rows(2).getLong(1) === 2L) assert(rows(2).getLong(2) === 1L) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 52aa1f6558f80..a4fd1058afce5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -59,7 +59,7 @@ class DataFrameSuite extends QueryTest { } test("rename nested groupby") { - val df = Seq((1,(1,1))).toDF() + val df = Seq((1, (1, 1))).toDF() checkAnswer( df.groupBy("_1").agg(sum("_2._1")).toDF("key", "total"), @@ -211,23 +211,23 @@ class DataFrameSuite extends QueryTest { test("global sorting") { checkAnswer( testData2.orderBy('a.asc, 'b.asc), - Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2))) + Seq(Row(1, 1), Row(1, 2), Row(2, 1), Row(2, 2), Row(3, 1), Row(3, 2))) checkAnswer( testData2.orderBy(asc("a"), desc("b")), - Seq(Row(1,2), Row(1,1), Row(2,2), Row(2,1), Row(3,2), Row(3,1))) + Seq(Row(1, 2), Row(1, 1), Row(2, 2), Row(2, 1), Row(3, 2), Row(3, 1))) checkAnswer( testData2.orderBy('a.asc, 'b.desc), - Seq(Row(1,2), Row(1,1), Row(2,2), Row(2,1), Row(3,2), Row(3,1))) + Seq(Row(1, 2), Row(1, 1), Row(2, 2), Row(2, 1), Row(3, 2), Row(3, 1))) checkAnswer( testData2.orderBy('a.desc, 'b.desc), - Seq(Row(3,2), Row(3,1), Row(2,2), Row(2,1), Row(1,2), Row(1,1))) + Seq(Row(3, 2), Row(3, 1), Row(2, 2), Row(2, 1), Row(1, 2), Row(1, 1))) checkAnswer( testData2.orderBy('a.desc, 'b.asc), - Seq(Row(3,1), Row(3,2), Row(2,1), Row(2,2), Row(1,1), Row(1,2))) + Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2, 2), Row(1, 1), Row(1, 2))) checkAnswer( arrayData.toDF().orderBy('data.getItem(0).asc), @@ -331,7 +331,7 @@ class DataFrameSuite extends QueryTest { checkAnswer( df, testData.collect().toSeq) - assert(df.schema.map(_.name) === Seq("key","value")) + assert(df.schema.map(_.name) === Seq("key", "value")) } test("withColumnRenamed") { @@ -364,30 +364,35 @@ class DataFrameSuite extends QueryTest { test("describe") { val describeTestData = Seq( - ("Bob", 16, 176), + ("Bob", 16, 176), ("Alice", 32, 164), ("David", 60, 192), - ("Amy", 24, 180)).toDF("name", "age", "height") + ("Amy", 24, 180)).toDF("name", "age", "height") val describeResult = Seq( - Row("count", 4, 4), - Row("mean", 33.0, 178.0), - Row("stddev", 16.583123951777, 10.0), - Row("min", 16, 164), - Row("max", 60, 192)) + Row("count", "4", "4"), + Row("mean", "33.0", "178.0"), + Row("stddev", "16.583123951777", "10.0"), + Row("min", "16", "164"), + Row("max", "60", "192")) val emptyDescribeResult = Seq( - Row("count", 0, 0), - Row("mean", null, null), - Row("stddev", null, null), - Row("min", null, null), - Row("max", null, null)) + Row("count", "0", "0"), + Row("mean", null, null), + Row("stddev", null, null), + Row("min", null, null), + Row("max", null, null)) def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name) val describeTwoCols = describeTestData.describe("age", "height") assert(getSchemaAsSeq(describeTwoCols) === Seq("summary", "age", "height")) checkAnswer(describeTwoCols, describeResult) + // All aggregate value should have been cast to string + describeTwoCols.collect().foreach { row => + assert(row.get(1).isInstanceOf[String], "expected string but found " + row.get(1).getClass) + assert(row.get(2).isInstanceOf[String], "expected string but found " + row.get(2).getClass) + } val describeAllCols = describeTestData.describe() assert(getSchemaAsSeq(describeAllCols) === Seq("summary", "age", "height")) @@ -459,6 +464,33 @@ class DataFrameSuite extends QueryTest { assert(complexData.filter(complexData("m")(complexData("s")("value")) === 1).count() == 1) } + test("SPARK-7551: support backticks for DataFrame attribute resolution") { + val df = TestSQLContext.read.json(TestSQLContext.sparkContext.makeRDD( + """{"a.b": {"c": {"d..e": {"f": 1}}}}""" :: Nil)) + checkAnswer( + df.select(df("`a.b`.c.`d..e`.`f`")), + Row(1) + ) + + val df2 = TestSQLContext.read.json(TestSQLContext.sparkContext.makeRDD( + """{"a b": {"c": {"d e": {"f": 1}}}}""" :: Nil)) + checkAnswer( + df2.select(df2("`a b`.c.d e.f")), + Row(1) + ) + + def checkError(testFun: => Unit): Unit = { + val e = intercept[org.apache.spark.sql.AnalysisException] { + testFun + } + assert(e.getMessage.contains("syntax error in attribute name:")) + } + checkError(df("`abc.`c`")) + checkError(df("`abc`..d")) + checkError(df("`a`.b.")) + checkError(df("`a.b`.c.`d")) + } + test("SPARK-7324 dropDuplicates") { val testData = TestSQLContext.sparkContext.parallelize( (2, 1, 2) :: (1, 1, 1) :: @@ -505,4 +537,44 @@ class DataFrameSuite extends QueryTest { val p = df.logicalPlan.asInstanceOf[Project].child.asInstanceOf[Project] assert(!p.child.isInstanceOf[Project]) } + + test("SPARK-7150 range api") { + // numSlice is greater than length + val res1 = TestSQLContext.range(0, 10, 1, 15).select("id") + assert(res1.count == 10) + assert(res1.agg(sum("id")).as("sumid").collect() === Seq(Row(45))) + + val res2 = TestSQLContext.range(3, 15, 3, 2).select("id") + assert(res2.count == 4) + assert(res2.agg(sum("id")).as("sumid").collect() === Seq(Row(30))) + + val res3 = TestSQLContext.range(1, -2).select("id") + assert(res3.count == 0) + + // start is positive, end is negative, step is negative + val res4 = TestSQLContext.range(1, -2, -2, 6).select("id") + assert(res4.count == 2) + assert(res4.agg(sum("id")).as("sumid").collect() === Seq(Row(0))) + + // start, end, step are negative + val res5 = TestSQLContext.range(-3, -8, -2, 1).select("id") + assert(res5.count == 3) + assert(res5.agg(sum("id")).as("sumid").collect() === Seq(Row(-15))) + + // start, end are negative, step is positive + val res6 = TestSQLContext.range(-8, -4, 2, 1).select("id") + assert(res6.count == 2) + assert(res6.agg(sum("id")).as("sumid").collect() === Seq(Row(-14))) + + val res7 = TestSQLContext.range(-10, -9, -20, 1).select("id") + assert(res7.count == 0) + + val res8 = TestSQLContext.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id") + assert(res8.count == 3) + assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3))) + + val res9 = TestSQLContext.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id") + assert(res9.count == 2) + assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 037d392c1f929..407c789657834 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -167,10 +167,10 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { val y = testData2.where($"a" === 1).as("y") checkAnswer( x.join(y).where($"x.a" === $"y.a"), - Row(1,1,1,1) :: - Row(1,1,1,2) :: - Row(1,2,1,1) :: - Row(1,2,1,2) :: Nil + Row(1, 1, 1, 1) :: + Row(1, 1, 1, 2) :: + Row(1, 2, 1, 1) :: + Row(1, 2, 1, 2) :: Nil ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala index f9f41eb358bd5..3ce97c3fffdb4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala @@ -28,7 +28,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter { import org.apache.spark.sql.test.TestSQLContext.implicits._ val df = - sparkContext.parallelize((1 to 10).map(i => (i,s"str$i"))).toDF("key", "value") + sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDF("key", "value") before { df.registerTempTable("ListTablesSuiteTable") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index c4281c4b55c02..dd68965444f5d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -206,7 +206,7 @@ class MathExpressionsSuite extends QueryTest { } test("log") { - testOneToOneNonNegativeMathFunction(log, math.log) + testOneToOneNonNegativeMathFunction(org.apache.spark.sql.functions.log, math.log) } test("log10") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index bbf9ab113ca43..98ba3c99283a1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -67,6 +67,10 @@ class QueryTest extends PlanTest { checkAnswer(df, Seq(expectedAnswer)) } + protected def checkAnswer(df: DataFrame, expectedAnswer: DataFrame): Unit = { + checkAnswer(df, expectedAnswer.collect()) + } + def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext) { test(sqlString) { checkAnswer(sqlContext.sql(sqlString), expectedAnswer) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index fb3ba4bc1b908..513ac915dcb2a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -17,15 +17,15 @@ package org.apache.spark.sql +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.execution.SparkSqlSerializer -import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow} import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ -class RowSuite extends FunSuite { +class RowSuite extends SparkFunSuite { test("create row") { val expected = new GenericMutableRow(4) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index bf73d0c7074a5..3a5f071e2f7cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql -import org.scalatest.FunSuiteLike - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.test._ /* Implicits */ import TestSQLContext._ -class SQLConfSuite extends QueryTest with FunSuiteLike { +class SQLConfSuite extends QueryTest { val testKey = "test.key.0" val testVal = "test.val.0" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala new file mode 100644 index 0000000000000..797d123b48668 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala @@ -0,0 +1,50 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test.TestSQLContext + +class SQLContextSuite extends SparkFunSuite with BeforeAndAfterAll { + + private val testSqlContext = TestSQLContext + private val testSparkContext = TestSQLContext.sparkContext + + override def afterAll(): Unit = { + SQLContext.setLastInstantiatedContext(testSqlContext) + } + + test("getOrCreate instantiates SQLContext") { + SQLContext.clearLastInstantiatedContext() + val sqlContext = SQLContext.getOrCreate(testSparkContext) + assert(sqlContext != null, "SQLContext.getOrCreate returned null") + assert(SQLContext.getOrCreate(testSparkContext).eq(sqlContext), + "SQLContext created by SQLContext.getOrCreate not returned by SQLContext.getOrCreate") + } + + test("getOrCreate gets last explicitly instantiated SQLContext") { + SQLContext.clearLastInstantiatedContext() + val sqlContext = new SQLContext(testSparkContext) + assert(SQLContext.getOrCreate(testSparkContext) != null, + "SQLContext.getOrCreate after explicitly created SQLContext returned null") + assert(SQLContext.getOrCreate(testSparkContext).eq(sqlContext), + "SQLContext.getOrCreate after explicitly created SQLContext did not return the context") + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index ec0e76cde6f7c..63f7d314fb699 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql.catalyst.DefaultParserDialect import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.GeneratedAggregate import org.apache.spark.sql.functions._ -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext} import org.apache.spark.sql.test.TestSQLContext.{udf => _, _} import org.apache.spark.sql.types._ @@ -32,15 +32,28 @@ import org.apache.spark.sql.types._ /** A SQL Dialect for testing purpose, and it can not be nested type */ class MyDialect extends DefaultParserDialect -class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { +class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { // Make sure the tables are loaded. TestData - import org.apache.spark.sql.test.TestSQLContext.implicits._ - val sqlCtx = TestSQLContext + val sqlContext = TestSQLContext + import sqlContext.implicits._ + + test("SPARK-6743: no columns from cache") { + Seq( + (83, 0, 38), + (26, 0, 79), + (43, 81, 24) + ).toDF("a", "b", "c").registerTempTable("cachedData") + + cacheTable("cachedData") + checkAnswer( + sql("SELECT t1.b FROM cachedData, cachedData t1 GROUP BY t1.b"), + Row(0) :: Row(81) :: Nil) + } test("self join with aliases") { - Seq(1,2,3).map(i => (i, i.toString)).toDF("int", "str").registerTempTable("df") + Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str").registerTempTable("df") checkAnswer( sql( @@ -63,7 +76,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("self join with alias in agg") { - Seq(1,2,3) + Seq(1, 2, 3) .map(i => (i, i.toString)) .toDF("int", "str") .groupBy("str") @@ -100,12 +113,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("SPARK-4625 support SORT BY in SimpleSQLParser & DSL") { checkAnswer( sql("SELECT a FROM testData2 SORT BY a"), - Seq(1, 1, 2 ,2 ,3 ,3).map(Row(_)) + Seq(1, 1, 2, 2, 3, 3).map(Row(_)) ) } test("grouping on nested fields") { - jsonRDD(sparkContext.parallelize("""{"nested": {"attribute": 1}, "value": 2}""" :: Nil)) + read.json(sparkContext.parallelize("""{"nested": {"attribute": 1}, "value": 2}""" :: Nil)) .registerTempTable("rows") checkAnswer( @@ -122,7 +135,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SPARK-6201 IN type conversion") { - jsonRDD(sparkContext.parallelize(Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}"))) + read.json( + sparkContext.parallelize(Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}"))) .registerTempTable("d") checkAnswer( @@ -141,7 +155,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { sql("SELECT ABS(2.5)"), Row(2.5)) } - + test("aggregation with codegen") { val originalValue = conf.codegenEnabled setConf(SQLConf.CODEGEN_ENABLED, "true") @@ -193,7 +207,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { "SELECT value, sum(key) FROM testData3x GROUP BY value", (1 to 100).map(i => Row(i.toString, 3 * i))) testCodeGen( - "SELECT sum(key), SUM(CAST(key as Double)) FROM testData3x", + "SELECT sum(key), SUM(CAST(key as Double)) FROM testData3x", Row(5050 * 3, 5050 * 3.0) :: Nil) // AVERAGE testCodeGen( @@ -297,6 +311,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SPARK-3173 Timestamp support in the parser") { + checkAnswer(sql( + "SELECT time FROM timestamps WHERE time='1969-12-31 16:00:00.0'"), + Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00"))) + checkAnswer(sql( "SELECT time FROM timestamps WHERE time=CAST('1969-12-31 16:00:00.001' AS TIMESTAMP)"), Row(java.sql.Timestamp.valueOf("1969-12-31 16:00:00.001"))) @@ -336,7 +354,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("left semi greater than predicate") { checkAnswer( sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.a >= y.a + 2"), - Seq(Row(3,1), Row(3,2)) + Seq(Row(3, 1), Row(3, 2)) ) } @@ -353,16 +371,16 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("agg") { checkAnswer( sql("SELECT a, SUM(b) FROM testData2 GROUP BY a"), - Seq(Row(1,3), Row(2,3), Row(3,3))) + Seq(Row(1, 3), Row(2, 3), Row(3, 3))) } test("literal in agg grouping expressions") { checkAnswer( sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"), - Seq(Row(1,2), Row(2,2), Row(3,2))) + Seq(Row(1, 2), Row(2, 2), Row(3, 2))) checkAnswer( sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"), - Seq(Row(1,2), Row(2,2), Row(3,2))) + Seq(Row(1, 2), Row(2, 2), Row(3, 2))) } test("aggregates with nulls") { @@ -387,19 +405,19 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { def sortTest(): Unit = { checkAnswer( sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC"), - Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2))) + Seq(Row(1, 1), Row(1, 2), Row(2, 1), Row(2, 2), Row(3, 1), Row(3, 2))) checkAnswer( sql("SELECT * FROM testData2 ORDER BY a ASC, b DESC"), - Seq(Row(1,2), Row(1,1), Row(2,2), Row(2,1), Row(3,2), Row(3,1))) + Seq(Row(1, 2), Row(1, 1), Row(2, 2), Row(2, 1), Row(3, 2), Row(3, 1))) checkAnswer( sql("SELECT * FROM testData2 ORDER BY a DESC, b DESC"), - Seq(Row(3,2), Row(3,1), Row(2,2), Row(2,1), Row(1,2), Row(1,1))) + Seq(Row(3, 2), Row(3, 1), Row(2, 2), Row(2, 1), Row(1, 2), Row(1, 1))) checkAnswer( sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC"), - Seq(Row(3,1), Row(3,2), Row(2,1), Row(2,2), Row(1,1), Row(1,2))) + Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2, 2), Row(1, 1), Row(1, 2))) checkAnswer( sql("SELECT b FROM binaryData ORDER BY a ASC"), @@ -534,7 +552,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("average overflow") { checkAnswer( sql("SELECT AVG(a),b FROM largeAndSmallInts group by b"), - Seq(Row(2147483645.0,1), Row(2.0,2))) + Seq(Row(2147483645.0, 1), Row(2.0, 2))) } test("count") { @@ -601,10 +619,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { | (SELECT * FROM testData2 WHERE a = 1) x JOIN | (SELECT * FROM testData2 WHERE a = 1) y |WHERE x.a = y.a""".stripMargin), - Row(1,1,1,1) :: - Row(1,1,1,2) :: - Row(1,2,1,1) :: - Row(1,2,1,2) :: Nil) + Row(1, 1, 1, 1) :: + Row(1, 1, 1, 2) :: + Row(1, 2, 1, 1) :: + Row(1, 2, 1, 2) :: Nil) } test("inner join, no matches") { @@ -897,7 +915,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row(values(0).toInt, values(1), values(2).toBoolean, v4) } - val df1 = sqlCtx.createDataFrame(rowRDD1, schema1) + val df1 = createDataFrame(rowRDD1, schema1) df1.registerTempTable("applySchema1") checkAnswer( sql("SELECT * FROM applySchema1"), @@ -927,7 +945,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) } - val df2 = sqlCtx.createDataFrame(rowRDD2, schema2) + val df2 = createDataFrame(rowRDD2, schema2) df2.registerTempTable("applySchema2") checkAnswer( sql("SELECT * FROM applySchema2"), @@ -952,7 +970,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row(Row(values(0).toInt, values(2).toBoolean), scala.collection.mutable.Map(values(1) -> v4)) } - val df3 = sqlCtx.createDataFrame(rowRDD3, schema2) + val df3 = createDataFrame(rowRDD3, schema2) df3.registerTempTable("applySchema3") checkAnswer( @@ -997,7 +1015,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { .build() val schemaWithMeta = new StructType(Array( schema("id"), schema("name").copy(metadata = metadata), schema("age"))) - val personWithMeta = sqlCtx.createDataFrame(person.rdd, schemaWithMeta) + val personWithMeta = createDataFrame(person.rdd, schemaWithMeta) def validateMetadata(rdd: DataFrame): Unit = { assert(rdd.schema("name").metadata.getString(docKey) == docValue) } @@ -1195,7 +1213,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("SPARK-3483 Special chars in column names") { val data = sparkContext.parallelize( Seq("""{"key?number1": "value1", "key.number2": "value2"}""")) - jsonRDD(data).registerTempTable("records") + read.json(data).registerTempTable("records") sql("SELECT `key?number1`, `key.number2` FROM records") } @@ -1236,11 +1254,11 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SPARK-4322 Grouping field with struct field as sub expression") { - jsonRDD(sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)).registerTempTable("data") + read.json(sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)).registerTempTable("data") checkAnswer(sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c"), Row(1)) dropTempTable("data") - jsonRDD(sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data") + read.json(sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data") checkAnswer(sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1"), Row(2)) dropTempTable("data") } @@ -1248,22 +1266,22 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { test("SPARK-4432 Fix attribute reference resolution error when using ORDER BY") { checkAnswer( sql("SELECT a + b FROM testData2 ORDER BY a"), - Seq(2, 3, 3 ,4 ,4 ,5).map(Row(_)) + Seq(2, 3, 3, 4, 4, 5).map(Row(_)) ) } test("oder by asc by default when not specify ascending and descending") { checkAnswer( sql("SELECT a, b FROM testData2 ORDER BY a desc, b"), - Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2,2), Row(1, 1), Row(1, 2)) + Seq(Row(3, 1), Row(3, 2), Row(2, 1), Row(2, 2), Row(1, 1), Row(1, 2)) ) } test("Supporting relational operator '<=>' in Spark SQL") { - val nullCheckData1 = TestData(1,"1") :: TestData(2,null) :: Nil + val nullCheckData1 = TestData(1, "1") :: TestData(2, null) :: Nil val rdd1 = sparkContext.parallelize((0 to 1).map(i => nullCheckData1(i))) rdd1.toDF().registerTempTable("nulldata1") - val nullCheckData2 = TestData(1,"1") :: TestData(2,null) :: Nil + val nullCheckData2 = TestData(1, "1") :: TestData(2, null) :: Nil val rdd2 = sparkContext.parallelize((0 to 1).map(i => nullCheckData2(i))) rdd2.toDF().registerTempTable("nulldata2") checkAnswer(sql("SELECT nulldata1.key FROM nulldata1 join " + @@ -1272,7 +1290,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("Multi-column COUNT(DISTINCT ...)") { - val data = TestData(1,"val_1") :: TestData(2,"val_2") :: Nil + val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil val rdd = sparkContext.parallelize((0 to 1).map(i => data(i))) rdd.toDF().registerTempTable("distinctData") checkAnswer(sql("SELECT COUNT(DISTINCT key,value) FROM distinctData"), Row(2)) @@ -1288,7 +1306,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SPARK-6145: ORDER BY test for nested fields") { - jsonRDD(sparkContext.makeRDD("""{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil)) + read.json(sparkContext.makeRDD("""{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil)) .registerTempTable("nestedOrder") checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY a.b"), Row(1)) @@ -1300,17 +1318,37 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("SPARK-6145: special cases") { - jsonRDD(sparkContext.makeRDD( + read.json(sparkContext.makeRDD( """{"a": {"b": [1]}, "b": [{"a": 1}], "c0": {"a": 1}}""" :: Nil)).registerTempTable("t") checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY c0.a"), Row(1)) checkAnswer(sql("SELECT b[0].a FROM t ORDER BY c0.a"), Row(1)) } test("SPARK-6898: complete support for special chars in column names") { - jsonRDD(sparkContext.makeRDD( + read.json(sparkContext.makeRDD( """{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil)) .registerTempTable("t") checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1)) } + + test("SPARK-7952: fix the equality check between boolean and numeric types") { + withTempTable("t") { + // numeric field i, boolean field j, result of i = j, result of i <=> j + Seq[(Integer, java.lang.Boolean, java.lang.Boolean, java.lang.Boolean)]( + (1, true, true, true), + (0, false, true, true), + (2, true, false, false), + (2, false, false, false), + (null, true, null, false), + (null, false, null, false), + (0, null, null, false), + (1, null, null, false), + (null, null, null, true) + ).toDF("i", "b", "r1", "r2").registerTempTable("t") + + checkAnswer(sql("select i = b from t"), sql("select r1 from t")) + checkAnswer(sql("select i <=> b from t"), sql("select r2 from t")) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index 3fa00fd9d0ccb..d2ede39f0a5f6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -19,8 +19,7 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.test.TestSQLContext._ @@ -74,20 +73,20 @@ case class ComplexReflectData( mapFieldContainsNull: Map[Int, Option[Long]], dataField: Data) -class ScalaReflectionRelationSuite extends FunSuite { +class ScalaReflectionRelationSuite extends SparkFunSuite { import org.apache.spark.sql.test.TestSQLContext.implicits._ test("query case class RDD") { val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, - new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3)) + new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1, 2, 3)) val rdd = sparkContext.parallelize(data :: Nil) rdd.toDF().registerTempTable("reflectData") assert(sql("SELECT * FROM reflectData").collect().head === Row("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), - new Timestamp(12345), Seq(1,2,3))) + new Timestamp(12345), Seq(1, 2, 3))) } test("query case class RDD with nulls") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala index 6f6d3c9c243d4..1e8cde606b67b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala @@ -17,13 +17,11 @@ package org.apache.spark.sql -import org.scalatest.FunSuite - -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.sql.test.TestSQLContext -class SerializationSuite extends FunSuite { +class SerializationSuite extends SparkFunSuite { test("[SPARK-5235] SQLContext should be serializable") { val sqlContext = new SQLContext(TestSQLContext.sparkContext) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 446771ab2a5a5..725a18bfae3a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -109,8 +109,8 @@ object TestData { case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]]) val arrayData = TestSQLContext.sparkContext.parallelize( - ArrayData(Seq(1,2,3), Seq(Seq(1,2,3))) :: - ArrayData(Seq(2,3,4), Seq(Seq(2,3,4))) :: Nil) + ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) :: + ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil) arrayData.toDF().registerTempTable("arrayData") case class MapData(data: scala.collection.Map[Int, String]) @@ -175,7 +175,7 @@ object TestData { "4, D4, true, 2147483644" :: Nil) case class TimestampField(time: Timestamp) - val timestamps = TestSQLContext.sparkContext.parallelize((1 to 3).map { i => + val timestamps = TestSQLContext.sparkContext.parallelize((0 to 3).map { i => TimestampField(new Timestamp(i)) }) timestamps.toDF().registerTempTable("timestamps") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index d615542ab50a7..1a9ba66416b21 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -38,7 +38,7 @@ class UDFSuite extends QueryTest { } test("TwoArgument UDF") { - udf.register("strLenScala", (_: String).length + (_:Int)) + udf.register("strLenScala", (_: String).length + (_: Int)) assert(sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 2672e20deadc5..dc2d43a197f40 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -105,13 +105,13 @@ class UserDefinedTypeSuite extends QueryTest { test("UDTs with Parquet") { val tempDir = Utils.createTempDir() tempDir.delete() - pointsRDD.saveAsParquetFile(tempDir.getCanonicalPath) + pointsRDD.write.parquet(tempDir.getCanonicalPath) } test("Repartition UDTs with Parquet") { val tempDir = Utils.createTempDir() tempDir.delete() - pointsRDD.repartition(1).saveAsParquetFile(tempDir.getCanonicalPath) + pointsRDD.repartition(1).write.parquet(tempDir.getCanonicalPath) } // Tests to make sure that all operators correctly convert types on the way out. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index 7cefcf44061ce..339e719f39f16 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.columnar -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.types._ -class ColumnStatsSuite extends FunSuite { +class ColumnStatsSuite extends SparkFunSuite { testColumnStats(classOf[ByteColumnStats], BYTE, Row(Byte.MaxValue, Byte.MinValue, 0)) testColumnStats(classOf[ShortColumnStats], SHORT, Row(Short.MaxValue, Short.MinValue, 0)) testColumnStats(classOf[IntColumnStats], INT, Row(Int.MaxValue, Int.MinValue, 0)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index 1e105e259dce7..a1e76eaa982cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -23,15 +23,14 @@ import java.sql.Timestamp import com.esotericsoftware.kryo.{Serializer, Kryo} import com.esotericsoftware.kryo.io.{Input, Output} import org.apache.spark.serializer.KryoRegistrator -import org.scalatest.FunSuite -import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.types._ -class ColumnTypeSuite extends FunSuite with Logging { +class ColumnTypeSuite extends SparkFunSuite with Logging { val DEFAULT_BUFFER_SIZE = 512 test("defaultSize") { @@ -73,7 +72,7 @@ class ColumnTypeSuite extends FunSuite with Logging { checkActualSize(TIMESTAMP, new Timestamp(0L), 12) val binary = Array.fill[Byte](4)(0: Byte) - checkActualSize(BINARY, binary, 4 + 4) + checkActualSize(BINARY, binary, 4 + 4) val generic = Map(1 -> "a") checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 8) @@ -167,7 +166,7 @@ class ColumnTypeSuite extends FunSuite with Logging { val serializer = new SparkSqlSerializer(conf).newInstance() val buffer = ByteBuffer.allocate(512) - val obj = CustomClass(Int.MaxValue,Long.MaxValue) + val obj = CustomClass(Int.MaxValue, Long.MaxValue) val serializedObj = serializer.serialize(obj).array() GENERIC.append(serializer.serialize(obj).array(), buffer) @@ -278,7 +277,7 @@ private[columnar] object CustomerSerializer extends Serializer[CustomClass] { override def read(kryo: Kryo, input: Input, aClass: Class[CustomClass]): CustomClass = { val a = input.readInt() val b = input.readLong() - CustomClass(a,b) + CustomClass(a, b) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 56591d9dba29e..055453e688e73 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -173,7 +173,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { new Timestamp(i), (1 to i).toSeq, (0 to i).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap, - Row((i - 0.25).toFloat, (1 to i).toSeq)) + Row((i - 0.25).toFloat, Seq(true, false, null))) } createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types") // Cache the table. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala index a0702144f942c..2a6e0c376551a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala @@ -19,8 +19,7 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.types.DataType @@ -39,7 +38,7 @@ object TestNullableColumnAccessor { } } -class NullableColumnAccessorSuite extends FunSuite { +class NullableColumnAccessorSuite extends SparkFunSuite { import ColumnarTestUtils._ Seq( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala index 3a5605d2335d7..cb4e9f1eb7f46 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.columnar -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.types._ @@ -35,7 +34,7 @@ object TestNullableColumnBuilder { } } -class NullableColumnBuilderSuite extends FunSuite { +class NullableColumnBuilderSuite extends SparkFunSuite { import ColumnarTestUtils._ Seq( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index 2a0b701cad7fa..cda1b0992e36f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -17,13 +17,14 @@ package org.apache.spark.sql.columnar -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} +import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.test.TestSQLContext.implicits._ -class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfter { +class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter { val originalColumnBatchSize = conf.columnBatchSize val originalInMemoryPartitionPruning = conf.inMemoryPartitionPruning diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala index 8b518f094174c..20d65a74e3b7a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql.columnar.compression -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar.{NoopColumnStats, BOOLEAN} import org.apache.spark.sql.columnar.ColumnarTestUtils._ -class BooleanBitSetSuite extends FunSuite { +class BooleanBitSetSuite extends SparkFunSuite { import BooleanBitSet._ def skeleton(count: Int) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala index 64b70552eb047..acfab6586c0d1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala @@ -19,16 +19,15 @@ package org.apache.spark.sql.columnar.compression import java.nio.ByteBuffer -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types.AtomicType -class DictionaryEncodingSuite extends FunSuite { - testDictionaryEncoding(new IntColumnStats, INT) - testDictionaryEncoding(new LongColumnStats, LONG) +class DictionaryEncodingSuite extends SparkFunSuite { + testDictionaryEncoding(new IntColumnStats, INT) + testDictionaryEncoding(new LongColumnStats, LONG) testDictionaryEncoding(new StringColumnStats, STRING) def testDictionaryEncoding[T <: AtomicType]( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala index bfd99f143bedc..2111e9fbe62cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.sql.columnar.compression -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types.IntegralType -class IntegralDeltaSuite extends FunSuite { - testIntegralDelta(new IntColumnStats, INT, IntDelta) +class IntegralDeltaSuite extends SparkFunSuite { + testIntegralDelta(new IntColumnStats, INT, IntDelta) testIntegralDelta(new LongColumnStats, LONG, LongDelta) def testIntegralDelta[I <: IntegralType]( @@ -116,7 +115,7 @@ class IntegralDeltaSuite extends FunSuite { test(s"$scheme: simple case") { val input = columnType match { - case INT => Seq(2: Int, 1: Int, 2: Int, 130: Int) + case INT => Seq(2: Int, 1: Int, 2: Int, 130: Int) case LONG => Seq(2: Long, 1: Long, 2: Long, 130: Long) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala index fde7a4595be0e..67ec08f594a43 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala @@ -17,20 +17,19 @@ package org.apache.spark.sql.columnar.compression -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.columnar._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types.AtomicType -class RunLengthEncodingSuite extends FunSuite { +class RunLengthEncodingSuite extends SparkFunSuite { testRunLengthEncoding(new NoopColumnStats, BOOLEAN) - testRunLengthEncoding(new ByteColumnStats, BYTE) - testRunLengthEncoding(new ShortColumnStats, SHORT) - testRunLengthEncoding(new IntColumnStats, INT) - testRunLengthEncoding(new LongColumnStats, LONG) - testRunLengthEncoding(new StringColumnStats, STRING) + testRunLengthEncoding(new ByteColumnStats, BYTE) + testRunLengthEncoding(new ShortColumnStats, SHORT) + testRunLengthEncoding(new IntColumnStats, INT) + testRunLengthEncoding(new LongColumnStats, LONG) + testRunLengthEncoding(new StringColumnStats, STRING) def testRunLengthEncoding[T <: AtomicType]( columnStats: ColumnStats, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 523be56df65ba..45a7e8fe68f72 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.execution -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{SQLConf, execution} import org.apache.spark.sql.functions._ import org.apache.spark.sql.TestData._ @@ -31,7 +30,7 @@ import org.apache.spark.sql.test.TestSQLContext.planner._ import org.apache.spark.sql.types._ -class PlannerSuite extends FunSuite { +class PlannerSuite extends SparkFunSuite { test("unions are collapsed") { val query = testData.unionAll(testData).unionAll(testData).logicalPlan val planned = BasicOperators(query).head diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala index 15337c4045436..6ca5390cde23e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala @@ -19,17 +19,17 @@ package org.apache.spark.sql.execution import java.sql.{Timestamp, Date} -import org.scalatest.{FunSuite, BeforeAndAfterAll} +import org.scalatest.BeforeAndAfterAll import org.apache.spark.rdd.ShuffledRDD import org.apache.spark.serializer.Serializer -import org.apache.spark.ShuffleDependency +import org.apache.spark.{ShuffleDependency, SparkFunSuite} import org.apache.spark.sql.types._ import org.apache.spark.sql.Row import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.{MyDenseVectorUDT, QueryTest} -class SparkSqlSerializer2DataTypeSuite extends FunSuite { +class SparkSqlSerializer2DataTypeSuite extends SparkFunSuite { // Make sure that we will not use serializer2 for unsupported data types. def checkSupported(dataType: DataType, isSupported: Boolean): Unit = { val testName = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index 358d8cf06e463..8ec3985e00360 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.execution.debug -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.TestData._ import org.apache.spark.sql.test.TestSQLContext._ -class DebuggingSuite extends FunSuite { +class DebuggingSuite extends SparkFunSuite { test("DataFrame.debug()") { testData.debug() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 2aad01ded1acf..5290c28cfca02 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -17,13 +17,12 @@ package org.apache.spark.sql.execution.joins -import org.scalatest.FunSuite - +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.{Projection, Row} import org.apache.spark.util.collection.CompactBuffer -class HashedRelationSuite extends FunSuite { +class HashedRelationSuite extends SparkFunSuite { // Key is simply the record itself private val keyProjection = new Projection { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 2abfe7f167f77..e20c66cb2f1d7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -21,20 +21,28 @@ import java.math.BigDecimal import java.sql.DriverManager import java.util.{Calendar, GregorianCalendar, Properties} +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.test._ import org.apache.spark.sql.types._ import org.h2.jdbc.JdbcSQLException -import org.scalatest.{FunSuite, BeforeAndAfter} +import org.scalatest.BeforeAndAfter import TestSQLContext._ import TestSQLContext.implicits._ -class JDBCSuite extends FunSuite with BeforeAndAfter { +class JDBCSuite extends SparkFunSuite with BeforeAndAfter { val url = "jdbc:h2:mem:testdb0" val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass" var conn: java.sql.Connection = null val testBytes = Array[Byte](99.toByte, 134.toByte, 135.toByte, 200.toByte, 205.toByte) + val testH2Dialect = new JdbcDialect { + def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2") + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = + Some(StringType) + } + before { Class.forName("org.h2.Driver") // Extra properties that will be specified for our database. We need these to test @@ -61,6 +69,14 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { |OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) + sql( + s""" + |CREATE TEMPORARY TABLE fetchtwo + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass', + | fetchSize '2') + """.stripMargin.replaceAll("\n", " ")) + sql( s""" |CREATE TEMPORARY TABLE parts @@ -178,6 +194,14 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { assert(names(2).equals("mary")) } + test("SELECT first field when fetchSize is two") { + val names = sql("SELECT NAME FROM fetchtwo").collect().map(x => x.getString(0)).sortWith(_ < _) + assert(names.size === 3) + assert(names(0).equals("fred")) + assert(names(1).equals("joe 'foo' \"bar\"")) + assert(names(2).equals("mary")) + } + test("SELECT second field") { val ids = sql("SELECT THEID FROM foobar").collect().map(x => x.getInt(0)).sortWith(_ < _) assert(ids.size === 3) @@ -186,6 +210,14 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { assert(ids(2) === 3) } + test("SELECT second field when fetchSize is two") { + val ids = sql("SELECT THEID FROM fetchtwo").collect().map(x => x.getInt(0)).sortWith(_ < _) + assert(ids.size === 3) + assert(ids(0) === 1) + assert(ids(1) === 2) + assert(ids(2) === 3) + } + test("SELECT * partitioned") { assert(sql("SELECT * FROM parts").collect().size == 3) } @@ -221,22 +253,32 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { } test("Basic API") { - assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE").collect().size === 3) + assert(TestSQLContext.read.jdbc( + urlWithUserAndPass, "TEST.PEOPLE", new Properties).collect().length === 3) + } + + test("Basic API with FetchSize") { + val properties = new Properties + properties.setProperty("fetchSize", "2") + assert(TestSQLContext.read.jdbc( + urlWithUserAndPass, "TEST.PEOPLE", properties).collect().length === 3) } test("Partitioning via JDBCPartitioningInfo API") { - assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3) - .collect.size === 3) + assert( + TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties) + .collect().length === 3) } test("Partitioning via list-of-where-clauses API") { val parts = Array[String]("THEID < 2", "THEID >= 2") - assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts).collect().size === 3) + assert(TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties) + .collect().length === 3) } test("H2 integral types") { val rows = sql("SELECT * FROM inttypes WHERE A IS NOT NULL").collect() - assert(rows.size === 1) + assert(rows.length === 1) assert(rows(0).getInt(0) === 1) assert(rows(0).getBoolean(1) === false) assert(rows(0).getInt(2) === 3) @@ -246,7 +288,7 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { test("H2 null entries") { val rows = sql("SELECT * FROM inttypes WHERE A IS NULL").collect() - assert(rows.size === 1) + assert(rows.length === 1) assert(rows(0).isNullAt(0)) assert(rows(0).isNullAt(1)) assert(rows(0).isNullAt(2)) @@ -286,24 +328,28 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { } test("test DATE types") { - val rows = TestSQLContext.jdbc(urlWithUserAndPass, "TEST.TIMETYPES").collect() - val cachedRows = TestSQLContext.jdbc(urlWithUserAndPass, "TEST.TIMETYPES").cache().collect() + val rows = TestSQLContext.read.jdbc( + urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() + val cachedRows = TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) + .cache().collect() assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) assert(rows(1).getAs[java.sql.Date](1) === null) assert(cachedRows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) } test("test DATE types in cache") { - val rows = TestSQLContext.jdbc(urlWithUserAndPass, "TEST.TIMETYPES").collect() - TestSQLContext - .jdbc(urlWithUserAndPass, "TEST.TIMETYPES").cache().registerTempTable("mycached_date") + val rows = + TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() + TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) + .cache().registerTempTable("mycached_date") val cachedRows = sql("select * from mycached_date").collect() assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) assert(cachedRows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) } test("test types for null value") { - val rows = TestSQLContext.jdbc(urlWithUserAndPass, "TEST.NULLTYPES").collect() + val rows = TestSQLContext.read.jdbc( + urlWithUserAndPass, "TEST.NULLTYPES", new Properties).collect() assert((0 to 14).forall(i => rows(0).isNullAt(i))) } @@ -346,4 +392,46 @@ class JDBCSuite extends FunSuite with BeforeAndAfter { """.stripMargin.replaceAll("\n", " ")) } } + + test("Remap types via JdbcDialects") { + JdbcDialects.registerDialect(testH2Dialect) + val df = TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties) + assert(df.schema.filter( + _.dataType != org.apache.spark.sql.types.StringType + ).isEmpty) + val rows = df.collect() + assert(rows(0).get(0).isInstanceOf[String]) + assert(rows(0).get(1).isInstanceOf[String]) + JdbcDialects.unregisterDialect(testH2Dialect) + } + + test("Default jdbc dialect registration") { + assert(JdbcDialects.get("jdbc:mysql://127.0.0.1/db") == MySQLDialect) + assert(JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") == PostgresDialect) + assert(JdbcDialects.get("test.invalid") == NoopDialect) + } + + test("Dialect unregister") { + JdbcDialects.registerDialect(testH2Dialect) + JdbcDialects.unregisterDialect(testH2Dialect) + assert(JdbcDialects.get(urlWithUserAndPass) == NoopDialect) + } + + test("Aggregated dialects") { + val agg = new AggregatedDialect(List(new JdbcDialect { + def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2:") + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = + if (sqlType % 2 == 0) { + Some(LongType) + } else { + None + } + }, testH2Dialect)) + assert(agg.canHandle("jdbc:h2:xxx")) + assert(!agg.canHandle("jdbc:h2")) + assert(agg.getCatalystType(0, "", 1, null) == Some(LongType)) + assert(agg.getCatalystType(1, "", 1, null) == Some(StringType)) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index f3ce8e66460e5..2de8c1a6098e0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -20,13 +20,14 @@ package org.apache.spark.sql.jdbc import java.sql.DriverManager import java.util.Properties -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.Row +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{SaveMode, Row} import org.apache.spark.sql.test._ import org.apache.spark.sql.types._ -class JDBCWriteSuite extends FunSuite with BeforeAndAfter { +class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { val url = "jdbc:h2:mem:testdb2" var conn: java.sql.Connection = null val url1 = "jdbc:h2:mem:testdb3" @@ -35,14 +36,37 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter { properties.setProperty("user", "testUser") properties.setProperty("password", "testPass") properties.setProperty("rowId", "false") - + before { Class.forName("org.h2.Driver") conn = DriverManager.getConnection(url) conn.prepareStatement("create schema test").executeUpdate() - + conn1 = DriverManager.getConnection(url1, properties) conn1.prepareStatement("create schema test").executeUpdate() + conn1.prepareStatement("drop table if exists test.people").executeUpdate() + conn1.prepareStatement( + "create table test.people (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate() + conn1.prepareStatement("insert into test.people values ('fred', 1)").executeUpdate() + conn1.prepareStatement("insert into test.people values ('mary', 2)").executeUpdate() + conn1.prepareStatement("drop table if exists test.people1").executeUpdate() + conn1.prepareStatement( + "create table test.people1 (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate() + conn1.commit() + + TestSQLContext.sql( + s""" + |CREATE TEMPORARY TABLE PEOPLE + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url1', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass') + """.stripMargin.replaceAll("\n", " ")) + + TestSQLContext.sql( + s""" + |CREATE TEMPORARY TABLE PEOPLE1 + |USING org.apache.spark.sql.jdbc + |OPTIONS (url '$url1', dbtable 'TEST.PEOPLE1', user 'testUser', password 'testPass') + """.stripMargin.replaceAll("\n", " ")) } after { @@ -67,52 +91,66 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter { test("Basic CREATE") { val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2) - df.createJDBCTable(url, "TEST.BASICCREATETEST", false) - assert(2 == TestSQLContext.jdbc(url, "TEST.BASICCREATETEST").count) - assert(2 == TestSQLContext.jdbc(url, "TEST.BASICCREATETEST").collect()(0).length) + df.write.jdbc(url, "TEST.BASICCREATETEST", new Properties) + assert(2 == TestSQLContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count) + assert(2 == + TestSQLContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length) } test("CREATE with overwrite") { val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x3), schema3) val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2) - df.createJDBCTable(url1, "TEST.DROPTEST", false, properties) - assert(2 == TestSQLContext.jdbc(url1, "TEST.DROPTEST", properties).count) - assert(3 == TestSQLContext.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) + df.write.jdbc(url1, "TEST.DROPTEST", properties) + assert(2 == TestSQLContext.read.jdbc(url1, "TEST.DROPTEST", properties).count) + assert(3 == TestSQLContext.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) - df2.createJDBCTable(url1, "TEST.DROPTEST", true, properties) - assert(1 == TestSQLContext.jdbc(url1, "TEST.DROPTEST", properties).count) - assert(2 == TestSQLContext.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) + df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.DROPTEST", properties) + assert(1 == TestSQLContext.read.jdbc(url1, "TEST.DROPTEST", properties).count) + assert(2 == TestSQLContext.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) } test("CREATE then INSERT to append") { val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2) val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2) - df.createJDBCTable(url, "TEST.APPENDTEST", false) - df2.insertIntoJDBC(url, "TEST.APPENDTEST", false) - assert(3 == TestSQLContext.jdbc(url, "TEST.APPENDTEST").count) - assert(2 == TestSQLContext.jdbc(url, "TEST.APPENDTEST").collect()(0).length) + df.write.jdbc(url, "TEST.APPENDTEST", new Properties) + df2.write.mode(SaveMode.Append).jdbc(url, "TEST.APPENDTEST", new Properties) + assert(3 == TestSQLContext.read.jdbc(url, "TEST.APPENDTEST", new Properties).count) + assert(2 == + TestSQLContext.read.jdbc(url, "TEST.APPENDTEST", new Properties).collect()(0).length) } test("CREATE then INSERT to truncate") { val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2) val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2) - df.createJDBCTable(url1, "TEST.TRUNCATETEST", false, properties) - df2.insertIntoJDBC(url1, "TEST.TRUNCATETEST", true, properties) - assert(1 == TestSQLContext.jdbc(url1, "TEST.TRUNCATETEST", properties).count) - assert(2 == TestSQLContext.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length) + df.write.jdbc(url1, "TEST.TRUNCATETEST", properties) + df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.TRUNCATETEST", properties) + assert(1 == TestSQLContext.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count) + assert(2 == TestSQLContext.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length) } test("Incompatible INSERT to append") { val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2) val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr2x3), schema3) - df.createJDBCTable(url, "TEST.INCOMPATIBLETEST", false) + df.write.jdbc(url, "TEST.INCOMPATIBLETEST", new Properties) intercept[org.apache.spark.SparkException] { - df2.insertIntoJDBC(url, "TEST.INCOMPATIBLETEST", true) + df2.write.mode(SaveMode.Append).jdbc(url, "TEST.INCOMPATIBLETEST", new Properties) } } + test("INSERT to JDBC Datasource") { + TestSQLContext.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") + assert(2 == TestSQLContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count) + assert(2 == TestSQLContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) + } + + test("INSERT to JDBC Datasource with overwrite") { + TestSQLContext.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") + TestSQLContext.sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE") + assert(2 == TestSQLContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count) + assert(2 == TestSQLContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 263fafba930ce..f8d62f9e7e02b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -24,7 +24,7 @@ import com.fasterxml.jackson.core.JsonFactory import org.scalactic.Tolerance._ import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.functions._ +import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.json.InferSchema.compatibleType import org.apache.spark.sql.sources.LogicalRelation import org.apache.spark.sql.test.TestSQLContext @@ -214,7 +214,7 @@ class JsonSuite extends QueryTest { } test("Complex field and type inferring with null in sampling") { - val jsonDF = jsonRDD(jsonNullStruct) + val jsonDF = read.json(jsonNullStruct) val expectedSchema = StructType( StructField("headers", StructType( StructField("Charset", StringType, true) :: @@ -233,7 +233,7 @@ class JsonSuite extends QueryTest { } test("Primitive field and type inferring") { - val jsonDF = jsonRDD(primitiveFieldAndType) + val jsonDF = read.json(primitiveFieldAndType) val expectedSchema = StructType( StructField("bigInteger", DecimalType.Unlimited, true) :: @@ -261,7 +261,7 @@ class JsonSuite extends QueryTest { } test("Complex field and type inferring") { - val jsonDF = jsonRDD(complexFieldAndType1) + val jsonDF = read.json(complexFieldAndType1) val expectedSchema = StructType( StructField("arrayOfArray1", ArrayType(ArrayType(StringType, true), true), true) :: @@ -360,7 +360,7 @@ class JsonSuite extends QueryTest { } test("GetField operation on complex data type") { - val jsonDF = jsonRDD(complexFieldAndType1) + val jsonDF = read.json(complexFieldAndType1) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -376,7 +376,7 @@ class JsonSuite extends QueryTest { } test("Type conflict in primitive field values") { - val jsonDF = jsonRDD(primitiveFieldValueTypeConflict) + val jsonDF = read.json(primitiveFieldValueTypeConflict) val expectedSchema = StructType( StructField("num_bool", StringType, true) :: @@ -450,7 +450,7 @@ class JsonSuite extends QueryTest { } ignore("Type conflict in primitive field values (Ignored)") { - val jsonDF = jsonRDD(primitiveFieldValueTypeConflict) + val jsonDF = read.json(primitiveFieldValueTypeConflict) jsonDF.registerTempTable("jsonTable") // Right now, the analyzer does not promote strings in a boolean expression. @@ -503,7 +503,7 @@ class JsonSuite extends QueryTest { } test("Type conflict in complex field values") { - val jsonDF = jsonRDD(complexFieldValueTypeConflict) + val jsonDF = read.json(complexFieldValueTypeConflict) val expectedSchema = StructType( StructField("array", ArrayType(LongType, true), true) :: @@ -522,12 +522,12 @@ class JsonSuite extends QueryTest { Row(Seq(), "11", "[1,2,3]", Row(null), "[]") :: Row(null, """{"field":false}""", null, null, "{}") :: Row(Seq(4, 5, 6), null, "str", Row(null), "[7,8,9]") :: - Row(Seq(7), "{}","""["str1","str2",33]""", Row("str"), """{"field":true}""") :: Nil + Row(Seq(7), "{}", """["str1","str2",33]""", Row("str"), """{"field":true}""") :: Nil ) } test("Type conflict in array elements") { - val jsonDF = jsonRDD(arrayElementTypeConflict) + val jsonDF = read.json(arrayElementTypeConflict) val expectedSchema = StructType( StructField("array1", ArrayType(StringType, true), true) :: @@ -555,7 +555,7 @@ class JsonSuite extends QueryTest { } test("Handling missing fields") { - val jsonDF = jsonRDD(missingFields) + val jsonDF = read.json(missingFields) val expectedSchema = StructType( StructField("a", BooleanType, true) :: @@ -575,7 +575,7 @@ class JsonSuite extends QueryTest { dir.delete() val path = dir.getCanonicalPath sparkContext.parallelize(1 to 100).map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path) - val jsonDF = jsonFile(path, 0.49) + val jsonDF = read.option("samplingRatio", "0.49").json(path) val analyzed = jsonDF.queryExecution.analyzed assert( @@ -590,7 +590,7 @@ class JsonSuite extends QueryTest { val schema = StructType(StructField("a", LongType, true) :: Nil) val logicalRelation = - jsonFile(path, schema).queryExecution.analyzed.asInstanceOf[LogicalRelation] + read.schema(schema).json(path).queryExecution.analyzed.asInstanceOf[LogicalRelation] val relationWithSchema = logicalRelation.relation.asInstanceOf[JSONRelation] assert(relationWithSchema.path === Some(path)) assert(relationWithSchema.schema === schema) @@ -602,7 +602,7 @@ class JsonSuite extends QueryTest { dir.delete() val path = dir.getCanonicalPath primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) - val jsonDF = jsonFile(path) + val jsonDF = read.json(path) val expectedSchema = StructType( StructField("bigInteger", DecimalType.Unlimited, true) :: @@ -671,7 +671,7 @@ class JsonSuite extends QueryTest { StructField("null", StringType, true) :: StructField("string", StringType, true) :: Nil) - val jsonDF1 = jsonFile(path, schema) + val jsonDF1 = read.schema(schema).json(path) assert(schema === jsonDF1.schema) @@ -688,7 +688,7 @@ class JsonSuite extends QueryTest { "this is a simple string.") ) - val jsonDF2 = jsonRDD(primitiveFieldAndType, schema) + val jsonDF2 = read.schema(schema).json(primitiveFieldAndType) assert(schema === jsonDF2.schema) @@ -709,7 +709,7 @@ class JsonSuite extends QueryTest { test("Applying schemas with MapType") { val schemaWithSimpleMap = StructType( StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) - val jsonWithSimpleMap = jsonRDD(mapType1, schemaWithSimpleMap) + val jsonWithSimpleMap = read.schema(schemaWithSimpleMap).json(mapType1) jsonWithSimpleMap.registerTempTable("jsonWithSimpleMap") @@ -737,7 +737,7 @@ class JsonSuite extends QueryTest { val schemaWithComplexMap = StructType( StructField("map", MapType(StringType, innerStruct, true), false) :: Nil) - val jsonWithComplexMap = jsonRDD(mapType2, schemaWithComplexMap) + val jsonWithComplexMap = read.schema(schemaWithComplexMap).json(mapType2) jsonWithComplexMap.registerTempTable("jsonWithComplexMap") @@ -763,7 +763,7 @@ class JsonSuite extends QueryTest { } test("SPARK-2096 Correctly parse dot notations") { - val jsonDF = jsonRDD(complexFieldAndType2) + val jsonDF = read.json(complexFieldAndType2) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -781,7 +781,7 @@ class JsonSuite extends QueryTest { } test("SPARK-3390 Complex arrays") { - val jsonDF = jsonRDD(complexFieldAndType2) + val jsonDF = read.json(complexFieldAndType2) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -804,7 +804,7 @@ class JsonSuite extends QueryTest { } test("SPARK-3308 Read top level JSON arrays") { - val jsonDF = jsonRDD(jsonArray) + val jsonDF = read.json(jsonArray) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -825,7 +825,7 @@ class JsonSuite extends QueryTest { val oldColumnNameOfCorruptRecord = TestSQLContext.conf.columnNameOfCorruptRecord TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") - val jsonDF = jsonRDD(corruptRecords) + val jsonDF = read.json(corruptRecords) jsonDF.registerTempTable("jsonTable") val schema = StructType( @@ -879,7 +879,7 @@ class JsonSuite extends QueryTest { } test("SPARK-4068: nulls in arrays") { - val jsonDF = jsonRDD(nullsInArrays) + val jsonDF = read.json(nullsInArrays) jsonDF.registerTempTable("jsonTable") val schema = StructType( @@ -956,8 +956,8 @@ class JsonSuite extends QueryTest { assert(result2(1) === "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}") assert(result2(3) === "{\"f1\":{\"f11\":4,\"f12\":true},\"f2\":{\"D4\":2147483644}}") - val jsonDF = jsonRDD(primitiveFieldAndType) - val primTable = jsonRDD(jsonDF.toJSON) + val jsonDF = read.json(primitiveFieldAndType) + val primTable = read.json(jsonDF.toJSON) primTable.registerTempTable("primativeTable") checkAnswer( sql("select * from primativeTable"), @@ -969,8 +969,8 @@ class JsonSuite extends QueryTest { "this is a simple string.") ) - val complexJsonDF = jsonRDD(complexFieldAndType1) - val compTable = jsonRDD(complexJsonDF.toJSON) + val complexJsonDF = read.json(complexFieldAndType1) + val compTable = read.json(complexJsonDF.toJSON) compTable.registerTempTable("complexTable") // Access elements of a primitive array. checkAnswer( @@ -1073,4 +1073,31 @@ class JsonSuite extends QueryTest { assert(StructType(Seq()) === emptySchema) } + test("SPARK-7565 MapType in JsonRDD") { + val useStreaming = getConf(SQLConf.USE_JACKSON_STREAMING_API, "true") + val oldColumnNameOfCorruptRecord = TestSQLContext.conf.columnNameOfCorruptRecord + TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") + + val schemaWithSimpleMap = StructType( + StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) + try{ + for (useStreaming <- List("true", "false")) { + setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming) + val temp = Utils.createTempDir().getPath + + val df = read.schema(schemaWithSimpleMap).json(mapType1) + df.write.mode("overwrite").parquet(temp) + // order of MapType is not defined + assert(read.parquet(temp).count() == 5) + + val df2 = read.json(corruptRecords) + df2.write.mode("overwrite").parquet(temp) + checkAnswer(read.parquet(temp), df2.collect()) + } + } finally { + setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming) + setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) + } + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala index 10d0ede4dc0dc..bdc2ebabc5e9a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala @@ -328,12 +328,12 @@ class ParquetDataSourceOnFilterSuite extends ParquetFilterSuiteBase with BeforeA withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") { withTempPath { dir => val path = s"${dir.getCanonicalPath}/part=1" - (1 to 3).map(i => (i, i.toString)).toDF("a", "b").saveAsParquetFile(path) + (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path) // If the "part = 1" filter gets pushed down, this query will throw an exception since // "part" is not a valid column in the actual Parquet file checkAnswer( - sqlContext.parquetFile(path).filter("part = 1"), + sqlContext.read.parquet(path).filter("part = 1"), (1 to 3).map(i => Row(i, i.toString, 1))) } } @@ -350,14 +350,14 @@ class ParquetDataSourceOffFilterSuite extends ParquetFilterSuiteBase with Before override protected def afterAll(): Unit = { sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString) } - + test("SPARK-6742: don't push down predicates which reference partition columns") { import sqlContext.implicits._ withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") { withTempPath { dir => val path = s"${dir.getCanonicalPath}/part=1" - (1 to 3).map(i => (i, i.toString)).toDF("a", "b").saveAsParquetFile(path) + (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path) // If the "part = 1" filter gets pushed down, this query will throw an exception since // "part" is not a valid column in the actual Parquet file @@ -365,7 +365,7 @@ class ParquetDataSourceOffFilterSuite extends ParquetFilterSuiteBase with Before path, Some(sqlContext.sparkContext.hadoopConfiguration), sqlContext, Seq(AttributeReference("part", IntegerType, false)()) )) - + checkAnswer( df.filter("a = 1 or part = 1"), (1 to 3).map(i => Row(1, i, i.toString))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index b504842053690..dd48bb350f26d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -35,6 +35,7 @@ import parquet.schema.{MessageType, MessageTypeParser} import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.test.TestSQLContext.implicits._ @@ -113,24 +114,24 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) { withTempPath { dir => val data = makeDecimalRDD(DecimalType(precision, scale)) - data.saveAsParquetFile(dir.getCanonicalPath) - checkAnswer(parquetFile(dir.getCanonicalPath), data.collect().toSeq) + data.write.parquet(dir.getCanonicalPath) + checkAnswer(read.parquet(dir.getCanonicalPath), data.collect().toSeq) } } // Decimals with precision above 18 are not yet supported - intercept[RuntimeException] { + intercept[Throwable] { withTempPath { dir => - makeDecimalRDD(DecimalType(19, 10)).saveAsParquetFile(dir.getCanonicalPath) - parquetFile(dir.getCanonicalPath).collect() + makeDecimalRDD(DecimalType(19, 10)).write.parquet(dir.getCanonicalPath) + read.parquet(dir.getCanonicalPath).collect() } } // Unlimited-length decimals are not yet supported - intercept[RuntimeException] { + intercept[Throwable] { withTempPath { dir => - makeDecimalRDD(DecimalType.Unlimited).saveAsParquetFile(dir.getCanonicalPath) - parquetFile(dir.getCanonicalPath).collect() + makeDecimalRDD(DecimalType.Unlimited).write.parquet(dir.getCanonicalPath) + read.parquet(dir.getCanonicalPath).collect() } } } @@ -145,8 +146,8 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { withTempPath { dir => val data = makeDateRDD() - data.saveAsParquetFile(dir.getCanonicalPath) - checkAnswer(parquetFile(dir.getCanonicalPath), data.collect().toSeq) + data.write.parquet(dir.getCanonicalPath) + checkAnswer(read.parquet(dir.getCanonicalPath), data.collect().toSeq) } } @@ -282,7 +283,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { withTempDir { dir => val path = new Path(dir.toURI.toString, "part-r-0.parquet") makeRawParquetFile(path) - checkAnswer(parquetFile(path.toString), (0 until 10).map { i => + checkAnswer(read.parquet(path.toString), (0 until 10).map { i => Row(i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble) }) } @@ -310,8 +311,8 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { test("save - overwrite") { withParquetFile((1 to 10).map(i => (i, i.toString))) { file => val newData = (11 to 20).map(i => (i, i.toString)) - newData.toDF().save("org.apache.spark.sql.parquet", SaveMode.Overwrite, Map("path" -> file)) - checkAnswer(parquetFile(file), newData.map(Row.fromTuple)) + newData.toDF().write.format("parquet").mode(SaveMode.Overwrite).save(file) + checkAnswer(read.parquet(file), newData.map(Row.fromTuple)) } } @@ -319,8 +320,8 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { val data = (1 to 10).map(i => (i, i.toString)) withParquetFile(data) { file => val newData = (11 to 20).map(i => (i, i.toString)) - newData.toDF().save("org.apache.spark.sql.parquet", SaveMode.Ignore, Map("path" -> file)) - checkAnswer(parquetFile(file), data.map(Row.fromTuple)) + newData.toDF().write.format("parquet").mode(SaveMode.Ignore).save(file) + checkAnswer(read.parquet(file), data.map(Row.fromTuple)) } } @@ -329,8 +330,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { withParquetFile(data) { file => val newData = (11 to 20).map(i => (i, i.toString)) val errorMessage = intercept[Throwable] { - newData.toDF().save( - "org.apache.spark.sql.parquet", SaveMode.ErrorIfExists, Map("path" -> file)) + newData.toDF().write.format("parquet").mode(SaveMode.ErrorIfExists).save(file) }.getMessage assert(errorMessage.contains("already exists")) } @@ -340,8 +340,8 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { val data = (1 to 10).map(i => (i, i.toString)) withParquetFile(data) { file => val newData = (11 to 20).map(i => (i, i.toString)) - newData.toDF().save("org.apache.spark.sql.parquet", SaveMode.Append, Map("path" -> file)) - checkAnswer(parquetFile(file), (data ++ newData).map(Row.fromTuple)) + newData.toDF().write.format("parquet").mode(SaveMode.Append).save(file) + checkAnswer(read.parquet(file), (data ++ newData).map(Row.fromTuple)) } } @@ -373,7 +373,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { path, new Footer(path, new ParquetMetadata(fileMetadata, Nil)) :: Nil) - assertResult(parquetFile(path.toString).schema) { + assertResult(read.parquet(path.toString).schema) { StructType( StructField("a", BooleanType, nullable = false) :: StructField("b", IntegerType, nullable = false) :: @@ -391,7 +391,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { sqlContext.udf.register("div0", (x: Int) => x / 0) withTempPath { dir => intercept[org.apache.spark.SparkException] { - sqlContext.sql("select div0(1)").saveAsParquetFile(dir.getCanonicalPath) + sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath) } val path = new Path(dir.getCanonicalPath, "_temporary") val fs = path.getFileSystem(configuration) @@ -419,11 +419,11 @@ class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterA test("SPARK-6330 regression test") { // In 1.3.0, save to fs other than file: without configuring core-site.xml would get: // IllegalArgumentException: Wrong FS: hdfs://..., expected: file:/// - intercept[java.io.FileNotFoundException] { - sqlContext.parquetFile("file:///nonexistent") + intercept[Throwable] { + sqlContext.read.parquet("file:///nonexistent") } val errorMessage = intercept[Throwable] { - sqlContext.parquetFile("hdfs://nonexistent") + sqlContext.read.parquet("hdfs://nonexistent") }.toString assert(errorMessage.contains("UnknownHostException")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala index bea568ed40049..3b29979452ad9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala @@ -16,16 +16,21 @@ */ package org.apache.spark.sql.parquet +import java.io.File +import java.math.BigInteger +import java.sql.Timestamp + import scala.collection.mutable.ArrayBuffer +import com.google.common.io.Files import org.apache.hadoop.fs.Path import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.sources.PartitioningUtils._ -import org.apache.spark.sql.sources.{Partition, PartitionSpec} +import org.apache.spark.sql.sources.{LogicalRelation, Partition, PartitionSpec} import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.sql.{QueryTest, Row, SQLContext} +import org.apache.spark.sql.{Column, QueryTest, Row, SQLContext} // The data where the partitioning key exists only in the directory structure. case class ParquetData(intField: Int, stringField: String) @@ -39,7 +44,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { import sqlContext._ import sqlContext.implicits._ - val defaultPartitionName = "__NULL__" + val defaultPartitionName = "__HIVE_DEFAULT_PARTITION__" test("column type inference") { def check(raw: String, literal: Literal): Unit = { @@ -54,44 +59,45 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { } test("parse partition") { - def check(path: String, expected: PartitionValues): Unit = { + def check(path: String, expected: Option[PartitionValues]): Unit = { assert(expected === parsePartition(new Path(path), defaultPartitionName)) } def checkThrows[T <: Throwable: Manifest](path: String, expected: String): Unit = { val message = intercept[T] { - parsePartition(new Path(path), defaultPartitionName) + parsePartition(new Path(path), defaultPartitionName).get }.getMessage assert(message.contains(expected)) } - check( - "file:///", - PartitionValues( - ArrayBuffer.empty[String], - ArrayBuffer.empty[Literal])) - - check( - "file://path/a=10", + check("file://path/a=10", Some { PartitionValues( ArrayBuffer("a"), - ArrayBuffer(Literal.create(10, IntegerType)))) + ArrayBuffer(Literal.create(10, IntegerType))) + }) - check( - "file://path/a=10/b=hello/c=1.5", + check("file://path/a=10/b=hello/c=1.5", Some { PartitionValues( ArrayBuffer("a", "b", "c"), ArrayBuffer( Literal.create(10, IntegerType), Literal.create("hello", StringType), - Literal.create(1.5, FloatType)))) + Literal.create(1.5, FloatType))) + }) - check( - "file://path/a=10/b_hello/c=1.5", + check("file://path/a=10/b_hello/c=1.5", Some { PartitionValues( ArrayBuffer("c"), - ArrayBuffer(Literal.create(1.5, FloatType)))) + ArrayBuffer(Literal.create(1.5, FloatType))) + }) + + check("file:///", None) + check("file:///path/_temporary", None) + check("file:///path/_temporary/c=1.5", None) + check("file:///path/_temporary/path", None) + check("file://path/a=10/_temporary/c=1.5", None) + check("file://path/a=10/c=1.5/_temporary", None) checkThrows[AssertionError]("file://path/=10", "Empty partition column name") checkThrows[AssertionError]("file://path/a=", "Empty partition column value") @@ -121,6 +127,25 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { Partition(Row(10, "20"), "hdfs://host:9000/path/a=10/b=20"), Partition(Row(10.5, "hello"), "hdfs://host:9000/path/a=10.5/b=hello")))) + check(Seq( + "hdfs://host:9000/path/_temporary", + "hdfs://host:9000/path/a=10/b=20", + "hdfs://host:9000/path/a=10.5/b=hello", + "hdfs://host:9000/path/a=10.5/_temporary", + "hdfs://host:9000/path/a=10.5/_TeMpOrArY", + "hdfs://host:9000/path/a=10.5/b=hello/_temporary", + "hdfs://host:9000/path/a=10.5/b=hello/_TEMPORARY", + "hdfs://host:9000/path/_temporary/path", + "hdfs://host:9000/path/a=11/_temporary/path", + "hdfs://host:9000/path/a=10.5/b=world/_temporary/path"), + PartitionSpec( + StructType(Seq( + StructField("a", FloatType), + StructField("b", StringType))), + Seq( + Partition(Row(10, "20"), "hdfs://host:9000/path/a=10/b=20"), + Partition(Row(10.5, "hello"), "hdfs://host:9000/path/a=10.5/b=hello")))) + check(Seq( s"hdfs://host:9000/path/a=10/b=20", s"hdfs://host:9000/path/a=$defaultPartitionName/b=hello"), @@ -142,6 +167,11 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { Seq( Partition(Row(10, null), s"hdfs://host:9000/path/a=10/b=$defaultPartitionName"), Partition(Row(10.5, null), s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName")))) + + check(Seq( + s"hdfs://host:9000/path1", + s"hdfs://host:9000/path2"), + PartitionSpec.emptySpec) } test("read partitioned table - normal case") { @@ -150,12 +180,18 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { pi <- Seq(1, 2) ps <- Seq("foo", "bar") } { + val dir = makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps) makeParquetFile( (1 to 10).map(i => ParquetData(i, i.toString)), - makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) + dir) + // Introduce _temporary dir to test the robustness of the schema discovery process. + new File(dir.toString, "_temporary").mkdir() } + // Introduce _temporary dir to the base dir the robustness of the schema discovery process. + new File(base.getCanonicalPath, "_temporary").mkdir() - parquetFile(base.getCanonicalPath).registerTempTable("t") + println("load the partitioned table") + read.parquet(base.getCanonicalPath).registerTempTable("t") withTempTable("t") { checkAnswer( @@ -202,7 +238,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - parquetFile(base.getCanonicalPath).registerTempTable("t") + read.parquet(base.getCanonicalPath).registerTempTable("t") withTempTable("t") { checkAnswer( @@ -250,12 +286,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - val parquetRelation = load( - "org.apache.spark.sql.parquet", - Map( - "path" -> base.getCanonicalPath, - ParquetRelation2.DEFAULT_PARTITION_NAME -> defaultPartitionName)) - + val parquetRelation = read.format("org.apache.spark.sql.parquet").load(base.getCanonicalPath) parquetRelation.registerTempTable("t") withTempTable("t") { @@ -295,12 +326,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - val parquetRelation = load( - "org.apache.spark.sql.parquet", - Map( - "path" -> base.getCanonicalPath, - ParquetRelation2.DEFAULT_PARTITION_NAME -> defaultPartitionName)) - + val parquetRelation = read.format("org.apache.spark.sql.parquet").load(base.getCanonicalPath) parquetRelation.registerTempTable("t") withTempTable("t") { @@ -332,7 +358,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { (1 to 10).map(i => (i, i.toString)).toDF("intField", "stringField"), makePartitionDir(base, defaultPartitionName, "pi" -> 2)) - load(base.getCanonicalPath, "org.apache.spark.sql.parquet").registerTempTable("t") + read.format("org.apache.spark.sql.parquet").load(base.getCanonicalPath).registerTempTable("t") withTempTable("t") { checkAnswer( @@ -341,4 +367,86 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { } } } + + test("SPARK-7749 Non-partitioned table should have empty partition spec") { + withTempPath { dir => + (1 to 10).map(i => (i, i.toString)).toDF("a", "b").write.parquet(dir.getCanonicalPath) + val queryExecution = read.parquet(dir.getCanonicalPath).queryExecution + queryExecution.analyzed.collectFirst { + case LogicalRelation(relation: ParquetRelation2) => + assert(relation.partitionSpec === PartitionSpec.emptySpec) + }.getOrElse { + fail(s"Expecting a ParquetRelation2, but got:\n$queryExecution") + } + } + } + + test("SPARK-7847: Dynamic partition directory path escaping and unescaping") { + withTempPath { dir => + val df = Seq("/", "[]", "?").zipWithIndex.map(_.swap).toDF("i", "s") + df.write.format("parquet").partitionBy("s").save(dir.getCanonicalPath) + checkAnswer(read.parquet(dir.getCanonicalPath), df.collect()) + } + } + + test("Various partition value types") { + val row = + Row( + 100.toByte, + 40000.toShort, + Int.MaxValue, + Long.MaxValue, + 1.5.toFloat, + 4.5, + new java.math.BigDecimal(new BigInteger("212500"), 5), + new java.math.BigDecimal(2.125), + java.sql.Date.valueOf("2015-05-23"), + new Timestamp(0), + "This is a string, /[]?=:", + "This is not a partition column") + + // BooleanType is not supported yet + val partitionColumnTypes = + Seq( + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType, + DecimalType(10, 5), + DecimalType.Unlimited, + DateType, + TimestampType, + StringType) + + val partitionColumns = partitionColumnTypes.zipWithIndex.map { + case (t, index) => StructField(s"p_$index", t) + } + + val schema = StructType(partitionColumns :+ StructField(s"i", StringType)) + val df = createDataFrame(sparkContext.parallelize(row :: Nil), schema) + + withTempPath { dir => + df.write.format("parquet").partitionBy(partitionColumns.map(_.name): _*).save(dir.toString) + val fields = schema.map(f => Column(f.name).cast(f.dataType)) + checkAnswer(read.load(dir.toString).select(fields: _*), row) + } + } + + test("SPARK-8037: Ignores files whose name starts with dot") { + withTempPath { dir => + val df = (1 to 3).map(i => (i, i, i, i)).toDF("a", "b", "c", "d") + + df.write + .format("parquet") + .partitionBy("b", "c", "d") + .save(dir.getCanonicalPath) + + Files.touch(new File(s"${dir.getCanonicalPath}/b=1", ".DS_Store")) + Files.createParentDirs(new File(s"${dir.getCanonicalPath}/b=1/c=1/.foo/bar")) + + checkAnswer(read.format("parquet").load(dir.getCanonicalPath), df) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index b98ba09ccfc2d..304936fb2be8e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.parquet import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql.types._ import org.apache.spark.sql.{SQLConf, QueryTest} import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.test.TestSQLContext @@ -111,6 +112,18 @@ class ParquetQuerySuiteBase extends QueryTest with ParquetTest { List(Row("same", "run_5", 100))) } } + + test("SPARK-6917 DecimalType should work with non-native types") { + val data = (1 to 10).map(i => Row(Decimal(i, 18, 0), new java.sql.Timestamp(i))) + val schema = StructType(List(StructField("d", DecimalType(18, 0), false), + StructField("time", TimestampType, false)).toArray) + withTempPath { file => + val df = sqlContext.createDataFrame(sparkContext.parallelize(data), schema) + df.write.parquet(file.getCanonicalPath) + val df2 = sqlContext.read.parquet(file.getCanonicalPath) + checkAnswer(df2, df.collect().toSeq) + } + } } class ParquetDataSourceOnQuerySuite extends ParquetQuerySuiteBase with BeforeAndAfterAll { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala index c964b6d984557..caec2a6f25489 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql.parquet import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag -import org.scalatest.FunSuite import parquet.schema.MessageTypeParser +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types._ -class ParquetSchemaSuite extends FunSuite with ParquetTest { +class ParquetSchemaSuite extends SparkFunSuite with ParquetTest { val sqlContext = TestSQLContext /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala similarity index 59% rename from sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala rename to sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala index 9d17516e0ef7d..516ba373f41d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala @@ -21,10 +21,9 @@ import java.io.File import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag -import scala.util.Try -import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode} -import org.apache.spark.util.Utils +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.{DataFrame, SaveMode} /** * A helper trait that provides convenient facilities for Parquet testing. @@ -33,54 +32,9 @@ import org.apache.spark.util.Utils * convenient to use tuples rather than special case classes when writing test cases/suites. * Especially, `Tuple1.apply` can be used to easily wrap a single type/value. */ -private[sql] trait ParquetTest { - val sqlContext: SQLContext - +private[sql] trait ParquetTest extends SQLTestUtils { import sqlContext.implicits.{localSeqToDataFrameHolder, rddToDataFrameHolder} - import sqlContext.{conf, sparkContext} - - protected def configuration = sparkContext.hadoopConfiguration - - /** - * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL - * configurations. - * - * @todo Probably this method should be moved to a more general place - */ - protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { - val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(conf.getConf(key)).toOption) - (keys, values).zipped.foreach(conf.setConf) - try f finally { - keys.zip(currentValues).foreach { - case (key, Some(value)) => conf.setConf(key, value) - case (key, None) => conf.unsetConf(key) - } - } - } - - /** - * Generates a temporary path without creating the actual file/directory, then pass it to `f`. If - * a file/directory is created there by `f`, it will be delete after `f` returns. - * - * @todo Probably this method should be moved to a more general place - */ - protected def withTempPath(f: File => Unit): Unit = { - val path = Utils.createTempDir() - path.delete() - try f(path) finally Utils.deleteRecursively(path) - } - - /** - * Creates a temporary directory, which is then passed to `f` and will be deleted after `f` - * returns. - * - * @todo Probably this method should be moved to a more general place - */ - protected def withTempDir(f: File => Unit): Unit = { - val dir = Utils.createTempDir().getCanonicalFile - try f(dir) finally Utils.deleteRecursively(dir) - } + import sqlContext.sparkContext /** * Writes `data` to a Parquet file, which is then passed to `f` and will be deleted after `f` @@ -90,7 +44,7 @@ private[sql] trait ParquetTest { (data: Seq[T]) (f: String => Unit): Unit = { withTempPath { file => - sparkContext.parallelize(data).toDF().saveAsParquetFile(file.getCanonicalPath) + sparkContext.parallelize(data).toDF().write.parquet(file.getCanonicalPath) f(file.getCanonicalPath) } } @@ -102,14 +56,7 @@ private[sql] trait ParquetTest { protected def withParquetDataFrame[T <: Product: ClassTag: TypeTag] (data: Seq[T]) (f: DataFrame => Unit): Unit = { - withParquetFile(data)(path => f(sqlContext.parquetFile(path))) - } - - /** - * Drops temporary table `tableName` after calling `f`. - */ - protected def withTempTable(tableName: String)(f: => Unit): Unit = { - try f finally sqlContext.dropTempTable(tableName) + withParquetFile(data)(path => f(sqlContext.read.parquet(path))) } /** @@ -128,12 +75,12 @@ private[sql] trait ParquetTest { protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( data: Seq[T], path: File): Unit = { - data.toDF().save(path.getCanonicalPath, "org.apache.spark.sql.parquet", SaveMode.Overwrite) + data.toDF().write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath) } protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( df: DataFrame, path: File): Unit = { - df.save(path.getCanonicalPath, "org.apache.spark.sql.parquet", SaveMode.Overwrite) + df.write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath) } protected def makePartitionDir( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index 4e54b2eb8df7a..d2d1011b8e917 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -33,7 +33,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll { override def beforeAll(): Unit = { path = Utils.createTempDir() val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - jsonRDD(rdd).registerTempTable("jt") + read.json(rdd).registerTempTable("jt") } override def afterAll(): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala index 6664e8d64c13a..5c3467158a01b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala @@ -43,7 +43,7 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo StructField("bigintType", LongType, nullable = false), StructField("tinyintType", ByteType, nullable = false), StructField("decimalType", DecimalType.Unlimited, nullable = false), - StructField("fixedDecimalType", DecimalType(5,1), nullable = false), + StructField("fixedDecimalType", DecimalType(5, 1), nullable = false), StructField("binaryType", BinaryType, nullable = false), StructField("booleanType", BooleanType, nullable = false), StructField("smallIntType", ShortType, nullable = false), @@ -51,8 +51,7 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo StructField("mapType", MapType(StringType, StringType)), StructField("arrayType", ArrayType(StringType)), StructField("structType", - StructType(StructField("f1",StringType) :: - (StructField("f2",IntegerType)) :: Nil + StructType(StructField("f1", StringType) :: StructField("f2", IntegerType) :: Nil ) ) )) @@ -99,4 +98,10 @@ class DDLTestSuite extends DataSourceTest { Row("arrayType", "array", ""), Row("structType", "struct", "") )) + + test("SPARK-7686 DescribeCommand should have correct physical plan output attributes") { + val attributes = sql("describe ddlPeople").queryExecution.executedPlan.output + assert(attributes.map(_.name) === Seq("col_name", "data_type", "comment")) + assert(attributes.map(_.dataType).toSet === Set(StringType)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index cce747e7dbf64..db94b1f3e8926 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -154,7 +154,7 @@ class FilteredScanSuite extends DataSourceTest { sqlTest( "SELECT a, b FROM oneToTenFiltered WHERE a IN (1,3,5)", - Seq(1,3,5).map(i => Row(i, i * 2))) + Seq(1, 3, 5).map(i => Row(i, i * 2))) sqlTest( "SELECT a, b FROM oneToTenFiltered WHERE A = 1", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index d1d427e1790bd..6f375ef36237d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -33,7 +33,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { override def beforeAll: Unit = { path = Utils.createTempDir() val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - jsonRDD(rdd).registerTempTable("jt") + read.json(rdd).registerTempTable("jt") sql( s""" |CREATE TEMPORARY TABLE jsonTable (a int, b string) @@ -109,7 +109,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { // Writing the table to less part files. val rdd1 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""), 5) - jsonRDD(rdd1).registerTempTable("jt1") + read.json(rdd1).registerTempTable("jt1") sql( s""" |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt1 @@ -121,7 +121,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { // Writing the table to more part files. val rdd2 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""), 10) - jsonRDD(rdd2).registerTempTable("jt2") + read.json(rdd2).registerTempTable("jt2") sql( s""" |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt2 @@ -154,13 +154,13 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { } test("save directly to the path of a JSON table") { - table("jt").selectExpr("a * 5 as a", "b").save(path.toString, "json", SaveMode.Overwrite) + table("jt").selectExpr("a * 5 as a", "b").write.mode(SaveMode.Overwrite).json(path.toString) checkAnswer( sql("SELECT a, b FROM jsonTable"), (1 to 10).map(i => Row(i * 5, s"str$i")) ) - table("jt").save(path.toString, "json", SaveMode.Overwrite) + table("jt").write.mode(SaveMode.Overwrite).json(path.toString) checkAnswer( sql("SELECT a, b FROM jsonTable"), (1 to 10).map(i => Row(i, s"str$i")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala index 8331a14c9295c..296b0d6f74a0c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.sources -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class ResolvedDataSourceSuite extends FunSuite { +class ResolvedDataSourceSuite extends SparkFunSuite { test("builtin sources") { assert(ResolvedDataSource.lookupDataSource("jdbc") === diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala index 6567d1acd7644..274c652dd14d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala @@ -42,7 +42,7 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { path.delete() val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - df = jsonRDD(rdd) + df = read.json(rdd) df.registerTempTable("jsonTable") } @@ -57,41 +57,48 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { def checkLoad(): Unit = { conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") - checkAnswer(load(path.toString), df.collect()) + checkAnswer(read.load(path.toString), df.collect()) // Test if we can pick up the data source name passed in load. conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") - checkAnswer(load(path.toString, "org.apache.spark.sql.json"), df.collect()) - checkAnswer(load("org.apache.spark.sql.json", Map("path" -> path.toString)), df.collect()) + checkAnswer(read.format("json").load(path.toString), df.collect()) + checkAnswer(read.format("json").load(path.toString), df.collect()) val schema = StructType(StructField("b", StringType, true) :: Nil) checkAnswer( - load("org.apache.spark.sql.json", schema, Map("path" -> path.toString)), + read.format("json").schema(schema).load(path.toString), sql("SELECT b FROM jsonTable").collect()) } test("save with path and load") { conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") - df.save(path.toString) + df.write.save(path.toString) + checkLoad() + } + + test("save with string mode and path, and load") { + conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") + path.createNewFile() + df.write.mode("overwrite").save(path.toString) checkLoad() } test("save with path and datasource, and load") { conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") - df.save(path.toString, "org.apache.spark.sql.json") + df.write.json(path.toString) checkLoad() } test("save with data source and options, and load") { conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") - df.save("org.apache.spark.sql.json", SaveMode.ErrorIfExists, Map("path" -> path.toString)) + df.write.mode(SaveMode.ErrorIfExists).json(path.toString) checkLoad() } test("save and save again") { - df.save(path.toString, "org.apache.spark.sql.json") + df.write.json(path.toString) var message = intercept[RuntimeException] { - df.save(path.toString, "org.apache.spark.sql.json") + df.write.json(path.toString) }.getMessage assert( @@ -100,14 +107,14 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll { if (path.exists()) Utils.deleteRecursively(path) - df.save(path.toString, "org.apache.spark.sql.json") + df.write.json(path.toString) checkLoad() - df.save("org.apache.spark.sql.json", SaveMode.Overwrite, Map("path" -> path.toString)) + df.write.mode(SaveMode.Overwrite).json(path.toString) checkLoad() message = intercept[RuntimeException] { - df.save("org.apache.spark.sql.json", SaveMode.Append, Map("path" -> path.toString)) + df.write.mode(SaveMode.Append).json(path.toString) }.getMessage assert( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala new file mode 100644 index 0000000000000..17a8b0cca09df --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import java.io.File + +import scala.util.Try + +import org.apache.spark.sql.SQLContext +import org.apache.spark.util.Utils + +trait SQLTestUtils { + val sqlContext: SQLContext + + import sqlContext.{conf, sparkContext} + + protected def configuration = sparkContext.hadoopConfiguration + + /** + * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL + * configurations. + * + * @todo Probably this method should be moved to a more general place + */ + protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + val (keys, values) = pairs.unzip + val currentValues = keys.map(key => Try(conf.getConf(key)).toOption) + (keys, values).zipped.foreach(conf.setConf) + try f finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => conf.setConf(key, value) + case (key, None) => conf.unsetConf(key) + } + } + } + + /** + * Generates a temporary path without creating the actual file/directory, then pass it to `f`. If + * a file/directory is created there by `f`, it will be delete after `f` returns. + * + * @todo Probably this method should be moved to a more general place + */ + protected def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + /** + * Creates a temporary directory, which is then passed to `f` and will be deleted after `f` + * returns. + * + * @todo Probably this method should be moved to a more general place + */ + protected def withTempDir(f: File => Unit): Unit = { + val dir = Utils.createTempDir().getCanonicalFile + try f(dir) finally Utils.deleteRecursively(dir) + } + + /** + * Drops temporary table `tableName` after calling `f`. + */ + protected def withTempTable(tableNames: String*)(f: => Unit): Unit = { + try f finally tableNames.foreach(sqlContext.dropTempTable) + } + + /** + * Drops table `tableName` after calling `f`. + */ + protected def withTable(tableNames: String*)(f: => Unit): Unit = { + try f finally { + tableNames.foreach { name => + sqlContext.sql(s"DROP TABLE IF EXISTS $name") + } + } + } +} diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 437f697d25bf3..20d3c7d4c5959 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -41,6 +41,13 @@ spark-hive_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + com.google.guava guava diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 3458b04bfba0f..94687eeda4179 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -17,23 +17,23 @@ package org.apache.spark.sql.hive.thriftserver +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + import org.apache.commons.logging.LogFactory import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hive.service.cli.thrift.{ThriftBinaryCLIService, ThriftHttpCLIService} import org.apache.hive.service.server.{HiveServer2, ServerOptionsProcessor} -import org.apache.spark.sql.SQLConf -import org.apache.spark.{SparkContext, SparkConf, Logging} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd, SparkListenerJobStart} +import org.apache.spark.sql.SQLConf import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ -import org.apache.spark.scheduler.{SparkListenerJobStart, SparkListenerApplicationEnd, SparkListener} import org.apache.spark.sql.hive.thriftserver.ui.ThriftServerTab +import org.apache.spark.sql.hive.{HiveContext, HiveShim} import org.apache.spark.util.Utils - -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer +import org.apache.spark.{Logging, SparkContext} /** * The main entry point for the Spark SQL port of HiveServer2. Starts up a `SparkSQLContext` and a @@ -51,6 +51,7 @@ object HiveThriftServer2 extends Logging { @DeveloperApi def startWithContext(sqlContext: HiveContext): Unit = { val server = new HiveThriftServer2(sqlContext) + sqlContext.setConf("spark.sql.hive.version", HiveShim.version) server.init(sqlContext.hiveconf) server.start() listener = new HiveThriftServer2Listener(server, sqlContext.conf) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index deb1008c468bf..14f6f658d9b75 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -43,7 +43,7 @@ import org.apache.spark.util.Utils private[hive] object SparkSQLCLIDriver { private var prompt = "spark-sql" private var continuedPrompt = "".padTo(prompt.length, ' ') - private var transport:TSocket = _ + private var transport: TSocket = _ installSignalHandler() @@ -276,13 +276,13 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { driver.init() val out = sessionState.out - val start:Long = System.currentTimeMillis() + val start: Long = System.currentTimeMillis() if (sessionState.getIsVerbose) { out.println(cmd) } val rc = driver.run(cmd) val end = System.currentTimeMillis() - val timeTaken:Double = (end - start) / 1000.0 + val timeTaken: Double = (end - start) / 1000.0 ret = rc.getResponseCode if (ret != 0) { @@ -310,7 +310,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { res.clear() } } catch { - case e:IOException => + case e: IOException => console.printError( s"""Failed with exception ${e.getClass.getName}: ${e.getMessage} |${org.apache.hadoop.util.StringUtils.stringifyException(e)} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala index 6a2be4a58e5cb..10c83d8b27a2a 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala @@ -47,7 +47,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" ++ generateSessionStatsTable() ++ generateSQLStatsTable() - UIUtils.headerSparkPage("ThriftServer", content, parent, Some(5000)) + UIUtils.headerSparkPage("JDBC/ODBC Server", content, parent, Some(5000)) } /** Generate basic stats of the thrift server program */ @@ -77,7 +77,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" [{id}] } - val detail = if(info.state == ExecutionState.FAILED) info.detail else info.executePlan + val detail = if (info.state == ExecutionState.FAILED) info.detail else info.executePlan {info.userName} @@ -85,7 +85,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" {info.groupId} {formatDate(info.startTimestamp)} - {if(info.finishTimestamp > 0) formatDate(info.finishTimestamp)} + {if (info.finishTimestamp > 0) formatDate(info.finishTimestamp)} {formatDurationOption(Some(info.totalTime))} {info.statement} {info.state} @@ -143,14 +143,14 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" val headerRow = Seq("User", "IP", "Session ID", "Start Time", "Finish Time", "Duration", "Total Execute") def generateDataRow(session: SessionInfo): Seq[Node] = { - val sessionLink = "%s/ThriftServer/session?id=%s" + val sessionLink = "%s/sql/session?id=%s" .format(UIUtils.prependBaseUri(parent.basePath), session.sessionId) {session.userName} {session.ip} {session.sessionId} {formatDate(session.startTimestamp)} - {if(session.finishTimestamp > 0) formatDate(session.finishTimestamp)} + {if (session.finishTimestamp > 0) formatDate(session.finishTimestamp)} {formatDurationOption(Some(session.totalTime))} {session.totalExecution.toString} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala index 33ba038ecce73..3b01afa603cea 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala @@ -55,7 +55,7 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) Total run {sessionStat._2.totalExecution} SQL ++ generateSQLStatsTable(sessionStat._2.sessionId) - UIUtils.headerSparkPage("ThriftServer", content, parent, Some(5000)) + UIUtils.headerSparkPage("JDBC/ODBC Session", content, parent, Some(5000)) } /** Generate basic stats of the streaming program */ @@ -87,7 +87,7 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) [{id}] } - val detail = if(info.state == ExecutionState.FAILED) info.detail else info.executePlan + val detail = if (info.state == ExecutionState.FAILED) info.detail else info.executePlan {info.userName} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala index 343031f10c75c..94fd8a6bb60b9 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala @@ -27,7 +27,9 @@ import org.apache.spark.{SparkContext, Logging, SparkException} * This assumes the given SparkContext has enabled its SparkUI. */ private[thriftserver] class ThriftServerTab(sparkContext: SparkContext) - extends SparkUITab(getSparkUI(sparkContext), "ThriftServer") with Logging { + extends SparkUITab(getSparkUI(sparkContext), "sql") with Logging { + + override val name = "SQL" val parent = getSparkUI(sparkContext) val listener = HiveThriftServer2.listener diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index b070fa8eaa469..13b0c5951dddc 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -25,12 +25,16 @@ import scala.concurrent.{Await, Promise} import scala.sys.process.{Process, ProcessLogger} import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfter -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.util.Utils -class CliSuite extends FunSuite with BeforeAndAfter with Logging { +/** + * A test suite for the `spark-sql` CLI tool. Note that all test cases share the same temporary + * Hive metastore and warehouse. + */ +class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { val warehousePath = Utils.createTempDir() val metastorePath = Utils.createTempDir() @@ -58,13 +62,13 @@ class CliSuite extends FunSuite with BeforeAndAfter with Logging { | --master local | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$jdbcUrl | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath - | --driver-class-path ${sys.props("java.class.path")} """.stripMargin.split("\\s+").toSeq ++ extraArgs } var next = 0 val foundAllExpectedAnswers = Promise.apply[Unit]() - val queryStream = new ByteArrayInputStream(queries.mkString("\n").getBytes) + // Explicitly adds ENTER for each statement to make sure they are actually entered into the CLI. + val queryStream = new ByteArrayInputStream(queries.map(_ + "\n").mkString.getBytes) val buffer = new ArrayBuffer[String]() val lock = new Object @@ -124,12 +128,12 @@ class CliSuite extends FunSuite with BeforeAndAfter with Logging { "SELECT COUNT(*) FROM hive_test;" -> "5", "DROP TABLE hive_test;" - -> "Time taken: " + -> "OK" ) } test("Single command with -e") { - runCliWithin(1.minute, Seq("-e", "SHOW DATABASES;"))("" -> "OK") + runCliWithin(2.minute, Seq("-e", "SHOW DATABASES;"))("" -> "OK") } test("Single command with --database") { @@ -151,4 +155,33 @@ class CliSuite extends FunSuite with BeforeAndAfter with Logging { -> "hive_test" ) } + + test("Commands using SerDe provided in --jars") { + val jarFile = + "../hive/src/test/resources/hive-hcatalog-core-0.13.1.jar" + .split("/") + .mkString(File.separator) + + val dataFilePath = + Thread.currentThread().getContextClassLoader.getResource("data/files/small_kv.txt") + + runCliWithin(3.minute, Seq("--jars", s"$jarFile"))( + """CREATE TABLE t1(key string, val string) + |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe'; + """.stripMargin + -> "OK", + "CREATE TABLE sourceTable (key INT, val STRING);" + -> "OK", + s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE sourceTable;" + -> "OK", + "INSERT INTO TABLE t1 SELECT key, val FROM sourceTable;" + -> "Time taken:", + "SELECT count(key) FROM t1;" + -> "5", + "DROP TABLE t1;" + -> "OK", + "DROP TABLE sourceTable;" + -> "OK" + ) + } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 1fadea97fd07f..a93a3dee43511 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -27,6 +27,8 @@ import scala.concurrent.{Await, Promise} import scala.sys.process.{Process, ProcessLogger} import scala.util.{Random, Try} +import com.google.common.base.Charsets.UTF_8 +import com.google.common.io.Files import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hive.jdbc.HiveDriver import org.apache.hive.service.auth.PlainSaslHelper @@ -35,9 +37,9 @@ import org.apache.hive.service.cli.thrift.TCLIService.Client import org.apache.hive.service.cli.thrift.ThriftCLIServiceClient import org.apache.thrift.protocol.TBinaryProtocol import org.apache.thrift.transport.TSocket -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.hive.HiveShim import org.apache.spark.util.Utils @@ -54,7 +56,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { override def mode: ServerMode.Value = ServerMode.binary private def withCLIServiceClient(f: ThriftCLIServiceClient => Unit): Unit = { - // Transport creation logics below mimics HiveConnection.createBinaryTransport + // Transport creation logic below mimics HiveConnection.createBinaryTransport val rawTransport = new TSocket("localhost", serverPort) val user = System.getProperty("user.name") val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport) @@ -391,10 +393,10 @@ abstract class HiveThriftJdbcTest extends HiveThriftServer2Test { val statements = connections.map(_.createStatement()) try { - statements.zip(fs).map { case (s, f) => f(s) } + statements.zip(fs).foreach { case (s, f) => f(s) } } finally { - statements.map(_.close()) - connections.map(_.close()) + statements.foreach(_.close()) + connections.foreach(_.close()) } } @@ -403,7 +405,7 @@ abstract class HiveThriftJdbcTest extends HiveThriftServer2Test { } } -abstract class HiveThriftServer2Test extends FunSuite with BeforeAndAfterAll with Logging { +abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAll with Logging { def mode: ServerMode.Value private val CLASS_NAME = HiveThriftServer2.getClass.getCanonicalName.stripSuffix("$") @@ -433,15 +435,33 @@ abstract class HiveThriftServer2Test extends FunSuite with BeforeAndAfterAll wit ConfVars.HIVE_SERVER2_THRIFT_HTTP_PORT } + val driverClassPath = { + // Writes a temporary log4j.properties and prepend it to driver classpath, so that it + // overrides all other potential log4j configurations contained in other dependency jar files. + val tempLog4jConf = Utils.createTempDir().getCanonicalPath + + Files.write( + """log4j.rootCategory=INFO, console + |log4j.appender.console=org.apache.log4j.ConsoleAppender + |log4j.appender.console.target=System.err + |log4j.appender.console.layout=org.apache.log4j.PatternLayout + |log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n + """.stripMargin, + new File(s"$tempLog4jConf/log4j.properties"), + UTF_8) + + tempLog4jConf + File.pathSeparator + sys.props("java.class.path") + } + s"""$startScript | --master local - | --hiveconf hive.root.logger=INFO,console | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost | --hiveconf ${ConfVars.HIVE_SERVER2_TRANSPORT_MODE}=$mode | --hiveconf $portConf=$port - | --driver-class-path ${sys.props("java.class.path")} + | --driver-class-path $driverClassPath + | --driver-java-options -Dlog4j.debug | --conf spark.ui.enabled=false """.stripMargin.split("\\s+").toSeq } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala index 47541015a3611..4c9fab7ef6136 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala @@ -17,21 +17,18 @@ package org.apache.spark.sql.hive.thriftserver - - import scala.util.Random +import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.openqa.selenium.WebDriver import org.openqa.selenium.htmlunit.HtmlUnitDriver -import org.scalatest.{Matchers, BeforeAndAfterAll} import org.scalatest.concurrent.Eventually._ import org.scalatest.selenium.WebBrowser import org.scalatest.time.SpanSugar._ +import org.scalatest.{BeforeAndAfterAll, Matchers} -import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.spark.sql.hive.HiveContext - class UISeleniumSuite extends HiveThriftJdbcTest with WebBrowser with Matchers with BeforeAndAfterAll { @@ -75,9 +72,9 @@ class UISeleniumSuite """.stripMargin.split("\\s+").toSeq } - test("thrift server ui test") { - withJdbcStatement(statement =>{ - val baseURL = s"http://localhost:${uiPort}" + ignore("thrift server ui test") { + withJdbcStatement { statement => + val baseURL = s"http://localhost:$uiPort" val queries = Seq( "CREATE TABLE test_map(key INT, value STRING)", @@ -86,20 +83,20 @@ class UISeleniumSuite queries.foreach(statement.execute) eventually(timeout(10 seconds), interval(50 milliseconds)) { - go to (baseURL) - find(cssSelector("""ul li a[href*="ThriftServer"]""")) should not be(None) + go to baseURL + find(cssSelector("""ul li a[href*="sql"]""")) should not be None } eventually(timeout(10 seconds), interval(50 milliseconds)) { - go to (baseURL + "/ThriftServer") - find(id("sessionstat")) should not be(None) - find(id("sqlstat")) should not be(None) + go to (baseURL + "/sql") + find(id("sessionstat")) should not be None + find(id("sqlstat")) should not be None // check whether statements exists queries.foreach { line => findAll(cssSelector("""ul table tbody tr td""")).map(_.text).toList should contain (line) } } - }) + } } } diff --git a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala deleted file mode 100644 index b3a79ba1c7d6b..0000000000000 --- a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala +++ /dev/null @@ -1,278 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.thriftserver - -import java.sql.{Date, Timestamp} -import java.util.concurrent.Executors -import java.util.{ArrayList => JArrayList, Map => JMap, UUID} - -import org.apache.commons.logging.Log -import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hive.service.cli.thrift.TProtocolVersion -import org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager - -import scala.collection.JavaConversions._ -import scala.collection.mutable.{ArrayBuffer, Map => SMap} - -import org.apache.hadoop.hive.common.`type`.HiveDecimal -import org.apache.hadoop.hive.metastore.api.FieldSchema -import org.apache.hadoop.hive.shims.ShimLoader -import org.apache.hadoop.security.UserGroupInformation -import org.apache.hive.service.cli._ -import org.apache.hive.service.cli.operation.ExecuteStatementOperation -import org.apache.hive.service.cli.session.{SessionManager, HiveSession} - -import org.apache.spark.Logging -import org.apache.spark.sql.{DataFrame, SQLConf, Row => SparkRow} -import org.apache.spark.sql.execution.SetCommand -import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ -import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} -import org.apache.spark.sql.types._ - -/** - * A compatibility layer for interacting with Hive version 0.12.0. - */ -private[thriftserver] object HiveThriftServerShim { - val version = "0.12.0" - - def setServerUserName(sparkServiceUGI: UserGroupInformation, sparkCliService:SparkSQLCLIService) = { - val serverUserName = ShimLoader.getHadoopShims.getShortUserName(sparkServiceUGI) - setSuperField(sparkCliService, "serverUserName", serverUserName) - } -} - -private[hive] class SparkSQLDriver(val _context: HiveContext = SparkSQLEnv.hiveContext) - extends AbstractSparkSQLDriver(_context) { - override def getResults(res: JArrayList[String]): Boolean = { - if (hiveResponse == null) { - false - } else { - res.addAll(hiveResponse) - hiveResponse = null - true - } - } -} - -private[hive] class SparkExecuteStatementOperation( - parentSession: HiveSession, - statement: String, - confOverlay: JMap[String, String])( - hiveContext: HiveContext, - sessionToActivePool: SMap[SessionHandle, String]) - extends ExecuteStatementOperation(parentSession, statement, confOverlay) with Logging { - - private var result: DataFrame = _ - private var iter: Iterator[SparkRow] = _ - private var dataTypes: Array[DataType] = _ - - def close(): Unit = { - // RDDs will be cleaned automatically upon garbage collection. - logDebug("CLOSING") - } - - def getNextRowSet(order: FetchOrientation, maxRowsL: Long): RowSet = { - if (!iter.hasNext) { - new RowSet() - } else { - // maxRowsL here typically maps to java.sql.Statement.getFetchSize, which is an int - val maxRows = maxRowsL.toInt - var curRow = 0 - var rowSet = new ArrayBuffer[Row](maxRows.min(1024)) - - while (curRow < maxRows && iter.hasNext) { - val sparkRow = iter.next() - val row = new Row() - var curCol = 0 - - while (curCol < sparkRow.length) { - if (sparkRow.isNullAt(curCol)) { - addNullColumnValue(sparkRow, row, curCol) - } else { - addNonNullColumnValue(sparkRow, row, curCol) - } - curCol += 1 - } - rowSet += row - curRow += 1 - } - new RowSet(rowSet, 0) - } - } - - def addNonNullColumnValue(from: SparkRow, to: Row, ordinal: Int) { - dataTypes(ordinal) match { - case StringType => - to.addString(from(ordinal).asInstanceOf[String]) - case IntegerType => - to.addColumnValue(ColumnValue.intValue(from.getInt(ordinal))) - case BooleanType => - to.addColumnValue(ColumnValue.booleanValue(from.getBoolean(ordinal))) - case DoubleType => - to.addColumnValue(ColumnValue.doubleValue(from.getDouble(ordinal))) - case FloatType => - to.addColumnValue(ColumnValue.floatValue(from.getFloat(ordinal))) - case DecimalType() => - val hiveDecimal = from.getDecimal(ordinal) - to.addColumnValue(ColumnValue.stringValue(new HiveDecimal(hiveDecimal))) - case LongType => - to.addColumnValue(ColumnValue.longValue(from.getLong(ordinal))) - case ByteType => - to.addColumnValue(ColumnValue.byteValue(from.getByte(ordinal))) - case ShortType => - to.addColumnValue(ColumnValue.shortValue(from.getShort(ordinal))) - case DateType => - to.addColumnValue(ColumnValue.dateValue(from(ordinal).asInstanceOf[Date])) - case TimestampType => - to.addColumnValue( - ColumnValue.timestampValue(from.get(ordinal).asInstanceOf[Timestamp])) - case BinaryType | _: ArrayType | _: StructType | _: MapType => - val hiveString = HiveContext.toHiveString((from.get(ordinal), dataTypes(ordinal))) - to.addColumnValue(ColumnValue.stringValue(hiveString)) - } - } - - def addNullColumnValue(from: SparkRow, to: Row, ordinal: Int) { - dataTypes(ordinal) match { - case StringType => - to.addString(null) - case IntegerType => - to.addColumnValue(ColumnValue.intValue(null)) - case BooleanType => - to.addColumnValue(ColumnValue.booleanValue(null)) - case DoubleType => - to.addColumnValue(ColumnValue.doubleValue(null)) - case FloatType => - to.addColumnValue(ColumnValue.floatValue(null)) - case DecimalType() => - to.addColumnValue(ColumnValue.stringValue(null: HiveDecimal)) - case LongType => - to.addColumnValue(ColumnValue.longValue(null)) - case ByteType => - to.addColumnValue(ColumnValue.byteValue(null)) - case ShortType => - to.addColumnValue(ColumnValue.shortValue(null)) - case DateType => - to.addColumnValue(ColumnValue.dateValue(null)) - case TimestampType => - to.addColumnValue(ColumnValue.timestampValue(null)) - case BinaryType | _: ArrayType | _: StructType | _: MapType => - to.addColumnValue(ColumnValue.stringValue(null: String)) - } - } - - def getResultSetSchema: TableSchema = { - logInfo(s"Result Schema: ${result.queryExecution.analyzed.output}") - if (result.queryExecution.analyzed.output.size == 0) { - new TableSchema(new FieldSchema("Result", "string", "") :: Nil) - } else { - val schema = result.queryExecution.analyzed.output.map { attr => - new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") - } - new TableSchema(schema) - } - } - - def run(): Unit = { - val statementId = UUID.randomUUID().toString - logInfo(s"Running query '$statement'") - setState(OperationState.RUNNING) - HiveThriftServer2.listener.onStatementStart( - statementId, parentSession.getSessionHandle.getSessionId.toString, statement, statementId) - hiveContext.sparkContext.setJobGroup(statementId, statement) - sessionToActivePool.get(parentSession.getSessionHandle).foreach { pool => - hiveContext.sparkContext.setLocalProperty("spark.scheduler.pool", pool) - } - try { - result = hiveContext.sql(statement) - logDebug(result.queryExecution.toString()) - result.queryExecution.logical match { - case SetCommand(Some((SQLConf.THRIFTSERVER_POOL, Some(value))), _) => - sessionToActivePool(parentSession.getSessionHandle) = value - logInfo(s"Setting spark.scheduler.pool=$value for future statements in this session.") - case _ => - } - HiveThriftServer2.listener.onStatementParsed(statementId, result.queryExecution.toString()) - iter = { - val useIncrementalCollect = - hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean - if (useIncrementalCollect) { - result.rdd.toLocalIterator - } else { - result.collect().iterator - } - } - dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray - setHasResultSet(true) - } catch { - // Actually do need to catch Throwable as some failures don't inherit from Exception and - // HiveServer will silently swallow them. - case e: Throwable => - setState(OperationState.ERROR) - HiveThriftServer2.listener.onStatementError( - statementId, e.getMessage, e.getStackTraceString) - logError("Error executing query:",e) - throw new HiveSQLException(e.toString) - } - setState(OperationState.FINISHED) - HiveThriftServer2.listener.onStatementFinish(statementId) - } -} - -private[hive] class SparkSQLSessionManager(hiveContext: HiveContext) - extends SessionManager - with ReflectedCompositeService { - - private lazy val sparkSqlOperationManager = new SparkSQLOperationManager(hiveContext) - - override def init(hiveConf: HiveConf) { - setSuperField(this, "hiveConf", hiveConf) - - val backgroundPoolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS) - setSuperField(this, "backgroundOperationPool", Executors.newFixedThreadPool(backgroundPoolSize)) - getAncestorField[Log](this, 3, "LOG").info( - s"HiveServer2: Async execution pool size $backgroundPoolSize") - - setSuperField(this, "operationManager", sparkSqlOperationManager) - addService(sparkSqlOperationManager) - - initCompositeService(hiveConf) - } - - override def openSession( - username: String, - passwd: String, - sessionConf: java.util.Map[String, String], - withImpersonation: Boolean, - delegationToken: String): SessionHandle = { - hiveContext.openSession() - val sessionHandle = super.openSession( - username, passwd, sessionConf, withImpersonation, delegationToken) - HiveThriftServer2.listener.onSessionCreated("UNKNOWN", sessionHandle.getSessionId.toString) - sessionHandle - } - - override def closeSession(sessionHandle: SessionHandle) { - HiveThriftServer2.listener.onSessionClosed(sessionHandle.getSessionId.toString) - super.closeSession(sessionHandle) - sparkSqlOperationManager.sessionToActivePool -= sessionHandle - - hiveContext.detachSession() - } -} diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index b6245a57074c8..0b1917a392901 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -250,7 +250,10 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // The isolated classloader seemed to make some of our test reset mechanisms less robust. "combine1", // This test changes compression settings in a way that breaks all subsequent tests. - "load_dyn_part14.*" // These work alone but fail when run with other tests... + "load_dyn_part14.*", // These work alone but fail when run with other tests... + + // the answer is sensitive for jdk version + "udf_java_method" ) ++ HiveShim.compatibilityBlackList /** @@ -877,7 +880,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udf_int", "udf_isnotnull", "udf_isnull", - "udf_java_method", "udf_lcase", "udf_length", "udf_lessthan", diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index e322340094e6f..923ffabb9b99e 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -41,6 +41,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-sql_${scala.binary.version} @@ -136,16 +143,6 @@ - - hive-0.12.0 - - - com.twitter - parquet-hive-bundle - 1.5.0 - - - diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala index 3f20c6142e59a..7f8449cdc282d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala @@ -29,10 +29,10 @@ import org.apache.spark.sql.hive.execution.{AddJar, AddFile, HiveNativeCommand} private[hive] class ExtendedHiveQlParser extends AbstractSparkSQLParser { // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` // properties via reflection the class in runtime for constructing the SqlLexical object - protected val ADD = Keyword("ADD") - protected val DFS = Keyword("DFS") + protected val ADD = Keyword("ADD") + protected val DFS = Keyword("DFS") protected val FILE = Keyword("FILE") - protected val JAR = Keyword("JAR") + protected val JAR = Keyword("JAR") protected lazy val start: Parser[LogicalPlan] = dfs | addJar | addFile | hiveQl diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 766c42d040f80..fbf2c7d8cbc06 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive import java.io.{BufferedReader, File, InputStreamReader, PrintStream} +import java.net.{URL, URLClassLoader} import java.sql.Timestamp import java.util.{ArrayList => JArrayList} @@ -25,6 +26,7 @@ import org.apache.hadoop.hive.ql.parse.VariableSubstitution import org.apache.spark.sql.catalyst.ParserDialect import scala.collection.JavaConversions._ +import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.language.implicitConversions import org.apache.hadoop.fs.{FileSystem, Path} @@ -122,6 +124,29 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { protected[hive] def hiveMetastoreJars: String = getConf(HIVE_METASTORE_JARS, "builtin") + /** + * A comma separated list of class prefixes that should be loaded using the classloader that + * is shared between Spark SQL and a specific version of Hive. An example of classes that should + * be shared is JDBC drivers that are needed to talk to the metastore. Other classes that need + * to be shared are those that interact with classes that are already shared. For example, + * custom appenders that are used by log4j. + */ + protected[hive] def hiveMetastoreSharedPrefixes: Seq[String] = + getConf("spark.sql.hive.metastore.sharedPrefixes", jdbcPrefixes) + .split(",").filterNot(_ == "") + + private def jdbcPrefixes = Seq( + "com.mysql.jdbc", "org.postgresql", "com.microsoft.sqlserver", "oracle.jdbc").mkString(",") + + /** + * A comma separated list of class prefixes that should explicitly be reloaded for each version + * of Hive that Spark SQL is communicating with. For example, Hive UDFs that are declared in a + * prefix that typically would be shared (i.e. org.apache.spark.*) + */ + protected[hive] def hiveMetastoreBarrierPrefixes: Seq[String] = + getConf("spark.sql.hive.metastore.barrierPrefixes", "") + .split(",").filterNot(_ == "") + @transient protected[sql] lazy val substitutor = new VariableSubstitution() @@ -130,11 +155,11 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * Hive 13 as this is the version of Hive that is packaged with Spark SQL. This copy of the * client is used for execution related tasks like registering temporary functions or ensuring * that the ThreadLocal SessionState is correctly populated. This copy of Hive is *not* used - * for storing peristent metadata, and only point to a dummy metastore in a temporary directory. + * for storing persistent metadata, and only point to a dummy metastore in a temporary directory. */ @transient protected[hive] lazy val executionHive: ClientWrapper = { - logInfo(s"Initilizing execution hive, version $hiveExecutionVersion") + logInfo(s"Initializing execution hive, version $hiveExecutionVersion") new ClientWrapper( version = IsolatedClientLoader.hiveVersion(hiveExecutionVersion), config = newTemporaryConfiguration()) @@ -164,13 +189,22 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { "Specify a vaild path to the correct hive jars using $HIVE_METASTORE_JARS " + s"or change $HIVE_METASTORE_VERSION to $hiveExecutionVersion.") } - val jars = getClass.getClassLoader match { - case urlClassLoader: java.net.URLClassLoader => urlClassLoader.getURLs - case other => - throw new IllegalArgumentException( - "Unable to locate hive jars to connect to metastore " + - s"using classloader ${other.getClass.getName}. " + - "Please set spark.sql.hive.metastore.jars") + + // We recursively find all jars in the class loader chain, + // starting from the given classLoader. + def allJars(classLoader: ClassLoader): Array[URL] = classLoader match { + case null => Array.empty[URL] + case urlClassLoader: URLClassLoader => + urlClassLoader.getURLs ++ allJars(urlClassLoader.getParent) + case other => allJars(other.getParent) + } + + val classLoader = Utils.getContextOrSparkClassLoader + val jars = allJars(classLoader) + if (jars.length == 0) { + throw new IllegalArgumentException( + "Unable to locate hive jars to connect to metastore. " + + "Please set spark.sql.hive.metastore.jars.") } logInfo( @@ -179,12 +213,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { version = metaVersion, execJars = jars.toSeq, config = allConfig, - isolationOn = true) + isolationOn = true, + barrierPrefixes = hiveMetastoreBarrierPrefixes, + sharedPrefixes = hiveMetastoreSharedPrefixes) } else if (hiveMetastoreJars == "maven") { // TODO: Support for loading the jars from an already downloaded location. logInfo( s"Initializing HiveMetastoreConnection version $hiveMetastoreVersion using maven.") - IsolatedClientLoader.forVersion(hiveMetastoreVersion, allConfig ) + IsolatedClientLoader.forVersion(hiveMetastoreVersion, allConfig) } else { // Convert to files and expand any directories. val jars = @@ -210,7 +246,9 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { version = metaVersion, execJars = jars.toSeq, config = allConfig, - isolationOn = true) + isolationOn = true, + barrierPrefixes = hiveMetastoreBarrierPrefixes, + sharedPrefixes = hiveMetastoreSharedPrefixes) } isolatedLoader.client } @@ -316,9 +354,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { override def setConf(key: String, value: String): Unit = { super.setConf(key, value) - hiveconf.set(key, value) executionHive.runSqlHive(s"SET $key=$value") metadataHive.runSqlHive(s"SET $key=$value") + // If users put any Spark SQL setting in the spark conf (e.g. spark-defaults.conf), + // this setConf will be called in the constructor of the SQLContext. + // Also, calling hiveconf will create a default session containing a HiveConf, which + // will interfer with the creation of executionHive (which is a lazy val). So, + // we put hiveconf.set at the end of this method. + hiveconf.set(key, value) } /* A catalyst metadata catalog that points to the Hive Metastore. */ @@ -330,12 +373,12 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { @transient override protected[sql] lazy val functionRegistry = new HiveFunctionRegistry with OverrideFunctionRegistry { - def caseSensitive: Boolean = false + override def conf: CatalystConf = currentSession().conf } /* An analyzer that uses the Hive metastore. */ @transient - override protected[sql] lazy val analyzer = + override protected[sql] lazy val analyzer: Analyzer = new Analyzer(catalog, functionRegistry, conf) { override val extendedResolutionRules = catalog.ParquetConversions :: @@ -345,6 +388,10 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { ResolveHiveWindowFunction :: sources.PreInsertCastAndRename :: Nil + + override val extendedCheckRules = Seq( + sources.PreWriteCheck(catalog) + ) } override protected[sql] def createSession(): SQLSession = { @@ -480,8 +527,19 @@ private[hive] object HiveContext { def newTemporaryConfiguration(): Map[String, String] = { val tempDir = Utils.createTempDir() val localMetastore = new File(tempDir, "metastore").getAbsolutePath - Map( - "javax.jdo.option.ConnectionURL" -> s"jdbc:derby:;databaseName=$localMetastore;create=true") + val propMap: HashMap[String, String] = HashMap() + // We have to mask all properties in hive-site.xml that relates to metastore data source + // as we used a local metastore here. + HiveConf.ConfVars.values().foreach { confvar => + if (confvar.varname.contains("datanucleus") || confvar.varname.contains("jdo")) { + propMap.put(confvar.varname, confvar.defaultVal) + } + } + propMap.put("javax.jdo.option.ConnectionURL", + s"jdbc:derby:;databaseName=$localMetastore;create=true") + propMap.put("datanucleus.rdbms.datastoreAdapterClassName", + "org.datanucleus.store.rdbms.adapter.DerbyAdapter") + propMap.toMap } protected val primitiveTypes = @@ -495,7 +553,7 @@ private[hive] object HiveContext { }.mkString("{", ",", "}") case (seq: Seq[_], ArrayType(typ, _)) => seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") - case (map: Map[_,_], MapType(kType, vType, _)) => + case (map: Map[_, _], MapType(kType, vType, _)) => map.map { case (key, value) => toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 74ae984f34866..24cd335082639 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -18,12 +18,13 @@ package org.apache.spark.sql.hive import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar} -import org.apache.hadoop.hive.serde2.objectinspector._ import org.apache.hadoop.hive.serde2.objectinspector.primitive._ +import org.apache.hadoop.hive.serde2.objectinspector.{StructField => HiveStructField, _} import org.apache.hadoop.hive.serde2.{io => hiveIo} import org.apache.hadoop.{io => hadoopIo} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types import org.apache.spark.sql.types._ @@ -121,7 +122,7 @@ import scala.collection.JavaConversions._ * even a normal java object (POJO) * UnionObjectInspector: (tag: Int, object data) (TODO: not supported by SparkSQL yet) * - * 3) ConstantObjectInspector: + * 3) ConstantObjectInspector: * Constant object inspector can be either primitive type or Complex type, and it bundles a * constant value as its property, usually the value is created when the constant object inspector * constructed. @@ -132,7 +133,7 @@ import scala.collection.JavaConversions._ } }}} * Hive provides 3 built-in constant object inspectors: - * Primitive Object Inspectors: + * Primitive Object Inspectors: * WritableConstantStringObjectInspector * WritableConstantHiveVarcharObjectInspector * WritableConstantHiveDecimalObjectInspector @@ -146,9 +147,9 @@ import scala.collection.JavaConversions._ * WritableConstantByteObjectInspector * WritableConstantBinaryObjectInspector * WritableConstantDateObjectInspector - * Map Object Inspector: + * Map Object Inspector: * StandardConstantMapObjectInspector - * List Object Inspector: + * List Object Inspector: * StandardConstantListObjectInspector]] * Struct Object Inspector: Hive doesn't provide the built-in constant object inspector for Struct * Union Object Inspector: Hive doesn't provide the built-in constant object inspector for Union @@ -249,9 +250,9 @@ private[hive] trait HiveInspectors { poi.getWritableConstantValue.getHiveDecimal) case poi: WritableConstantTimestampObjectInspector => poi.getWritableConstantValue.getTimestamp.clone() - case poi: WritableConstantIntObjectInspector => + case poi: WritableConstantIntObjectInspector => poi.getWritableConstantValue.get() - case poi: WritableConstantDoubleObjectInspector => + case poi: WritableConstantDoubleObjectInspector => poi.getWritableConstantValue.get() case poi: WritableConstantBooleanObjectInspector => poi.getWritableConstantValue.get() @@ -305,7 +306,7 @@ private[hive] trait HiveInspectors { // In order to keep backward-compatible, we have to copy the // bytes with old apis val bw = x.getPrimitiveWritableObject(data) - val result = new Array[Byte](bw.getLength()) + val result = new Array[Byte](bw.getLength()) System.arraycopy(bw.getBytes(), 0, result, 0, bw.getLength()) result case x: DateObjectInspector if x.preferWritable() => @@ -334,7 +335,7 @@ private[hive] trait HiveInspectors { val allRefs = si.getAllStructFieldRefs new GenericRow( allRefs.map(r => - unwrap(si.getStructFieldData(data,r), r.getFieldObjectInspector)).toArray) + unwrap(si.getStructFieldData(data, r), r.getFieldObjectInspector)).toArray) } @@ -393,6 +394,30 @@ private[hive] trait HiveInspectors { identity[Any] } + /** + * Builds specific unwrappers ahead of time according to object inspector + * types to avoid pattern matching and branching costs per row. + */ + def unwrapperFor(field: HiveStructField): (Any, MutableRow, Int) => Unit = + field.getFieldObjectInspector match { + case oi: BooleanObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setBoolean(ordinal, oi.get(value)) + case oi: ByteObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setByte(ordinal, oi.get(value)) + case oi: ShortObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setShort(ordinal, oi.get(value)) + case oi: IntObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setInt(ordinal, oi.get(value)) + case oi: LongObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setLong(ordinal, oi.get(value)) + case oi: FloatObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setFloat(ordinal, oi.get(value)) + case oi: DoubleObjectInspector => + (value: Any, row: MutableRow, ordinal: Int) => row.setDouble(ordinal, oi.get(value)) + case oi => + (value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = unwrap(value, oi) + } + /** * Converts native catalyst types to the types expected by Hive * @param a the value to be wrapped @@ -536,8 +561,8 @@ private[hive] trait HiveInspectors { case DecimalType() => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector case StructType(fields) => ObjectInspectorFactory.getStandardStructObjectInspector( - java.util.Arrays.asList(fields.map(f => f.name) :_*), - java.util.Arrays.asList(fields.map(f => toInspector(f.dataType)) :_*)) + java.util.Arrays.asList(fields.map(f => f.name) : _*), + java.util.Arrays.asList(fields.map(f => toInspector(f.dataType)) : _*)) } /** @@ -652,8 +677,8 @@ private[hive] trait HiveInspectors { getListTypeInfo(elemType.toTypeInfo) case StructType(fields) => getStructTypeInfo( - java.util.Arrays.asList(fields.map(_.name) :_*), - java.util.Arrays.asList(fields.map(_.dataType.toTypeInfo) :_*)) + java.util.Arrays.asList(fields.map(_.name) : _*), + java.util.Arrays.asList(fields.map(_.dataType.toTypeInfo) : _*)) case MapType(keyType, valueType, _) => getMapTypeInfo(keyType.toTypeInfo, valueType.toTypeInfo) case BinaryType => binaryTypeInfo diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index d754c8e3a8aa1..ca1f49b546bd7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.parquet.ParquetRelation2 import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource} import org.apache.spark.sql.types._ -import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode} +import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode, sources} import org.apache.spark.util.Utils /* Implicit conversions */ @@ -66,11 +66,11 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive def schemaStringFromParts: Option[String] = { table.properties.get("spark.sql.sources.schema.numParts").map { numParts => val parts = (0 until numParts.toInt).map { index => - val part = table.properties.get(s"spark.sql.sources.schema.part.${index}").orNull + val part = table.properties.get(s"spark.sql.sources.schema.part.$index").orNull if (part == null) { throw new AnalysisException( - s"Could not read schema from the metastore because it is corrupted " + - s"(missing part ${index} of the schema).") + "Could not read schema from the metastore because it is corrupted " + + s"(missing part $index of the schema, $numParts parts are expected).") } part @@ -89,6 +89,11 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val userSpecifiedSchema = schemaString.map(s => DataType.fromJson(s).asInstanceOf[StructType]) + // We only need names at here since userSpecifiedSchema we loaded from the metastore + // contains partition columns. We can always get datatypes of partitioning columns + // from userSpecifiedSchema. + val partitionColumns = table.partitionColumns.map(_.name) + // It does not appear that the ql client for the metastore has a way to enumerate all the // SerDe properties directly... val options = table.serdeProperties @@ -97,7 +102,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive ResolvedDataSource( hive, userSpecifiedSchema, - Array.empty[String], + partitionColumns.toArray, table.properties("spark.sql.sources.provider"), options) @@ -111,8 +116,8 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive override def refreshTable(databaseName: String, tableName: String): Unit = { // refreshTable does not eagerly reload the cache. It just invalidate the cache. // Next time when we use the table, it will be populated in the cache. - // Since we also cache ParquetRealtions converted from Hive Parquet tables and - // adding converted ParquetRealtions into the cache is not defined in the load function + // Since we also cache ParquetRelations converted from Hive Parquet tables and + // adding converted ParquetRelations into the cache is not defined in the load function // of the cache (instead, we add the cache entry in convertToParquetRelation), // it is better at here to invalidate the cache to avoid confusing waring logs from the // cache loader (e.g. cannot find data source provider, which is only defined for @@ -133,12 +138,17 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive def createDataSourceTable( tableName: String, userSpecifiedSchema: Option[StructType], + partitionColumns: Array[String], provider: String, options: Map[String, String], isExternal: Boolean): Unit = { val (dbName, tblName) = processDatabaseAndTableName("default", tableName) val tableProperties = new scala.collection.mutable.HashMap[String, String] tableProperties.put("spark.sql.sources.provider", provider) + + // Saves optional user specified schema. Serialized JSON schema string may be too long to be + // stored into a single metastore SerDe property. In this case, we split the JSON string and + // store each part as a separate SerDe property. if (userSpecifiedSchema.isDefined) { val threshold = conf.schemaStringLengthThreshold val schemaJsonString = userSpecifiedSchema.get.json @@ -146,8 +156,29 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val parts = schemaJsonString.grouped(threshold).toSeq tableProperties.put("spark.sql.sources.schema.numParts", parts.size.toString) parts.zipWithIndex.foreach { case (part, index) => - tableProperties.put(s"spark.sql.sources.schema.part.${index}", part) + tableProperties.put(s"spark.sql.sources.schema.part.$index", part) + } + } + + val metastorePartitionColumns = userSpecifiedSchema.map { schema => + val fields = partitionColumns.map(col => schema(col)) + fields.map { field => + HiveColumn( + name = field.name, + hiveType = HiveMetastoreTypes.toMetastoreType(field.dataType), + comment = "") + }.toSeq + }.getOrElse { + if (partitionColumns.length > 0) { + // The table does not have a specified schema, which means that the schema will be inferred + // when we load the table. So, we are not expecting partition columns and we will discover + // partitions when we load the table. However, if there are specified partition columns, + // we simplily ignore them and provide a warning message.. + logWarning( + s"The schema and partitions of table $tableName will be inferred when it is loaded. " + + s"Specified partition columns (${partitionColumns.mkString(",")}) will be ignored.") } + Seq.empty[HiveColumn] } val tableType = if (isExternal) { @@ -163,7 +194,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive specifiedDatabase = Option(dbName), name = tblName, schema = Seq.empty, - partitionColumns = Seq.empty, + partitionColumns = metastorePartitionColumns, tableType = tableType, properties = tableProperties.toMap, serdeProperties = options)) @@ -199,7 +230,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val dataSourceTable = cachedDataSourceTables(QualifiedTableName(databaseName, tblName).toLowerCase) // Then, if alias is specified, wrap the table with a Subquery using the alias. - // Othersie, wrap the table with a Subquery using the table name. + // Otherwise, wrap the table with a Subquery using the table name. val withAlias = alias.map(a => Subquery(a, dataSourceTable)).getOrElse( Subquery(tableIdent.last, dataSourceTable)) @@ -244,7 +275,9 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val useCached = parquetRelation.paths.toSet == pathsInMetastore.toSet && logical.schema.sameType(metastoreSchema) && - parquetRelation.maybePartitionSpec == partitionSpecInMetastore + parquetRelation.partitionSpec == partitionSpecInMetastore.getOrElse { + PartitionSpec(StructType(Nil), Array.empty[sources.Partition]) + } if (useCached) { Some(logical) @@ -256,7 +289,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive case other => logWarning( s"${metastoreRelation.databaseName}.${metastoreRelation.tableName} should be stored " + - s"as Parquet. However, we are getting a ${other} from the metastore cache. " + + s"as Parquet. However, we are getting a $other from the metastore cache. " + s"This cached entry will be invalidated.") cachedDataSourceTables.invalidate(tableIdentifier) None @@ -278,8 +311,9 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val cached = getCached(tableIdentifier, paths, metastoreSchema, Some(partitionSpec)) val parquetRelation = cached.getOrElse { - val created = - LogicalRelation(ParquetRelation2(paths, parquetOptions, None, Some(partitionSpec))(hive)) + val created = LogicalRelation( + new ParquetRelation2( + paths.toArray, None, Some(partitionSpec), parquetOptions)(hive)) cachedDataSourceTables.put(tableIdentifier, created) created } @@ -290,8 +324,8 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val cached = getCached(tableIdentifier, paths, metastoreSchema, None) val parquetRelation = cached.getOrElse { - val created = - LogicalRelation(ParquetRelation2(paths, parquetOptions)(hive)) + val created = LogicalRelation( + new ParquetRelation2(paths.toArray, None, None, parquetOptions)(hive)) cachedDataSourceTables.put(tableIdentifier, created) created } @@ -482,17 +516,19 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive def castChildOutput(p: InsertIntoTable, table: MetastoreRelation, child: LogicalPlan) : LogicalPlan = { val childOutputDataTypes = child.output.map(_.dataType) + val numDynamicPartitions = p.partition.values.count(_.isEmpty) val tableOutputDataTypes = - (table.attributes ++ table.partitionKeys).take(child.output.length).map(_.dataType) + (table.attributes ++ table.partitionKeys.takeRight(numDynamicPartitions)) + .take(child.output.length).map(_.dataType) if (childOutputDataTypes == tableOutputDataTypes) { - p + InsertIntoHiveTable(table, p.partition, p.child, p.overwrite, p.ifNotExists) } else if (childOutputDataTypes.size == tableOutputDataTypes.size && childOutputDataTypes.zip(tableOutputDataTypes) .forall { case (left, right) => left.sameType(right) }) { // If both types ignoring nullability of ArrayType, MapType, StructType are the same, // use InsertIntoHiveTable instead of InsertIntoTable. - InsertIntoHiveTable(p.table, p.partition, p.child, p.overwrite, p.ifNotExists) + InsertIntoHiveTable(table, p.partition, p.child, p.overwrite, p.ifNotExists) } else { // Only do the casting when child output data types differ from table output data types. val castedChildOutput = child.output.zip(table.output).map { @@ -510,13 +546,17 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive * UNIMPLEMENTED: It needs to be decided how we will persist in-memory tables to the metastore. * For now, if this functionality is desired mix in the in-memory [[OverrideCatalog]]. */ - override def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit = ??? + override def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit = { + throw new UnsupportedOperationException + } /** * UNIMPLEMENTED: It needs to be decided how we will persist in-memory tables to the metastore. * For now, if this functionality is desired mix in the in-memory [[OverrideCatalog]]. */ - override def unregisterTable(tableIdentifier: Seq[String]): Unit = ??? + override def unregisterTable(tableIdentifier: Seq[String]): Unit = { + throw new UnsupportedOperationException + } override def unregisterAllTables(): Unit = {} } @@ -527,7 +567,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive * because Hive table doesn't have nullability for ARRAY, MAP, STRUCT types. */ private[hive] case class InsertIntoHiveTable( - table: LogicalPlan, + table: MetastoreRelation, partition: Map[String, Option[String]], child: LogicalPlan, overwrite: Boolean, @@ -537,7 +577,13 @@ private[hive] case class InsertIntoHiveTable( override def children: Seq[LogicalPlan] = child :: Nil override def output: Seq[Attribute] = child.output - override lazy val resolved: Boolean = childrenResolved && child.output.zip(table.output).forall { + val numDynamicPartitions = partition.values.count(_.isEmpty) + + // This is the expected schema of the table prepared to be inserted into, + // including dynamic partition columns. + val tableOutput = table.attributes ++ table.partitionKeys.takeRight(numDynamicPartitions) + + override lazy val resolved: Boolean = childrenResolved && child.output.zip(tableOutput).forall { case (childAttr, tableAttr) => childAttr.dataType.sameType(tableAttr.dataType) } } @@ -550,7 +596,7 @@ private[hive] case class MetastoreRelation self: Product => - override def equals(other: scala.Any): Boolean = other match { + override def equals(other: Any): Boolean = other match { case relation: MetastoreRelation => databaseName == relation.databaseName && tableName == relation.tableName && @@ -665,25 +711,25 @@ private[hive] case class MetastoreRelation hiveQlTable.getMetadata ) - implicit class SchemaAttribute(f: FieldSchema) { + implicit class SchemaAttribute(f: HiveColumn) { def toAttribute: AttributeReference = AttributeReference( - f.getName, - HiveMetastoreTypes.toDataType(f.getType), + f.name, + HiveMetastoreTypes.toDataType(f.hiveType), // Since data can be dumped in randomly with no validation, everything is nullable. nullable = true )(qualifiers = Seq(alias.getOrElse(tableName))) } - // Must be a stable value since new attributes are born here. - val partitionKeys = hiveQlTable.getPartitionKeys.map(_.toAttribute) + /** PartitionKey attributes */ + val partitionKeys = table.partitionColumns.map(_.toAttribute) /** Non-partitionKey attributes */ - val attributes = hiveQlTable.getCols.map(_.toAttribute) + val attributes = table.schema.map(_.toAttribute) val output = attributes ++ partitionKeys /** An attribute map that can be used to lookup original attributes based on expression id. */ - val attributeMap = AttributeMap(output.map(o => (o,o))) + val attributeMap = AttributeMap(output.map(o => (o, o))) /** An attribute map for determining the ordinal for non-partition columns. */ val columnOrdinals = AttributeMap(attributes.zipWithIndex) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 2cbb5ca4d2e0c..a5ca3613c5e00 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -57,7 +57,7 @@ private[hive] case object NativePlaceholder extends LogicalPlan { override def output: Seq[Attribute] = Seq.empty } -case class CreateTableAsSelect( +private[hive] case class CreateTableAsSelect( tableDesc: HiveTable, child: LogicalPlan, allowExisting: Boolean) extends UnaryNode with Command { @@ -665,7 +665,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C HiveColumn(field.getName, field.getType, field.getComment) }) } - case Token("TOK_TABLEROWFORMAT", Token("TOK_SERDEPROPS", child :: Nil) :: Nil)=> + case Token("TOK_TABLEROWFORMAT", Token("TOK_SERDEPROPS", child :: Nil) :: Nil) => val serdeParams = new java.util.HashMap[String, String]() child match { case Token("TOK_TABLEROWFORMATFIELD", rowChild1 :: rowChild2) => @@ -775,7 +775,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // Support "TRUNCATE TABLE table_name [PARTITION partition_spec]" case Token("TOK_TRUNCATETABLE", - Token("TOK_TABLE_PARTITION",table)::Nil) => NativePlaceholder + Token("TOK_TABLE_PARTITION", table) :: Nil) => NativePlaceholder case Token("TOK_QUERY", queryArgs) if Seq("TOK_FROM", "TOK_INSERT").contains(queryArgs.head.getText) => @@ -1151,7 +1151,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Seq(false, false) => Inner }.toBuffer - val joinedTables = tables.reduceLeft(Join(_,_, Inner, None)) + val joinedTables = tables.reduceLeft(Join(_, _, Inner, None)) // Must be transform down. val joinedResult = joinedTables transform { @@ -1171,7 +1171,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // worth the number of hacks that will be required to implement it. Namely, we need to add // some sort of mapped star expansion that would expand all child output row to be similarly // named output expressions where some aggregate expression has been applied (i.e. First). - ??? // Aggregate(groups, Star(None, First(_)) :: Nil, joinedResult) + // Aggregate(groups, Star(None, First(_)) :: Nil, joinedResult) + throw new UnsupportedOperationException case Token(allJoinTokens(joinToken), relation1 :: @@ -1560,6 +1561,10 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C """.stripMargin) } + /* Case insensitive matches for Window Specification */ + val PRECEDING = "(?i)preceding".r + val FOLLOWING = "(?i)following".r + val CURRENT = "(?i)current".r def nodesToWindowSpecification(nodes: Seq[ASTNode]): WindowSpec = nodes match { case Token(windowName, Nil) :: Nil => // Refer to a window spec defined in the window clause. @@ -1613,11 +1618,19 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } else { val frameType = rowFrame.map(_ => RowFrame).getOrElse(RangeFrame) def nodeToBoundary(node: Node): FrameBoundary = node match { - case Token("preceding", Token(count, Nil) :: Nil) => - if (count == "unbounded") UnboundedPreceding else ValuePreceding(count.toInt) - case Token("following", Token(count, Nil) :: Nil) => - if (count == "unbounded") UnboundedFollowing else ValueFollowing(count.toInt) - case Token("current", Nil) => CurrentRow + case Token(PRECEDING(), Token(count, Nil) :: Nil) => + if (count.toLowerCase() == "unbounded") { + UnboundedPreceding + } else { + ValuePreceding(count.toInt) + } + case Token(FOLLOWING(), Token(count, Nil) :: Nil) => + if (count.toLowerCase() == "unbounded") { + UnboundedFollowing + } else { + ValueFollowing(count.toInt) + } + case Token(CURRENT(), Nil) => CurrentRow case _ => throw new NotImplementedError( s"""No parse rules for the Window Frame Boundary based on Node ${node.getName} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index d46a127d47d31..c6b65106452bf 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -140,7 +140,7 @@ private[hive] trait HiveStrategies { PhysicalRDD(plan.output, sparkContext.emptyRDD[Row]) :: Nil } else { hiveContext - .parquetFile(partitionLocations: _*) + .read.parquet(partitionLocations: _*) .addPartitioningAttributes(relation.partitionKeys) .lowerCase .where(unresolvedOtherPredicates) @@ -152,7 +152,7 @@ private[hive] trait HiveStrategies { } else { hiveContext - .parquetFile(relation.hiveQlTable.getDataLocation.toString) + .read.parquet(relation.hiveQlTable.getDataLocation.toString) .lowerCase .where(unresolvedOtherPredicates) .select(unresolvedProjection: _*) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index b69312f0f8717..294fc3bd7d5e9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -35,7 +35,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.Logging import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.DateUtils +import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.util.Utils /** @@ -79,7 +79,7 @@ class HadoopTableReader( makeRDDForTable( hiveTable, Class.forName( - relation.tableDesc.getSerdeClassName, true, Utils.getSparkClassLoader) + relation.tableDesc.getSerdeClassName, true, Utils.getContextOrSparkClassLoader) .asInstanceOf[Class[Deserializer]], filterOpt = None) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 7f94c93ba49c1..16851fdd71a98 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -56,8 +56,7 @@ private[hive] object IsolatedClientLoader { (if (version.hasBuiltinsJar) "hive-builtins" :: Nil else Nil)) .map(a => s"org.apache.hive:$a:${version.fullVersion}") :+ "com.google.guava:guava:14.0.1" :+ - "org.apache.hadoop:hadoop-client:2.4.0" :+ - "mysql:mysql-connector-java:5.1.12" + "org.apache.hadoop:hadoop-client:2.4.0" val classpath = quietly { SparkSubmitUtils.resolveMavenCoordinates( @@ -91,14 +90,14 @@ private[hive] object IsolatedClientLoader { * `ClientInterface`, unless `isolationOn` is set to `false`. * * @param version The version of hive on the classpath. used to pick specific function signatures - * that are not compatibile accross versions. + * that are not compatible across versions. * @param execJars A collection of jar files that must include hive and hadoop. * @param config A set of options that will be added to the HiveConf of the constructed client. * @param isolationOn When true, custom versions of barrier classes will be constructed. Must be * true unless loading the version of hive that is on Sparks classloader. - * @param rootClassLoader The system root classloader. Must not know about hive classes. - * @param baseClassLoader The spark classloader that is used to load shared classes. - * + * @param rootClassLoader The system root classloader. + * @param baseClassLoader The spark classloader that is used to load shared classes. Must not know + * about Hive classes. */ private[hive] class IsolatedClientLoader( val version: HiveVersion, @@ -106,10 +105,12 @@ private[hive] class IsolatedClientLoader( val config: Map[String, String] = Map.empty, val isolationOn: Boolean = true, val rootClassLoader: ClassLoader = ClassLoader.getSystemClassLoader.getParent.getParent, - val baseClassLoader: ClassLoader = Thread.currentThread().getContextClassLoader) + val baseClassLoader: ClassLoader = Thread.currentThread().getContextClassLoader, + val sharedPrefixes: Seq[String] = Seq.empty, + val barrierPrefixes: Seq[String] = Seq.empty) extends Logging { - // Check to make sure that the root classloader does not know about Hive. + // Check to make sure that the base classloader does not know about Hive. assert(Try(baseClassLoader.loadClass("org.apache.hive.HiveConf")).isFailure) /** All jars used by the hive specific classloader. */ @@ -122,13 +123,14 @@ private[hive] class IsolatedClientLoader( name.startsWith("scala.") || name.startsWith("com.google") || name.startsWith("java.lang.") || - name.startsWith("java.net") + name.startsWith("java.net") || + sharedPrefixes.exists(name.startsWith) /** True if `name` refers to a spark class that must see specific version of Hive. */ protected def isBarrierClass(name: String): Boolean = - name.startsWith("org.apache.spark.sql.hive.execution.PairSerDe") || name.startsWith(classOf[ClientWrapper].getName) || - name.startsWith(classOf[ReflectionMagic].getName) + name.startsWith(classOf[ReflectionMagic].getName) || + barrierPrefixes.exists(name.startsWith) protected def classToPath(name: String): String = name.replaceAll("\\.", "/") + ".class" diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ReflectionMagic.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ReflectionMagic.scala index c600b158c5460..4d053ae42c2ea 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ReflectionMagic.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ReflectionMagic.scala @@ -30,7 +30,7 @@ private[client] object ReflectionException { /** * Provides implicit functions on any object for calling methods reflectively. */ -protected trait ReflectionMagic { +private[client] trait ReflectionMagic { /** code for InstanceMagic println( (1 to 22).map { n => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala index 7db9200d47440..410d9881ac214 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -29,5 +29,5 @@ package object client { case object v13 extends HiveVersion("0.13.1", false) } // scalastyle:on - + } \ No newline at end of file diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index 62dc4167b78dd..11ee5503146b9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -63,7 +63,7 @@ case class HiveTableScan( BindReferences.bindReference(pred, relation.partitionKeys) } - // Create a local copy of hiveconf,so that scan specific modifications should not impact + // Create a local copy of hiveconf,so that scan specific modifications should not impact // other queries @transient private[this] val hiveExtraConf = new HiveConf(context.hiveconf) @@ -72,7 +72,7 @@ case class HiveTableScan( addColumnMetadataToConf(hiveExtraConf) @transient - private[this] val hadoopReader = + private[this] val hadoopReader = new HadoopTableReader(attributes, relation, context, hiveExtraConf) private[this] def castFromString(value: String, dataType: DataType) = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index c0b0b104e9142..8613332186f28 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -106,7 +106,7 @@ case class InsertIntoHiveTable( } writerContainer - .getLocalFileWriter(row) + .getLocalFileWriter(row, table.schema) .write(serializer.serialize(outputData, standardOI)) } @@ -194,10 +194,9 @@ case class InsertIntoHiveTable( if (partition.nonEmpty) { // loadPartition call orders directories created on the iteration order of the this map - val orderedPartitionSpec = new util.LinkedHashMap[String,String]() - table.hiveQlTable.getPartCols().foreach{ - entry=> - orderedPartitionSpec.put(entry.getName,partitionSpec.get(entry.getName).getOrElse("")) + val orderedPartitionSpec = new util.LinkedHashMap[String, String]() + table.hiveQlTable.getPartCols().foreach { entry => + orderedPartitionSpec.put(entry.getName, partitionSpec.get(entry.getName).getOrElse("")) } val partVals = MetaStoreUtils.getPvals(table.hiveQlTable.getPartCols, partitionSpec) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index bfd26e0170c70..fd623370cc407 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -62,7 +62,7 @@ case class ScriptTransformation( val inputStream = proc.getInputStream val outputStream = proc.getOutputStream val reader = new BufferedReader(new InputStreamReader(inputStream)) - + val (outputSerde, outputSoi) = ioschema.initOutputSerDe(output) val iterator: Iterator[Row] = new Iterator[Row] with HiveInspectors { @@ -95,7 +95,7 @@ case class ScriptTransformation( val raw = outputSerde.deserialize(writable) val dataList = outputSoi.getStructFieldsDataAsList(raw) val fieldList = outputSoi.getAllStructFieldRefs() - + var i = 0 dataList.foreach( element => { if (element == null) { @@ -117,7 +117,7 @@ case class ScriptTransformation( if (!hasNext) { throw new NoSuchElementException } - + if (outputSerde == null) { val prevLine = curLine curLine = reader.readLine() @@ -192,7 +192,7 @@ case class HiveScriptIOSchema ( val inputRowFormatMap = inputRowFormat.toMap.withDefault((k) => defaultFormat(k)) val outputRowFormatMap = outputRowFormat.toMap.withDefault((k) => defaultFormat(k)) - + def initInputSerDe(input: Seq[Expression]): (AbstractSerDe, ObjectInspector) = { val (columns, columnTypes) = parseAttrs(input) val serde = initSerDe(inputSerdeClass, columns, columnTypes, inputSerdeProps) @@ -206,22 +206,22 @@ case class HiveScriptIOSchema ( } def parseAttrs(attrs: Seq[Expression]): (Seq[String], Seq[DataType]) = { - + val columns = attrs.map { case aref: AttributeReference => aref.name case e: NamedExpression => e.name case _ => null } - + val columnTypes = attrs.map { case aref: AttributeReference => aref.dataType case e: NamedExpression => e.dataType - case _ => null + case _ => null } (columns, columnTypes) } - + def initSerDe(serdeClassName: String, columns: Seq[String], columnTypes: Seq[DataType], serdeProps: Seq[(String, String)]): AbstractSerDe = { @@ -240,7 +240,7 @@ case class HiveScriptIOSchema ( (kv._1.split("'")(1), kv._2.split("'")(1)) }).toMap + (serdeConstants.LIST_COLUMNS -> columns.mkString(",")) propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames) - + val properties = new Properties() properties.putAll(propsMap) serde.initialize(null, properties) @@ -261,7 +261,7 @@ case class HiveScriptIOSchema ( null } } - + def initOutputputSoi(outputSerde: AbstractSerDe): StructObjectInspector = { if (outputSerde != null) { outputSerde.getObjectInspector().asInstanceOf[StructObjectInspector] diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 8e405e080489f..0ba94d7b7c649 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -146,6 +146,7 @@ case class CreateMetastoreDataSource( hiveContext.catalog.createDataSourceTable( tableName, userSpecifiedSchema, + Array.empty[String], provider, optionsWithPath, isExternal) @@ -194,7 +195,7 @@ case class CreateMetastoreDataSourceAsSelect( sqlContext, Some(query.schema.asNullable), partitionColumns, provider, optionsWithPath) val createdRelation = LogicalRelation(resolved.relation) EliminateSubQueries(sqlContext.table(tableName).logicalPlan) match { - case l @ LogicalRelation(_: InsertableRelation | _: FSBasedRelation) => + case l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation) => if (l.relation != createdRelation.relation) { val errorDescription = s"Cannot append to table $tableName because the resolved relation does not " + @@ -244,6 +245,7 @@ case class CreateMetastoreDataSourceAsSelect( hiveContext.catalog.createDataSourceTable( tableName, Some(resolved.relation.schema), + partitionColumns, provider, optionsWithPath, isExternal) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index fd0b6f058595d..1658bb93b0b79 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -75,9 +75,11 @@ private[hive] abstract class HiveFunctionRegistry private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Expression with HiveInspectors with Logging { - type EvaluatedType = Any + type UDFType = UDF + override def deterministic: Boolean = isUDFDeterministic + override def nullable: Boolean = true @transient @@ -139,7 +141,8 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector) private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Expression with HiveInspectors with Logging { type UDFType = GenericUDF - type EvaluatedType = Any + + override def deterministic: Boolean = isUDFDeterministic override def nullable: Boolean = true @@ -316,7 +319,7 @@ private[hive] case class HiveWindowFunction( // The object inspector of values returned from the Hive window function. @transient - protected lazy val returnInspector = { + protected lazy val returnInspector = { evaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, inputInspectors) } @@ -336,8 +339,6 @@ private[hive] case class HiveWindowFunction( def nullable: Boolean = true - override type EvaluatedType = Any - override def eval(input: Row): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") @@ -413,7 +414,7 @@ private[hive] case class HiveGenericUdaf( protected lazy val resolver: AbstractGenericUDAFResolver = funcWrapper.createFunction() @transient - protected lazy val objectInspector = { + protected lazy val objectInspector = { val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors.toArray, false, false) resolver.getEvaluator(parameterInfo) .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray) @@ -446,7 +447,7 @@ private[hive] case class HiveUdaf( new GenericUDAFBridge(funcWrapper.createFunction()) @transient - protected lazy val objectInspector = { + protected lazy val objectInspector = { val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors.toArray, false, false) resolver.getEvaluator(parameterInfo) .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray) @@ -483,7 +484,11 @@ private[hive] case class HiveGenericUdtf( extends Generator with HiveInspectors { @transient - protected lazy val function: GenericUDTF = funcWrapper.createFunction() + protected lazy val function: GenericUDTF = { + val fun: GenericUDTF = funcWrapper.createFunction() + fun.setCollector(collector) + fun + } @transient protected lazy val inputInspectors = children.map(toInspector) @@ -494,6 +499,9 @@ private[hive] case class HiveGenericUdtf( @transient protected lazy val udtInput = new Array[AnyRef](children.length) + @transient + protected lazy val collector = new UDTFCollector + lazy val elementTypes = outputInspector.getAllStructFieldRefs.map { field => (inspectorToDataType(field.getFieldObjectInspector), true) } @@ -502,8 +510,7 @@ private[hive] case class HiveGenericUdtf( outputInspector // Make sure initialized. val inputProjection = new InterpretedProjection(children) - val collector = new UDTFCollector - function.setCollector(collector) + function.process(wrap(inputProjection(input), inputInspectors, udtInput)) collector.collectRows() } @@ -525,6 +532,12 @@ private[hive] case class HiveGenericUdtf( } } + override def terminate(): TraversableOnce[Row] = { + outputInspector // Make sure initialized. + function.close() + collector.collectRows() + } + override def toString: String = { s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } @@ -546,12 +559,12 @@ private[hive] case class HiveUdafFunction( } else { funcWrapper.createFunction[AbstractGenericUDAFResolver]() } - + private val inspectors = exprs.map(toInspector).toArray - - private val function = { + + private val function = { val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors, false, false) - resolver.getEvaluator(parameterInfo) + resolver.getEvaluator(parameterInfo) } private val returnInspector = function.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors) @@ -566,7 +579,7 @@ private[hive] case class HiveUdafFunction( @transient protected lazy val cached = new Array[AnyRef](exprs.length) - + def update(input: Row): Unit = { val inputs = inputProjection(input) function.iterate(buffer, wrap(inputs, inspectors, cached)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index cbc381cc81b59..2bb526b14be34 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -34,8 +34,10 @@ import org.apache.hadoop.hive.common.FileUtils import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.sql.Row import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter} +import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.hive.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.hive.HiveShim._ +import org.apache.spark.sql.types._ /** * Internal helper class that saves an RDD using a Hive OutputFormat. @@ -69,7 +71,7 @@ private[hive] class SparkHiveWriterContainer( @transient protected lazy val jobContext = newJobContext(conf.value, jID.value) @transient private lazy val taskContext = newTaskAttemptContext(conf.value, taID.value) @transient private lazy val outputFormat = - conf.value.getOutputFormat.asInstanceOf[HiveOutputFormat[AnyRef,Writable]] + conf.value.getOutputFormat.asInstanceOf[HiveOutputFormat[AnyRef, Writable]] def driverSideSetup() { setIDs(0, 0, 0) @@ -92,7 +94,7 @@ private[hive] class SparkHiveWriterContainer( "part-" + numberFormat.format(splitID) + extension } - def getLocalFileWriter(row: Row): FileSinkOperator.RecordWriter = writer + def getLocalFileWriter(row: Row, schema: StructType): FileSinkOperator.RecordWriter = writer def close() { // Seems the boolean value passed into close does not matter. @@ -195,11 +197,20 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( jobConf.setBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, oldMarker) } - override def getLocalFileWriter(row: Row): FileSinkOperator.RecordWriter = { + override def getLocalFileWriter(row: Row, schema: StructType): FileSinkOperator.RecordWriter = { + def convertToHiveRawString(col: String, value: Any): String = { + val raw = String.valueOf(value) + schema(col).dataType match { + case DateType => DateUtils.toString(raw.toInt) + case _: DecimalType => BigDecimal(raw).toString() + case _ => raw + } + } + val dynamicPartPath = dynamicPartColNames .zip(row.toSeq.takeRight(dynamicPartColNames.length)) .map { case (col, rawVal) => - val string = if (rawVal == null) null else String.valueOf(rawVal) + val string = if (rawVal == null) null else convertToHiveRawString(col, rawVal) val colString = if (string == null || string.isEmpty) { defaultPartName diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala new file mode 100644 index 0000000000000..1e51173a19882 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.ql.io.orc.{OrcFile, Reader} +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector + +import org.apache.spark.Logging +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.hive.HiveMetastoreTypes +import org.apache.spark.sql.types.StructType + +private[orc] object OrcFileOperator extends Logging{ + def getFileReader(pathStr: String, config: Option[Configuration] = None ): Reader = { + val conf = config.getOrElse(new Configuration) + val fspath = new Path(pathStr) + val fs = fspath.getFileSystem(conf) + val orcFiles = listOrcFiles(pathStr, conf) + + // TODO Need to consider all files when schema evolution is taken into account. + OrcFile.createReader(fs, orcFiles.head) + } + + def readSchema(path: String, conf: Option[Configuration]): StructType = { + val reader = getFileReader(path, conf) + val readerInspector = reader.getObjectInspector.asInstanceOf[StructObjectInspector] + val schema = readerInspector.getTypeName + HiveMetastoreTypes.toDataType(schema).asInstanceOf[StructType] + } + + def getObjectInspector(path: String, conf: Option[Configuration]): StructObjectInspector = { + getFileReader(path, conf).getObjectInspector.asInstanceOf[StructObjectInspector] + } + + def listOrcFiles(pathStr: String, conf: Configuration): Seq[Path] = { + val origPath = new Path(pathStr) + val fs = origPath.getFileSystem(conf) + val path = origPath.makeQualified(fs) + val paths = SparkHadoopUtil.get.listLeafStatuses(fs, origPath) + .filterNot(_.isDir) + .map(_.getPath) + .filterNot(_.getName.startsWith("_")) + .filterNot(_.getName.startsWith(".")) + + if (paths == null || paths.size == 0) { + throw new IllegalArgumentException( + s"orcFileOperator: path $path does not have valid orc files matching the pattern") + } + + paths + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala new file mode 100644 index 0000000000000..250e73a4dba92 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import org.apache.hadoop.hive.common.`type`.{HiveChar, HiveDecimal, HiveVarchar} +import org.apache.hadoop.hive.ql.io.sarg.SearchArgument +import org.apache.hadoop.hive.ql.io.sarg.SearchArgument.Builder +import org.apache.hadoop.hive.serde2.io.DateWritable + +import org.apache.spark.Logging +import org.apache.spark.sql.sources._ + +/** + * It may be optimized by push down partial filters. But we are conservative here. + * Because if some filters fail to be parsed, the tree may be corrupted, + * and cannot be used anymore. + */ +private[orc] object OrcFilters extends Logging { + def createFilter(expr: Array[Filter]): Option[SearchArgument] = { + expr.reduceOption(And).flatMap { conjunction => + val builder = SearchArgument.FACTORY.newBuilder() + buildSearchArgument(conjunction, builder).map(_.build()) + } + } + + private def buildSearchArgument(expression: Filter, builder: Builder): Option[Builder] = { + def newBuilder = SearchArgument.FACTORY.newBuilder() + + def isSearchableLiteral(value: Any) = value match { + // These are types recognized by the `SearchArgumentImpl.BuilderImpl.boxLiteral()` method. + case _: String | _: Long | _: Double | _: DateWritable | _: HiveDecimal | _: HiveChar | + _: HiveVarchar | _: Byte | _: Short | _: Integer | _: Float => true + case _ => false + } + + // lian: I probably missed something here, and had to end up with a pretty weird double-checking + // pattern when converting `And`/`Or`/`Not` filters. + // + // The annoying part is that, `SearchArgument` builder methods like `startAnd()` `startOr()`, + // and `startNot()` mutate internal state of the builder instance. This forces us to translate + // all convertible filters with a single builder instance. However, before actually converting a + // filter, we've no idea whether it can be recognized by ORC or not. Thus, when an inconvertible + // filter is found, we may already end up with a builder whose internal state is inconsistent. + // + // For example, to convert an `And` filter with builder `b`, we call `b.startAnd()` first, and + // then try to convert its children. Say we convert `left` child successfully, but find that + // `right` child is inconvertible. Alas, `b.startAnd()` call can't be rolled back, and `b` is + // inconsistent now. + // + // The workaround employed here is that, for `And`/`Or`/`Not`, we first try to convert their + // children with brand new builders, and only do the actual conversion with the right builder + // instance when the children are proven to be convertible. + // + // P.S.: Hive seems to use `SearchArgument` together with `ExprNodeGenericFuncDesc` only. + // Usage of builder methods mentioned above can only be found in test code, where all tested + // filters are known to be convertible. + + expression match { + case And(left, right) => + val tryLeft = buildSearchArgument(left, newBuilder) + val tryRight = buildSearchArgument(right, newBuilder) + + val conjunction = for { + _ <- tryLeft + _ <- tryRight + lhs <- buildSearchArgument(left, builder.startAnd()) + rhs <- buildSearchArgument(right, lhs) + } yield rhs.end() + + // For filter `left AND right`, we can still push down `left` even if `right` is not + // convertible, and vice versa. + conjunction + .orElse(tryLeft.flatMap(_ => buildSearchArgument(left, builder))) + .orElse(tryRight.flatMap(_ => buildSearchArgument(right, builder))) + + case Or(left, right) => + for { + _ <- buildSearchArgument(left, newBuilder) + _ <- buildSearchArgument(right, newBuilder) + lhs <- buildSearchArgument(left, builder.startOr()) + rhs <- buildSearchArgument(right, lhs) + } yield rhs.end() + + case Not(child) => + for { + _ <- buildSearchArgument(child, newBuilder) + negate <- buildSearchArgument(child, builder.startNot()) + } yield negate.end() + + case EqualTo(attribute, value) => + Option(value) + .filter(isSearchableLiteral) + .map(builder.equals(attribute, _)) + + case LessThan(attribute, value) => + Option(value) + .filter(isSearchableLiteral) + .map(builder.lessThan(attribute, _)) + + case LessThanOrEqual(attribute, value) => + Option(value) + .filter(isSearchableLiteral) + .map(builder.lessThanEquals(attribute, _)) + + case GreaterThan(attribute, value) => + Option(value) + .filter(isSearchableLiteral) + .map(builder.startNot().lessThanEquals(attribute, _).end()) + + case GreaterThanOrEqual(attribute, value) => + Option(value) + .filter(isSearchableLiteral) + .map(builder.startNot().lessThan(attribute, _).end()) + + case IsNull(attribute) => + Some(builder.isNull(attribute)) + + case IsNotNull(attribute) => + Some(builder.startNot().isNull(attribute).end()) + + case In(attribute, values) => + Option(values) + .filter(_.forall(isSearchableLiteral)) + .map(builder.in(attribute, _)) + + case _ => None + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala new file mode 100644 index 0000000000000..f03c4cd54e7e6 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -0,0 +1,303 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import java.util.Properties + +import com.google.common.base.Objects +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hadoop.hive.ql.io.orc.{OrcInputFormat, OrcOutputFormat, OrcSerde, OrcSplit} +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils +import org.apache.hadoop.io.{NullWritable, Writable} +import org.apache.hadoop.mapred.{InputFormat => MapRedInputFormat, JobConf, RecordWriter, Reporter} +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat +import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mapred.SparkHadoopMapRedUtil +import org.apache.spark.rdd.{HadoopRDD, RDD} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.hive.{HiveContext, HiveInspectors, HiveMetastoreTypes, HiveShim} +import org.apache.spark.sql.sources.{Filter, _} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.{Logging, SerializableWritable} + +/* Implicit conversions */ +import scala.collection.JavaConversions._ + +private[sql] class DefaultSource extends HadoopFsRelationProvider { + def createRelation( + sqlContext: SQLContext, + paths: Array[String], + dataSchema: Option[StructType], + partitionColumns: Option[StructType], + parameters: Map[String, String]): HadoopFsRelation = { + assert( + sqlContext.isInstanceOf[HiveContext], + "The ORC data source can only be used with HiveContext.") + + new OrcRelation(paths, dataSchema, None, partitionColumns, parameters)(sqlContext) + } +} + +private[orc] class OrcOutputWriter( + path: String, + dataSchema: StructType, + context: TaskAttemptContext) + extends OutputWriter with SparkHadoopMapRedUtil with HiveInspectors { + + private val serializer = { + val table = new Properties() + table.setProperty("columns", dataSchema.fieldNames.mkString(",")) + table.setProperty("columns.types", dataSchema.map { f => + HiveMetastoreTypes.toMetastoreType(f.dataType) + }.mkString(":")) + + val serde = new OrcSerde + serde.initialize(context.getConfiguration, table) + serde + } + + // Object inspector converted from the schema of the relation to be written. + private val structOI = { + val typeInfo = + TypeInfoUtils.getTypeInfoFromTypeString( + HiveMetastoreTypes.toMetastoreType(dataSchema)) + + TypeInfoUtils + .getStandardJavaObjectInspectorFromTypeInfo(typeInfo) + .asInstanceOf[StructObjectInspector] + } + + // Used to hold temporary `Writable` fields of the next row to be written. + private val reusableOutputBuffer = new Array[Any](dataSchema.length) + + // Used to convert Catalyst values into Hadoop `Writable`s. + private val wrappers = structOI.getAllStructFieldRefs.map { ref => + wrapperFor(ref.getFieldObjectInspector) + }.toArray + + // `OrcRecordWriter.close()` creates an empty file if no rows are written at all. We use this + // flag to decide whether `OrcRecordWriter.close()` needs to be called. + private var recordWriterInstantiated = false + + private lazy val recordWriter: RecordWriter[NullWritable, Writable] = { + recordWriterInstantiated = true + + val conf = context.getConfiguration + val partition = context.getTaskAttemptID.getTaskID.getId + val filename = f"part-r-$partition%05d-${System.currentTimeMillis}%015d.orc" + + new OrcOutputFormat().getRecordWriter( + new Path(path, filename).getFileSystem(conf), + conf.asInstanceOf[JobConf], + new Path(path, filename).toUri.getPath, + Reporter.NULL + ).asInstanceOf[RecordWriter[NullWritable, Writable]] + } + + override def write(row: Row): Unit = { + var i = 0 + while (i < row.length) { + reusableOutputBuffer(i) = wrappers(i)(row(i)) + i += 1 + } + + recordWriter.write( + NullWritable.get(), + serializer.serialize(reusableOutputBuffer, structOI)) + } + + override def close(): Unit = { + if (recordWriterInstantiated) { + recordWriter.close(Reporter.NULL) + } + } +} + +@DeveloperApi +private[sql] class OrcRelation( + override val paths: Array[String], + maybeDataSchema: Option[StructType], + maybePartitionSpec: Option[PartitionSpec], + override val userDefinedPartitionColumns: Option[StructType], + parameters: Map[String, String])( + @transient val sqlContext: SQLContext) + extends HadoopFsRelation(maybePartitionSpec) + with Logging { + + private[sql] def this( + paths: Array[String], + maybeDataSchema: Option[StructType], + maybePartitionSpec: Option[PartitionSpec], + parameters: Map[String, String])( + sqlContext: SQLContext) = { + this( + paths, + maybeDataSchema, + maybePartitionSpec, + maybePartitionSpec.map(_.partitionColumns), + parameters)(sqlContext) + } + + override val dataSchema: StructType = maybeDataSchema.getOrElse { + OrcFileOperator.readSchema( + paths.head, Some(sqlContext.sparkContext.hadoopConfiguration)) + } + + override def needConversion: Boolean = false + + override def equals(other: Any): Boolean = other match { + case that: OrcRelation => + paths.toSet == that.paths.toSet && + dataSchema == that.dataSchema && + schema == that.schema && + partitionColumns == that.partitionColumns + case _ => false + } + + override def hashCode(): Int = { + Objects.hashCode( + paths.toSet, + dataSchema, + schema, + partitionColumns) + } + + override def buildScan( + requiredColumns: Array[String], + filters: Array[Filter], + inputPaths: Array[FileStatus]): RDD[Row] = { + val output = StructType(requiredColumns.map(dataSchema(_))).toAttributes + OrcTableScan(output, this, filters, inputPaths).execute() + } + + override def prepareJobForWrite(job: Job): OutputWriterFactory = { + new OutputWriterFactory { + override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new OrcOutputWriter(path, dataSchema, context) + } + } + } +} + +private[orc] case class OrcTableScan( + attributes: Seq[Attribute], + @transient relation: OrcRelation, + filters: Array[Filter], + @transient inputPaths: Array[FileStatus]) + extends Logging + with HiveInspectors { + + @transient private val sqlContext = relation.sqlContext + + private def addColumnIds( + output: Seq[Attribute], + relation: OrcRelation, + conf: Configuration): Unit = { + val ids = output.map(a => relation.dataSchema.fieldIndex(a.name): Integer) + val (sortedIds, sortedNames) = ids.zip(attributes.map(_.name)).sorted.unzip + HiveShim.appendReadColumns(conf, sortedIds, sortedNames) + } + + // Transform all given raw `Writable`s into `Row`s. + private def fillObject( + path: String, + conf: Configuration, + iterator: Iterator[Writable], + nonPartitionKeyAttrs: Seq[(Attribute, Int)], + mutableRow: MutableRow): Iterator[Row] = { + val deserializer = new OrcSerde + val soi = OrcFileOperator.getObjectInspector(path, Some(conf)) + val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.map { + case (attr, ordinal) => + soi.getStructFieldRef(attr.name.toLowerCase) -> ordinal + }.unzip + val unwrappers = fieldRefs.map(unwrapperFor) + // Map each tuple to a row object + iterator.map { value => + val raw = deserializer.deserialize(value) + var i = 0 + while (i < fieldRefs.length) { + val fieldValue = soi.getStructFieldData(raw, fieldRefs(i)) + if (fieldValue == null) { + mutableRow.setNullAt(fieldOrdinals(i)) + } else { + unwrappers(i)(fieldValue, mutableRow, fieldOrdinals(i)) + } + i += 1 + } + mutableRow: Row + } + } + + def execute(): RDD[Row] = { + val job = new Job(sqlContext.sparkContext.hadoopConfiguration) + val conf = job.getConfiguration + + // Tries to push down filters if ORC filter push-down is enabled + if (sqlContext.conf.orcFilterPushDown) { + OrcFilters.createFilter(filters).foreach { f => + conf.set(OrcTableScan.SARG_PUSHDOWN, f.toKryo) + conf.setBoolean(ConfVars.HIVEOPTINDEXFILTER.varname, true) + } + } + + // Sets requested columns + addColumnIds(attributes, relation, conf) + + if (inputPaths.nonEmpty) { + FileInputFormat.setInputPaths(job, inputPaths.map(_.getPath): _*) + } + + val inputFormatClass = + classOf[OrcInputFormat] + .asInstanceOf[Class[_ <: MapRedInputFormat[NullWritable, Writable]]] + + val rdd = sqlContext.sparkContext.hadoopRDD( + conf.asInstanceOf[JobConf], + inputFormatClass, + classOf[NullWritable], + classOf[Writable] + ).asInstanceOf[HadoopRDD[NullWritable, Writable]] + + val wrappedConf = new SerializableWritable(conf) + + rdd.mapPartitionsWithInputSplit { case (split: OrcSplit, iterator) => + val mutableRow = new SpecificMutableRow(attributes.map(_.dataType)) + fillObject( + split.getPath.toString, + wrappedConf.value, + iterator.map(_._2), + attributes.zipWithIndex, + mutableRow) + } + } +} + +private[orc] object OrcTableScan { + // This constant duplicates `OrcInputFormat.SARG_PUSHDOWN`, which is unfortunately not public. + private[orc] val SARG_PUSHDOWN = "sarg.pushdown" +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 1598d4bd47550..7c7afc824d7a6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -48,7 +48,14 @@ import scala.collection.JavaConversions._ // SPARK-3729: Test key required to check for initialization errors with config. object TestHive extends TestHiveContext( - new SparkContext("local[2]", "TestSQLContext", new SparkConf().set("spark.sql.test", ""))) + new SparkContext( + "local[2]", + "TestSQLContext", + new SparkConf() + .set("spark.sql.test", "") + .set( + "spark.sql.hive.metastore.barrierPrefixes", + "org.apache.spark.sql.hive.execution.PairSerDe"))) /** * A locally running test instance of Spark's Hive execution engine. @@ -75,9 +82,11 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { lazy val warehousePath = Utils.createTempDir() + private lazy val temporaryConfig = newTemporaryConfiguration() + /** Sets up the system initially or after a RESET command */ protected override def configure(): Map[String, String] = - newTemporaryConfiguration() ++ Map("hive.metastore.warehouse.dir" -> warehousePath.toString) + temporaryConfig ++ Map("hive.metastore.warehouse.dir" -> warehousePath.toString) val testTempDir = Utils.createTempDir() @@ -180,7 +189,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { } } - case class TestTable(name: String, commands: (()=>Unit)*) + case class TestTable(name: String, commands: (() => Unit)*) protected[hive] implicit class SqlCmd(sql: String) { def cmd: () => Unit = { @@ -244,8 +253,8 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { | 'serialization.format'='${classOf[TBinaryProtocol].getName}' |) |STORED AS - |INPUTFORMAT '${classOf[SequenceFileInputFormat[_,_]].getName}' - |OUTPUTFORMAT '${classOf[SequenceFileOutputFormat[_,_]].getName}' + |INPUTFORMAT '${classOf[SequenceFileInputFormat[_, _]].getName}' + |OUTPUTFORMAT '${classOf[SequenceFileOutputFormat[_, _]].getName}' """.stripMargin) runSqlHive( diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java new file mode 100644 index 0000000000000..c4828c4717643 --- /dev/null +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaDataFrameSuite.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.hive; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.*; +import org.apache.spark.sql.expressions.Window; +import org.apache.spark.sql.hive.HiveContext; +import org.apache.spark.sql.hive.test.TestHive$; + +public class JavaDataFrameSuite { + private transient JavaSparkContext sc; + private transient HiveContext hc; + + DataFrame df; + + private void checkAnswer(DataFrame actual, List expected) { + String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected); + if (errorMessage != null) { + Assert.fail(errorMessage); + } + } + + @Before + public void setUp() throws IOException { + hc = TestHive$.MODULE$; + sc = new JavaSparkContext(hc.sparkContext()); + + List jsonObjects = new ArrayList(10); + for (int i = 0; i < 10; i++) { + jsonObjects.add("{\"key\":" + i + ", \"value\":\"str" + i + "\"}"); + } + df = hc.jsonRDD(sc.parallelize(jsonObjects)); + df.registerTempTable("window_table"); + } + + @After + public void tearDown() throws IOException { + // Clean up tables. + hc.sql("DROP TABLE IF EXISTS window_table"); + } + + @Test + public void saveTableAndQueryIt() { + checkAnswer( + df.select(functions.avg("key").over( + Window.partitionBy("value").orderBy("key").rowsBetween(-1, 1))), + hc.sql("SELECT avg(key) " + + "OVER (PARTITION BY value " + + " ORDER BY key " + + " ROWS BETWEEN 1 preceding and 1 following) " + + "FROM window_table").collectAsList()); + } +} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java similarity index 89% rename from sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java rename to sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java index 53ddecf57958b..64d1ce92931eb 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java @@ -14,7 +14,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.hive; + +package test.org.apache.spark.sql.hive; import java.io.File; import java.io.IOException; @@ -36,6 +37,7 @@ import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.QueryTest$; import org.apache.spark.sql.Row; +import org.apache.spark.sql.hive.HiveContext; import org.apache.spark.sql.hive.test.TestHive$; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; @@ -81,7 +83,7 @@ public void setUp() throws IOException { jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}"); } JavaRDD rdd = sc.parallelize(jsonObjects); - df = sqlContext.jsonRDD(rdd); + df = sqlContext.read().json(rdd); df.registerTempTable("jsonTable"); } @@ -96,7 +98,11 @@ public void tearDown() throws IOException { public void saveExternalTableAndQueryIt() { Map options = new HashMap(); options.put("path", path.toString()); - df.saveAsTable("javaSavedTable", "org.apache.spark.sql.json", SaveMode.Append, options); + df.write() + .format("org.apache.spark.sql.json") + .mode(SaveMode.Append) + .options(options) + .saveAsTable("javaSavedTable"); checkAnswer( sqlContext.sql("SELECT * FROM javaSavedTable"), @@ -115,7 +121,11 @@ public void saveExternalTableAndQueryIt() { public void saveExternalTableWithSchemaAndQueryIt() { Map options = new HashMap(); options.put("path", path.toString()); - df.saveAsTable("javaSavedTable", "org.apache.spark.sql.json", SaveMode.Append, options); + df.write() + .format("org.apache.spark.sql.json") + .mode(SaveMode.Append) + .options(options) + .saveAsTable("javaSavedTable"); checkAnswer( sqlContext.sql("SELECT * FROM javaSavedTable"), @@ -138,7 +148,11 @@ public void saveExternalTableWithSchemaAndQueryIt() { @Test public void saveTableAndQueryIt() { Map options = new HashMap(); - df.saveAsTable("javaSavedTable", "org.apache.spark.sql.json", SaveMode.Append, options); + df.write() + .format("org.apache.spark.sql.json") + .mode(SaveMode.Append) + .options(options) + .saveAsTable("javaSavedTable"); checkAnswer( sqlContext.sql("SELECT * FROM javaSavedTable"), diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFIntegerToString.java similarity index 100% rename from sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFIntegerToString.java rename to sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFIntegerToString.java diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFListListInt.java similarity index 100% rename from sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListListInt.java rename to sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFListListInt.java diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFListString.java similarity index 100% rename from sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java rename to sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFListString.java diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFStringString.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFStringString.java similarity index 100% rename from sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFStringString.java rename to sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFStringString.java diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFTwoListList.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFTwoListList.java similarity index 100% rename from sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFTwoListList.java rename to sql/hive/src/test/java/test/org/apache/spark/sql/hive/execution/UDFTwoListList.java diff --git a/sql/hive/src/test/resources/TestUDTF.jar b/sql/hive/src/test/resources/TestUDTF.jar new file mode 100644 index 0000000000000..514f2d5d26fd3 Binary files /dev/null and b/sql/hive/src/test/resources/TestUDTF.jar differ diff --git a/sql/hive/src/test/resources/golden/Test UDTF.close in Lateral Views-0-ac5c96224a534f07b49462ad76620678 b/sql/hive/src/test/resources/golden/Test UDTF.close in Lateral Views-0-ac5c96224a534f07b49462ad76620678 new file mode 100644 index 0000000000000..946e72fc87c2e --- /dev/null +++ b/sql/hive/src/test/resources/golden/Test UDTF.close in Lateral Views-0-ac5c96224a534f07b49462ad76620678 @@ -0,0 +1,2 @@ +97 500 +97 500 diff --git a/sql/hive/src/test/resources/golden/Test UDTF.close in SELECT-0-517f834fef35b896ec64399f42b2a151 b/sql/hive/src/test/resources/golden/Test UDTF.close in SELECT-0-517f834fef35b896ec64399f42b2a151 new file mode 100644 index 0000000000000..a5c8806279fa7 --- /dev/null +++ b/sql/hive/src/test/resources/golden/Test UDTF.close in SELECT-0-517f834fef35b896ec64399f42b2a151 @@ -0,0 +1,2 @@ +3 +3 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index fc6c3c35037b0..39d315aaeab57 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -57,7 +57,7 @@ class CachedTableSuite extends QueryTest { checkAnswer( sql("SELECT * FROM src s"), preCacheResults) - + uncacheTable("src") assertCached(sql("SELECT * FROM src"), 0) } @@ -162,7 +162,7 @@ class CachedTableSuite extends QueryTest { test("REFRESH TABLE also needs to recache the data (data source tables)") { val tempPath: File = Utils.createTempDir() tempPath.delete() - table("src").save(tempPath.toString, "parquet", SaveMode.Overwrite) + table("src").write.mode(SaveMode.Overwrite).parquet(tempPath.toString) sql("DROP TABLE IF EXISTS refreshTable") createExternalTable("refreshTable", tempPath.toString, "parquet") checkAnswer( @@ -172,7 +172,7 @@ class CachedTableSuite extends QueryTest { sql("CACHE TABLE refreshTable") assertCached(table("refreshTable")) // Append new data. - table("src").save(tempPath.toString, "parquet", SaveMode.Append) + table("src").write.mode(SaveMode.Append).parquet(tempPath.toString) // We are still using the old data. assertCached(table("refreshTable")) checkAnswer( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala new file mode 100644 index 0000000000000..fb10f8583da99 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.{DataFrame, QueryTest} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.scalatest.BeforeAndAfterAll + +// TODO ideally we should put the test suite into the package `sql`, as +// `hive` package is optional in compiling, however, `SQLContext.sql` doesn't +// support the `cube` or `rollup` yet. +class HiveDataFrameAnalyticsSuite extends QueryTest with BeforeAndAfterAll { + private var testData: DataFrame = _ + + override def beforeAll() { + testData = Seq((1, 2), (2, 4)).toDF("a", "b") + TestHive.registerDataFrameAsTable(testData, "mytable") + } + + override def afterAll(): Unit = { + TestHive.dropTempTable("mytable") + } + + test("rollup") { + checkAnswer( + testData.rollup($"a" + $"b", $"b").agg(sum($"a" - $"b")), + sql("select a + b, b, sum(a - b) from mytable group by a + b, b with rollup").collect() + ) + + checkAnswer( + testData.rollup("a", "b").agg(sum("b")), + sql("select a, b, sum(b) from mytable group by a, b with rollup").collect() + ) + } + + test("cube") { + checkAnswer( + testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")), + sql("select a + b, b, sum(a - b) from mytable group by a + b, b with cube").collect() + ) + + checkAnswer( + testData.cube("a", "b").agg(sum("b")), + sql("select a, b, sum(b) from mytable group by a, b with cube").collect() + ) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala new file mode 100644 index 0000000000000..efb3f2545db84 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.{Row, QueryTest} +import org.apache.spark.sql.expressions.Window +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ + +class HiveDataFrameWindowSuite extends QueryTest { + + test("reuse window partitionBy") { + val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + val w = Window.partitionBy("key").orderBy("value") + + checkAnswer( + df.select( + lead("key", 1).over(w), + lead("value", 1).over(w)), + Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil) + } + + test("reuse window orderBy") { + val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + val w = Window.orderBy("value").partitionBy("key") + + checkAnswer( + df.select( + lead("key", 1).over(w), + lead("value", 1).over(w)), + Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil) + } + + test("lead") { + val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + + checkAnswer( + df.select( + lead("value", 1).over(Window.partitionBy($"key").orderBy($"value"))), + sql( + """SELECT + | lead(value) OVER (PARTITION BY key ORDER BY value) + | FROM window_table""".stripMargin).collect()) + } + + test("lag") { + val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + + checkAnswer( + df.select( + lag("value", 1).over(Window.partitionBy($"key").orderBy($"value"))), + sql( + """SELECT + | lag(value) OVER (PARTITION BY key ORDER BY value) + | FROM window_table""".stripMargin).collect()) + } + + test("lead with default value") { + val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"), + (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( + df.select( + lead("value", 2, "n/a").over(Window.partitionBy("key").orderBy("value"))), + sql( + """SELECT + | lead(value, 2, "n/a") OVER (PARTITION BY key ORDER BY value) + | FROM window_table""".stripMargin).collect()) + } + + test("lag with default value") { + val df = Seq((1, "1"), (1, "1"), (2, "2"), (1, "1"), + (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( + df.select( + lag("value", 2, "n/a").over(Window.partitionBy($"key").orderBy($"value"))), + sql( + """SELECT + | lag(value, 2, "n/a") OVER (PARTITION BY key ORDER BY value) + | FROM window_table""".stripMargin).collect()) + } + + test("rank functions in unspecific window") { + val df = Seq((1, "1"), (2, "2"), (1, "2"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( + df.select( + $"key", + max("key").over(Window.partitionBy("value").orderBy("key")), + min("key").over(Window.partitionBy("value").orderBy("key")), + mean("key").over(Window.partitionBy("value").orderBy("key")), + count("key").over(Window.partitionBy("value").orderBy("key")), + sum("key").over(Window.partitionBy("value").orderBy("key")), + ntile(2).over(Window.partitionBy("value").orderBy("key")), + rowNumber().over(Window.partitionBy("value").orderBy("key")), + denseRank().over(Window.partitionBy("value").orderBy("key")), + rank().over(Window.partitionBy("value").orderBy("key")), + cumeDist().over(Window.partitionBy("value").orderBy("key")), + percentRank().over(Window.partitionBy("value").orderBy("key"))), + sql( + s"""SELECT + |key, + |max(key) over (partition by value order by key), + |min(key) over (partition by value order by key), + |avg(key) over (partition by value order by key), + |count(key) over (partition by value order by key), + |sum(key) over (partition by value order by key), + |ntile(2) over (partition by value order by key), + |row_number() over (partition by value order by key), + |dense_rank() over (partition by value order by key), + |rank() over (partition by value order by key), + |cume_dist() over (partition by value order by key), + |percent_rank() over (partition by value order by key) + |FROM window_table""".stripMargin).collect()) + } + + test("aggregation and rows between") { + val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( + df.select( + avg("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 2))), + sql( + """SELECT + | avg(key) OVER + | (PARTITION BY value ORDER BY key ROWS BETWEEN 1 preceding and 2 following) + | FROM window_table""".stripMargin).collect()) + } + + test("aggregation and range betweens") { + val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( + df.select( + avg("key").over(Window.partitionBy($"value").orderBy($"key").rangeBetween(-1, 1))), + sql( + """SELECT + | avg(key) OVER + | (PARTITION BY value ORDER BY key RANGE BETWEEN 1 preceding and 1 following) + | FROM window_table""".stripMargin).collect()) + } + + test("aggregation and rows betweens with unbounded") { + val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( + df.select( + $"key", + last("value").over( + Window.partitionBy($"value").orderBy($"key").rowsBetween(0, Long.MaxValue)), + last("value").over( + Window.partitionBy($"value").orderBy($"key").rowsBetween(Long.MinValue, 0)), + last("value").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 3))), + sql( + """SELECT + | key, + | last_value(value) OVER + | (PARTITION BY value ORDER BY key ROWS between current row and unbounded following), + | last_value(value) OVER + | (PARTITION BY value ORDER BY key ROWS between unbounded preceding and current row), + | last_value(value) OVER + | (PARTITION BY value ORDER BY key ROWS between 1 preceding and 3 following) + | FROM window_table""".stripMargin).collect()) + } + + test("aggregation and range betweens with unbounded") { + val df = Seq((1, "1"), (2, "2"), (2, "2"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + df.registerTempTable("window_table") + checkAnswer( + df.select( + $"key", + last("value").over( + Window.partitionBy($"value").orderBy($"key").rangeBetween(1, Long.MaxValue)) + .equalTo("2") + .as("last_v"), + avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(Long.MinValue, 1)) + .as("avg_key1"), + avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(0, Long.MaxValue)) + .as("avg_key2"), + avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(-1, 0)) + .as("avg_key3") + ), + sql( + """SELECT + | key, + | last_value(value) OVER + | (PARTITION BY value ORDER BY key RANGE 1 preceding) == "2", + | avg(key) OVER + | (PARTITION BY value ORDER BY key RANGE BETWEEN unbounded preceding and 1 following), + | avg(key) OVER + | (PARTITION BY value ORDER BY key RANGE BETWEEN current row and unbounded following), + | avg(key) OVER + | (PARTITION BY value ORDER BY key RANGE BETWEEN 1 preceding and current row) + | FROM window_table""".stripMargin).collect()) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index 2a7374cc172b7..df137e7b2b333 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -26,12 +26,12 @@ import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectIns import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory import org.apache.hadoop.io.LongWritable -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.{Literal, Row} import org.apache.spark.sql.types._ -class HiveInspectorSuite extends FunSuite with HiveInspectors { +class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { test("Test wrap SettableStructObjectInspector") { val udaf = new UDAFPercentile.PercentileLongEvaluator() udaf.init() @@ -78,10 +78,10 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { Literal(java.sql.Date.valueOf("2014-09-23")) :: Literal(Decimal(BigDecimal(123.123))) :: Literal(new java.sql.Timestamp(123123)) :: - Literal(Array[Byte](1,2,3)) :: - Literal.create(Seq[Int](1,2,3), ArrayType(IntegerType)) :: - Literal.create(Map[Int, Int](1->2, 2->1), MapType(IntegerType, IntegerType)) :: - Literal.create(Row(1,2.0d,3.0f), + Literal(Array[Byte](1, 2, 3)) :: + Literal.create(Seq[Int](1, 2, 3), ArrayType(IntegerType)) :: + Literal.create(Map[Int, Int](1 -> 2, 2 -> 1), MapType(IntegerType, IntegerType)) :: + Literal.create(Row(1, 2.0d, 3.0f), StructType(StructField("c1", IntegerType) :: StructField("c2", DoubleType) :: StructField("c3", FloatType) :: Nil)) :: @@ -111,8 +111,8 @@ class HiveInspectorSuite extends FunSuite with HiveInspectors { case DecimalType() => PrimitiveObjectInspectorFactory.writableHiveDecimalObjectInspector case StructType(fields) => ObjectInspectorFactory.getStandardStructObjectInspector( - java.util.Arrays.asList(fields.map(f => f.name) :_*), - java.util.Arrays.asList(fields.map(f => toWritableInspector(f.dataType)) :_*)) + java.util.Arrays.asList(fields.map(f => f.name) : _*), + java.util.Arrays.asList(fields.map(f => toWritableInspector(f.dataType)) : _*)) } def checkDataType(dt1: Seq[DataType], dt2: Seq[DataType]): Unit = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index fa8e11ffec2b4..e9bb32667936c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -17,13 +17,13 @@ package org.apache.spark.sql.hive +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.hive.test.TestHive -import org.scalatest.FunSuite import org.apache.spark.sql.test.ExamplePointUDT import org.apache.spark.sql.types.StructType -class HiveMetastoreCatalogSuite extends FunSuite { +class HiveMetastoreCatalogSuite extends SparkFunSuite { test("struct field should accept underscore in sub-column name") { val metastr = "struct" diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala index 7ff5719adb3ab..5a5ea10e3c82e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala @@ -55,8 +55,8 @@ class HiveParquetSuite extends QueryTest with ParquetTest { test(s"$prefix: Converting Hive to Parquet Table via saveAsParquetFile") { withTempPath { dir => - sql("SELECT * FROM src").saveAsParquetFile(dir.getCanonicalPath) - parquetFile(dir.getCanonicalPath).registerTempTable("p") + sql("SELECT * FROM src").write.parquet(dir.getCanonicalPath) + read.parquet(dir.getCanonicalPath).registerTempTable("p") withTempTable("p") { checkAnswer( sql("SELECT * FROM src ORDER BY key"), @@ -68,8 +68,8 @@ class HiveParquetSuite extends QueryTest with ParquetTest { test(s"$prefix: INSERT OVERWRITE TABLE Parquet table") { withParquetTable((1 to 10).map(i => (i, s"val_$i")), "t") { withTempPath { file => - sql("SELECT * FROM t LIMIT 1").saveAsParquetFile(file.getCanonicalPath) - parquetFile(file.getCanonicalPath).registerTempTable("p") + sql("SELECT * FROM t LIMIT 1").write.parquet(file.getCanonicalPath) + read.parquet(file.getCanonicalPath).registerTempTable("p") withTempTable("p") { // let's do three overwrites for good measure sql("INSERT OVERWRITE TABLE p SELECT * FROM t") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala index 941a2941649b8..f765395e148af 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala @@ -20,12 +20,13 @@ package org.apache.spark.sql.hive import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde.serdeConstants +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.hive.client.{ManagedTable, HiveColumn, ExternalTable, HiveTable} -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll -class HiveQlSuite extends FunSuite with BeforeAndAfterAll { +class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll { override def beforeAll() { if (SessionState.get() == null) { SessionState.start(new HiveConf()) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index ecb990e8aac91..aa5dbe2db6903 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -53,7 +53,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { sql("CREATE TABLE createAndInsertTest (key int, value string)") // Add some data. - testData.insertInto("createAndInsertTest") + testData.write.mode(SaveMode.Append).insertInto("createAndInsertTest") // Make sure the table has also been updated. checkAnswer( @@ -62,7 +62,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { ) // Add more data. - testData.insertInto("createAndInsertTest") + testData.write.mode(SaveMode.Append).insertInto("createAndInsertTest") // Make sure the table has been updated. checkAnswer( @@ -71,7 +71,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { ) // Now overwrite. - testData.insertInto("createAndInsertTest", overwrite = true) + testData.write.mode(SaveMode.Overwrite).insertInto("createAndInsertTest") // Make sure the registered table has also been updated. checkAnswer( @@ -160,7 +160,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=1"::Nil , "p1=a"::"p2=b"::"p3=c"::"p4=c"::"p5=4"::Nil ) - assert(listFolders(tmpDir,List()).sortBy(_.toString()) == expected.sortBy(_.toString)) + assert(listFolders(tmpDir, List()).sortBy(_.toString()) == expected.sortBy(_.toString)) sql("DROP TABLE table_with_partition") sql("DROP TABLE tmp_table") } @@ -240,7 +240,7 @@ class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter { checkAnswer(sql("select key,value from table_with_partition where ds='1' "), testData.collect().toSeq ) - + // test difference type of field sql("ALTER TABLE table_with_partition CHANGE COLUMN key key BIGINT") checkAnswer(sql("select key,value from table_with_partition where ds='1' "), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala index e12a6c21ccac4..1c15997ea8e6d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala @@ -29,7 +29,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfterAll { import org.apache.spark.sql.hive.test.TestHive.implicits._ val df = - sparkContext.parallelize((1 to 10).map(i => (i,s"str$i"))).toDF("key", "value") + sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDF("key", "value") override def beforeAll(): Unit = { // The catalog in HiveContext is a case insensitive one. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 47c60f651d14c..58e2d1fbfa73e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -21,744 +21,818 @@ import java.io.File import scala.collection.mutable.ArrayBuffer -import org.scalatest.BeforeAndAfterEach +import org.scalatest.BeforeAndAfterAll import org.apache.hadoop.fs.Path -import org.apache.hadoop.hive.metastore.TableType -import org.apache.hadoop.hive.ql.metadata.Table import org.apache.hadoop.mapred.InvalidInputException import org.apache.spark.sql._ -import org.apache.spark.util.Utils -import org.apache.spark.sql.types._ import org.apache.spark.sql.hive.client.{HiveTable, ManagedTable} +import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.parquet.ParquetRelation2 import org.apache.spark.sql.sources.LogicalRelation +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils /** * Tests for persisting tables created though the data sources API into the metastore. */ -class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { +class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll { + override val sqlContext = TestHive - override def afterEach(): Unit = { - reset() - Utils.deleteRecursively(tempPath) + var jsonFilePath: String = _ + + override def beforeAll(): Unit = { + jsonFilePath = Utils.getSparkClassLoader.getResource("sample.json").getFile } - val filePath = Utils.getSparkClassLoader.getResource("sample.json").getFile - var tempPath: File = Utils.createTempDir() - tempPath.delete() - - test ("persistent JSON table") { - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) - - checkAnswer( - sql("SELECT * FROM jsonTable"), - jsonFile(filePath).collect().toSeq) + test("persistent JSON table") { + withTable("jsonTable") { + sql( + s"""CREATE TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) + + checkAnswer( + sql("SELECT * FROM jsonTable"), + read.json(jsonFilePath).collect().toSeq) + } } - test ("persistent JSON table with a user specified schema") { - sql( - s""" - |CREATE TABLE jsonTable ( - |a string, - |b String, - |`c_!@(3)` int, - |`` Struct<`d!`:array, `=`:array>>) - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) - - jsonFile(filePath).registerTempTable("expectedJsonTable") - - checkAnswer( - sql("SELECT a, b, `c_!@(3)`, ``.`d!`, ``.`=` FROM jsonTable"), - sql("SELECT a, b, `c_!@(3)`, ``.`d!`, ``.`=` FROM expectedJsonTable").collect().toSeq) + test("persistent JSON table with a user specified schema") { + withTable("jsonTable") { + sql( + s"""CREATE TABLE jsonTable ( + |a string, + |b String, + |`c_!@(3)` int, + |`` Struct<`d!`:array, `=`:array>>) + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) + + withTempTable("expectedJsonTable") { + read.json(jsonFilePath).registerTempTable("expectedJsonTable") + checkAnswer( + sql("SELECT a, b, `c_!@(3)`, ``.`d!`, ``.`=` FROM jsonTable"), + sql("SELECT a, b, `c_!@(3)`, ``.`d!`, ``.`=` FROM expectedJsonTable")) + } + } } - test ("persistent JSON table with a user specified schema with a subset of fields") { - // This works because JSON objects are self-describing and JSONRelation can get needed - // field values based on field names. - sql( - s""" - |CREATE TABLE jsonTable (`` Struct<`=`:array>>, b String) - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) - - val innerStruct = StructType( - StructField("=", ArrayType(StructType(StructField("Dd2", BooleanType, true) :: Nil))) :: Nil) - val expectedSchema = StructType( - StructField("", innerStruct, true) :: - StructField("b", StringType, true) :: Nil) - - assert(expectedSchema === table("jsonTable").schema) - - jsonFile(filePath).registerTempTable("expectedJsonTable") - - checkAnswer( - sql("SELECT b, ``.`=` FROM jsonTable"), - sql("SELECT b, ``.`=` FROM expectedJsonTable").collect().toSeq) + test("persistent JSON table with a user specified schema with a subset of fields") { + withTable("jsonTable") { + // This works because JSON objects are self-describing and JSONRelation can get needed + // field values based on field names. + sql( + s"""CREATE TABLE jsonTable (`` Struct<`=`:array>>, b String) + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) + + val innerStruct = StructType(Seq( + StructField("=", ArrayType(StructType(StructField("Dd2", BooleanType, true) :: Nil))))) + + val expectedSchema = StructType(Seq( + StructField("", innerStruct, true), + StructField("b", StringType, true))) + + assert(expectedSchema === table("jsonTable").schema) + + withTempTable("expectedJsonTable") { + read.json(jsonFilePath).registerTempTable("expectedJsonTable") + checkAnswer( + sql("SELECT b, ``.`=` FROM jsonTable"), + sql("SELECT b, ``.`=` FROM expectedJsonTable")) + } + } } test("resolve shortened provider names") { - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) - - checkAnswer( - sql("SELECT * FROM jsonTable"), - jsonFile(filePath).collect().toSeq) + withTable("jsonTable") { + sql( + s""" + |CREATE TABLE jsonTable + |USING org.apache.spark.sql.json + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) + + checkAnswer( + sql("SELECT * FROM jsonTable"), + read.json(jsonFilePath).collect().toSeq) + } } test("drop table") { - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) - - checkAnswer( - sql("SELECT * FROM jsonTable"), - jsonFile(filePath).collect().toSeq) - - sql("DROP TABLE jsonTable") - - intercept[Exception] { - sql("SELECT * FROM jsonTable").collect() - } + withTable("jsonTable") { + sql( + s""" + |CREATE TABLE jsonTable + |USING org.apache.spark.sql.json + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) - assert( - (new File(filePath)).exists(), - "The table with specified path is considered as an external table, " + - "its data should not deleted after DROP TABLE.") + checkAnswer( + sql("SELECT * FROM jsonTable"), + read.json(jsonFilePath)) + + sql("DROP TABLE jsonTable") + + intercept[Exception] { + sql("SELECT * FROM jsonTable").collect() + } + + assert( + new File(jsonFilePath).exists(), + "The table with specified path is considered as an external table, " + + "its data should not deleted after DROP TABLE.") + } } test("check change without refresh") { - val tempDir = File.createTempFile("sparksql", "json", Utils.createTempDir()) - tempDir.delete() - sparkContext.parallelize(("a", "b") :: Nil).toDF() - .toJSON.saveAsTextFile(tempDir.getCanonicalPath) - - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json - |OPTIONS ( - | path '${tempDir.getCanonicalPath}' - |) - """.stripMargin) - - checkAnswer( - sql("SELECT * FROM jsonTable"), - Row("a", "b")) - - Utils.deleteRecursively(tempDir) - sparkContext.parallelize(("a1", "b1", "c1") :: Nil).toDF() - .toJSON.saveAsTextFile(tempDir.getCanonicalPath) - - // Schema is cached so the new column does not show. The updated values in existing columns - // will show. - checkAnswer( - sql("SELECT * FROM jsonTable"), - Row("a1", "b1")) - - sql("REFRESH TABLE jsonTable") - - // Check that the refresh worked - checkAnswer( - sql("SELECT * FROM jsonTable"), - Row("a1", "b1", "c1")) - Utils.deleteRecursively(tempDir) + withTempPath { tempDir => + withTable("jsonTable") { + (("a", "b") :: Nil).toDF().toJSON.saveAsTextFile(tempDir.getCanonicalPath) + + sql( + s"""CREATE TABLE jsonTable + |USING org.apache.spark.sql.json + |OPTIONS ( + | path '${tempDir.getCanonicalPath}' + |) + """.stripMargin) + + checkAnswer( + sql("SELECT * FROM jsonTable"), + Row("a", "b")) + + Utils.deleteRecursively(tempDir) + (("a1", "b1", "c1") :: Nil).toDF().toJSON.saveAsTextFile(tempDir.getCanonicalPath) + + // Schema is cached so the new column does not show. The updated values in existing columns + // will show. + checkAnswer( + sql("SELECT * FROM jsonTable"), + Row("a1", "b1")) + + sql("REFRESH TABLE jsonTable") + + // Check that the refresh worked + checkAnswer( + sql("SELECT * FROM jsonTable"), + Row("a1", "b1", "c1")) + } + } } test("drop, change, recreate") { - val tempDir = File.createTempFile("sparksql", "json", Utils.createTempDir()) - tempDir.delete() - sparkContext.parallelize(("a", "b") :: Nil).toDF() - .toJSON.saveAsTextFile(tempDir.getCanonicalPath) - - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json - |OPTIONS ( - | path '${tempDir.getCanonicalPath}' - |) - """.stripMargin) - - checkAnswer( - sql("SELECT * FROM jsonTable"), - Row("a", "b")) - - Utils.deleteRecursively(tempDir) - sparkContext.parallelize(("a", "b", "c") :: Nil).toDF() - .toJSON.saveAsTextFile(tempDir.getCanonicalPath) - - sql("DROP TABLE jsonTable") - - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json - |OPTIONS ( - | path '${tempDir.getCanonicalPath}' - |) - """.stripMargin) - - // New table should reflect new schema. - checkAnswer( - sql("SELECT * FROM jsonTable"), - Row("a", "b", "c")) - Utils.deleteRecursively(tempDir) + withTempPath { tempDir => + (("a", "b") :: Nil).toDF().toJSON.saveAsTextFile(tempDir.getCanonicalPath) + + withTable("jsonTable") { + sql( + s"""CREATE TABLE jsonTable + |USING org.apache.spark.sql.json + |OPTIONS ( + | path '${tempDir.getCanonicalPath}' + |) + """.stripMargin) + + checkAnswer( + sql("SELECT * FROM jsonTable"), + Row("a", "b")) + + Utils.deleteRecursively(tempDir) + (("a", "b", "c") :: Nil).toDF().toJSON.saveAsTextFile(tempDir.getCanonicalPath) + + sql("DROP TABLE jsonTable") + + sql( + s"""CREATE TABLE jsonTable + |USING org.apache.spark.sql.json + |OPTIONS ( + | path '${tempDir.getCanonicalPath}' + |) + """.stripMargin) + + // New table should reflect new schema. + checkAnswer( + sql("SELECT * FROM jsonTable"), + Row("a", "b", "c")) + } + } } test("invalidate cache and reload") { - sql( - s""" - |CREATE TABLE jsonTable (`c_!@(3)` int) - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) + withTable("jsonTable") { + sql( + s"""CREATE TABLE jsonTable (`c_!@(3)` int) + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) - jsonFile(filePath).registerTempTable("expectedJsonTable") + withTempTable("expectedJsonTable") { + read.json(jsonFilePath).registerTempTable("expectedJsonTable") - checkAnswer( - sql("SELECT * FROM jsonTable"), - sql("SELECT `c_!@(3)` FROM expectedJsonTable").collect().toSeq) + checkAnswer( + sql("SELECT * FROM jsonTable"), + sql("SELECT `c_!@(3)` FROM expectedJsonTable").collect().toSeq) - // Discard the cached relation. - invalidateTable("jsonTable") + // Discard the cached relation. + invalidateTable("jsonTable") - checkAnswer( - sql("SELECT * FROM jsonTable"), - sql("SELECT `c_!@(3)` FROM expectedJsonTable").collect().toSeq) + checkAnswer( + sql("SELECT * FROM jsonTable"), + sql("SELECT `c_!@(3)` FROM expectedJsonTable").collect().toSeq) - invalidateTable("jsonTable") - val expectedSchema = StructType(StructField("c_!@(3)", IntegerType, true) :: Nil) + invalidateTable("jsonTable") + val expectedSchema = StructType(StructField("c_!@(3)", IntegerType, true) :: Nil) - assert(expectedSchema === table("jsonTable").schema) + assert(expectedSchema === table("jsonTable").schema) + } + } } test("CTAS") { - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) - - sql( - s""" - |CREATE TABLE ctasJsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${tempPath}' - |) AS - |SELECT * FROM jsonTable - """.stripMargin) - - assert(table("ctasJsonTable").schema === table("jsonTable").schema) - - checkAnswer( - sql("SELECT * FROM ctasJsonTable"), - sql("SELECT * FROM jsonTable").collect()) + withTempPath { tempPath => + withTable("jsonTable", "ctasJsonTable") { + sql( + s"""CREATE TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) + + sql( + s"""CREATE TABLE ctasJsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$tempPath' + |) AS + |SELECT * FROM jsonTable + """.stripMargin) + + assert(table("ctasJsonTable").schema === table("jsonTable").schema) + + checkAnswer( + sql("SELECT * FROM ctasJsonTable"), + sql("SELECT * FROM jsonTable").collect()) + } + } } test("CTAS with IF NOT EXISTS") { - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) - - sql( - s""" - |CREATE TABLE ctasJsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${tempPath}' - |) AS - |SELECT * FROM jsonTable - """.stripMargin) - - // Create the table again should trigger a AnalysisException. - val message = intercept[AnalysisException] { - sql( - s""" - |CREATE TABLE ctasJsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${tempPath}' - |) AS - |SELECT * FROM jsonTable - """.stripMargin) - }.getMessage - assert(message.contains("Table ctasJsonTable already exists."), - "We should complain that ctasJsonTable already exists") - - // The following statement should be fine if it has IF NOT EXISTS. - // It tries to create a table ctasJsonTable with a new schema. - // The actual table's schema and data should not be changed. - sql( - s""" - |CREATE TABLE IF NOT EXISTS ctasJsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${tempPath}' - |) AS - |SELECT a FROM jsonTable - """.stripMargin) - - // Discard the cached relation. - invalidateTable("ctasJsonTable") - - // Schema should not be changed. - assert(table("ctasJsonTable").schema === table("jsonTable").schema) - // Table data should not be changed. - checkAnswer( - sql("SELECT * FROM ctasJsonTable"), - sql("SELECT * FROM jsonTable").collect()) + withTempPath { path => + val tempPath = path.getCanonicalPath + + withTable("jsonTable", "ctasJsonTable") { + sql( + s"""CREATE TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) + + sql( + s"""CREATE TABLE ctasJsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$tempPath' + |) AS + |SELECT * FROM jsonTable + """.stripMargin) + + // Create the table again should trigger a AnalysisException. + val message = intercept[AnalysisException] { + sql( + s"""CREATE TABLE ctasJsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$tempPath' + |) AS + |SELECT * FROM jsonTable + """.stripMargin) + }.getMessage + + assert( + message.contains("Table ctasJsonTable already exists."), + "We should complain that ctasJsonTable already exists") + + // The following statement should be fine if it has IF NOT EXISTS. + // It tries to create a table ctasJsonTable with a new schema. + // The actual table's schema and data should not be changed. + sql( + s"""CREATE TABLE IF NOT EXISTS ctasJsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$tempPath' + |) AS + |SELECT a FROM jsonTable + """.stripMargin) + + // Discard the cached relation. + invalidateTable("ctasJsonTable") + + // Schema should not be changed. + assert(table("ctasJsonTable").schema === table("jsonTable").schema) + // Table data should not be changed. + checkAnswer( + sql("SELECT * FROM ctasJsonTable"), + sql("SELECT * FROM jsonTable").collect()) + } + } } test("CTAS a managed table") { - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${filePath}' - |) - """.stripMargin) - - val expectedPath = catalog.hiveDefaultTableFilePath("ctasJsonTable") - val filesystemPath = new Path(expectedPath) - val fs = filesystemPath.getFileSystem(sparkContext.hadoopConfiguration) - if (fs.exists(filesystemPath)) fs.delete(filesystemPath, true) - - // It is a managed table when we do not specify the location. - sql( - s""" - |CREATE TABLE ctasJsonTable - |USING org.apache.spark.sql.json.DefaultSource - |AS - |SELECT * FROM jsonTable - """.stripMargin) - - assert(fs.exists(filesystemPath), s"$expectedPath should exist after we create the table.") - - sql( - s""" - |CREATE TABLE loadedTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path '${expectedPath}' - |) - """.stripMargin) - - assert(table("ctasJsonTable").schema === table("loadedTable").schema) - - checkAnswer( - sql("SELECT * FROM ctasJsonTable"), - sql("SELECT * FROM loadedTable").collect() - ) - - sql("DROP TABLE ctasJsonTable") - assert(!fs.exists(filesystemPath), s"$expectedPath should not exist after we drop the table.") - } - - test("SPARK-5286 Fail to drop an invalid table when using the data source API") { - sql( - s""" - |CREATE TABLE jsonTable - |USING org.apache.spark.sql.json.DefaultSource - |OPTIONS ( - | path 'it is not a path at all!' - |) - """.stripMargin) - - sql("DROP TABLE jsonTable").collect().foreach(println) - } - - test("SPARK-5839 HiveMetastoreCatalog does not recognize table aliases of data source tables.") { - val originalDefaultSource = conf.defaultDataSourceName - - val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - val df = jsonRDD(rdd) - - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") - // Save the df as a managed table (by not specifiying the path). - df.saveAsTable("savedJsonTable") - - checkAnswer( - sql("SELECT * FROM savedJsonTable where savedJsonTable.a < 5"), - (1 to 4).map(i => Row(i, s"str${i}"))) + withTable("jsonTable", "ctasJsonTable", "loadedTable") { + sql( + s"""CREATE TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$jsonFilePath' + |) + """.stripMargin) + + val expectedPath = catalog.hiveDefaultTableFilePath("ctasJsonTable") + val filesystemPath = new Path(expectedPath) + val fs = filesystemPath.getFileSystem(sparkContext.hadoopConfiguration) + if (fs.exists(filesystemPath)) fs.delete(filesystemPath, true) + + // It is a managed table when we do not specify the location. + sql( + s"""CREATE TABLE ctasJsonTable + |USING org.apache.spark.sql.json.DefaultSource + |AS + |SELECT * FROM jsonTable + """.stripMargin) - checkAnswer( - sql("SELECT * FROM savedJsonTable tmp where tmp.a > 5"), - (6 to 10).map(i => Row(i, s"str${i}"))) + assert(fs.exists(filesystemPath), s"$expectedPath should exist after we create the table.") - invalidateTable("savedJsonTable") + sql( + s"""CREATE TABLE loadedTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$expectedPath' + |) + """.stripMargin) - checkAnswer( - sql("SELECT * FROM savedJsonTable where savedJsonTable.a < 5"), - (1 to 4).map(i => Row(i, s"str${i}"))) + assert(table("ctasJsonTable").schema === table("loadedTable").schema) - checkAnswer( - sql("SELECT * FROM savedJsonTable tmp where tmp.a > 5"), - (6 to 10).map(i => Row(i, s"str${i}"))) + checkAnswer( + sql("SELECT * FROM ctasJsonTable"), + sql("SELECT * FROM loadedTable")) - // Drop table will also delete the data. - sql("DROP TABLE savedJsonTable") + sql("DROP TABLE ctasJsonTable") + assert(!fs.exists(filesystemPath), s"$expectedPath should not exist after we drop the table.") + } + } - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) + test("SPARK-5286 Fail to drop an invalid table when using the data source API") { + withTable("jsonTable") { + sql( + s"""CREATE TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path 'it is not a path at all!' + |) + """.stripMargin) + + sql("DROP TABLE jsonTable").collect().foreach(println) + } } - test("save table") { - val originalDefaultSource = conf.defaultDataSourceName + test("SPARK-5839 HiveMetastoreCatalog does not recognize table aliases of data source tables.") { + withTable("savedJsonTable") { + // Save the df as a managed table (by not specifying the path). + (1 to 10) + .map(i => i -> s"str$i") + .toDF("a", "b") + .write + .format("json") + .saveAsTable("savedJsonTable") - val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - val df = jsonRDD(rdd) + checkAnswer( + sql("SELECT * FROM savedJsonTable where savedJsonTable.a < 5"), + (1 to 4).map(i => Row(i, s"str$i"))) - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") - // Save the df as a managed table (by not specifiying the path). - df.saveAsTable("savedJsonTable") + checkAnswer( + sql("SELECT * FROM savedJsonTable tmp where tmp.a > 5"), + (6 to 10).map(i => Row(i, s"str$i"))) - checkAnswer( - sql("SELECT * FROM savedJsonTable"), - df.collect()) + invalidateTable("savedJsonTable") - // Right now, we cannot append to an existing JSON table. - intercept[RuntimeException] { - df.saveAsTable("savedJsonTable", SaveMode.Append) - } + checkAnswer( + sql("SELECT * FROM savedJsonTable where savedJsonTable.a < 5"), + (1 to 4).map(i => Row(i, s"str$i"))) - // We can overwrite it. - df.saveAsTable("savedJsonTable", SaveMode.Overwrite) - checkAnswer( - sql("SELECT * FROM savedJsonTable"), - df.collect()) - - // When the save mode is Ignore, we will do nothing when the table already exists. - df.select("b").saveAsTable("savedJsonTable", SaveMode.Ignore) - assert(df.schema === table("savedJsonTable").schema) - checkAnswer( - sql("SELECT * FROM savedJsonTable"), - df.collect()) - - // Drop table will also delete the data. - sql("DROP TABLE savedJsonTable") - intercept[InvalidInputException] { - jsonFile(catalog.hiveDefaultTableFilePath("savedJsonTable")) + checkAnswer( + sql("SELECT * FROM savedJsonTable tmp where tmp.a > 5"), + (6 to 10).map(i => Row(i, s"str$i"))) } + } - // Create an external table by specifying the path. - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") - df.saveAsTable( - "savedJsonTable", - "org.apache.spark.sql.json", - SaveMode.Append, - Map("path" -> tempPath.toString)) - checkAnswer( - sql("SELECT * FROM savedJsonTable"), - df.collect()) - - // Data should not be deleted after we drop the table. - sql("DROP TABLE savedJsonTable") - checkAnswer( - jsonFile(tempPath.toString), - df.collect()) - - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) + test("save table") { + withTempPath { path => + val tempPath = path.getCanonicalPath + + withTable("savedJsonTable") { + val df = (1 to 10).map(i => i -> s"str$i").toDF("a", "b") + + withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME -> "json") { + // Save the df as a managed table (by not specifying the path). + df.write.saveAsTable("savedJsonTable") + + checkAnswer(sql("SELECT * FROM savedJsonTable"), df) + + // Right now, we cannot append to an existing JSON table. + intercept[RuntimeException] { + df.write.mode(SaveMode.Append).saveAsTable("savedJsonTable") + } + + // We can overwrite it. + df.write.mode(SaveMode.Overwrite).saveAsTable("savedJsonTable") + checkAnswer(sql("SELECT * FROM savedJsonTable"), df) + + // When the save mode is Ignore, we will do nothing when the table already exists. + df.select("b").write.mode(SaveMode.Ignore).saveAsTable("savedJsonTable") + assert(df.schema === table("savedJsonTable").schema) + checkAnswer(sql("SELECT * FROM savedJsonTable"), df) + + // Drop table will also delete the data. + sql("DROP TABLE savedJsonTable") + intercept[InvalidInputException] { + read.json(catalog.hiveDefaultTableFilePath("savedJsonTable")) + } + } + + // Create an external table by specifying the path. + withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME -> "not a source name") { + df.write + .format("org.apache.spark.sql.json") + .mode(SaveMode.Append) + .option("path", tempPath.toString) + .saveAsTable("savedJsonTable") + + checkAnswer(sql("SELECT * FROM savedJsonTable"), df) + } + + // Data should not be deleted after we drop the table. + sql("DROP TABLE savedJsonTable") + checkAnswer(read.json(tempPath.toString), df) + } + } } test("create external table") { - val originalDefaultSource = conf.defaultDataSourceName - - val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - val df = jsonRDD(rdd) - - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") - df.saveAsTable( - "savedJsonTable", - "org.apache.spark.sql.json", - SaveMode.Append, - Map("path" -> tempPath.toString)) - - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json") - createExternalTable("createdJsonTable", tempPath.toString) - assert(table("createdJsonTable").schema === df.schema) - checkAnswer( - sql("SELECT * FROM createdJsonTable"), - df.collect()) - - var message = intercept[AnalysisException] { - createExternalTable("createdJsonTable", filePath.toString) - }.getMessage - assert(message.contains("Table createdJsonTable already exists."), - "We should complain that ctasJsonTable already exists") - - // Data should not be deleted. - sql("DROP TABLE createdJsonTable") - checkAnswer( - jsonFile(tempPath.toString), - df.collect()) - - // Try to specify the schema. - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name") - val schema = StructType(StructField("b", StringType, true) :: Nil) - createExternalTable( - "createdJsonTable", - "org.apache.spark.sql.json", - schema, - Map("path" -> tempPath.toString)) - checkAnswer( - sql("SELECT * FROM createdJsonTable"), - sql("SELECT b FROM savedJsonTable").collect()) - - sql("DROP TABLE createdJsonTable") - - message = intercept[RuntimeException] { - createExternalTable( - "createdJsonTable", - "org.apache.spark.sql.json", - schema, - Map.empty[String, String]) - }.getMessage - assert( - message.contains("'path' must be specified for json data."), - "We should complain that path is not specified.") - - sql("DROP TABLE savedJsonTable") - conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource) + withTempPath { tempPath => + withTable("savedJsonTable", "createdJsonTable") { + val df = read.json(sparkContext.parallelize((1 to 10).map { i => + s"""{ "a": $i, "b": "str$i" }""" + })) + + withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME -> "not a source name") { + df.write + .format("json") + .mode(SaveMode.Append) + .option("path", tempPath.toString) + .saveAsTable("savedJsonTable") + } + + withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME -> "json") { + createExternalTable("createdJsonTable", tempPath.toString) + assert(table("createdJsonTable").schema === df.schema) + checkAnswer(sql("SELECT * FROM createdJsonTable"), df) + + assert( + intercept[AnalysisException] { + createExternalTable("createdJsonTable", jsonFilePath.toString) + }.getMessage.contains("Table createdJsonTable already exists."), + "We should complain that createdJsonTable already exists") + } + + // Data should not be deleted. + sql("DROP TABLE createdJsonTable") + checkAnswer(read.json(tempPath.toString), df) + + // Try to specify the schema. + withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME -> "not a source name") { + val schema = StructType(StructField("b", StringType, true) :: Nil) + createExternalTable( + "createdJsonTable", + "org.apache.spark.sql.json", + schema, + Map("path" -> tempPath.toString)) + + checkAnswer( + sql("SELECT * FROM createdJsonTable"), + sql("SELECT b FROM savedJsonTable")) + + sql("DROP TABLE createdJsonTable") + + assert( + intercept[RuntimeException] { + createExternalTable( + "createdJsonTable", + "org.apache.spark.sql.json", + schema, + Map.empty[String, String]) + }.getMessage.contains("'path' must be specified for json data."), + "We should complain that path is not specified.") + } + } + } } if (HiveShim.version == "0.13.1") { test("scan a parquet table created through a CTAS statement") { - val originalConvertMetastore = getConf("spark.sql.hive.convertMetastoreParquet", "true") - val originalUseDataSource = getConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") - setConf("spark.sql.hive.convertMetastoreParquet", "true") - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") - - val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}""")) - jsonRDD(rdd).registerTempTable("jt") - sql( - """ - |create table test_parquet_ctas STORED AS parquET - |AS select tmp.a from jt tmp where tmp.a < 5 - """.stripMargin) - - checkAnswer( - sql(s"SELECT a FROM test_parquet_ctas WHERE a > 2 "), - Row(3) :: Row(4) :: Nil - ) - - table("test_parquet_ctas").queryExecution.optimizedPlan match { - case LogicalRelation(p: ParquetRelation2) => // OK - case _ => - fail( - "test_parquet_ctas should be converted to " + - s"${classOf[ParquetRelation2].getCanonicalName}") + withSQLConf( + "spark.sql.hive.convertMetastoreParquet" -> "true", + SQLConf.PARQUET_USE_DATA_SOURCE_API -> "true") { + + withTempTable("jt") { + (1 to 10).map(i => i -> s"str$i").toDF("a", "b").registerTempTable("jt") + + withTable("test_parquet_ctas") { + sql( + """CREATE TABLE test_parquet_ctas STORED AS PARQUET + |AS SELECT tmp.a FROM jt tmp WHERE tmp.a < 5 + """.stripMargin) + + checkAnswer( + sql(s"SELECT a FROM test_parquet_ctas WHERE a > 2 "), + Row(3) :: Row(4) :: Nil) + + table("test_parquet_ctas").queryExecution.optimizedPlan match { + case LogicalRelation(p: ParquetRelation2) => // OK + case _ => + fail(s"test_parquet_ctas should have be converted to ${classOf[ParquetRelation2]}") + } + } + } } - - // Clenup and reset confs. - sql("DROP TABLE IF EXISTS jt") - sql("DROP TABLE IF EXISTS test_parquet_ctas") - setConf("spark.sql.hive.convertMetastoreParquet", originalConvertMetastore) - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalUseDataSource) } } test("Pre insert nullability check (ArrayType)") { - val df1 = - createDataFrame(Tuple1(Seq(Int.box(1), null.asInstanceOf[Integer])) :: Nil).toDF("a") - val expectedSchema1 = - StructType( - StructField("a", ArrayType(IntegerType, containsNull = true), nullable = true) :: Nil) - assert(df1.schema === expectedSchema1) - df1.saveAsTable("arrayInParquet", "parquet", SaveMode.Overwrite) - - val df2 = - createDataFrame(Tuple1(Seq(2, 3)) :: Nil).toDF("a") - val expectedSchema2 = - StructType( - StructField("a", ArrayType(IntegerType, containsNull = false), nullable = true) :: Nil) - assert(df2.schema === expectedSchema2) - df2.insertInto("arrayInParquet", overwrite = false) - createDataFrame(Tuple1(Seq(4, 5)) :: Nil).toDF("a") - .saveAsTable("arrayInParquet", SaveMode.Append) // This one internally calls df2.insertInto. - createDataFrame(Tuple1(Seq(Int.box(6), null.asInstanceOf[Integer])) :: Nil).toDF("a") - .saveAsTable("arrayInParquet", "parquet", SaveMode.Append) - refreshTable("arrayInParquet") - - checkAnswer( - sql("SELECT a FROM arrayInParquet"), - Row(ArrayBuffer(1, null)) :: - Row(ArrayBuffer(2, 3)) :: - Row(ArrayBuffer(4, 5)) :: - Row(ArrayBuffer(6, null)) :: Nil) - - sql("DROP TABLE arrayInParquet") + withTable("arrayInParquet") { + { + val df = (Tuple1(Seq(Int.box(1), null: Integer)) :: Nil).toDF("a") + val expectedSchema = + StructType( + StructField( + "a", + ArrayType(IntegerType, containsNull = true), + nullable = true) :: Nil) + + assert(df.schema === expectedSchema) + + df.write + .format("parquet") + .mode(SaveMode.Overwrite) + .saveAsTable("arrayInParquet") + } + + { + val df = (Tuple1(Seq(2, 3)) :: Nil).toDF("a") + val expectedSchema = + StructType( + StructField( + "a", + ArrayType(IntegerType, containsNull = false), + nullable = true) :: Nil) + + assert(df.schema === expectedSchema) + + df.write + .format("parquet") + .mode(SaveMode.Append) + .insertInto("arrayInParquet") + } + + (Tuple1(Seq(4, 5)) :: Nil).toDF("a") + .write + .mode(SaveMode.Append) + .saveAsTable("arrayInParquet") // This one internally calls df2.insertInto. + + (Tuple1(Seq(Int.box(6), null: Integer)) :: Nil).toDF("a") + .write + .mode(SaveMode.Append) + .saveAsTable("arrayInParquet") + + refreshTable("arrayInParquet") + + checkAnswer( + sql("SELECT a FROM arrayInParquet"), + Row(ArrayBuffer(1, null)) :: + Row(ArrayBuffer(2, 3)) :: + Row(ArrayBuffer(4, 5)) :: + Row(ArrayBuffer(6, null)) :: Nil) + } } test("Pre insert nullability check (MapType)") { - val df1 = - createDataFrame(Tuple1(Map(1 -> null.asInstanceOf[Integer])) :: Nil).toDF("a") - val mapType1 = MapType(IntegerType, IntegerType, valueContainsNull = true) - val expectedSchema1 = - StructType( - StructField("a", mapType1, nullable = true) :: Nil) - assert(df1.schema === expectedSchema1) - df1.saveAsTable("mapInParquet", "parquet", SaveMode.Overwrite) - - val df2 = - createDataFrame(Tuple1(Map(2 -> 3)) :: Nil).toDF("a") - val mapType2 = MapType(IntegerType, IntegerType, valueContainsNull = false) - val expectedSchema2 = - StructType( - StructField("a", mapType2, nullable = true) :: Nil) - assert(df2.schema === expectedSchema2) - df2.insertInto("mapInParquet", overwrite = false) - createDataFrame(Tuple1(Map(4 -> 5)) :: Nil).toDF("a") - .saveAsTable("mapInParquet", SaveMode.Append) // This one internally calls df2.insertInto. - createDataFrame(Tuple1(Map(6 -> null.asInstanceOf[Integer])) :: Nil).toDF("a") - .saveAsTable("mapInParquet", "parquet", SaveMode.Append) - refreshTable("mapInParquet") - - checkAnswer( - sql("SELECT a FROM mapInParquet"), - Row(Map(1 -> null)) :: - Row(Map(2 -> 3)) :: - Row(Map(4 -> 5)) :: - Row(Map(6 -> null)) :: Nil) - - sql("DROP TABLE mapInParquet") + withTable("mapInParquet") { + { + val df = (Tuple1(Map(1 -> (null: Integer))) :: Nil).toDF("a") + val expectedSchema = + StructType( + StructField( + "a", + MapType(IntegerType, IntegerType, valueContainsNull = true), + nullable = true) :: Nil) + + assert(df.schema === expectedSchema) + + df.write + .format("parquet") + .mode(SaveMode.Overwrite) + .saveAsTable("mapInParquet") + } + + { + val df = (Tuple1(Map(2 -> 3)) :: Nil).toDF("a") + val expectedSchema = + StructType( + StructField( + "a", + MapType(IntegerType, IntegerType, valueContainsNull = false), + nullable = true) :: Nil) + + assert(df.schema === expectedSchema) + + df.write + .format("parquet") + .mode(SaveMode.Append) + .insertInto("mapInParquet") + } + + (Tuple1(Map(4 -> 5)) :: Nil).toDF("a") + .write + .format("parquet") + .mode(SaveMode.Append) + .saveAsTable("mapInParquet") // This one internally calls df2.insertInto. + + (Tuple1(Map(6 -> null.asInstanceOf[Integer])) :: Nil).toDF("a") + .write + .format("parquet") + .mode(SaveMode.Append) + .saveAsTable("mapInParquet") + + refreshTable("mapInParquet") + + checkAnswer( + sql("SELECT a FROM mapInParquet"), + Row(Map(1 -> null)) :: + Row(Map(2 -> 3)) :: + Row(Map(4 -> 5)) :: + Row(Map(6 -> null)) :: Nil) + } } test("SPARK-6024 wide schema support") { - // We will need 80 splits for this schema if the threshold is 4000. - val schema = StructType((1 to 5000).map(i => StructField(s"c_${i}", StringType, true))) - assert( - schema.json.size > conf.schemaStringLengthThreshold, - "To correctly test the fix of SPARK-6024, the value of " + - s"spark.sql.sources.schemaStringLengthThreshold needs to be less than ${schema.json.size}") - // Manually create a metastore data source table. - catalog.createDataSourceTable( - tableName = "wide_schema", - userSpecifiedSchema = Some(schema), - provider = "json", - options = Map("path" -> "just a dummy path"), - isExternal = false) - - invalidateTable("wide_schema") - - val actualSchema = table("wide_schema").schema - assert(schema === actualSchema) + withSQLConf(SQLConf.SCHEMA_STRING_LENGTH_THRESHOLD -> "4000") { + withTable("wide_schema") { + // We will need 80 splits for this schema if the threshold is 4000. + val schema = StructType((1 to 5000).map(i => StructField(s"c_$i", StringType, true))) + + // Manually create a metastore data source table. + catalog.createDataSourceTable( + tableName = "wide_schema", + userSpecifiedSchema = Some(schema), + partitionColumns = Array.empty[String], + provider = "json", + options = Map("path" -> "just a dummy path"), + isExternal = false) + + invalidateTable("wide_schema") + + val actualSchema = table("wide_schema").schema + assert(schema === actualSchema) + } + } } test("SPARK-6655 still support a schema stored in spark.sql.sources.schema") { val tableName = "spark6655" - val schema = StructType(StructField("int", IntegerType, true) :: Nil) - - val hiveTable = HiveTable( - specifiedDatabase = Some("default"), - name = tableName, - schema = Seq.empty, - partitionColumns = Seq.empty, - properties = Map( - "spark.sql.sources.provider" -> "json", - "spark.sql.sources.schema" -> schema.json, - "EXTERNAL" -> "FALSE"), - tableType = ManagedTable, - serdeProperties = Map( - "path" -> catalog.hiveDefaultTableFilePath(tableName))) - - catalog.client.createTable(hiveTable) - - invalidateTable(tableName) - val actualSchema = table(tableName).schema - assert(schema === actualSchema) - sql(s"drop table $tableName") + withTable(tableName) { + val schema = StructType(StructField("int", IntegerType, true) :: Nil) + val hiveTable = HiveTable( + specifiedDatabase = Some("default"), + name = tableName, + schema = Seq.empty, + partitionColumns = Seq.empty, + properties = Map( + "spark.sql.sources.provider" -> "json", + "spark.sql.sources.schema" -> schema.json, + "EXTERNAL" -> "FALSE"), + tableType = ManagedTable, + serdeProperties = Map( + "path" -> catalog.hiveDefaultTableFilePath(tableName))) + + catalog.client.createTable(hiveTable) + + invalidateTable(tableName) + val actualSchema = table(tableName).schema + assert(schema === actualSchema) + } } + test("Saving partition columns information") { + val df = (1 to 10).map(i => (i, i + 1, s"str$i", s"str${i + 1}")).toDF("a", "b", "c", "d") + val tableName = s"partitionInfo_${System.currentTimeMillis()}" + + withTable(tableName) { + df.write.format("parquet").partitionBy("d", "b").saveAsTable(tableName) + invalidateTable(tableName) + val metastoreTable = catalog.client.getTable("default", tableName) + val expectedPartitionColumns = StructType(df.schema("d") :: df.schema("b") :: Nil) + val actualPartitionColumns = + StructType( + metastoreTable.partitionColumns.map(c => + StructField(c.name, HiveMetastoreTypes.toDataType(c.hiveType)))) + // Make sure partition columns are correctly stored in metastore. + assert( + expectedPartitionColumns.sameType(actualPartitionColumns), + s"Partitions columns stored in metastore $actualPartitionColumns is not the " + + s"partition columns defined by the saveAsTable operation $expectedPartitionColumns.") + + // Check the content of the saved table. + checkAnswer( + table(tableName).select("c", "b", "d", "a"), + df.select("c", "b", "d", "a")) + } + } test("insert into a table") { - def createDF(from: Int, to: Int): DataFrame = - createDataFrame((from to to).map(i => Tuple2(i, s"str$i"))).toDF("c1", "c2") + def createDF(from: Int, to: Int): DataFrame = { + (from to to).map(i => i -> s"str$i").toDF("c1", "c2") + } - createDF(0, 9).saveAsTable("insertParquet", "parquet") - checkAnswer( - sql("SELECT p.c1, p.c2 FROM insertParquet p WHERE p.c1 > 5"), - (6 to 9).map(i => Row(i, s"str$i"))) + withTable("insertParquet") { + createDF(0, 9).write.format("parquet").saveAsTable("insertParquet") + checkAnswer( + sql("SELECT p.c1, p.c2 FROM insertParquet p WHERE p.c1 > 5"), + (6 to 9).map(i => Row(i, s"str$i"))) - intercept[AnalysisException] { - createDF(10, 19).saveAsTable("insertParquet", "parquet") - } + intercept[AnalysisException] { + createDF(10, 19).write.format("parquet").saveAsTable("insertParquet") + } - createDF(10, 19).saveAsTable("insertParquet", "parquet", SaveMode.Append) - checkAnswer( - sql("SELECT p.c1, p.c2 FROM insertParquet p WHERE p.c1 > 5"), - (6 to 19).map(i => Row(i, s"str$i"))) + createDF(10, 19).write.mode(SaveMode.Append).format("parquet").saveAsTable("insertParquet") + checkAnswer( + sql("SELECT p.c1, p.c2 FROM insertParquet p WHERE p.c1 > 5"), + (6 to 19).map(i => Row(i, s"str$i"))) - createDF(20, 29).saveAsTable("insertParquet", "parquet", SaveMode.Append) - checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 25"), - (6 to 24).map(i => Row(i, s"str$i"))) + createDF(20, 29).write.mode(SaveMode.Append).format("parquet").saveAsTable("insertParquet") + checkAnswer( + sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 25"), + (6 to 24).map(i => Row(i, s"str$i"))) - intercept[AnalysisException] { - createDF(30, 39).saveAsTable("insertParquet") - } + intercept[AnalysisException] { + createDF(30, 39).write.saveAsTable("insertParquet") + } + + createDF(30, 39).write.mode(SaveMode.Append).saveAsTable("insertParquet") + checkAnswer( + sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 35"), + (6 to 34).map(i => Row(i, s"str$i"))) - createDF(30, 39).saveAsTable("insertParquet", SaveMode.Append) - checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 35"), - (6 to 34).map(i => Row(i, s"str$i"))) - - createDF(40, 49).insertInto("insertParquet") - checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 45"), - (6 to 44).map(i => Row(i, s"str$i"))) - - createDF(50, 59).saveAsTable("insertParquet", SaveMode.Overwrite) - checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 51 AND p.c1 < 55"), - (52 to 54).map(i => Row(i, s"str$i"))) - createDF(60, 69).saveAsTable("insertParquet", SaveMode.Ignore) - checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p"), - (50 to 59).map(i => Row(i, s"str$i"))) - - createDF(70, 79).insertInto("insertParquet", overwrite = true) - checkAnswer( - sql("SELECT p.c1, c2 FROM insertParquet p"), - (70 to 79).map(i => Row(i, s"str$i"))) + createDF(40, 49).write.mode(SaveMode.Append).insertInto("insertParquet") + checkAnswer( + sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 5 AND p.c1 < 45"), + (6 to 44).map(i => Row(i, s"str$i"))) + + createDF(50, 59).write.mode(SaveMode.Overwrite).saveAsTable("insertParquet") + checkAnswer( + sql("SELECT p.c1, c2 FROM insertParquet p WHERE p.c1 > 51 AND p.c1 < 55"), + (52 to 54).map(i => Row(i, s"str$i"))) + createDF(60, 69).write.mode(SaveMode.Ignore).saveAsTable("insertParquet") + checkAnswer( + sql("SELECT p.c1, c2 FROM insertParquet p"), + (50 to 59).map(i => Row(i, s"str$i"))) + + createDF(70, 79).write.mode(SaveMode.Overwrite).insertInto("insertParquet") + checkAnswer( + sql("SELECT p.c1, c2 FROM insertParquet p"), + (70 to 79).map(i => Row(i, s"str$i"))) + } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala index 8afe5459d4f1b..a492ecf203d17 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala @@ -17,13 +17,11 @@ package org.apache.spark.sql.hive -import org.scalatest.FunSuite - -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.sql.hive.test.TestHive -class SerializationSuite extends FunSuite { +class SerializationSuite extends SparkFunSuite { test("[SPARK-5840] HiveContext should be serializable") { val hiveContext = TestHive diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala index 85b6bc93d7122..8245047626d57 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala @@ -26,9 +26,9 @@ case class FunctionResult(f1: String, f2: String) class UDFSuite extends QueryTest { test("UDF case insensitive") { - udf.register("random0", () => { Math.random()}) - udf.register("RANDOM1", () => { Math.random()}) - udf.register("strlenScala", (_: String).length + (_:Int)) + udf.register("random0", () => { Math.random() }) + udf.register("RANDOM1", () => { Math.random() }) + udf.register("strlenScala", (_: String).length + (_: Int)) assert(sql("SELECT RANDOM0() FROM src LIMIT 1").head().getDouble(0) >= 0.0) assert(sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0) assert(sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 321dc8d7322b8..7eb4842726665 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -17,18 +17,17 @@ package org.apache.spark.sql.hive.client -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.util.Utils -import org.scalatest.FunSuite /** - * A simple set of tests that call the methods of a hive ClientInterface, loading different version - * of hive from maven central. These tests are simple in that they are mostly just testing to make - * sure that reflective calls are not throwing NoSuchMethod error, but the actually functionallity + * A simple set of tests that call the methods of a hive ClientInterface, loading different version + * of hive from maven central. These tests are simple in that they are mostly just testing to make + * sure that reflective calls are not throwing NoSuchMethod error, but the actually functionality * is not fully tested. */ -class VersionsSuite extends FunSuite with Logging { +class VersionsSuite extends SparkFunSuite with Logging { private def buildConf() = { lazy val warehousePath = Utils.createTempDir() lazy val metastorePath = Utils.createTempDir() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala index 23ece7e7cf6e9..b0d3dd44daedc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.sql.hive.test.TestHiveContext -import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.BeforeAndAfterAll -class ConcurrentHiveSuite extends FunSuite with BeforeAndAfterAll { +class ConcurrentHiveSuite extends SparkFunSuite with BeforeAndAfterAll { ignore("multiple instances not supported") { test("Multiple Hive Instances") { (1 to 10).map { i => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 9c056e493bfde..c9dd4c0935a72 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql.hive.execution import java.io._ -import org.scalatest.{BeforeAndAfterAll, FunSuite, GivenWhenThen} +import org.scalatest.{BeforeAndAfterAll, GivenWhenThen} -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.sources.DescribeCommand import org.apache.spark.sql.execution.{SetCommand, ExplainCommand} import org.apache.spark.sql.catalyst.planning.PhysicalOperation @@ -40,7 +40,7 @@ import org.apache.spark.sql.hive.test.TestHive * configured using system properties. */ abstract class HiveComparisonTest - extends FunSuite with BeforeAndAfterAll with GivenWhenThen with Logging { + extends SparkFunSuite with BeforeAndAfterAll with GivenWhenThen with Logging { /** * When set, any cache files that result in test failures will be deleted. Used when the test @@ -273,7 +273,7 @@ abstract class HiveComparisonTest } val hiveCacheFiles = queryList.zipWithIndex.map { - case (queryString, i) => + case (queryString, i) => val cachedAnswerName = s"$testCaseName-$i-${getMd5(queryString)}" new File(answerCache, cachedAnswerName) } @@ -304,7 +304,7 @@ abstract class HiveComparisonTest // other DDL has not been executed yet. hiveQueries.foreach(_.logical) val computedResults = (queryList.zipWithIndex, hiveQueries, hiveCacheFiles).zipped.map { - case ((queryString, i), hiveQuery, cachedAnswerFile)=> + case ((queryString, i), hiveQuery, cachedAnswerFile) => try { // Hooks often break the harness and don't really affect our test anyway, don't // even try running them. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 2c9c08a9f3898..440b7c87b0da2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -20,14 +20,15 @@ package org.apache.spark.sql.hive.execution import java.io.File import java.util.{Locale, TimeZone} -import org.scalatest.BeforeAndAfter - import scala.util.Try +import org.scalatest.BeforeAndAfter + import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.spark.{SparkFiles, SparkException} import org.apache.spark.sql.{AnalysisException, DataFrame, Row} +import org.apache.spark.sql.catalyst.expressions.Cast import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.test.TestHive @@ -51,14 +52,32 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) // Add Locale setting Locale.setDefault(Locale.US) + sql(s"ADD JAR ${TestHive.getHiveFile("TestUDTF.jar").getCanonicalPath()}") + // The function source code can be found at: + // https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF + sql( + """ + |CREATE TEMPORARY FUNCTION udtf_count2 + |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' + """.stripMargin) } override def afterAll() { TestHive.cacheTables = false TimeZone.setDefault(originalTimeZone) Locale.setDefault(originalLocale) + sql("DROP TEMPORARY FUNCTION udtf_count2") } + createQueryTest("Test UDTF.close in Lateral Views", + """ + |SELECT key, cc + |FROM src LATERAL VIEW udtf_count2(value) dd AS cc + """.stripMargin, false) // false mean we have to keep the temp function in registry + + createQueryTest("Test UDTF.close in SELECT", + "SELECT udtf_count2(a) FROM (SELECT 1 AS a FROM src LIMIT 3) table", false) + test("SPARK-4908: concurrent hive native commands") { (1 to 100).par.map { _ => sql("USE default") @@ -90,13 +109,13 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { | SELECT key FROM gen_tmp ORDER BY key ASC; """.stripMargin) - test("multiple generator in projection") { + test("multiple generators in projection") { intercept[AnalysisException] { - sql("SELECT explode(map(key, value)), key FROM src").collect() + sql("SELECT explode(array(key, key)), explode(array(key, key)) FROM src").collect() } intercept[AnalysisException] { - sql("SELECT explode(map(key, value)) as k1, k2, key FROM src").collect() + sql("SELECT explode(array(key, key)) as k1, explode(array(key, key)) FROM src").collect() } } @@ -397,6 +416,25 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { |SELECT * FROM createdtable; """.stripMargin) + test("SPARK-7270: consider dynamic partition when comparing table output") { + sql(s"CREATE TABLE test_partition (a STRING) PARTITIONED BY (b BIGINT, c STRING)") + sql(s"CREATE TABLE ptest (a STRING, b BIGINT, c STRING)") + + val analyzedPlan = sql( + """ + |INSERT OVERWRITE table test_partition PARTITION (b=1, c) + |SELECT 'a', 'c' from ptest + """.stripMargin).queryExecution.analyzed + + assertResult(false, "Incorrect cast detected\n" + analyzedPlan) { + var hasCast = false + analyzedPlan.collect { + case p: Project => p.transformExpressionsUp { case c: Cast => hasCast = true; c } + } + hasCast + } + } + createQueryTest("transform", "SELECT TRANSFORM (key) USING 'cat' AS (tKey) FROM src") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala index 8ad3627504229..b08db6de2d2f6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.hive.test.TestHive.{sparkContext, jsonRDD, sql} +import org.apache.spark.sql.hive.test.TestHive.{read, sparkContext, jsonRDD, sql} import org.apache.spark.sql.hive.test.TestHive.implicits._ case class Nested(a: Int, B: Int) @@ -31,14 +31,14 @@ case class Data(a: Int, B: Int, n: Nested, nestedArray: Seq[Nested]) class HiveResolutionSuite extends HiveComparisonTest { test("SPARK-3698: case insensitive test for nested data") { - jsonRDD(sparkContext.makeRDD( + read.json(sparkContext.makeRDD( """{"a": [{"a": {"a": 1}}]}""" :: Nil)).registerTempTable("nested") // This should be successfully analyzed sql("SELECT a[0].A.A from nested").queryExecution.analyzed } test("SPARK-5278: check ambiguous reference to fields") { - jsonRDD(sparkContext.makeRDD( + read.json(sparkContext.makeRDD( """{"a": [{"b": 1, "B": 2}]}""" :: Nil)).registerTempTable("nested") // there are 2 filed matching field name "b", we should report Ambiguous reference error @@ -77,7 +77,7 @@ class HiveResolutionSuite extends HiveComparisonTest { test("case insensitivity with scala reflection") { // Test resolution with Scala Reflection - sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) + sparkContext.parallelize(Data(1, 2, Nested(1, 2), Seq(Nested(1, 2))) :: Nil) .toDF().registerTempTable("caseSensitivityTest") val query = sql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest") @@ -88,14 +88,14 @@ class HiveResolutionSuite extends HiveComparisonTest { ignore("case insensitivity with scala reflection joins") { // Test resolution with Scala Reflection - sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) + sparkContext.parallelize(Data(1, 2, Nested(1, 2), Seq(Nested(1, 2))) :: Nil) .toDF().registerTempTable("caseSensitivityTest") sql("SELECT * FROM casesensitivitytest a JOIN casesensitivitytest b ON a.a = b.a").collect() } test("nested repeated resolution") { - sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) + sparkContext.parallelize(Data(1, 2, Nested(1, 2), Seq(Nested(1, 2))) :: Nil) .toDF().registerTempTable("nestedRepeatedTest") assert(sql("SELECT nestedArray[0].a FROM nestedRepeatedTest").collect().head(0) === 1) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala index ab53c6309e089..2209fc2f30a3c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -61,7 +61,7 @@ class HiveTableScanSuite extends HiveComparisonTest { TestHive.sql("select KEY from tb where VALUE='just_for_test' limit 5").collect() TestHive.sql("drop table tb") } - + test("Spark-4077: timestamp query for null value") { TestHive.sql("DROP TABLE IF EXISTS timestamp_query_null") TestHive.sql( @@ -71,12 +71,12 @@ class HiveTableScanSuite extends HiveComparisonTest { FIELDS TERMINATED BY ',' LINES TERMINATED BY '\n' """.stripMargin) - val location = + val location = Utils.getSparkClassLoader.getResource("data/files/issue-4077-data.txt").getFile() - + TestHive.sql(s"LOAD DATA LOCAL INPATH '$location' INTO TABLE timestamp_query_null") - assert(TestHive.sql("SELECT time from timestamp_query_null limit 2").collect() - === Array(Row(java.sql.Timestamp.valueOf("2014-12-11 00:00:00")),Row(null))) + assert(TestHive.sql("SELECT time from timestamp_query_null limit 2").collect() + === Array(Row(java.sql.Timestamp.valueOf("2014-12-11 00:00:00")), Row(null))) TestHive.sql("DROP TABLE timestamp_query_null") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala index 7f49eac490572..ce5985888f540 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala @@ -101,7 +101,7 @@ class HiveUdfSuite extends QueryTest { sql("DROP TEMPORARY FUNCTION IF EXISTS test_avg") TestHive.reset() } - + test("SPARK-2693 udaf aggregates test") { checkAnswer(sql("SELECT percentile(key, 1) FROM src LIMIT 1"), sql("SELECT max(key) FROM src").collect().toSeq) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index a5744ccc68a47..aba3becb1bce2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.sql.hive.execution +import org.apache.spark.sql.catalyst.DefaultParserDialect import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries import org.apache.spark.sql.catalyst.errors.DialectException -import org.apache.spark.sql.DefaultParserDialect -import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SQLConf} -import org.apache.spark.sql.hive.MetastoreRelation +import org.apache.spark.sql._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ -import org.apache.spark.sql.hive.{HiveQLDialect, HiveShim} +import org.apache.spark.sql.hive.{HiveQLDialect, HiveShim, MetastoreRelation} import org.apache.spark.sql.parquet.ParquetRelation2 import org.apache.spark.sql.sources.LogicalRelation import org.apache.spark.sql.types._ @@ -328,7 +327,7 @@ class SQLQuerySuite extends QueryTest { "org.apache.hadoop.hive.ql.io.RCFileInputFormat", "org.apache.hadoop.hive.ql.io.RCFileOutputFormat", "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe", - "serde_p1=p1", "serde_p2=p2", "tbl_p1=p11", "tbl_p2=p22","MANAGED_TABLE" + "serde_p1=p1", "serde_p2=p2", "tbl_p1=p11", "tbl_p2=p22", "MANAGED_TABLE" ) if (HiveShim.version =="0.13.1") { @@ -426,10 +425,10 @@ class SQLQuerySuite extends QueryTest { test("SPARK-4825 save join to table") { val testData = sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)).toDF() sql("CREATE TABLE test1 (key INT, value STRING)") - testData.insertInto("test1") + testData.write.mode(SaveMode.Append).insertInto("test1") sql("CREATE TABLE test2 (key INT, value STRING)") - testData.insertInto("test2") - testData.insertInto("test2") + testData.write.mode(SaveMode.Append).insertInto("test2") + testData.write.mode(SaveMode.Append).insertInto("test2") sql("CREATE TABLE test AS SELECT COUNT(a.value) FROM test1 a JOIN test2 b ON a.key = b.key") checkAnswer( table("test"), @@ -536,26 +535,49 @@ class SQLQuerySuite extends QueryTest { test("SPARK-4296 Grouping field with Hive UDF as sub expression") { val rdd = sparkContext.makeRDD( """{"a": "str", "b":"1", "c":"1970-01-01 00:00:00"}""" :: Nil) - jsonRDD(rdd).registerTempTable("data") + read.json(rdd).registerTempTable("data") checkAnswer( sql("SELECT concat(a, '-', b), year(c) FROM data GROUP BY concat(a, '-', b), year(c)"), Row("str-1", 1970)) dropTempTable("data") - jsonRDD(rdd).registerTempTable("data") + read.json(rdd).registerTempTable("data") checkAnswer(sql("SELECT year(c) + 1 FROM data GROUP BY year(c) + 1"), Row(1971)) dropTempTable("data") } - test("resolve udtf with single alias") { + test("resolve udtf in projection #1") { val rdd = sparkContext.makeRDD((1 to 5).map(i => s"""{"a":[$i, ${i + 1}]}""")) - jsonRDD(rdd).registerTempTable("data") + read.json(rdd).registerTempTable("data") val df = sql("SELECT explode(a) AS val FROM data") val col = df("val") } + test("resolve udtf in projection #2") { + val rdd = sparkContext.makeRDD((1 to 2).map(i => s"""{"a":[$i, ${i + 1}]}""")) + jsonRDD(rdd).registerTempTable("data") + checkAnswer(sql("SELECT explode(map(1, 1)) FROM data LIMIT 1"), Row(1, 1) :: Nil) + checkAnswer(sql("SELECT explode(map(1, 1)) as (k1, k2) FROM data LIMIT 1"), Row(1, 1) :: Nil) + intercept[AnalysisException] { + sql("SELECT explode(map(1, 1)) as k1 FROM data LIMIT 1") + } + + intercept[AnalysisException] { + sql("SELECT explode(map(1, 1)) as (k1, k2, k3) FROM data LIMIT 1") + } + } + + // TGF with non-TGF in project is allowed in Spark SQL, but not in Hive + test("TGF with non-TGF in projection") { + val rdd = sparkContext.makeRDD( """{"a": "1", "b":"1"}""" :: Nil) + jsonRDD(rdd).registerTempTable("data") + checkAnswer( + sql("SELECT explode(map(a, b)) as (k1, k2), a, b FROM data"), + Row("1", "1", "1", "1") :: Nil) + } + test("logical.Project should not be resolved if it contains aggregates or generators") { // This test is used to test the fix of SPARK-5875. // The original issue was that Project's resolved will be true when it contains @@ -564,7 +586,7 @@ class SQLQuerySuite extends QueryTest { // PreInsertionCasts will actually start to work before ImplicitGenerate and then // generates an invalid query plan. val rdd = sparkContext.makeRDD((1 to 5).map(i => s"""{"a":[$i, ${i + 1}]}""")) - jsonRDD(rdd).registerTempTable("data") + read.json(rdd).registerTempTable("data") val originalConf = getConf("spark.sql.hive.convertCTAS", "false") setConf("spark.sql.hive.convertCTAS", "false") @@ -596,7 +618,7 @@ class SQLQuerySuite extends QueryTest { sql(s"DROP TABLE $tableName") } } - + test("SPARK-5203 union with different decimal precision") { Seq.empty[(Decimal, Decimal)] .toDF("d1", "d2") @@ -758,10 +780,130 @@ class SQLQuerySuite extends QueryTest { ).map(i => Row(i._1, i._2, i._3, i._4))) } + test("window function: multiple window expressions in a single expression") { + val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y") + nums.registerTempTable("nums") + + val expected = + Row(1, 1, 1, 55, 1, 57) :: + Row(0, 2, 3, 55, 2, 60) :: + Row(1, 3, 6, 55, 4, 65) :: + Row(0, 4, 10, 55, 6, 71) :: + Row(1, 5, 15, 55, 9, 79) :: + Row(0, 6, 21, 55, 12, 88) :: + Row(1, 7, 28, 55, 16, 99) :: + Row(0, 8, 36, 55, 20, 111) :: + Row(1, 9, 45, 55, 25, 125) :: + Row(0, 10, 55, 55, 30, 140) :: Nil + + val actual = sql( + """ + |SELECT + | y, + | x, + | sum(x) OVER w1 AS running_sum, + | sum(x) OVER w2 AS total_sum, + | sum(x) OVER w3 AS running_sum_per_y, + | ((sum(x) OVER w1) + (sum(x) OVER w2) + (sum(x) OVER w3)) as combined2 + |FROM nums + |WINDOW w1 AS (ORDER BY x ROWS BETWEEN UnBOUNDED PRECEDiNG AND CuRRENT RoW), + | w2 AS (ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOuNDED FoLLOWING), + | w3 AS (PARTITION BY y ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) + """.stripMargin) + + checkAnswer(actual, expected) + + dropTempTable("nums") + } + test("test case key when") { (1 to 5).map(i => (i, i.toString)).toDF("k", "v").registerTempTable("t") checkAnswer( sql("SELECT CASE k WHEN 2 THEN 22 WHEN 4 THEN 44 ELSE 0 END, v FROM t"), Row(0, "1") :: Row(22, "2") :: Row(0, "3") :: Row(44, "4") :: Row(0, "5") :: Nil) } + + test("SPARK-7595: Window will cause resolve failed with self join") { + checkAnswer(sql( + """ + |with + | v1 as (select key, count(value) over (partition by key) cnt_val from src), + | v2 as (select v1.key, v1_lag.cnt_val from v1, v1 v1_lag where v1.key = v1_lag.key) + | select * from v2 order by key limit 1 + """.stripMargin), Row(0, 3)) + } + + test("SPARK-7269 Check analysis failed in case in-sensitive") { + Seq(1, 2, 3).map { i => + (i.toString, i.toString) + }.toDF("key", "value").registerTempTable("df_analysis") + sql("SELECT kEy from df_analysis group by key").collect() + sql("SELECT kEy+3 from df_analysis group by key+3").collect() + sql("SELECT kEy+3, a.kEy, A.kEy from df_analysis A group by key").collect() + sql("SELECT cast(kEy+1 as Int) from df_analysis A group by cast(key+1 as int)").collect() + sql("SELECT cast(kEy+1 as Int) from df_analysis A group by key+1").collect() + sql("SELECT 2 from df_analysis A group by key+1").collect() + intercept[AnalysisException] { + sql("SELECT kEy+1 from df_analysis group by key+3") + } + intercept[AnalysisException] { + sql("SELECT cast(key+2 as Int) from df_analysis A group by cast(key+1 as int)") + } + } + + // `Math.exp(1.0)` has different result for different jdk version, so not use createQueryTest + test("udf_java_method") { + checkAnswer(sql( + """ + |SELECT java_method("java.lang.String", "valueOf", 1), + | java_method("java.lang.String", "isEmpty"), + | java_method("java.lang.Math", "max", 2, 3), + | java_method("java.lang.Math", "min", 2, 3), + | java_method("java.lang.Math", "round", 2.5), + | java_method("java.lang.Math", "exp", 1.0), + | java_method("java.lang.Math", "floor", 1.9) + |FROM src tablesample (1 rows) + """.stripMargin), + Row( + "1", + "true", + java.lang.Math.max(2, 3).toString, + java.lang.Math.min(2, 3).toString, + java.lang.Math.round(2.5).toString, + java.lang.Math.exp(1.0).toString, + java.lang.Math.floor(1.9).toString)) + } + + test("dynamic partition value test") { + try { + sql("set hive.exec.dynamic.partition.mode=nonstrict") + // date + sql("drop table if exists dynparttest1") + sql("create table dynparttest1 (value int) partitioned by (pdate date)") + sql( + """ + |insert into table dynparttest1 partition(pdate) + | select count(*), cast('2015-05-21' as date) as pdate from src + """.stripMargin) + checkAnswer( + sql("select * from dynparttest1"), + Seq(Row(500, java.sql.Date.valueOf("2015-05-21")))) + + // decimal + sql("drop table if exists dynparttest2") + sql("create table dynparttest2 (value int) partitioned by (pdec decimal(5, 1))") + sql( + """ + |insert into table dynparttest2 partition(pdec) + | select count(*), cast('100.12' as decimal(5, 1)) as pdec from src + """.stripMargin) + checkAnswer( + sql("select * from dynparttest2"), + Seq(Row(500, new java.math.BigDecimal("100.1")))) + } finally { + sql("drop table if exists dynparttest1") + sql("drop table if exists dynparttest2") + sql("set hive.exec.dynamic.partition.mode=strict") + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala new file mode 100644 index 0000000000000..080af5bb23c16 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import org.apache.hadoop.fs.Path + +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.sources.HadoopFsRelationTest +import org.apache.spark.sql.types._ + +class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { + override val dataSourceName: String = classOf[DefaultSource].getCanonicalName + + import sqlContext._ + import sqlContext.implicits._ + + test("save()/load() - partitioned table - simple queries - partition columns in data") { + withTempDir { file => + val basePath = new Path(file.getCanonicalPath) + val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) + val qualifiedBasePath = fs.makeQualified(basePath) + + for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { + val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") + sparkContext + .parallelize(for (i <- 1 to 3) yield (i, s"val_$i", p1)) + .toDF("a", "b", "p1") + .write + .format("orc") + .save(partitionDir.toString) + } + + val dataSchemaWithPartition = + StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) + + checkQueries( + load( + source = dataSourceName, + options = Map( + "path" -> file.getCanonicalPath, + "dataSchema" -> dataSchemaWithPartition.json))) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala new file mode 100644 index 0000000000000..0e63d84e9824a --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala @@ -0,0 +1,257 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import java.io.File +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ +import org.apache.spark.util.Utils +import org.scalatest.BeforeAndAfterAll + +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.TypeTag + + +// The data where the partitioning key exists only in the directory structure. +case class OrcParData(intField: Int, stringField: String) + +// The data that also includes the partitioning key +case class OrcParDataWithKey(intField: Int, pi: Int, stringField: String, ps: String) + +// TODO This test suite duplicates ParquetPartitionDiscoverySuite a lot +class OrcPartitionDiscoverySuite extends QueryTest with BeforeAndAfterAll { + val defaultPartitionName = ConfVars.DEFAULTPARTITIONNAME.defaultVal + + def withTempDir(f: File => Unit): Unit = { + val dir = Utils.createTempDir().getCanonicalFile + try f(dir) finally Utils.deleteRecursively(dir) + } + + def makeOrcFile[T <: Product: ClassTag: TypeTag]( + data: Seq[T], path: File): Unit = { + data.toDF().write.format("orc").mode("overwrite").save(path.getCanonicalPath) + } + + + def makeOrcFile[T <: Product: ClassTag: TypeTag]( + df: DataFrame, path: File): Unit = { + df.write.format("orc").mode("overwrite").save(path.getCanonicalPath) + } + + protected def withTempTable(tableName: String)(f: => Unit): Unit = { + try f finally TestHive.dropTempTable(tableName) + } + + protected def makePartitionDir( + basePath: File, + defaultPartitionName: String, + partitionCols: (String, Any)*): File = { + val partNames = partitionCols.map { case (k, v) => + val valueString = if (v == null || v == "") defaultPartitionName else v.toString + s"$k=$valueString" + } + + val partDir = partNames.foldLeft(basePath) { (parent, child) => + new File(parent, child) + } + + assert(partDir.mkdirs(), s"Couldn't create directory $partDir") + partDir + } + + test("read partitioned table - normal case") { + withTempDir { base => + for { + pi <- Seq(1, 2) + ps <- Seq("foo", "bar") + } { + makeOrcFile( + (1 to 10).map(i => OrcParData(i, i.toString)), + makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) + } + + read.format("orc").load(base.getCanonicalPath).registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + ps <- Seq("foo", "bar") + } yield Row(i, i.toString, pi, ps)) + + checkAnswer( + sql("SELECT intField, pi FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + _ <- Seq("foo", "bar") + } yield Row(i, pi)) + + checkAnswer( + sql("SELECT * FROM t WHERE pi = 1"), + for { + i <- 1 to 10 + ps <- Seq("foo", "bar") + } yield Row(i, i.toString, 1, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE ps = 'foo'"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + } yield Row(i, i.toString, pi, "foo")) + } + } + } + + test("read partitioned table - partition key included in orc file") { + withTempDir { base => + for { + pi <- Seq(1, 2) + ps <- Seq("foo", "bar") + } { + makeOrcFile( + (1 to 10).map(i => OrcParDataWithKey(i, pi, i.toString, ps)), + makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) + } + + read.format("orc").load(base.getCanonicalPath).registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + ps <- Seq("foo", "bar") + } yield Row(i, pi, i.toString, ps)) + + checkAnswer( + sql("SELECT intField, pi FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + _ <- Seq("foo", "bar") + } yield Row(i, pi)) + + checkAnswer( + sql("SELECT * FROM t WHERE pi = 1"), + for { + i <- 1 to 10 + ps <- Seq("foo", "bar") + } yield Row(i, 1, i.toString, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE ps = 'foo'"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + } yield Row(i, pi, i.toString, "foo")) + } + } + } + + + test("read partitioned table - with nulls") { + withTempDir { base => + for { + // Must be `Integer` rather than `Int` here. `null.asInstanceOf[Int]` results in a zero... + pi <- Seq(1, null.asInstanceOf[Integer]) + ps <- Seq("foo", null.asInstanceOf[String]) + } { + makeOrcFile( + (1 to 10).map(i => OrcParData(i, i.toString)), + makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) + } + + read + .format("orc") + .option(ConfVars.DEFAULTPARTITIONNAME.varname, defaultPartitionName) + .load(base.getCanonicalPath) + .registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, null.asInstanceOf[Integer]) + ps <- Seq("foo", null.asInstanceOf[String]) + } yield Row(i, i.toString, pi, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE pi IS NULL"), + for { + i <- 1 to 10 + ps <- Seq("foo", null.asInstanceOf[String]) + } yield Row(i, i.toString, null, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE ps IS NULL"), + for { + i <- 1 to 10 + pi <- Seq(1, null.asInstanceOf[Integer]) + } yield Row(i, i.toString, pi, null)) + } + } + } + + test("read partitioned table - with nulls and partition keys are included in Orc file") { + withTempDir { base => + for { + pi <- Seq(1, 2) + ps <- Seq("foo", null.asInstanceOf[String]) + } { + makeOrcFile( + (1 to 10).map(i => OrcParDataWithKey(i, pi, i.toString, ps)), + makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) + } + + read + .format("orc") + .option(ConfVars.DEFAULTPARTITIONNAME.varname, defaultPartitionName) + .load(base.getCanonicalPath) + .registerTempTable("t") + + withTempTable("t") { + checkAnswer( + sql("SELECT * FROM t"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + ps <- Seq("foo", null.asInstanceOf[String]) + } yield Row(i, pi, i.toString, ps)) + + checkAnswer( + sql("SELECT * FROM t WHERE ps IS NULL"), + for { + i <- 1 to 10 + pi <- Seq(1, 2) + } yield Row(i, pi, i.toString, null)) + } + } + } +} + diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala new file mode 100644 index 0000000000000..57c23fe77f8b5 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -0,0 +1,295 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import java.io.File + +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hadoop.hive.ql.io.orc.CompressionKind +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive.implicits._ + +case class AllDataTypesWithNonPrimitiveType( + stringField: String, + intField: Int, + longField: Long, + floatField: Float, + doubleField: Double, + shortField: Short, + byteField: Byte, + booleanField: Boolean, + array: Seq[Int], + arrayContainsNull: Seq[Option[Int]], + map: Map[Int, Long], + mapValueContainsNull: Map[Int, Option[Long]], + data: (Seq[Int], (Int, String))) + +case class BinaryData(binaryData: Array[Byte]) + +case class Contact(name: String, phone: String) + +case class Person(name: String, age: Int, contacts: Seq[Contact]) + +class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { + override val sqlContext = TestHive + + import TestHive.read + + def getTempFilePath(prefix: String, suffix: String = ""): File = { + val tempFile = File.createTempFile(prefix, suffix) + tempFile.delete() + tempFile + } + + test("Read/write All Types") { + val data = (0 to 255).map { i => + (s"$i", i, i.toLong, i.toFloat, i.toDouble, i.toShort, i.toByte, i % 2 == 0) + } + + withOrcFile(data) { file => + checkAnswer( + read.format("orc").load(file), + data.toDF().collect()) + } + } + + test("Read/write binary data") { + withOrcFile(BinaryData("test".getBytes("utf8")) :: Nil) { file => + val bytes = read.format("orc").load(file).head().getAs[Array[Byte]](0) + assert(new String(bytes, "utf8") === "test") + } + } + + test("Read/write all types with non-primitive type") { + val data = (0 to 255).map { i => + AllDataTypesWithNonPrimitiveType( + s"$i", i, i.toLong, i.toFloat, i.toDouble, i.toShort, i.toByte, i % 2 == 0, + 0 until i, + (0 until i).map(Option(_).filter(_ % 3 == 0)), + (0 until i).map(i => i -> i.toLong).toMap, + (0 until i).map(i => i -> Option(i.toLong)).toMap + (i -> None), + (0 until i, (i, s"$i"))) + } + + withOrcFile(data) { file => + checkAnswer( + read.format("orc").load(file), + data.toDF().collect()) + } + } + + test("Creating case class RDD table") { + val data = (1 to 100).map(i => (i, s"val_$i")) + sparkContext.parallelize(data).toDF().registerTempTable("t") + withTempTable("t") { + checkAnswer(sql("SELECT * FROM t"), data.toDF().collect()) + } + } + + test("Simple selection form ORC table") { + val data = (1 to 10).map { i => + Person(s"name_$i", i, (0 to 1).map { m => Contact(s"contact_$m", s"phone_$m") }) + } + + withOrcTable(data, "t") { + // ppd: + // leaf-0 = (LESS_THAN_EQUALS age 5) + // expr = leaf-0 + assert(sql("SELECT name FROM t WHERE age <= 5").count() === 5) + + // ppd: + // leaf-0 = (LESS_THAN_EQUALS age 5) + // expr = (not leaf-0) + assertResult(10) { + sql("SELECT name, contacts FROM t where age > 5") + .flatMap(_.getAs[Seq[_]]("contacts")) + .count() + } + + // ppd: + // leaf-0 = (LESS_THAN_EQUALS age 5) + // leaf-1 = (LESS_THAN age 8) + // expr = (and (not leaf-0) leaf-1) + { + val df = sql("SELECT name, contacts FROM t WHERE age > 5 AND age < 8") + assert(df.count() === 2) + assertResult(4) { + df.flatMap(_.getAs[Seq[_]]("contacts")).count() + } + } + + // ppd: + // leaf-0 = (LESS_THAN age 2) + // leaf-1 = (LESS_THAN_EQUALS age 8) + // expr = (or leaf-0 (not leaf-1)) + { + val df = sql("SELECT name, contacts FROM t WHERE age < 2 OR age > 8") + assert(df.count() === 3) + assertResult(6) { + df.flatMap(_.getAs[Seq[_]]("contacts")).count() + } + } + } + } + + test("save and load case class RDD with `None`s as orc") { + val data = ( + None: Option[Int], + None: Option[Long], + None: Option[Float], + None: Option[Double], + None: Option[Boolean] + ) :: Nil + + withOrcFile(data) { file => + checkAnswer( + read.format("orc").load(file), + Row(Seq.fill(5)(null): _*)) + } + } + + // We only support zlib in Hive 0.12.0 now + test("Default compression options for writing to an ORC file") { + withOrcFile((1 to 100).map(i => (i, s"val_$i"))) { file => + assertResult(CompressionKind.ZLIB) { + OrcFileOperator.getFileReader(file).getCompression + } + } + } + + // Following codec is supported in hive-0.13.1, ignore it now + ignore("Other compression options for writing to an ORC file - 0.13.1 and above") { + val data = (1 to 100).map(i => (i, s"val_$i")) + val conf = sparkContext.hadoopConfiguration + + conf.set(ConfVars.HIVE_ORC_DEFAULT_COMPRESS.varname, "SNAPPY") + withOrcFile(data) { file => + assertResult(CompressionKind.SNAPPY) { + OrcFileOperator.getFileReader(file).getCompression + } + } + + conf.set(ConfVars.HIVE_ORC_DEFAULT_COMPRESS.varname, "NONE") + withOrcFile(data) { file => + assertResult(CompressionKind.NONE) { + OrcFileOperator.getFileReader(file).getCompression + } + } + + conf.set(ConfVars.HIVE_ORC_DEFAULT_COMPRESS.varname, "LZO") + withOrcFile(data) { file => + assertResult(CompressionKind.LZO) { + OrcFileOperator.getFileReader(file).getCompression + } + } + } + + test("simple select queries") { + withOrcTable((0 until 10).map(i => (i, i.toString)), "t") { + checkAnswer( + sql("SELECT `_1` FROM t where t.`_1` > 5"), + (6 until 10).map(Row.apply(_))) + + checkAnswer( + sql("SELECT `_1` FROM t as tmp where tmp.`_1` < 5"), + (0 until 5).map(Row.apply(_))) + } + } + + test("appending") { + val data = (0 until 10).map(i => (i, i.toString)) + createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + withOrcTable(data, "t") { + sql("INSERT INTO TABLE t SELECT * FROM tmp") + checkAnswer(table("t"), (data ++ data).map(Row.fromTuple)) + } + catalog.unregisterTable(Seq("tmp")) + } + + test("overwriting") { + val data = (0 until 10).map(i => (i, i.toString)) + createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp") + withOrcTable(data, "t") { + sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp") + checkAnswer(table("t"), data.map(Row.fromTuple)) + } + catalog.unregisterTable(Seq("tmp")) + } + + test("self-join") { + // 4 rows, cells of column 1 of row 2 and row 4 are null + val data = (1 to 4).map { i => + val maybeInt = if (i % 2 == 0) None else Some(i) + (maybeInt, i.toString) + } + + withOrcTable(data, "t") { + val selfJoin = sql("SELECT * FROM t x JOIN t y WHERE x.`_1` = y.`_1`") + val queryOutput = selfJoin.queryExecution.analyzed.output + + assertResult(4, "Field count mismatches")(queryOutput.size) + assertResult(2, "Duplicated expression ID in query plan:\n $selfJoin") { + queryOutput.filter(_.name == "_1").map(_.exprId).size + } + + checkAnswer(selfJoin, List(Row(1, "1", 1, "1"), Row(3, "3", 3, "3"))) + } + } + + test("nested data - struct with array field") { + val data = (1 to 10).map(i => Tuple1((i, Seq("val_$i")))) + withOrcTable(data, "t") { + checkAnswer(sql("SELECT `_1`.`_2`[0] FROM t"), data.map { + case Tuple1((_, Seq(string))) => Row(string) + }) + } + } + + test("nested data - array of struct") { + val data = (1 to 10).map(i => Tuple1(Seq(i -> "val_$i"))) + withOrcTable(data, "t") { + checkAnswer(sql("SELECT `_1`[0].`_2` FROM t"), data.map { + case Tuple1(Seq((_, string))) => Row(string) + }) + } + } + + test("columns only referenced by pushed down filters should remain") { + withOrcTable((1 to 10).map(Tuple1.apply), "t") { + checkAnswer(sql("SELECT `_1` FROM t WHERE `_1` < 10"), (1 to 9).map(Row.apply(_))) + } + } + + test("SPARK-5309 strings stored using dictionary compression in orc") { + withOrcTable((0 until 1000).map(i => ("same", "run_" + i / 100, 1)), "t") { + checkAnswer( + sql("SELECT `_1`, `_2`, SUM(`_3`) FROM t GROUP BY `_1`, `_2`"), + (0 until 10).map(i => Row("same", "run_" + i, 100))) + + checkAnswer( + sql("SELECT `_1`, `_2`, SUM(`_3`) FROM t WHERE `_2` = 'run_5' GROUP BY `_1`, `_2`"), + List(Row("same", "run_5", 100))) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala new file mode 100644 index 0000000000000..82e08caf46457 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import java.io.File + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.{QueryTest, Row} + +case class OrcData(intField: Int, stringField: String) + +abstract class OrcSuite extends QueryTest with BeforeAndAfterAll { + var orcTableDir: File = null + var orcTableAsDir: File = null + + override def beforeAll(): Unit = { + super.beforeAll() + + orcTableAsDir = File.createTempFile("orctests", "sparksql") + orcTableAsDir.delete() + orcTableAsDir.mkdir() + + // Hack: to prepare orc data files using hive external tables + orcTableDir = File.createTempFile("orctests", "sparksql") + orcTableDir.delete() + orcTableDir.mkdir() + import org.apache.spark.sql.hive.test.TestHive.implicits._ + + sparkContext + .makeRDD(1 to 10) + .map(i => OrcData(i, s"part-$i")) + .toDF() + .registerTempTable(s"orc_temp_table") + + sql( + s"""CREATE EXTERNAL TABLE normal_orc( + | intField INT, + | stringField STRING + |) + |STORED AS ORC + |LOCATION '${orcTableAsDir.getCanonicalPath}' + """.stripMargin) + + sql( + s"""INSERT INTO TABLE normal_orc + |SELECT intField, stringField FROM orc_temp_table + """.stripMargin) + } + + override def afterAll(): Unit = { + orcTableDir.delete() + orcTableAsDir.delete() + } + + test("create temporary orc table") { + checkAnswer(sql("SELECT COUNT(*) FROM normal_orc_source"), Row(10)) + + checkAnswer( + sql("SELECT * FROM normal_orc_source"), + (1 to 10).map(i => Row(i, s"part-$i"))) + + checkAnswer( + sql("SELECT * FROM normal_orc_source where intField > 5"), + (6 to 10).map(i => Row(i, s"part-$i"))) + + checkAnswer( + sql("SELECT COUNT(intField), stringField FROM normal_orc_source GROUP BY stringField"), + (1 to 10).map(i => Row(1, s"part-$i"))) + } + + test("create temporary orc table as") { + checkAnswer(sql("SELECT COUNT(*) FROM normal_orc_as_source"), Row(10)) + + checkAnswer( + sql("SELECT * FROM normal_orc_source"), + (1 to 10).map(i => Row(i, s"part-$i"))) + + checkAnswer( + sql("SELECT * FROM normal_orc_source WHERE intField > 5"), + (6 to 10).map(i => Row(i, s"part-$i"))) + + checkAnswer( + sql("SELECT COUNT(intField), stringField FROM normal_orc_source GROUP BY stringField"), + (1 to 10).map(i => Row(1, s"part-$i"))) + } + + test("appending insert") { + sql("INSERT INTO TABLE normal_orc_source SELECT * FROM orc_temp_table WHERE intField > 5") + + checkAnswer( + sql("SELECT * FROM normal_orc_source"), + (1 to 5).map(i => Row(i, s"part-$i")) ++ (6 to 10).flatMap { i => + Seq.fill(2)(Row(i, s"part-$i")) + }) + } + + test("overwrite insert") { + sql( + """INSERT OVERWRITE TABLE normal_orc_as_source + |SELECT * FROM orc_temp_table WHERE intField > 5 + """.stripMargin) + + checkAnswer( + sql("SELECT * FROM normal_orc_as_source"), + (6 to 10).map(i => Row(i, s"part-$i"))) + } +} + +class OrcSourceSuite extends OrcSuite { + override def beforeAll(): Unit = { + super.beforeAll() + + sql( + s"""CREATE TEMPORARY TABLE normal_orc_source + |USING org.apache.spark.sql.hive.orc + |OPTIONS ( + | PATH '${new File(orcTableAsDir.getAbsolutePath).getCanonicalPath}' + |) + """.stripMargin) + + sql( + s"""CREATE TEMPORARY TABLE normal_orc_as_source + |USING org.apache.spark.sql.hive.orc + |OPTIONS ( + | PATH '${new File(orcTableAsDir.getAbsolutePath).getCanonicalPath}' + |) + """.stripMargin) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala new file mode 100644 index 0000000000000..750f0b04aaa87 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import java.io.File + +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql._ + +private[sql] trait OrcTest extends SQLTestUtils { + protected def hiveContext = sqlContext.asInstanceOf[HiveContext] + + import sqlContext.sparkContext + import sqlContext.implicits._ + + /** + * Writes `data` to a Orc file, which is then passed to `f` and will be deleted after `f` + * returns. + */ + protected def withOrcFile[T <: Product: ClassTag: TypeTag] + (data: Seq[T]) + (f: String => Unit): Unit = { + withTempPath { file => + sparkContext.parallelize(data).toDF().write.format("orc").save(file.getCanonicalPath) + f(file.getCanonicalPath) + } + } + + /** + * Writes `data` to a Orc file and reads it back as a [[DataFrame]], + * which is then passed to `f`. The Orc file will be deleted after `f` returns. + */ + protected def withOrcDataFrame[T <: Product: ClassTag: TypeTag] + (data: Seq[T]) + (f: DataFrame => Unit): Unit = { + withOrcFile(data)(path => f(hiveContext.read.format("orc").load(path))) + } + + /** + * Writes `data` to a Orc file, reads it back as a [[DataFrame]] and registers it as a + * temporary table named `tableName`, then call `f`. The temporary table together with the + * Orc file will be dropped/deleted after `f` returns. + */ + protected def withOrcTable[T <: Product: ClassTag: TypeTag] + (data: Seq[T], tableName: String) + (f: => Unit): Unit = { + withOrcDataFrame(data) { df => + hiveContext.registerDataFrameAsTable(df, tableName) + withTempTable(tableName)(f) + } + } + + protected def makeOrcFile[T <: Product: ClassTag: TypeTag]( + data: Seq[T], path: File): Unit = { + data.toDF().write.format("orc").mode(SaveMode.Overwrite).save(path.getCanonicalPath) + } + + protected def makeOrcFile[T <: Product: ClassTag: TypeTag]( + df: DataFrame, path: File): Unit = { + df.write.format("orc").mode(SaveMode.Overwrite).save(path.getCanonicalPath) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index bf1121ddf0273..e62ac909cbd0c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -21,16 +21,16 @@ import java.io.File import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.{QueryTest, SQLConf} import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.{ExecutedCommand, PhysicalRDD} import org.apache.spark.sql.hive.execution.HiveTableScan import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ -import org.apache.spark.sql.sources.{InsertIntoDataSource, LogicalRelation} import org.apache.spark.sql.parquet.{ParquetRelation2, ParquetTableScan} -import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.sources.{InsertIntoDataSource, InsertIntoHadoopFsRelation, LogicalRelation} import org.apache.spark.sql.types._ +import org.apache.spark.sql.{DataFrame, QueryTest, SQLConf, SaveMode} import org.apache.spark.util.Utils // The data where the partitioning key exists only in the directory structure. @@ -38,7 +38,7 @@ case class ParquetData(intField: Int, stringField: String) // The data that also includes the partitioning key case class ParquetDataWithKey(p: Int, intField: Int, stringField: String) -case class StructContainer(intStructField :Int, stringStructField: String) +case class StructContainer(intStructField: Int, stringStructField: String) case class ParquetDataWithComplexTypes( intField: Int, @@ -151,9 +151,9 @@ class ParquetMetastoreSuiteBase extends ParquetPartitioningTest { } val rdd1 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}""")) - jsonRDD(rdd1).registerTempTable("jt") + read.json(rdd1).registerTempTable("jt") val rdd2 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":[$i, null]}""")) - jsonRDD(rdd2).registerTempTable("jt_array") + read.json(rdd2).registerTempTable("jt_array") setConf("spark.sql.hive.convertMetastoreParquet", "true") } @@ -292,10 +292,10 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { ) table("test_parquet_ctas").queryExecution.optimizedPlan match { - case LogicalRelation(p: ParquetRelation2) => // OK - case _ => - fail( - s"test_parquet_ctas should be converted to ${classOf[ParquetRelation2].getCanonicalName}") + case LogicalRelation(_: ParquetRelation2) => // OK + case _ => fail( + "test_parquet_ctas should be converted to " + + s"${classOf[ParquetRelation2].getCanonicalName}") } sql("DROP TABLE IF EXISTS test_parquet_ctas") @@ -316,12 +316,10 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt") df.queryExecution.executedPlan match { - case ExecutedCommand( - InsertIntoDataSource( - LogicalRelation(r: ParquetRelation2), query, overwrite)) => // OK + case ExecutedCommand(InsertIntoHadoopFsRelation(_: ParquetRelation2, _, _)) => // OK case o => fail("test_insert_parquet should be converted to a " + s"${classOf[ParquetRelation2].getCanonicalName} and " + - s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan." + + s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan. " + s"However, found a ${o.toString} ") } @@ -348,9 +346,7 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt_array") df.queryExecution.executedPlan match { - case ExecutedCommand( - InsertIntoDataSource( - LogicalRelation(r: ParquetRelation2), query, overwrite)) => // OK + case ExecutedCommand(InsertIntoHadoopFsRelation(r: ParquetRelation2, _, _)) => // OK case o => fail("test_insert_parquet should be converted to a " + s"${classOf[ParquetRelation2].getCanonicalName} and " + s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan." + @@ -390,10 +386,58 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { sql("DROP TABLE ms_convert") } + def collectParquetRelation(df: DataFrame): ParquetRelation2 = { + val plan = df.queryExecution.analyzed + plan.collectFirst { + case LogicalRelation(r: ParquetRelation2) => r + }.getOrElse { + fail(s"Expecting a ParquetRelation2, but got:\n$plan") + } + } + + test("SPARK-7749: non-partitioned metastore Parquet table lookup should use cached relation") { + sql( + s"""CREATE TABLE nonPartitioned ( + | key INT, + | value STRING + |) + |STORED AS PARQUET + """.stripMargin) + + // First lookup fills the cache + val r1 = collectParquetRelation(table("nonPartitioned")) + // Second lookup should reuse the cache + val r2 = collectParquetRelation(table("nonPartitioned")) + // They should be the same instance + assert(r1 eq r2) + + sql("DROP TABLE nonPartitioned") + } + + test("SPARK-7749: partitioned metastore Parquet table lookup should use cached relation") { + sql( + s"""CREATE TABLE partitioned ( + | key INT, + | value STRING + |) + |PARTITIONED BY (part INT) + |STORED AS PARQUET + """.stripMargin) + + // First lookup fills the cache + val r1 = collectParquetRelation(table("partitioned")) + // Second lookup should reuse the cache + val r2 = collectParquetRelation(table("partitioned")) + // They should be the same instance + assert(r1 eq r2) + + sql("DROP TABLE partitioned") + } + test("Caching converted data source Parquet Relations") { - def checkCached(tableIdentifer: catalog.QualifiedTableName): Unit = { + def checkCached(tableIdentifier: catalog.QualifiedTableName): Unit = { // Converted test_parquet should be cached. - catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) match { + catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) match { case null => fail("Converted test_parquet should be cached in the cache.") case logical @ LogicalRelation(parquetRelation: ParquetRelation2) => // OK case other => @@ -419,30 +463,30 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' """.stripMargin) - var tableIdentifer = catalog.QualifiedTableName("default", "test_insert_parquet") + var tableIdentifier = catalog.QualifiedTableName("default", "test_insert_parquet") // First, make sure the converted test_parquet is not cached. - assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) // Table lookup will make the table cached. table("test_insert_parquet") - checkCached(tableIdentifer) + checkCached(tableIdentifier) // For insert into non-partitioned table, we will do the conversion, // so the converted test_insert_parquet should be cached. invalidateTable("test_insert_parquet") - assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) sql( """ |INSERT INTO TABLE test_insert_parquet |select a, b from jt """.stripMargin) - checkCached(tableIdentifer) + checkCached(tableIdentifier) // Make sure we can read the data. checkAnswer( sql("select * from test_insert_parquet"), sql("select a, b from jt").collect()) // Invalidate the cache. invalidateTable("test_insert_parquet") - assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) // Create a partitioned table. sql( @@ -459,8 +503,8 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' """.stripMargin) - tableIdentifer = catalog.QualifiedTableName("default", "test_parquet_partitioned_cache_test") - assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + tableIdentifier = catalog.QualifiedTableName("default", "test_parquet_partitioned_cache_test") + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) sql( """ |INSERT INTO TABLE test_parquet_partitioned_cache_test @@ -469,18 +513,18 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { """.stripMargin) // Right now, insert into a partitioned Parquet is not supported in data source Parquet. // So, we expect it is not cached. - assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) sql( """ |INSERT INTO TABLE test_parquet_partitioned_cache_test |PARTITION (date='2015-04-02') |select a, b from jt """.stripMargin) - assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) // Make sure we can cache the partitioned table. table("test_parquet_partitioned_cache_test") - checkCached(tableIdentifer) + checkCached(tableIdentifier) // Make sure we can read the data. checkAnswer( sql("select STRINGField, date, intField from test_parquet_partitioned_cache_test"), @@ -492,7 +536,7 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { """.stripMargin).collect()) invalidateTable("test_parquet_partitioned_cache_test") - assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifer) === null) + assert(catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) === null) sql("DROP TABLE test_insert_parquet") sql("DROP TABLE test_parquet_partitioned_cache_test") @@ -622,16 +666,16 @@ class ParquetSourceSuiteBase extends ParquetPartitioningTest { sql("drop table if exists spark_6016_fix") // Create a DataFrame with two partitions. So, the created table will have two parquet files. - val df1 = jsonRDD(sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i}"""), 2)) - df1.saveAsTable("spark_6016_fix", "parquet", SaveMode.Overwrite) + val df1 = read.json(sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i}"""), 2)) + df1.write.mode(SaveMode.Overwrite).format("parquet").saveAsTable("spark_6016_fix") checkAnswer( sql("select * from spark_6016_fix"), (1 to 10).map(i => Row(i)) ) // Create a DataFrame with four partitions. So, the created table will have four parquet files. - val df2 = jsonRDD(sparkContext.parallelize((1 to 10).map(i => s"""{"b":$i}"""), 4)) - df2.saveAsTable("spark_6016_fix", "parquet", SaveMode.Overwrite) + val df2 = read.json(sparkContext.parallelize((1 to 10).map(i => s"""{"b":$i}"""), 4)) + df2.write.mode(SaveMode.Overwrite).format("parquet").saveAsTable("spark_6016_fix") // For the bug of SPARK-6016, we are caching two outdated footers for df1. Then, // since the new table has four parquet files, we are trying to read new footers from two files // and then merge metadata in footers of these four (two outdated ones and two latest one), @@ -668,7 +712,7 @@ class ParquetDataSourceOnSourceSuite extends ParquetSourceSuiteBase { StructField("a", arrayType1, nullable = true) :: Nil) assert(df.schema === expectedSchema1) - df.saveAsTable("alwaysNullable", "parquet") + df.write.format("parquet").saveAsTable("alwaysNullable") val mapType2 = MapType(IntegerType, IntegerType, valueContainsNull = true) val arrayType2 = ArrayType(IntegerType, containsNull = true) @@ -691,13 +735,13 @@ class ParquetDataSourceOnSourceSuite extends ParquetSourceSuiteBase { val filePath = new File(tempDir, "testParquet").getCanonicalPath val filePath2 = new File(tempDir, "testParquet2").getCanonicalPath - val df = Seq(1,2,3).map(i => (i, i.toString)).toDF("int", "str") + val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") val df2 = df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("y.str").max("y.int") - intercept[RuntimeException](df2.saveAsParquetFile(filePath)) + intercept[Throwable](df2.write.parquet(filePath)) val df3 = df2.toDF("str", "max_int") - df3.saveAsParquetFile(filePath2) - val df4 = parquetFile(filePath2) + df3.write.parquet(filePath2) + val df4 = read.parquet(filePath2) checkAnswer(df4, Row("1", 1) :: Row("2", 2) :: Row("3", 3) :: Nil) assert(df4.columns === Array("str", "max_int")) } @@ -736,14 +780,14 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll sparkContext.makeRDD(1 to 10) .map(i => ParquetData(i, s"part-$p")) .toDF() - .saveAsParquetFile(partDir.getCanonicalPath) + .write.parquet(partDir.getCanonicalPath) } sparkContext .makeRDD(1 to 10) .map(i => ParquetData(i, s"part-1")) .toDF() - .saveAsParquetFile(new File(normalTableDir, "normal").getCanonicalPath) + .write.parquet(new File(normalTableDir, "normal").getCanonicalPath) partitionedTableDirWithKey = Utils.createTempDir() @@ -752,7 +796,7 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll sparkContext.makeRDD(1 to 10) .map(i => ParquetDataWithKey(p, i, s"part-$p")) .toDF() - .saveAsParquetFile(partDir.getCanonicalPath) + .write.parquet(partDir.getCanonicalPath) } partitionedTableDirWithKeyAndComplexTypes = Utils.createTempDir() @@ -762,7 +806,7 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll sparkContext.makeRDD(1 to 10).map { i => ParquetDataWithKeyAndComplexTypes( p, i, s"part-$p", StructContainer(i, f"${i}_string"), 1 to i) - }.toDF().saveAsParquetFile(partDir.getCanonicalPath) + }.toDF().write.parquet(partDir.getCanonicalPath) } partitionedTableDirWithComplexTypes = Utils.createTempDir() @@ -771,7 +815,7 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll val partDir = new File(partitionedTableDirWithComplexTypes, s"p=$p") sparkContext.makeRDD(1 to 10).map { i => ParquetDataWithComplexTypes(i, s"part-$p", StructContainer(i, f"${i}_string"), 1 to i) - }.toDF().saveAsParquetFile(partDir.getCanonicalPath) + }.toDF().write.parquet(partDir.getCanonicalPath) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/FSBasedRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/FSBasedRelationSuite.scala deleted file mode 100644 index 415b1cd168848..0000000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/FSBasedRelationSuite.scala +++ /dev/null @@ -1,525 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources - -import org.apache.hadoop.fs.Path - -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.sql._ -import org.apache.spark.sql.hive.test.TestHive -import org.apache.spark.sql.parquet.ParquetTest -import org.apache.spark.sql.types._ - -// TODO Don't extend ParquetTest -// This test suite extends ParquetTest for some convenient utility methods. These methods should be -// moved to some more general places, maybe QueryTest. -class FSBasedRelationSuite extends QueryTest with ParquetTest { - override val sqlContext: SQLContext = TestHive - - import sqlContext._ - import sqlContext.implicits._ - - val dataSchema = - StructType( - Seq( - StructField("a", IntegerType, nullable = false), - StructField("b", StringType, nullable = false))) - - val testDF = (1 to 3).map(i => (i, s"val_$i")).toDF("a", "b") - - val partitionedTestDF1 = (for { - i <- 1 to 3 - p2 <- Seq("foo", "bar") - } yield (i, s"val_$i", 1, p2)).toDF("a", "b", "p1", "p2") - - val partitionedTestDF2 = (for { - i <- 1 to 3 - p2 <- Seq("foo", "bar") - } yield (i, s"val_$i", 2, p2)).toDF("a", "b", "p1", "p2") - - val partitionedTestDF = partitionedTestDF1.unionAll(partitionedTestDF2) - - def checkQueries(df: DataFrame): Unit = { - // Selects everything - checkAnswer( - df, - for (i <- 1 to 3; p1 <- 1 to 2; p2 <- Seq("foo", "bar")) yield Row(i, s"val_$i", p1, p2)) - - // Simple filtering and partition pruning - checkAnswer( - df.filter('a > 1 && 'p1 === 2), - for (i <- 2 to 3; p2 <- Seq("foo", "bar")) yield Row(i, s"val_$i", 2, p2)) - - // Simple projection and filtering - checkAnswer( - df.filter('a > 1).select('b, 'a + 1), - for (i <- 2 to 3; _ <- 1 to 2; _ <- Seq("foo", "bar")) yield Row(s"val_$i", i + 1)) - - // Simple projection and partition pruning - checkAnswer( - df.filter('a > 1 && 'p1 < 2).select('b, 'p1), - for (i <- 2 to 3; _ <- Seq("foo", "bar")) yield Row(s"val_$i", 1)) - - // Self-join - df.registerTempTable("t") - withTempTable("t") { - checkAnswer( - sql( - """SELECT l.a, r.b, l.p1, r.p2 - |FROM t l JOIN t r - |ON l.a = r.a AND l.p1 = r.p1 AND l.p2 = r.p2 - """.stripMargin), - for (i <- 1 to 3; p1 <- 1 to 2; p2 <- Seq("foo", "bar")) yield Row(i, s"val_$i", p1, p2)) - } - } - - test("save()/load() - non-partitioned table - Overwrite") { - withTempPath { file => - testDF.save( - path = file.getCanonicalPath, - source = classOf[SimpleTextSource].getCanonicalName, - mode = SaveMode.Overwrite) - - testDF.save( - path = file.getCanonicalPath, - source = classOf[SimpleTextSource].getCanonicalName, - mode = SaveMode.Overwrite) - - checkAnswer( - load( - source = classOf[SimpleTextSource].getCanonicalName, - options = Map( - "path" -> file.getCanonicalPath, - "dataSchema" -> dataSchema.json)), - testDF.collect()) - } - } - - test("save()/load() - non-partitioned table - Append") { - withTempPath { file => - testDF.save( - path = file.getCanonicalPath, - source = classOf[SimpleTextSource].getCanonicalName, - mode = SaveMode.Overwrite) - - testDF.save( - path = file.getCanonicalPath, - source = classOf[SimpleTextSource].getCanonicalName, - mode = SaveMode.Append) - - checkAnswer( - load( - source = classOf[SimpleTextSource].getCanonicalName, - options = Map( - "path" -> file.getCanonicalPath, - "dataSchema" -> dataSchema.json)).orderBy("a"), - testDF.unionAll(testDF).orderBy("a").collect()) - } - } - - test("save()/load() - non-partitioned table - ErrorIfExists") { - withTempDir { file => - intercept[RuntimeException] { - testDF.save( - path = file.getCanonicalPath, - source = classOf[SimpleTextSource].getCanonicalName, - mode = SaveMode.ErrorIfExists) - } - } - } - - test("save()/load() - non-partitioned table - Ignore") { - withTempDir { file => - testDF.save( - path = file.getCanonicalPath, - source = classOf[SimpleTextSource].getCanonicalName, - mode = SaveMode.Ignore) - - val path = new Path(file.getCanonicalPath) - val fs = path.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) - assert(fs.listStatus(path).isEmpty) - } - } - - test("save()/load() - partitioned table - simple queries") { - withTempPath { file => - partitionedTestDF.save( - source = classOf[SimpleTextSource].getCanonicalName, - mode = SaveMode.ErrorIfExists, - options = Map("path" -> file.getCanonicalPath), - partitionColumns = Seq("p1", "p2")) - - checkQueries( - load( - source = classOf[SimpleTextSource].getCanonicalName, - options = Map( - "path" -> file.getCanonicalPath, - "dataSchema" -> dataSchema.json))) - } - } - - test("save()/load() - partitioned table - simple queries - partition columns in data") { - withTempDir { file => - val basePath = new Path(file.getCanonicalPath) - val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) - val qualifiedBasePath = fs.makeQualified(basePath) - - for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { - val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") - sparkContext - .parallelize(for (i <- 1 to 3) yield s"$i,val_$i,$p1") - .saveAsTextFile(partitionDir.toString) - } - - val dataSchemaWithPartition = - StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) - - checkQueries( - load( - source = classOf[SimpleTextSource].getCanonicalName, - options = Map( - "path" -> file.getCanonicalPath, - "dataSchema" -> dataSchemaWithPartition.json))) - } - } - - test("save()/load() - partitioned table - Overwrite") { - withTempPath { file => - partitionedTestDF.save( - source = classOf[SimpleTextSource].getCanonicalName, - mode = SaveMode.Overwrite, - options = Map("path" -> file.getCanonicalPath), - partitionColumns = Seq("p1", "p2")) - - partitionedTestDF.save( - source = classOf[SimpleTextSource].getCanonicalName, - mode = SaveMode.Overwrite, - options = Map("path" -> file.getCanonicalPath), - partitionColumns = Seq("p1", "p2")) - - checkAnswer( - load( - source = classOf[SimpleTextSource].getCanonicalName, - options = Map( - "path" -> file.getCanonicalPath, - "dataSchema" -> dataSchema.json)), - partitionedTestDF.collect()) - } - } - - test("save()/load() - partitioned table - Append") { - withTempPath { file => - partitionedTestDF.save( - source = classOf[SimpleTextSource].getCanonicalName, - mode = SaveMode.Overwrite, - options = Map("path" -> file.getCanonicalPath), - partitionColumns = Seq("p1", "p2")) - - partitionedTestDF.save( - source = classOf[SimpleTextSource].getCanonicalName, - mode = SaveMode.Append, - options = Map("path" -> file.getCanonicalPath), - partitionColumns = Seq("p1", "p2")) - - checkAnswer( - load( - source = classOf[SimpleTextSource].getCanonicalName, - options = Map( - "path" -> file.getCanonicalPath, - "dataSchema" -> dataSchema.json)), - partitionedTestDF.unionAll(partitionedTestDF).collect()) - } - } - - test("save()/load() - partitioned table - Append - new partition values") { - withTempPath { file => - partitionedTestDF1.save( - source = classOf[SimpleTextSource].getCanonicalName, - mode = SaveMode.Overwrite, - options = Map("path" -> file.getCanonicalPath), - partitionColumns = Seq("p1", "p2")) - - partitionedTestDF2.save( - source = classOf[SimpleTextSource].getCanonicalName, - mode = SaveMode.Append, - options = Map("path" -> file.getCanonicalPath), - partitionColumns = Seq("p1", "p2")) - - checkAnswer( - load( - source = classOf[SimpleTextSource].getCanonicalName, - options = Map( - "path" -> file.getCanonicalPath, - "dataSchema" -> dataSchema.json)), - partitionedTestDF.collect()) - } - } - - test("save()/load() - partitioned table - ErrorIfExists") { - withTempDir { file => - intercept[RuntimeException] { - partitionedTestDF.save( - source = classOf[SimpleTextSource].getCanonicalName, - mode = SaveMode.ErrorIfExists, - options = Map("path" -> file.getCanonicalPath), - partitionColumns = Seq("p1", "p2")) - } - } - } - - test("save()/load() - partitioned table - Ignore") { - withTempDir { file => - partitionedTestDF.save( - path = file.getCanonicalPath, - source = classOf[SimpleTextSource].getCanonicalName, - mode = SaveMode.Ignore) - - val path = new Path(file.getCanonicalPath) - val fs = path.getFileSystem(SparkHadoopUtil.get.conf) - assert(fs.listStatus(path).isEmpty) - } - } - - def withTable(tableName: String)(f: => Unit): Unit = { - try f finally sql(s"DROP TABLE $tableName") - } - - test("saveAsTable()/load() - non-partitioned table - Overwrite") { - testDF.saveAsTable( - tableName = "t", - source = classOf[SimpleTextSource].getCanonicalName, - mode = SaveMode.Overwrite, - Map("dataSchema" -> dataSchema.json)) - - withTable("t") { - checkAnswer(table("t"), testDF.collect()) - } - } - - test("saveAsTable()/load() - non-partitioned table - Append") { - testDF.saveAsTable( - tableName = "t", - source = classOf[SimpleTextSource].getCanonicalName, - mode = SaveMode.Overwrite) - - testDF.saveAsTable( - tableName = "t", - source = classOf[SimpleTextSource].getCanonicalName, - mode = SaveMode.Append) - - withTable("t") { - checkAnswer(table("t"), testDF.unionAll(testDF).orderBy("a").collect()) - } - } - - test("saveAsTable()/load() - non-partitioned table - ErrorIfExists") { - Seq.empty[(Int, String)].toDF().registerTempTable("t") - - withTempTable("t") { - intercept[AnalysisException] { - testDF.saveAsTable( - tableName = "t", - source = classOf[SimpleTextSource].getCanonicalName, - mode = SaveMode.ErrorIfExists) - } - } - } - - test("saveAsTable()/load() - non-partitioned table - Ignore") { - Seq.empty[(Int, String)].toDF().registerTempTable("t") - - withTempTable("t") { - testDF.saveAsTable( - tableName = "t", - source = classOf[SimpleTextSource].getCanonicalName, - mode = SaveMode.Ignore) - - assert(table("t").collect().isEmpty) - } - } - - test("saveAsTable()/load() - partitioned table - simple queries") { - partitionedTestDF.saveAsTable( - tableName = "t", - source = classOf[SimpleTextSource].getCanonicalName, - mode = SaveMode.Overwrite, - Map("dataSchema" -> dataSchema.json)) - - withTable("t") { - checkQueries(table("t")) - } - } - - test("saveAsTable()/load() - partitioned table - Overwrite") { - partitionedTestDF.saveAsTable( - tableName = "t", - source = classOf[SimpleTextSource].getCanonicalName, - mode = SaveMode.Overwrite, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p1", "p2")) - - partitionedTestDF.saveAsTable( - tableName = "t", - source = classOf[SimpleTextSource].getCanonicalName, - mode = SaveMode.Overwrite, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p1", "p2")) - - withTable("t") { - checkAnswer(table("t"), partitionedTestDF.collect()) - } - } - - test("saveAsTable()/load() - partitioned table - Append") { - partitionedTestDF.saveAsTable( - tableName = "t", - source = classOf[SimpleTextSource].getCanonicalName, - mode = SaveMode.Overwrite, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p1", "p2")) - - partitionedTestDF.saveAsTable( - tableName = "t", - source = classOf[SimpleTextSource].getCanonicalName, - mode = SaveMode.Append, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p1", "p2")) - - withTable("t") { - checkAnswer(table("t"), partitionedTestDF.unionAll(partitionedTestDF).collect()) - } - } - - test("saveAsTable()/load() - partitioned table - Append - new partition values") { - partitionedTestDF1.saveAsTable( - tableName = "t", - source = classOf[SimpleTextSource].getCanonicalName, - mode = SaveMode.Overwrite, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p1", "p2")) - - partitionedTestDF2.saveAsTable( - tableName = "t", - source = classOf[SimpleTextSource].getCanonicalName, - mode = SaveMode.Append, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p1", "p2")) - - withTable("t") { - checkAnswer(table("t"), partitionedTestDF.collect()) - } - } - - test("saveAsTable()/load() - partitioned table - Append - mismatched partition columns") { - partitionedTestDF1.saveAsTable( - tableName = "t", - source = classOf[SimpleTextSource].getCanonicalName, - mode = SaveMode.Overwrite, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p1", "p2")) - - // Using only a subset of all partition columns - intercept[Throwable] { - partitionedTestDF2.saveAsTable( - tableName = "t", - source = classOf[SimpleTextSource].getCanonicalName, - mode = SaveMode.Append, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p1")) - } - - // Using different order of partition columns - intercept[Throwable] { - partitionedTestDF2.saveAsTable( - tableName = "t", - source = classOf[SimpleTextSource].getCanonicalName, - mode = SaveMode.Append, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p2", "p1")) - } - } - - test("saveAsTable()/load() - partitioned table - ErrorIfExists") { - Seq.empty[(Int, String)].toDF().registerTempTable("t") - - withTempTable("t") { - intercept[AnalysisException] { - partitionedTestDF.saveAsTable( - tableName = "t", - source = classOf[SimpleTextSource].getCanonicalName, - mode = SaveMode.ErrorIfExists, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p1", "p2")) - } - } - } - - test("saveAsTable()/load() - partitioned table - Ignore") { - Seq.empty[(Int, String)].toDF().registerTempTable("t") - - withTempTable("t") { - partitionedTestDF.saveAsTable( - tableName = "t", - source = classOf[SimpleTextSource].getCanonicalName, - mode = SaveMode.Ignore, - options = Map("dataSchema" -> dataSchema.json), - partitionColumns = Seq("p1", "p2")) - - assert(table("t").collect().isEmpty) - } - } - - test("Hadoop style globbing") { - withTempPath { file => - partitionedTestDF.save( - source = classOf[SimpleTextSource].getCanonicalName, - mode = SaveMode.Overwrite, - options = Map("path" -> file.getCanonicalPath), - partitionColumns = Seq("p1", "p2")) - - val df = load( - source = classOf[SimpleTextSource].getCanonicalName, - options = Map( - "path" -> s"${file.getCanonicalPath}/p1=*/p2=???", - "dataSchema" -> dataSchema.json)) - - val expectedPaths = Set( - s"${file.getCanonicalFile}/p1=1/p2=foo", - s"${file.getCanonicalFile}/p1=2/p2=foo", - s"${file.getCanonicalFile}/p1=1/p2=bar", - s"${file.getCanonicalFile}/p1=2/p2=bar" - ).map { p => - val path = new Path(p) - val fs = path.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) - path.makeQualified(fs.getUri, fs.getWorkingDirectory).toString - } - - println(df.queryExecution) - - val actualPaths = df.queryExecution.analyzed.collectFirst { - case LogicalRelation(relation: FSBasedRelation) => - relation.paths.toSet - }.getOrElse { - fail("Expect an FSBasedRelation, but none could be found") - } - - assert(actualPaths === expectedPaths) - checkAnswer(df, partitionedTestDF.collect()) - } - } -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index 8801aba2f64c3..0f959b3d0b86d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -21,28 +21,28 @@ import java.text.NumberFormat import java.util.UUID import com.google.common.base.Objects -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.{NullWritable, Text} import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat} -import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext} +import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} -import org.apache.spark.sql.types.{DataType, StructField, StructType} +import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.{Row, SQLContext} /** - * A simple example [[FSBasedRelationProvider]]. + * A simple example [[HadoopFsRelationProvider]]. */ -class SimpleTextSource extends FSBasedRelationProvider { +class SimpleTextSource extends HadoopFsRelationProvider { override def createRelation( sqlContext: SQLContext, paths: Array[String], schema: Option[StructType], partitionColumns: Option[StructType], - parameters: Map[String, String]): FSBasedRelation = { - val partitionsSchema = partitionColumns.getOrElse(StructType(Array.empty[StructField])) - new SimpleTextRelation(paths, schema, partitionsSchema, parameters)(sqlContext) + parameters: Map[String, String]): HadoopFsRelation = { + new SimpleTextRelation(paths, schema, partitionColumns, parameters)(sqlContext) } } @@ -59,38 +59,32 @@ class AppendingTextOutputFormat(outputFile: Path) extends TextOutputFormat[NullW } } -class SimpleTextOutputWriter extends OutputWriter { - private var recordWriter: RecordWriter[NullWritable, Text] = _ - private var taskAttemptContext: TaskAttemptContext = _ - - override def init( - path: String, - dataSchema: StructType, - context: TaskAttemptContext): Unit = { - recordWriter = new AppendingTextOutputFormat(new Path(path)).getRecordWriter(context) - taskAttemptContext = context - } +class SimpleTextOutputWriter(path: String, context: TaskAttemptContext) extends OutputWriter { + private val recordWriter: RecordWriter[NullWritable, Text] = + new AppendingTextOutputFormat(new Path(path)).getRecordWriter(context) override def write(row: Row): Unit = { val serialized = row.toSeq.map(_.toString).mkString(",") recordWriter.write(null, new Text(serialized)) } - override def close(): Unit = recordWriter.close(taskAttemptContext) + override def close(): Unit = { + recordWriter.close(context) + } } /** - * A simple example [[FSBasedRelation]], used for testing purposes. Data are stored as comma + * A simple example [[HadoopFsRelation]], used for testing purposes. Data are stored as comma * separated string lines. When scanning data, schema must be explicitly provided via data source * option `"dataSchema"`. */ class SimpleTextRelation( - paths: Array[String], + override val paths: Array[String], val maybeDataSchema: Option[StructType], - partitionsSchema: StructType, + override val userDefinedPartitionColumns: Option[StructType], parameters: Map[String, String])( @transient val sqlContext: SQLContext) - extends FSBasedRelation(paths, partitionsSchema) { + extends HadoopFsRelation { import sqlContext.sparkContext @@ -108,18 +102,63 @@ class SimpleTextRelation( } override def hashCode(): Int = - Objects.hashCode(paths, maybeDataSchema, dataSchema) + Objects.hashCode(paths, maybeDataSchema, dataSchema, partitionColumns) - override def outputWriterClass: Class[_ <: OutputWriter] = - classOf[SimpleTextOutputWriter] - - override def buildScan(inputPaths: Array[String]): RDD[Row] = { + override def buildScan(inputStatuses: Array[FileStatus]): RDD[Row] = { val fields = dataSchema.map(_.dataType) - sparkContext.textFile(inputPaths.mkString(",")).map { record => + sparkContext.textFile(inputStatuses.map(_.getPath).mkString(",")).map { record => Row(record.split(",").zip(fields).map { case (value, dataType) => - Cast(Literal(value), dataType).eval() + // `Cast`ed values are always of Catalyst types (i.e. UTF8String instead of String, etc.) + val catalystValue = Cast(Literal(value), dataType).eval() + // Here we're converting Catalyst values to Scala values to test `needsConversion` + CatalystTypeConverters.convertToScala(catalystValue, dataType) }: _*) } } + + override def prepareJobForWrite(job: Job): OutputWriterFactory = new OutputWriterFactory { + override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new SimpleTextOutputWriter(path, context) + } + } +} + +/** + * A simple example [[HadoopFsRelationProvider]]. + */ +class CommitFailureTestSource extends HadoopFsRelationProvider { + override def createRelation( + sqlContext: SQLContext, + paths: Array[String], + schema: Option[StructType], + partitionColumns: Option[StructType], + parameters: Map[String, String]): HadoopFsRelation = { + new CommitFailureTestRelation(paths, schema, partitionColumns, parameters)(sqlContext) + } +} + +class CommitFailureTestRelation( + override val paths: Array[String], + maybeDataSchema: Option[StructType], + override val userDefinedPartitionColumns: Option[StructType], + parameters: Map[String, String])( + @transient sqlContext: SQLContext) + extends SimpleTextRelation( + paths, maybeDataSchema, userDefinedPartitionColumns, parameters)(sqlContext) { + override def prepareJobForWrite(job: Job): OutputWriterFactory = new OutputWriterFactory { + override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new SimpleTextOutputWriter(path, context) { + override def close(): Unit = { + sys.error("Intentional task commitment failure for testing purpose.") + } + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala new file mode 100644 index 0000000000000..74095426741e3 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -0,0 +1,597 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import java.io.File + +import com.google.common.io.Files +import org.apache.hadoop.fs.Path + +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql._ +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types._ + +abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { + override val sqlContext: SQLContext = TestHive + + import sqlContext._ + import sqlContext.implicits._ + + val dataSourceName = classOf[SimpleTextSource].getCanonicalName + + val dataSchema = + StructType( + Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", StringType, nullable = false))) + + val testDF = (1 to 3).map(i => (i, s"val_$i")).toDF("a", "b") + + val partitionedTestDF1 = (for { + i <- 1 to 3 + p2 <- Seq("foo", "bar") + } yield (i, s"val_$i", 1, p2)).toDF("a", "b", "p1", "p2") + + val partitionedTestDF2 = (for { + i <- 1 to 3 + p2 <- Seq("foo", "bar") + } yield (i, s"val_$i", 2, p2)).toDF("a", "b", "p1", "p2") + + val partitionedTestDF = partitionedTestDF1.unionAll(partitionedTestDF2) + + def checkQueries(df: DataFrame): Unit = { + // Selects everything + checkAnswer( + df, + for (i <- 1 to 3; p1 <- 1 to 2; p2 <- Seq("foo", "bar")) yield Row(i, s"val_$i", p1, p2)) + + // Simple filtering and partition pruning + checkAnswer( + df.filter('a > 1 && 'p1 === 2), + for (i <- 2 to 3; p2 <- Seq("foo", "bar")) yield Row(i, s"val_$i", 2, p2)) + + // Simple projection and filtering + checkAnswer( + df.filter('a > 1).select('b, 'a + 1), + for (i <- 2 to 3; _ <- 1 to 2; _ <- Seq("foo", "bar")) yield Row(s"val_$i", i + 1)) + + // Simple projection and partition pruning + checkAnswer( + df.filter('a > 1 && 'p1 < 2).select('b, 'p1), + for (i <- 2 to 3; _ <- Seq("foo", "bar")) yield Row(s"val_$i", 1)) + + // Project many copies of columns with different types (reproduction for SPARK-7858) + checkAnswer( + df.filter('a > 1 && 'p1 < 2).select('b, 'b, 'b, 'b, 'p1, 'p1, 'p1, 'p1), + for (i <- 2 to 3; _ <- Seq("foo", "bar")) + yield Row(s"val_$i", s"val_$i", s"val_$i", s"val_$i", 1, 1, 1, 1)) + + // Self-join + df.registerTempTable("t") + withTempTable("t") { + checkAnswer( + sql( + """SELECT l.a, r.b, l.p1, r.p2 + |FROM t l JOIN t r + |ON l.a = r.a AND l.p1 = r.p1 AND l.p2 = r.p2 + """.stripMargin), + for (i <- 1 to 3; p1 <- 1 to 2; p2 <- Seq("foo", "bar")) yield Row(i, s"val_$i", p1, p2)) + } + } + + test("save()/load() - non-partitioned table - Overwrite") { + withTempPath { file => + testDF.write.mode(SaveMode.Overwrite).format(dataSourceName).save(file.getCanonicalPath) + testDF.write.mode(SaveMode.Overwrite).format(dataSourceName).save(file.getCanonicalPath) + + checkAnswer( + read.format(dataSourceName) + .option("path", file.getCanonicalPath) + .option("dataSchema", dataSchema.json) + .load(), + testDF.collect()) + } + } + + test("save()/load() - non-partitioned table - Append") { + withTempPath { file => + testDF.write.mode(SaveMode.Overwrite).format(dataSourceName).save(file.getCanonicalPath) + testDF.write.mode(SaveMode.Append).format(dataSourceName).save(file.getCanonicalPath) + + checkAnswer( + read.format(dataSourceName) + .option("dataSchema", dataSchema.json) + .load(file.getCanonicalPath).orderBy("a"), + testDF.unionAll(testDF).orderBy("a").collect()) + } + } + + test("save()/load() - non-partitioned table - ErrorIfExists") { + withTempDir { file => + intercept[RuntimeException] { + testDF.write.format(dataSourceName).mode(SaveMode.ErrorIfExists).save(file.getCanonicalPath) + } + } + } + + test("save()/load() - non-partitioned table - Ignore") { + withTempDir { file => + testDF.write.mode(SaveMode.Ignore).format(dataSourceName).save(file.getCanonicalPath) + + val path = new Path(file.getCanonicalPath) + val fs = path.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + assert(fs.listStatus(path).isEmpty) + } + } + + test("save()/load() - partitioned table - simple queries") { + withTempPath { file => + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.ErrorIfExists) + .partitionBy("p1", "p2") + .save(file.getCanonicalPath) + + checkQueries( + read.format(dataSourceName) + .option("dataSchema", dataSchema.json) + .load(file.getCanonicalPath)) + } + } + + test("save()/load() - partitioned table - Overwrite") { + withTempPath { file => + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .partitionBy("p1", "p2") + .save(file.getCanonicalPath) + + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .partitionBy("p1", "p2") + .save(file.getCanonicalPath) + + checkAnswer( + read.format(dataSourceName) + .option("dataSchema", dataSchema.json) + .load(file.getCanonicalPath), + partitionedTestDF.collect()) + } + } + + test("save()/load() - partitioned table - Append") { + withTempPath { file => + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .partitionBy("p1", "p2") + .save(file.getCanonicalPath) + + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.Append) + .partitionBy("p1", "p2") + .save(file.getCanonicalPath) + + checkAnswer( + read.format(dataSourceName) + .option("dataSchema", dataSchema.json) + .load(file.getCanonicalPath), + partitionedTestDF.unionAll(partitionedTestDF).collect()) + } + } + + test("save()/load() - partitioned table - Append - new partition values") { + withTempPath { file => + partitionedTestDF1.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .partitionBy("p1", "p2") + .save(file.getCanonicalPath) + + partitionedTestDF2.write + .format(dataSourceName) + .mode(SaveMode.Append) + .partitionBy("p1", "p2") + .save(file.getCanonicalPath) + + checkAnswer( + read.format(dataSourceName) + .option("dataSchema", dataSchema.json) + .load(file.getCanonicalPath), + partitionedTestDF.collect()) + } + } + + test("save()/load() - partitioned table - ErrorIfExists") { + withTempDir { file => + intercept[RuntimeException] { + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.ErrorIfExists) + .partitionBy("p1", "p2") + .save(file.getCanonicalPath) + } + } + } + + test("save()/load() - partitioned table - Ignore") { + withTempDir { file => + partitionedTestDF.write + .format(dataSourceName).mode(SaveMode.Ignore).save(file.getCanonicalPath) + + val path = new Path(file.getCanonicalPath) + val fs = path.getFileSystem(SparkHadoopUtil.get.conf) + assert(fs.listStatus(path).isEmpty) + } + } + + test("saveAsTable()/load() - non-partitioned table - Overwrite") { + testDF.write.format(dataSourceName).mode(SaveMode.Overwrite) + .option("dataSchema", dataSchema.json) + .saveAsTable("t") + + withTable("t") { + checkAnswer(table("t"), testDF.collect()) + } + } + + test("saveAsTable()/load() - non-partitioned table - Append") { + testDF.write.format(dataSourceName).mode(SaveMode.Overwrite).saveAsTable("t") + testDF.write.format(dataSourceName).mode(SaveMode.Append).saveAsTable("t") + + withTable("t") { + checkAnswer(table("t"), testDF.unionAll(testDF).orderBy("a").collect()) + } + } + + test("saveAsTable()/load() - non-partitioned table - ErrorIfExists") { + Seq.empty[(Int, String)].toDF().registerTempTable("t") + + withTempTable("t") { + intercept[AnalysisException] { + testDF.write.format(dataSourceName).mode(SaveMode.ErrorIfExists).saveAsTable("t") + } + } + } + + test("saveAsTable()/load() - non-partitioned table - Ignore") { + Seq.empty[(Int, String)].toDF().registerTempTable("t") + + withTempTable("t") { + testDF.write.format(dataSourceName).mode(SaveMode.Ignore).saveAsTable("t") + assert(table("t").collect().isEmpty) + } + } + + test("saveAsTable()/load() - partitioned table - simple queries") { + partitionedTestDF.write.format(dataSourceName) + .mode(SaveMode.Overwrite) + .option("dataSchema", dataSchema.json) + .saveAsTable("t") + + withTable("t") { + checkQueries(table("t")) + } + } + + test("saveAsTable()/load() - partitioned table - Overwrite") { + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .option("dataSchema", dataSchema.json) + .partitionBy("p1", "p2") + .saveAsTable("t") + + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .option("dataSchema", dataSchema.json) + .partitionBy("p1", "p2") + .saveAsTable("t") + + withTable("t") { + checkAnswer(table("t"), partitionedTestDF.collect()) + } + } + + test("saveAsTable()/load() - partitioned table - Append") { + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .option("dataSchema", dataSchema.json) + .partitionBy("p1", "p2") + .saveAsTable("t") + + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.Append) + .option("dataSchema", dataSchema.json) + .partitionBy("p1", "p2") + .saveAsTable("t") + + withTable("t") { + checkAnswer(table("t"), partitionedTestDF.unionAll(partitionedTestDF).collect()) + } + } + + test("saveAsTable()/load() - partitioned table - Append - new partition values") { + partitionedTestDF1.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .option("dataSchema", dataSchema.json) + .partitionBy("p1", "p2") + .saveAsTable("t") + + partitionedTestDF2.write + .format(dataSourceName) + .mode(SaveMode.Append) + .option("dataSchema", dataSchema.json) + .partitionBy("p1", "p2") + .saveAsTable("t") + + withTable("t") { + checkAnswer(table("t"), partitionedTestDF.collect()) + } + } + + test("saveAsTable()/load() - partitioned table - Append - mismatched partition columns") { + partitionedTestDF1.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .option("dataSchema", dataSchema.json) + .partitionBy("p1", "p2") + .saveAsTable("t") + + // Using only a subset of all partition columns + intercept[Throwable] { + partitionedTestDF2.write + .format(dataSourceName) + .mode(SaveMode.Append) + .option("dataSchema", dataSchema.json) + .partitionBy("p1") + .saveAsTable("t") + } + } + + test("saveAsTable()/load() - partitioned table - ErrorIfExists") { + Seq.empty[(Int, String)].toDF().registerTempTable("t") + + withTempTable("t") { + intercept[AnalysisException] { + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.ErrorIfExists) + .option("dataSchema", dataSchema.json) + .partitionBy("p1", "p2") + .saveAsTable("t") + } + } + } + + test("saveAsTable()/load() - partitioned table - Ignore") { + Seq.empty[(Int, String)].toDF().registerTempTable("t") + + withTempTable("t") { + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.Ignore) + .option("dataSchema", dataSchema.json) + .partitionBy("p1", "p2") + .saveAsTable("t") + + assert(table("t").collect().isEmpty) + } + } + + test("Hadoop style globbing") { + withTempPath { file => + partitionedTestDF.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .partitionBy("p1", "p2") + .save(file.getCanonicalPath) + + val df = read + .format(dataSourceName) + .option("dataSchema", dataSchema.json) + .load(s"${file.getCanonicalPath}/p1=*/p2=???") + + val expectedPaths = Set( + s"${file.getCanonicalFile}/p1=1/p2=foo", + s"${file.getCanonicalFile}/p1=2/p2=foo", + s"${file.getCanonicalFile}/p1=1/p2=bar", + s"${file.getCanonicalFile}/p1=2/p2=bar" + ).map { p => + val path = new Path(p) + val fs = path.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + path.makeQualified(fs.getUri, fs.getWorkingDirectory).toString + } + + val actualPaths = df.queryExecution.analyzed.collectFirst { + case LogicalRelation(relation: HadoopFsRelation) => + relation.paths.toSet + }.getOrElse { + fail("Expect an FSBasedRelation, but none could be found") + } + + assert(actualPaths === expectedPaths) + checkAnswer(df, partitionedTestDF.collect()) + } + } + + test("Partition column type casting") { + withTempPath { file => + val input = partitionedTestDF.select('a, 'b, 'p1.cast(StringType).as('ps), 'p2) + + input + .write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .partitionBy("ps", "p2") + .saveAsTable("t") + + withTempTable("t") { + checkAnswer(table("t"), input.collect()) + } + } + } + + test("SPARK-7616: adjust column name order accordingly when saving partitioned table") { + val df = (1 to 3).map(i => (i, s"val_$i", i * 2)).toDF("a", "b", "c") + + df.write + .format(dataSourceName) + .mode(SaveMode.Overwrite) + .partitionBy("c", "a") + .saveAsTable("t") + + withTable("t") { + checkAnswer(table("t"), df.select('b, 'c, 'a).collect()) + } + } +} + +class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { + override val dataSourceName: String = classOf[SimpleTextSource].getCanonicalName + + import sqlContext._ + + test("save()/load() - partitioned table - simple queries - partition columns in data") { + withTempDir { file => + val basePath = new Path(file.getCanonicalPath) + val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) + val qualifiedBasePath = fs.makeQualified(basePath) + + for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { + val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") + sparkContext + .parallelize(for (i <- 1 to 3) yield s"$i,val_$i,$p1") + .saveAsTextFile(partitionDir.toString) + } + + val dataSchemaWithPartition = + StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) + + checkQueries( + read.format(dataSourceName) + .option("dataSchema", dataSchemaWithPartition.json) + .load(file.getCanonicalPath)) + } + } +} + +class CommitFailureTestRelationSuite extends SparkFunSuite with SQLTestUtils { + import TestHive.implicits._ + + override val sqlContext = TestHive + + val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName + + test("SPARK-7684: commitTask() failure should fallback to abortTask()") { + withTempPath { file => + val df = (1 to 3).map(i => i -> s"val_$i").toDF("a", "b") + intercept[SparkException] { + df.write.format(dataSourceName).save(file.getCanonicalPath) + } + + val fs = new Path(file.getCanonicalPath).getFileSystem(SparkHadoopUtil.get.conf) + assert(!fs.exists(new Path(file.getCanonicalPath, "_temporary"))) + } + } +} + +class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { + override val dataSourceName: String = classOf[parquet.DefaultSource].getCanonicalName + + import sqlContext._ + import sqlContext.implicits._ + + test("save()/load() - partitioned table - simple queries - partition columns in data") { + withTempDir { file => + val basePath = new Path(file.getCanonicalPath) + val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) + val qualifiedBasePath = fs.makeQualified(basePath) + + for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { + val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") + sparkContext + .parallelize(for (i <- 1 to 3) yield (i, s"val_$i", p1)) + .toDF("a", "b", "p1") + .write.parquet(partitionDir.toString) + } + + val dataSchemaWithPartition = + StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) + + checkQueries( + read.format(dataSourceName) + .option("dataSchema", dataSchemaWithPartition.json) + .load(file.getCanonicalPath)) + } + } + + test("SPARK-7868: _temporary directories should be ignored") { + withTempPath { dir => + val df = Seq("a", "b", "c").zipWithIndex.toDF() + + df.write + .format("parquet") + .save(dir.getCanonicalPath) + + df.write + .format("parquet") + .save(s"${dir.getCanonicalPath}/_temporary") + + checkAnswer(read.format("parquet").load(dir.getCanonicalPath), df.collect()) + } + } + + test("SPARK-8014: Avoid scanning output directory when SaveMode isn't SaveMode.Append") { + withTempDir { dir => + val path = dir.getCanonicalPath + val df = Seq(1 -> "a").toDF() + + // Creates an arbitrary file. If this directory gets scanned, ParquetRelation2 will throw + // since it's not a valid Parquet file. + val emptyFile = new File(path, "empty") + Files.createParentDirs(emptyFile) + Files.touch(emptyFile) + + // This shouldn't throw anything. + df.write.format("parquet").mode(SaveMode.Ignore).save(path) + + // This should only complain that the destination directory already exists, rather than file + // "empty" is not a Parquet file. + assert { + intercept[RuntimeException] { + df.write.format("parquet").mode(SaveMode.ErrorIfExists).save(path) + }.getMessage.contains("already exists") + } + + // This shouldn't throw anything. + df.write.format("parquet").mode(SaveMode.Overwrite).save(path) + checkAnswer(read.format("parquet").load(path), df) + } + } +} diff --git a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala deleted file mode 100644 index 33e96eaabfbf6..0000000000000 --- a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala +++ /dev/null @@ -1,265 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive - -import java.net.URI -import java.util.{ArrayList => JArrayList, Properties} - -import scala.collection.JavaConversions._ -import scala.language.implicitConversions - -import org.apache.hadoop.{io => hadoopIo} -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path -import org.apache.hadoop.hive.common.`type`.HiveDecimal -import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.ql.Context -import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} -import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc} -import org.apache.hadoop.hive.ql.processors._ -import org.apache.hadoop.hive.ql.stats.StatsSetupConst -import org.apache.hadoop.hive.serde2.{ColumnProjectionUtils, Deserializer, io => hiveIo} -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, ObjectInspector, PrimitiveObjectInspector} -import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory -import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory} -import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, TypeInfoFactory} -import org.apache.hadoop.io.{NullWritable, Writable} -import org.apache.hadoop.mapred.InputFormat - -import org.apache.spark.sql.types.{UTF8String, Decimal, DecimalType} - -private[hive] case class HiveFunctionWrapper(functionClassName: String) - extends java.io.Serializable { - - // for Serialization - def this() = this(null) - - import org.apache.spark.util.Utils._ - def createFunction[UDFType <: AnyRef](): UDFType = { - getContextOrSparkClassLoader - .loadClass(functionClassName).newInstance.asInstanceOf[UDFType] - } -} - -/** - * A compatibility layer for interacting with Hive version 0.12.0. - */ -private[hive] object HiveShim { - val version = "0.12.0" - - def getTableDesc( - serdeClass: Class[_ <: Deserializer], - inputFormatClass: Class[_ <: InputFormat[_, _]], - outputFormatClass: Class[_], - properties: Properties) = { - new TableDesc(serdeClass, inputFormatClass, outputFormatClass, properties) - } - - def getStringWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.STRING, - getStringWritable(value)) - - def getIntWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.INT, - getIntWritable(value)) - - def getDoubleWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.DOUBLE, - getDoubleWritable(value)) - - def getBooleanWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.BOOLEAN, - getBooleanWritable(value)) - - def getLongWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.LONG, - getLongWritable(value)) - - def getFloatWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.FLOAT, - getFloatWritable(value)) - - def getShortWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.SHORT, - getShortWritable(value)) - - def getByteWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.BYTE, - getByteWritable(value)) - - def getBinaryWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.BINARY, - getBinaryWritable(value)) - - def getDateWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.DATE, - getDateWritable(value)) - - def getTimestampWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.TIMESTAMP, - getTimestampWritable(value)) - - def getDecimalWritableConstantObjectInspector(value: Any): ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.DECIMAL, - getDecimalWritable(value)) - - def getPrimitiveNullWritableConstantObjectInspector: ObjectInspector = - PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.VOID, null) - - def getStringWritable(value: Any): hadoopIo.Text = - if (value == null) null else new hadoopIo.Text(value.asInstanceOf[UTF8String].toString) - - def getIntWritable(value: Any): hadoopIo.IntWritable = - if (value == null) null else new hadoopIo.IntWritable(value.asInstanceOf[Int]) - - def getDoubleWritable(value: Any): hiveIo.DoubleWritable = - if (value == null) null else new hiveIo.DoubleWritable(value.asInstanceOf[Double]) - - def getBooleanWritable(value: Any): hadoopIo.BooleanWritable = - if (value == null) null else new hadoopIo.BooleanWritable(value.asInstanceOf[Boolean]) - - def getLongWritable(value: Any): hadoopIo.LongWritable = - if (value == null) null else new hadoopIo.LongWritable(value.asInstanceOf[Long]) - - def getFloatWritable(value: Any): hadoopIo.FloatWritable = - if (value == null) null else new hadoopIo.FloatWritable(value.asInstanceOf[Float]) - - def getShortWritable(value: Any): hiveIo.ShortWritable = - if (value == null) null else new hiveIo.ShortWritable(value.asInstanceOf[Short]) - - def getByteWritable(value: Any): hiveIo.ByteWritable = - if (value == null) null else new hiveIo.ByteWritable(value.asInstanceOf[Byte]) - - def getBinaryWritable(value: Any): hadoopIo.BytesWritable = - if (value == null) null else new hadoopIo.BytesWritable(value.asInstanceOf[Array[Byte]]) - - def getDateWritable(value: Any): hiveIo.DateWritable = - if (value == null) null else new hiveIo.DateWritable(value.asInstanceOf[Int]) - - def getTimestampWritable(value: Any): hiveIo.TimestampWritable = - if (value == null) { - null - } else { - new hiveIo.TimestampWritable(value.asInstanceOf[java.sql.Timestamp]) - } - - def getDecimalWritable(value: Any): hiveIo.HiveDecimalWritable = - if (value == null) { - null - } else { - new hiveIo.HiveDecimalWritable( - HiveShim.createDecimal(value.asInstanceOf[Decimal].toJavaBigDecimal)) - } - - def getPrimitiveNullWritable: NullWritable = NullWritable.get() - - def createDriverResultsArray = new JArrayList[String] - - def processResults(results: JArrayList[String]) = results - - def getStatsSetupConstTotalSize = StatsSetupConst.TOTAL_SIZE - - def getStatsSetupConstRawDataSize = StatsSetupConst.RAW_DATA_SIZE - - def createDefaultDBIfNeeded(context: HiveContext) = { } - - def getCommandProcessor(cmd: Array[String], conf: HiveConf) = { - CommandProcessorFactory.get(cmd(0), conf) - } - - def createDecimal(bd: java.math.BigDecimal): HiveDecimal = { - new HiveDecimal(bd) - } - - def appendReadColumns(conf: Configuration, ids: Seq[Integer], names: Seq[String]) { - ColumnProjectionUtils.appendReadColumnIDs(conf, ids) - ColumnProjectionUtils.appendReadColumnNames(conf, names) - } - - def getExternalTmpPath(context: Context, uri: URI) = { - context.getExternalTmpFileURI(uri) - } - - def getDataLocationPath(p: Partition) = p.getPartitionPath - - def getAllPartitionsOf(client: Hive, tbl: Table) = client.getAllPartitionsForPruner(tbl) - - def compatibilityBlackList = Seq( - "decimal_.*", - "udf7", - "drop_partitions_filter2", - "show_.*", - "serde_regex", - "udf_to_date", - "udaf_collect_set", - "udf_concat" - ) - - def setLocation(tbl: Table, crtTbl: CreateTableDesc): Unit = { - tbl.setDataLocation(new Path(crtTbl.getLocation()).toUri()) - } - - def decimalMetastoreString(decimalType: DecimalType): String = "decimal" - - def decimalTypeInfo(decimalType: DecimalType): TypeInfo = - TypeInfoFactory.decimalTypeInfo - - def decimalTypeInfoToCatalyst(inspector: PrimitiveObjectInspector): DecimalType = { - DecimalType.Unlimited - } - - def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = { - if (hdoi.preferWritable()) { - Decimal(hdoi.getPrimitiveWritableObject(data).getHiveDecimal().bigDecimalValue) - } else { - Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue()) - } - } - - def getConvertedOI( - inputOI: ObjectInspector, - outputOI: ObjectInspector): ObjectInspector = { - ObjectInspectorConverters.getConvertedOI(inputOI, outputOI, true) - } - - def prepareWritable(w: Writable): Writable = { - w - } - - def setTblNullFormat(crtTbl: CreateTableDesc, tbl: Table) = {} -} - -private[hive] class ShimFileSinkDesc( - var dir: String, - var tableInfo: TableDesc, - var compressed: Boolean) - extends FileSinkDesc(dir, tableInfo, compressed) { -} diff --git a/streaming/pom.xml b/streaming/pom.xml index 5ab7f4472c38b..49d035a1e9696 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -40,6 +40,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + diff --git a/core/src/main/resources/org/apache/spark/ui/static/streaming-page.css b/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.css similarity index 90% rename from core/src/main/resources/org/apache/spark/ui/static/streaming-page.css rename to streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.css index 5da9d631ad124..b22c884bfebdb 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/streaming-page.css +++ b/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.css @@ -56,3 +56,11 @@ .histogram { width: auto; } + +span.expand-input-rate { + cursor: pointer; +} + +tr.batch-table-cell-highlight > td { + background-color: #D6FFE4 !important; +} diff --git a/core/src/main/resources/org/apache/spark/ui/static/streaming-page.js b/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.js similarity index 79% rename from core/src/main/resources/org/apache/spark/ui/static/streaming-page.js rename to streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.js index a4e03b156f13e..75251f493ad22 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/streaming-page.js +++ b/streaming/src/main/resources/org/apache/spark/streaming/ui/static/streaming-page.js @@ -98,7 +98,16 @@ function drawTimeline(id, data, minX, maxX, minY, maxY, unitY, batchInterval) { var x = d3.scale.linear().domain([minX, maxX]).range([0, width]); var y = d3.scale.linear().domain([minY, maxY]).range([height, 0]); - var xAxis = d3.svg.axis().scale(x).orient("bottom").tickFormat(function(d) { return timeFormat[d]; }); + var xAxis = d3.svg.axis().scale(x).orient("bottom").tickFormat(function(d) { + var formattedDate = timeFormat[d]; + var dotIndex = formattedDate.indexOf('.'); + if (dotIndex >= 0) { + // Remove milliseconds + return formattedDate.substring(0, dotIndex); + } else { + return formattedDate; + } + }); var formatYValue = d3.format(",.2f"); var yAxis = d3.svg.axis().scale(y).orient("left").ticks(5).tickFormat(formatYValue); @@ -137,6 +146,12 @@ function drawTimeline(id, data, minX, maxX, minY, maxY, unitY, batchInterval) { .attr("class", "line") .attr("d", line); + // If the user click one point in the graphs, jump to the batch row and highlight it. And + // recovery the batch row after 3 seconds if necessary. + // We need to remember the last clicked batch so that we can recovery it. + var lastClickedBatch = null; + var lastTimeout = null; + // Add points to the line. However, we make it invisible at first. But when the user moves mouse // over a point, it will be displayed with its detail. svg.selectAll(".point") @@ -145,6 +160,7 @@ function drawTimeline(id, data, minX, maxX, minY, maxY, unitY, batchInterval) { .attr("stroke", "white") // white and opacity = 0 make it invisible .attr("fill", "white") .attr("opacity", "0") + .style("cursor", "pointer") .attr("cx", function(d) { return x(d.x); }) .attr("cy", function(d) { return y(d.y); }) .attr("r", function(d) { return 3; }) @@ -166,7 +182,29 @@ function drawTimeline(id, data, minX, maxX, minY, maxY, unitY, batchInterval) { .attr("opacity", "0"); }) .on("click", function(d) { - window.location.href = "batch/?id=" + d.x; + if (lastTimeout != null) { + window.clearTimeout(lastTimeout); + } + if (lastClickedBatch != null) { + clearBatchRow(lastClickedBatch); + lastClickedBatch = null; + } + lastClickedBatch = d.x; + highlightBatchRow(lastClickedBatch) + lastTimeout = window.setTimeout(function () { + lastTimeout = null; + if (lastClickedBatch != null) { + clearBatchRow(lastClickedBatch); + lastClickedBatch = null; + } + }, 3000); // Clean up after 3 seconds + + var batchSelector = $("#batch-" + d.x); + var topOffset = batchSelector.offset().top - 15; + if (topOffset < 0) { + topOffset = 0; + } + $('html,body').animate({scrollTop: topOffset}, 200); }); } @@ -209,6 +247,9 @@ function drawHistogram(id, values, minY, maxY, unitY, batchInterval) { svg.append("g") .attr("class", "x axis") .call(xAxis) + .append("text") + .attr("transform", "translate(" + (margin.left + width - 40) + ", 15)") + .text("#batches"); svg.append("g") .attr("class", "y axis") @@ -252,23 +293,29 @@ function drawHistogram(id, values, minY, maxY, unitY, batchInterval) { } $(function() { - function getParameterFromURL(param) - { - var parameters = window.location.search.substring(1); // Remove "?" - var keyValues = parameters.split('&'); - for (var i = 0; i < keyValues.length; i++) - { - var paramKeyValue = keyValues[i].split('='); - if (paramKeyValue[0] == param) - { - return paramKeyValue[1]; - } + var status = window.localStorage && window.localStorage.getItem("show-streams-detail") == "true"; + + $("span.expand-input-rate").click(function() { + status = !status; + $("#inputs-table").toggle('collapsed'); + // Toggle the class of the arrow between open and closed + $(this).find('.expand-input-rate-arrow').toggleClass('arrow-open').toggleClass('arrow-closed'); + if (window.localStorage) { + window.localStorage.setItem("show-streams-detail", "" + status); } - } + }); - if (getParameterFromURL("show-streams-detail") == "true") { - // Show the details for all InputDStream - $('#inputs-table').toggle('collapsed'); - $('#triangle').html('▼'); + if (status) { + $("#inputs-table").toggle('collapsed'); + // Toggle the class of the arrow between open and closed + $(this).find('.expand-input-rate-arrow').toggleClass('arrow-open').toggleClass('arrow-closed'); } }); + +function highlightBatchRow(batch) { + $("#batch-" + batch).parent().addClass("batch-table-cell-highlight"); +} + +function clearBatchRow(batch) { + $("#batch-" + batch).parent().removeClass("batch-table-cell-highlight"); +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 7bfae253c3a0c..d8dc4e4101664 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -102,6 +102,44 @@ object Checkpoint extends Logging { Seq.empty } } + + /** Serialize the checkpoint, or throw any exception that occurs */ + def serialize(checkpoint: Checkpoint, conf: SparkConf): Array[Byte] = { + val compressionCodec = CompressionCodec.createCodec(conf) + val bos = new ByteArrayOutputStream() + val zos = compressionCodec.compressedOutputStream(bos) + val oos = new ObjectOutputStream(zos) + Utils.tryWithSafeFinally { + oos.writeObject(checkpoint) + } { + oos.close() + } + bos.toByteArray + } + + /** Deserialize a checkpoint from the input stream, or throw any exception that occurs */ + def deserialize(inputStream: InputStream, conf: SparkConf): Checkpoint = { + val compressionCodec = CompressionCodec.createCodec(conf) + var ois: ObjectInputStreamWithLoader = null + Utils.tryWithSafeFinally { + + // ObjectInputStream uses the last defined user-defined class loader in the stack + // to find classes, which maybe the wrong class loader. Hence, a inherited version + // of ObjectInputStream is used to explicitly use the current thread's default class + // loader to find and load classes. This is a well know Java issue and has popped up + // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627) + val zis = compressionCodec.compressedInputStream(inputStream) + ois = new ObjectInputStreamWithLoader(zis, + Thread.currentThread().getContextClassLoader) + val cp = ois.readObject.asInstanceOf[Checkpoint] + cp.validate() + cp + } { + if (ois != null) { + ois.close() + } + } + } } @@ -189,17 +227,10 @@ class CheckpointWriter( } def write(checkpoint: Checkpoint, clearCheckpointDataLater: Boolean) { - val bos = new ByteArrayOutputStream() - val zos = compressionCodec.compressedOutputStream(bos) - val oos = new ObjectOutputStream(zos) - Utils.tryWithSafeFinally { - oos.writeObject(checkpoint) - } { - oos.close() - } try { + val bytes = Checkpoint.serialize(checkpoint, conf) executor.execute(new CheckpointWriteHandler( - checkpoint.checkpointTime, bos.toByteArray, clearCheckpointDataLater)) + checkpoint.checkpointTime, bytes, clearCheckpointDataLater)) logDebug("Submitted checkpoint of time " + checkpoint.checkpointTime + " writer queue") } catch { case rej: RejectedExecutionException => @@ -264,25 +295,8 @@ object CheckpointReader extends Logging { checkpointFiles.foreach(file => { logInfo("Attempting to load checkpoint from file " + file) try { - var ois: ObjectInputStreamWithLoader = null - var cp: Checkpoint = null - Utils.tryWithSafeFinally { - val fis = fs.open(file) - // ObjectInputStream uses the last defined user-defined class loader in the stack - // to find classes, which maybe the wrong class loader. Hence, a inherited version - // of ObjectInputStream is used to explicitly use the current thread's default class - // loader to find and load classes. This is a well know Java issue and has popped up - // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627) - val zis = compressionCodec.compressedInputStream(fis) - ois = new ObjectInputStreamWithLoader(zis, - Thread.currentThread().getContextClassLoader) - cp = ois.readObject.asInstanceOf[Checkpoint] - } { - if (ois != null) { - ois.close() - } - } - cp.validate() + val fis = fs.open(file) + val cp = Checkpoint.deserialize(fis, conf) logInfo("Checkpoint successfully loaded from file " + file) logInfo("Checkpoint was generated at time " + cp.checkpointTime) return Some(cp) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala index 85b354ff4aa0d..40789c66f3991 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala @@ -157,10 +157,10 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { def validate() { this.synchronized { - assert(batchDuration != null, "Batch duration has not been set") + require(batchDuration != null, "Batch duration has not been set") // assert(batchDuration >= Milliseconds(100), "Batch duration of " + batchDuration + // " is very low") - assert(getOutputStreams().size > 0, "No output streams registered, so nothing to execute") + require(getOutputStreams().size > 0, "No output operations registered, so nothing to execute") } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 407cab45ed4c6..9cd9684d36404 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -17,12 +17,13 @@ package org.apache.spark.streaming -import java.io.InputStream +import java.io.{InputStream, NotSerializableException} import java.util.concurrent.atomic.{AtomicInteger, AtomicReference} import scala.collection.Map import scala.collection.mutable.Queue import scala.reflect.ClassTag +import scala.util.control.NonFatal import akka.actor.{Props, SupervisorStrategy} import org.apache.hadoop.conf.Configuration @@ -34,14 +35,15 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.spark._ import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.input.FixedLengthBinaryInputFormat -import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.{RDD, RDDOperationScope} +import org.apache.spark.serializer.SerializationDebugger import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContextState._ import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.receiver.{ActorReceiver, ActorSupervisorStrategy, Receiver} import org.apache.spark.streaming.scheduler.{JobScheduler, StreamingListener} import org.apache.spark.streaming.ui.{StreamingJobProgressListener, StreamingTab} -import org.apache.spark.util.CallSite +import org.apache.spark.util.{CallSite, Utils} /** * Main entry point for Spark Streaming functionality. It provides methods used to create @@ -134,7 +136,7 @@ class StreamingContext private[streaming] ( if (sc_ != null) { sc_ } else if (isCheckpointPresent) { - new SparkContext(cp_.createSparkConf()) + SparkContext.getOrCreate(cp_.createSparkConf()) } else { throw new SparkException("Cannot create StreamingContext without a SparkContext") } @@ -155,7 +157,7 @@ class StreamingContext private[streaming] ( cp_.graph.restoreCheckpointData() cp_.graph } else { - assert(batchDur_ != null, "Batch duration for streaming context cannot be null") + require(batchDur_ != null, "Batch duration for StreamingContext cannot be null") val newGraph = new DStreamGraph() newGraph.setBatchDuration(batchDur_) newGraph @@ -200,6 +202,8 @@ class StreamingContext private[streaming] ( private val startSite = new AtomicReference[CallSite](null) + private var shutdownHookRef: AnyRef = _ + /** * Return the associated Spark context */ @@ -235,21 +239,46 @@ class StreamingContext private[streaming] ( } } + private[streaming] def isCheckpointingEnabled: Boolean = { + checkpointDir != null + } + private[streaming] def initialCheckpoint: Checkpoint = { if (isCheckpointPresent) cp_ else null } private[streaming] def getNewInputStreamId() = nextInputStreamId.getAndIncrement() + /** + * Execute a block of code in a scope such that all new DStreams created in this body will + * be part of the same scope. For more detail, see the comments in `doCompute`. + * + * Note: Return statements are NOT allowed in the given body. + */ + private[streaming] def withScope[U](body: => U): U = sparkContext.withScope(body) + + /** + * Execute a block of code in a scope such that all new DStreams created in this body will + * be part of the same scope. For more detail, see the comments in `doCompute`. + * + * Note: Return statements are NOT allowed in the given body. + */ + private[streaming] def withNamedScope[U](name: String)(body: => U): U = { + RDDOperationScope.withScope(sc, name, allowNesting = false, ignoreParent = false)(body) + } + /** * Create an input stream with any arbitrary user implemented receiver. * Find more details at: http://spark.apache.org/docs/latest/streaming-custom-receivers.html * @param receiver Custom implementation of Receiver + * + * @deprecated As of 1.0.0", replaced by `receiverStream`. */ @deprecated("Use receiverStream", "1.0.0") - def networkStream[T: ClassTag]( - receiver: Receiver[T]): ReceiverInputDStream[T] = { - receiverStream(receiver) + def networkStream[T: ClassTag](receiver: Receiver[T]): ReceiverInputDStream[T] = { + withNamedScope("network stream") { + receiverStream(receiver) + } } /** @@ -257,9 +286,10 @@ class StreamingContext private[streaming] ( * Find more details at: http://spark.apache.org/docs/latest/streaming-custom-receivers.html * @param receiver Custom implementation of Receiver */ - def receiverStream[T: ClassTag]( - receiver: Receiver[T]): ReceiverInputDStream[T] = { - new PluggableInputDStream[T](this, receiver) + def receiverStream[T: ClassTag](receiver: Receiver[T]): ReceiverInputDStream[T] = { + withNamedScope("receiver stream") { + new PluggableInputDStream[T](this, receiver) + } } /** @@ -279,7 +309,7 @@ class StreamingContext private[streaming] ( name: String, storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2, supervisorStrategy: SupervisorStrategy = ActorSupervisorStrategy.defaultStrategy - ): ReceiverInputDStream[T] = { + ): ReceiverInputDStream[T] = withNamedScope("actor stream") { receiverStream(new ActorReceiver[T](props, name, storageLevel, supervisorStrategy)) } @@ -296,7 +326,7 @@ class StreamingContext private[streaming] ( hostname: String, port: Int, storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2 - ): ReceiverInputDStream[String] = { + ): ReceiverInputDStream[String] = withNamedScope("socket text stream") { socketStream[String](hostname, port, SocketReceiver.bytesToLines, storageLevel) } @@ -334,7 +364,7 @@ class StreamingContext private[streaming] ( hostname: String, port: Int, storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2 - ): ReceiverInputDStream[T] = { + ): ReceiverInputDStream[T] = withNamedScope("raw socket stream") { new RawInputDStream[T](this, hostname, port, storageLevel) } @@ -408,7 +438,7 @@ class StreamingContext private[streaming] ( * file system. File names starting with . are ignored. * @param directory HDFS directory to monitor for new file */ - def textFileStream(directory: String): DStream[String] = { + def textFileStream(directory: String): DStream[String] = withNamedScope("text file stream") { fileStream[LongWritable, Text, TextInputFormat](directory).map(_._2.toString) } @@ -430,14 +460,15 @@ class StreamingContext private[streaming] ( @Experimental def binaryRecordsStream( directory: String, - recordLength: Int): DStream[Array[Byte]] = { + recordLength: Int): DStream[Array[Byte]] = withNamedScope("binary records stream") { val conf = sc_.hadoopConfiguration conf.setInt(FixedLengthBinaryInputFormat.RECORD_LENGTH_PROPERTY, recordLength) val br = fileStream[LongWritable, BytesWritable, FixedLengthBinaryInputFormat]( - directory, FileInputDStream.defaultFilter : Path => Boolean, newFilesOnly=true, conf) + directory, FileInputDStream.defaultFilter: Path => Boolean, newFilesOnly = true, conf) val data = br.map { case (k, v) => val bytes = v.getBytes - assert(bytes.length == recordLength, "Byte array does not have correct length") + require(bytes.length == recordLength, "Byte array does not have correct length. " + + s"${bytes.length} did not equal recordLength: $recordLength") bytes } data @@ -477,7 +508,7 @@ class StreamingContext private[streaming] ( /** * Create a unified DStream from multiple DStreams of the same type and same slide duration. */ - def union[T: ClassTag](streams: Seq[DStream[T]]): DStream[T] = { + def union[T: ClassTag](streams: Seq[DStream[T]]): DStream[T] = withScope { new UnionDStream[T](streams.toArray) } @@ -488,7 +519,7 @@ class StreamingContext private[streaming] ( def transform[T: ClassTag]( dstreams: Seq[DStream[_]], transformFunc: (Seq[RDD[_]], Time) => RDD[T] - ): DStream[T] = { + ): DStream[T] = withScope { new TransformedDStream[T](dstreams, sparkContext.clean(transformFunc)) } @@ -503,11 +534,26 @@ class StreamingContext private[streaming] ( assert(graph != null, "Graph is null") graph.validate() - assert( - checkpointDir == null || checkpointDuration != null, + require( + !isCheckpointingEnabled || checkpointDuration != null, "Checkpoint directory has been set, but the graph checkpointing interval has " + "not been set. Please use StreamingContext.checkpoint() to set the interval." ) + + // Verify whether the DStream checkpoint is serializable + if (isCheckpointingEnabled) { + val checkpoint = new Checkpoint(this, Time.apply(0)) + try { + Checkpoint.serialize(checkpoint, conf) + } catch { + case e: NotSerializableException => + throw new NotSerializableException( + "DStream checkpointing has been enabled but the DStreams with their functions " + + "are not serializable\nSerialization stack:\n" + + SerializationDebugger.find(checkpoint).map("\t- " + _).mkString("\n") + ) + } + } } /** @@ -528,26 +574,36 @@ class StreamingContext private[streaming] ( /** * Start the execution of the streams. * - * @throws SparkException if the StreamingContext is already stopped. + * @throws IllegalStateException if the StreamingContext is already stopped. */ def start(): Unit = synchronized { state match { case INITIALIZED => - validate() startSite.set(DStream.getCreationSite()) sparkContext.setCallSite(startSite.get) StreamingContext.ACTIVATION_LOCK.synchronized { StreamingContext.assertNoOtherContextIsActive() - scheduler.start() - uiTab.foreach(_.attach()) - state = StreamingContextState.ACTIVE + try { + validate() + scheduler.start() + state = StreamingContextState.ACTIVE + } catch { + case NonFatal(e) => + logError("Error starting the context, marking it as stopped", e) + scheduler.stop(false) + state = StreamingContextState.STOPPED + throw e + } StreamingContext.setActiveContext(this) } + shutdownHookRef = Utils.addShutdownHook( + StreamingContext.SHUTDOWN_HOOK_PRIORITY)(stopOnShutdown) + uiTab.foreach(_.attach()) logInfo("StreamingContext started") case ACTIVE => logWarning("StreamingContext has already been started") case STOPPED => - throw new SparkException("StreamingContext has already been stopped") + throw new IllegalStateException("StreamingContext has already been stopped") } } @@ -563,6 +619,8 @@ class StreamingContext private[streaming] ( * Wait for the execution to stop. Any exceptions that occurs during the execution * will be thrown in this thread. * @param timeout time to wait in milliseconds + * + * @deprecated As of 1.3.0, replaced by `awaitTerminationOrTimeout(Long)`. */ @deprecated("Use awaitTerminationOrTimeout(Long) instead", "1.3.0") def awaitTermination(timeout: Long) { @@ -619,6 +677,9 @@ class StreamingContext private[streaming] ( uiTab.foreach(_.detach()) StreamingContext.setActiveContext(null) waiter.notifyStop() + if (shutdownHookRef != null) { + Utils.removeShutdownHook(shutdownHookRef) + } logInfo("StreamingContext stopped successfully") } // Even if we have already stopped, we still need to attempt to stop the SparkContext because @@ -629,6 +690,13 @@ class StreamingContext private[streaming] ( state = STOPPED } } + + private def stopOnShutdown(): Unit = { + val stopGracefully = conf.getBoolean("spark.streaming.stopGracefullyOnShutdown", false) + logInfo(s"Invoking stop(stopGracefully=$stopGracefully) from shutdown hook") + // Do not stop SparkContext, let its own shutdown hook stop it + stop(stopSparkContext = false, stopGracefully = stopGracefully) + } } /** @@ -644,12 +712,14 @@ object StreamingContext extends Logging { */ private val ACTIVATION_LOCK = new Object() + private val SHUTDOWN_HOOK_PRIORITY = Utils.SPARK_CONTEXT_SHUTDOWN_PRIORITY + 1 + private val activeContext = new AtomicReference[StreamingContext](null) private def assertNoOtherContextIsActive(): Unit = { ACTIVATION_LOCK.synchronized { if (activeContext.get() != null) { - throw new SparkException( + throw new IllegalStateException( "Only one StreamingContext may be started in this JVM. " + "Currently running StreamingContext was started at" + activeContext.get.startSite.get.longForm) @@ -675,6 +745,10 @@ object StreamingContext extends Logging { } } + /** + * @deprecated As of 1.3.0, replaced by implicit functions in the DStream companion object. + * This is kept here only for backward compatibility. + */ @deprecated("Replaced by implicit functions in the DStream companion object. This is " + "kept here only for backward compatibility.", "1.3.0") def toPairDStreamFunctions[K, V](stream: DStream[(K, V)]) @@ -750,53 +824,6 @@ object StreamingContext extends Logging { checkpointOption.map(new StreamingContext(null, _, null)).getOrElse(creatingFunc()) } - /** - * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. - * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be - * recreated from the checkpoint data. If the data does not exist, then the StreamingContext - * will be created by called the provided `creatingFunc` on the provided `sparkContext`. Note - * that the SparkConf configuration in the checkpoint data will not be restored as the - * SparkContext has already been created. - * - * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program - * @param creatingFunc Function to create a new StreamingContext using the given SparkContext - * @param sparkContext SparkContext using which the StreamingContext will be created - */ - def getOrCreate( - checkpointPath: String, - creatingFunc: SparkContext => StreamingContext, - sparkContext: SparkContext - ): StreamingContext = { - getOrCreate(checkpointPath, creatingFunc, sparkContext, createOnError = false) - } - - /** - * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. - * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be - * recreated from the checkpoint data. If the data does not exist, then the StreamingContext - * will be created by called the provided `creatingFunc` on the provided `sparkContext`. Note - * that the SparkConf configuration in the checkpoint data will not be restored as the - * SparkContext has already been created. - * - * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program - * @param creatingFunc Function to create a new StreamingContext using the given SparkContext - * @param sparkContext SparkContext using which the StreamingContext will be created - * @param createOnError Whether to create a new StreamingContext if there is an - * error in reading checkpoint data. By default, an exception will be - * thrown on error. - */ - def getOrCreate( - checkpointPath: String, - creatingFunc: SparkContext => StreamingContext, - sparkContext: SparkContext, - createOnError: Boolean - ): StreamingContext = { - val checkpointOption = CheckpointReader.read( - checkpointPath, sparkContext.conf, sparkContext.hadoopConfiguration, createOnError) - checkpointOption.map(new StreamingContext(sparkContext, _, null)) - .getOrElse(creatingFunc(sparkContext)) - } - /** * Find the JAR from which a given class was loaded, to make it easy for users to pass * their JARs to StreamingContext. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index 93baad19e3ee1..959ac9c177f81 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -227,7 +227,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * @param numPartitions Number of partitions of each RDD in the new DStream. */ def groupByKeyAndWindow(windowDuration: Duration, slideDuration: Duration, numPartitions: Int) - :JavaPairDStream[K, JIterable[V]] = { + : JavaPairDStream[K, JIterable[V]] = { dstream.groupByKeyAndWindow(windowDuration, slideDuration, numPartitions) .mapValues(asJavaIterable _) } @@ -247,7 +247,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( windowDuration: Duration, slideDuration: Duration, partitioner: Partitioner - ):JavaPairDStream[K, JIterable[V]] = { + ): JavaPairDStream[K, JIterable[V]] = { dstream.groupByKeyAndWindow(windowDuration, slideDuration, partitioner) .mapValues(asJavaIterable _) } @@ -262,7 +262,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * batching interval */ def reduceByKeyAndWindow(reduceFunc: JFunction2[V, V, V], windowDuration: Duration) - :JavaPairDStream[K, V] = { + : JavaPairDStream[K, V] = { dstream.reduceByKeyAndWindow(reduceFunc, windowDuration) } @@ -281,7 +281,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( reduceFunc: JFunction2[V, V, V], windowDuration: Duration, slideDuration: Duration - ):JavaPairDStream[K, V] = { + ): JavaPairDStream[K, V] = { dstream.reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index d8fbed2c50644..989e3a729ebc2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -148,6 +148,9 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { /** The underlying SparkContext */ val sparkContext = new JavaSparkContext(ssc.sc) + /** + * @deprecated As of 0.9.0, replaced by `sparkContext` + */ @deprecated("use sparkContext", "0.9.0") val sc: JavaSparkContext = sparkContext @@ -619,6 +622,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Wait for the execution to stop. Any exceptions that occurs during the execution * will be thrown in this thread. * @param timeout time to wait in milliseconds + * @deprecated As of 1.3.0, replaced by `awaitTerminationOrTimeout(Long)`. */ @deprecated("Use awaitTerminationOrTimeout(Long) instead", "1.3.0") def awaitTermination(timeout: Long): Unit = { @@ -677,6 +681,7 @@ object JavaStreamingContext { * * @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext + * @deprecated As of 1.4.0, replaced by `getOrCreate` without JavaStreamingContextFactor. */ @deprecated("use getOrCreate without JavaStreamingContextFactor", "1.4.0") def getOrCreate( @@ -699,6 +704,7 @@ object JavaStreamingContext { * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible * file system + * @deprecated As of 1.4.0, replaced by `getOrCreate` without JavaStreamingContextFactor. */ @deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0") def getOrCreate( @@ -724,6 +730,7 @@ object JavaStreamingContext { * file system * @param createOnError Whether to create a new JavaStreamingContext if there is an * error in reading checkpoint data. + * @deprecated As of 1.4.0, replaced by `getOrCreate` without JavaStreamingContextFactor. */ @deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0") def getOrCreate( @@ -804,51 +811,6 @@ object JavaStreamingContext { new JavaStreamingContext(ssc) } - /** - * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. - * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be - * recreated from the checkpoint data. If the data does not exist, then the provided factory - * will be used to create a JavaStreamingContext. - * - * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program - * @param creatingFunc Function to create a new JavaStreamingContext - * @param sparkContext SparkContext using which the StreamingContext will be created - */ - def getOrCreate( - checkpointPath: String, - creatingFunc: JFunction[JavaSparkContext, JavaStreamingContext], - sparkContext: JavaSparkContext - ): JavaStreamingContext = { - val ssc = StreamingContext.getOrCreate(checkpointPath, (sparkContext: SparkContext) => { - creatingFunc.call(new JavaSparkContext(sparkContext)).ssc - }, sparkContext.sc) - new JavaStreamingContext(ssc) - } - - /** - * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. - * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be - * recreated from the checkpoint data. If the data does not exist, then the provided factory - * will be used to create a JavaStreamingContext. - * - * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program - * @param creatingFunc Function to create a new JavaStreamingContext - * @param sparkContext SparkContext using which the StreamingContext will be created - * @param createOnError Whether to create a new JavaStreamingContext if there is an - * error in reading checkpoint data. - */ - def getOrCreate( - checkpointPath: String, - creatingFunc: JFunction[JavaSparkContext, JavaStreamingContext], - sparkContext: JavaSparkContext, - createOnError: Boolean - ): JavaStreamingContext = { - val ssc = StreamingContext.getOrCreate(checkpointPath, (sparkContext: SparkContext) => { - creatingFunc.call(new JavaSparkContext(sparkContext)).ssc - }, sparkContext.sc, createOnError) - new JavaStreamingContext(ssc) - } - /** * Find the JAR from which a given class was loaded, to make it easy for users to pass * their JARs to StreamingContext. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index 4c28654ef6413..d06401245ff17 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -109,7 +109,7 @@ private[python] object PythonTransformFunctionSerializer { } def serialize(func: PythonTransformFunction): Array[Byte] = { - assert(serializer != null, "Serializer has not been registered!") + require(serializer != null, "Serializer has not been registered!") // get the id of PythonTransformFunction in py4j val h = Proxy.getInvocationHandler(func.asInstanceOf[Proxy]) val f = h.getClass().getDeclaredField("id") @@ -119,7 +119,7 @@ private[python] object PythonTransformFunctionSerializer { } def deserialize(bytes: Array[Byte]): PythonTransformFunction = { - assert(serializer != null, "Serializer has not been registered!") + require(serializer != null, "Serializer has not been registered!") serializer.loads(bytes) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index 64de7526a6a34..192aa6a139bcb 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -25,12 +25,13 @@ import scala.language.implicitConversions import scala.reflect.ClassTag import scala.util.matching.Regex -import org.apache.spark.{Logging, SparkException} -import org.apache.spark.rdd.{BlockRDD, PairRDDFunctions, RDD} +import org.apache.spark.{Logging, SparkContext, SparkException} +import org.apache.spark.rdd.{BlockRDD, PairRDDFunctions, RDD, RDDOperationScope} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ import org.apache.spark.streaming.StreamingContext.rddToFileName import org.apache.spark.streaming.scheduler.Job +import org.apache.spark.streaming.ui.UIUtils import org.apache.spark.util.{CallSite, MetadataCleaner, Utils} /** @@ -73,7 +74,7 @@ abstract class DStream[T: ClassTag] ( def dependencies: List[DStream[_]] /** Method that generates a RDD for the given time */ - def compute (validTime: Time): Option[RDD[T]] + def compute(validTime: Time): Option[RDD[T]] // ======================================================================= // Methods and fields available on all DStreams @@ -111,6 +112,44 @@ abstract class DStream[T: ClassTag] ( /* Set the creation call site */ private[streaming] val creationSite = DStream.getCreationSite() + /** + * The base scope associated with the operation that created this DStream. + * + * This is the medium through which we pass the DStream operation name (e.g. updatedStateByKey) + * to the RDDs created by this DStream. Note that we never use this scope directly in RDDs. + * Instead, we instantiate a new scope during each call to `compute` based on this one. + * + * This is not defined if the DStream is created outside of one of the public DStream operations. + */ + protected[streaming] val baseScope: Option[String] = { + Option(ssc.sc.getLocalProperty(SparkContext.RDD_SCOPE_KEY)) + } + + /** + * Make a scope that groups RDDs created in the same DStream operation in the same batch. + * + * Each DStream produces many scopes and each scope may be shared by other DStreams created + * in the same operation. Separate calls to the same DStream operation create separate scopes. + * For instance, `dstream.map(...).map(...)` creates two separate scopes per batch. + */ + private def makeScope(time: Time): Option[RDDOperationScope] = { + baseScope.map { bsJson => + val formattedBatchTime = UIUtils.formatBatchTime( + time.milliseconds, ssc.graph.batchDuration.milliseconds, showYYYYMMSS = false) + val bs = RDDOperationScope.fromJson(bsJson) + val baseName = bs.name // e.g. countByWindow, "kafka stream [0]" + val scopeName = + if (baseName.length > 10) { + // If the operation name is too long, wrap the line + s"$baseName\n@ $formattedBatchTime" + } else { + s"$baseName @ $formattedBatchTime" + } + val scopeId = s"${bs.id}_${time.milliseconds}" + new RDDOperationScope(scopeName, id = scopeId) + } + } + /** Persist the RDDs of this DStream with the given storage level */ def persist(level: StorageLevel): DStream[T] = { if (this.isInitialized) { @@ -178,53 +217,52 @@ abstract class DStream[T: ClassTag] ( case StreamingContextState.INITIALIZED => // good to go case StreamingContextState.ACTIVE => - throw new SparkException( + throw new IllegalStateException( "Adding new inputs, transformations, and output operations after " + "starting a context is not supported") case StreamingContextState.STOPPED => - throw new SparkException( + throw new IllegalStateException( "Adding new inputs, transformations, and output operations after " + "stopping a context is not supported") } } private[streaming] def validateAtStart() { - assert(rememberDuration != null, "Remember duration is set to null") + require(rememberDuration != null, "Remember duration is set to null") - assert( + require( !mustCheckpoint || checkpointDuration != null, "The checkpoint interval for " + this.getClass.getSimpleName + " has not been set." + " Please use DStream.checkpoint() to set the interval." ) - assert( + require( checkpointDuration == null || context.sparkContext.checkpointDir.isDefined, - "The checkpoint directory has not been set. Please use StreamingContext.checkpoint()" + - " or SparkContext.checkpoint() to set the checkpoint directory." + "The checkpoint directory has not been set. Please set it by StreamingContext.checkpoint()." ) - assert( + require( checkpointDuration == null || checkpointDuration >= slideDuration, "The checkpoint interval for " + this.getClass.getSimpleName + " has been set to " + checkpointDuration + " which is lower than its slide time (" + slideDuration + "). " + "Please set it to at least " + slideDuration + "." ) - assert( + require( checkpointDuration == null || checkpointDuration.isMultipleOf(slideDuration), "The checkpoint interval for " + this.getClass.getSimpleName + " has been set to " + checkpointDuration + " which not a multiple of its slide time (" + slideDuration + "). " + - "Please set it to a multiple " + slideDuration + "." + "Please set it to a multiple of " + slideDuration + "." ) - assert( + require( checkpointDuration == null || storageLevel != StorageLevel.NONE, "" + this.getClass.getSimpleName + " has been marked for checkpointing but the storage " + "level has not been set to enable persisting. Please use DStream.persist() to set the " + "storage level to use memory for better checkpointing performance." ) - assert( + require( checkpointDuration == null || rememberDuration > checkpointDuration, "The remember duration for " + this.getClass.getSimpleName + " has been set to " + rememberDuration + " which is not more than the checkpoint interval (" + @@ -233,7 +271,7 @@ abstract class DStream[T: ClassTag] ( val metadataCleanerDelay = MetadataCleaner.getDelaySeconds(ssc.conf) logInfo("metadataCleanupDelay = " + metadataCleanerDelay) - assert( + require( metadataCleanerDelay < 0 || rememberDuration.milliseconds < metadataCleanerDelay * 1000, "It seems you are doing some DStream window operation or setting a checkpoint interval " + "which requires " + this.getClass.getSimpleName + " to remember generated RDDs for more " + @@ -295,28 +333,23 @@ abstract class DStream[T: ClassTag] ( * Get the RDD corresponding to the given time; either retrieve it from cache * or compute-and-cache it. */ - private[streaming] def getOrCompute(time: Time): Option[RDD[T]] = { + private[streaming] final def getOrCompute(time: Time): Option[RDD[T]] = { // If RDD was already generated, then retrieve it from HashMap, // or else compute the RDD generatedRDDs.get(time).orElse { // Compute the RDD if time is valid (e.g. correct time in a sliding window) // of RDD generation, else generate nothing. if (isTimeValid(time)) { - // Set the thread-local property for call sites to this DStream's creation site - // such that RDDs generated by compute gets that as their creation site. - // Note that this `getOrCompute` may get called from another DStream which may have - // set its own call site. So we store its call site in a temporary variable, - // set this DStream's creation site, generate RDDs and then restore the previous call site. - val prevCallSite = ssc.sparkContext.getCallSite() - ssc.sparkContext.setCallSite(creationSite) - // Disable checks for existing output directories in jobs launched by the streaming - // scheduler, since we may need to write output to an existing directory during checkpoint - // recovery; see SPARK-4835 for more details. We need to have this call here because - // compute() might cause Spark jobs to be launched. - val rddOption = PairRDDFunctions.disableOutputSpecValidation.withValue(true) { - compute(time) + + val rddOption = createRDDWithLocalProperties(time) { + // Disable checks for existing output directories in jobs launched by the streaming + // scheduler, since we may need to write output to an existing directory during checkpoint + // recovery; see SPARK-4835 for more details. We need to have this call here because + // compute() might cause Spark jobs to be launched. + PairRDDFunctions.disableOutputSpecValidation.withValue(true) { + compute(time) + } } - ssc.sparkContext.setCallSite(prevCallSite) rddOption.foreach { case newRDD => // Register the generated RDD for caching and checkpointing @@ -337,6 +370,41 @@ abstract class DStream[T: ClassTag] ( } } + /** + * Wrap a body of code such that the call site and operation scope + * information are passed to the RDDs created in this body properly. + */ + protected def createRDDWithLocalProperties[U](time: Time)(body: => U): U = { + val scopeKey = SparkContext.RDD_SCOPE_KEY + val scopeNoOverrideKey = SparkContext.RDD_SCOPE_NO_OVERRIDE_KEY + // Pass this DStream's operation scope and creation site information to RDDs through + // thread-local properties in our SparkContext. Since this method may be called from another + // DStream, we need to temporarily store any old scope and creation site information to + // restore them later after setting our own. + val prevCallSite = ssc.sparkContext.getCallSite() + val prevScope = ssc.sparkContext.getLocalProperty(scopeKey) + val prevScopeNoOverride = ssc.sparkContext.getLocalProperty(scopeNoOverrideKey) + + try { + ssc.sparkContext.setCallSite(creationSite) + // Use the DStream's base scope for this RDD so we can (1) preserve the higher level + // DStream operation name, and (2) share this scope with other DStreams created in the + // same operation. Disallow nesting so that low-level Spark primitives do not show up. + // TODO: merge callsites with scopes so we can just reuse the code there + makeScope(time).foreach { s => + ssc.sparkContext.setLocalProperty(scopeKey, s.toJson) + ssc.sparkContext.setLocalProperty(scopeNoOverrideKey, "true") + } + + body + } finally { + // Restore any state that was modified before returning + ssc.sparkContext.setCallSite(prevCallSite) + ssc.sparkContext.setLocalProperty(scopeKey, prevScope) + ssc.sparkContext.setLocalProperty(scopeNoOverrideKey, prevScopeNoOverride) + } + } + /** * Generate a SparkStreaming job for the given time. This is an internal method that * should not be called directly. This default implementation creates a job @@ -456,7 +524,7 @@ abstract class DStream[T: ClassTag] ( // ======================================================================= /** Return a new DStream by applying a function to all elements of this DStream. */ - def map[U: ClassTag](mapFunc: T => U): DStream[U] = { + def map[U: ClassTag](mapFunc: T => U): DStream[U] = ssc.withScope { new MappedDStream(this, context.sparkContext.clean(mapFunc)) } @@ -464,26 +532,31 @@ abstract class DStream[T: ClassTag] ( * Return a new DStream by applying a function to all elements of this DStream, * and then flattening the results */ - def flatMap[U: ClassTag](flatMapFunc: T => Traversable[U]): DStream[U] = { + def flatMap[U: ClassTag](flatMapFunc: T => Traversable[U]): DStream[U] = ssc.withScope { new FlatMappedDStream(this, context.sparkContext.clean(flatMapFunc)) } /** Return a new DStream containing only the elements that satisfy a predicate. */ - def filter(filterFunc: T => Boolean): DStream[T] = new FilteredDStream(this, filterFunc) + def filter(filterFunc: T => Boolean): DStream[T] = ssc.withScope { + new FilteredDStream(this, context.sparkContext.clean(filterFunc)) + } /** * Return a new DStream in which each RDD is generated by applying glom() to each RDD of * this DStream. Applying glom() to an RDD coalesces all elements within each partition into * an array. */ - def glom(): DStream[Array[T]] = new GlommedDStream(this) - + def glom(): DStream[Array[T]] = ssc.withScope { + new GlommedDStream(this) + } /** * Return a new DStream with an increased or decreased level of parallelism. Each RDD in the * returned DStream has exactly numPartitions partitions. */ - def repartition(numPartitions: Int): DStream[T] = this.transform(_.repartition(numPartitions)) + def repartition(numPartitions: Int): DStream[T] = ssc.withScope { + this.transform(_.repartition(numPartitions)) + } /** * Return a new DStream in which each RDD is generated by applying mapPartitions() to each RDDs @@ -493,7 +566,7 @@ abstract class DStream[T: ClassTag] ( def mapPartitions[U: ClassTag]( mapPartFunc: Iterator[T] => Iterator[U], preservePartitioning: Boolean = false - ): DStream[U] = { + ): DStream[U] = ssc.withScope { new MapPartitionedDStream(this, context.sparkContext.clean(mapPartFunc), preservePartitioning) } @@ -501,14 +574,15 @@ abstract class DStream[T: ClassTag] ( * Return a new DStream in which each RDD has a single element generated by reducing each RDD * of this DStream. */ - def reduce(reduceFunc: (T, T) => T): DStream[T] = + def reduce(reduceFunc: (T, T) => T): DStream[T] = ssc.withScope { this.map(x => (null, x)).reduceByKey(reduceFunc, 1).map(_._2) + } /** * Return a new DStream in which each RDD has a single element generated by counting each RDD * of this DStream. */ - def count(): DStream[Long] = { + def count(): DStream[Long] = ssc.withScope { this.map(_ => (null, 1L)) .transform(_.union(context.sparkContext.makeRDD(Seq((null, 0L)), 1))) .reduceByKey(_ + _) @@ -522,24 +596,29 @@ abstract class DStream[T: ClassTag] ( * `numPartitions` not specified). */ def countByValue(numPartitions: Int = ssc.sc.defaultParallelism)(implicit ord: Ordering[T] = null) - : DStream[(T, Long)] = + : DStream[(T, Long)] = ssc.withScope { this.map(x => (x, 1L)).reduceByKey((x: Long, y: Long) => x + y, numPartitions) + } /** * Apply a function to each RDD in this DStream. This is an output operator, so * 'this' DStream will be registered as an output stream and therefore materialized. + * + * @deprecated As of 0.9.0, replaced by `foreachRDD`. */ @deprecated("use foreachRDD", "0.9.0") - def foreach(foreachFunc: RDD[T] => Unit): Unit = { + def foreach(foreachFunc: RDD[T] => Unit): Unit = ssc.withScope { this.foreachRDD(foreachFunc) } /** * Apply a function to each RDD in this DStream. This is an output operator, so * 'this' DStream will be registered as an output stream and therefore materialized. + * + * @deprecated As of 0.9.0, replaced by `foreachRDD`. */ @deprecated("use foreachRDD", "0.9.0") - def foreach(foreachFunc: (RDD[T], Time) => Unit): Unit = { + def foreach(foreachFunc: (RDD[T], Time) => Unit): Unit = ssc.withScope { this.foreachRDD(foreachFunc) } @@ -547,17 +626,18 @@ abstract class DStream[T: ClassTag] ( * Apply a function to each RDD in this DStream. This is an output operator, so * 'this' DStream will be registered as an output stream and therefore materialized. */ - def foreachRDD(foreachFunc: RDD[T] => Unit) { - this.foreachRDD((r: RDD[T], t: Time) => foreachFunc(r)) + def foreachRDD(foreachFunc: RDD[T] => Unit): Unit = ssc.withScope { + val cleanedF = context.sparkContext.clean(foreachFunc, false) + this.foreachRDD((r: RDD[T], t: Time) => cleanedF(r)) } /** * Apply a function to each RDD in this DStream. This is an output operator, so * 'this' DStream will be registered as an output stream and therefore materialized. */ - def foreachRDD(foreachFunc: (RDD[T], Time) => Unit) { - // because the DStream is reachable from the outer object here, and because - // DStreams can't be serialized with closures, we can't proactively check + def foreachRDD(foreachFunc: (RDD[T], Time) => Unit): Unit = ssc.withScope { + // because the DStream is reachable from the outer object here, and because + // DStreams can't be serialized with closures, we can't proactively check // it for serializability and so we pass the optional false to SparkContext.clean new ForEachDStream(this, context.sparkContext.clean(foreachFunc, false)).register() } @@ -566,9 +646,9 @@ abstract class DStream[T: ClassTag] ( * Return a new DStream in which each RDD is generated by applying a function * on each RDD of 'this' DStream. */ - def transform[U: ClassTag](transformFunc: RDD[T] => RDD[U]): DStream[U] = { - // because the DStream is reachable from the outer object here, and because - // DStreams can't be serialized with closures, we can't proactively check + def transform[U: ClassTag](transformFunc: RDD[T] => RDD[U]): DStream[U] = ssc.withScope { + // because the DStream is reachable from the outer object here, and because + // DStreams can't be serialized with closures, we can't proactively check // it for serializability and so we pass the optional false to SparkContext.clean val cleanedF = context.sparkContext.clean(transformFunc, false) transform((r: RDD[T], t: Time) => cleanedF(r)) @@ -578,12 +658,12 @@ abstract class DStream[T: ClassTag] ( * Return a new DStream in which each RDD is generated by applying a function * on each RDD of 'this' DStream. */ - def transform[U: ClassTag](transformFunc: (RDD[T], Time) => RDD[U]): DStream[U] = { - // because the DStream is reachable from the outer object here, and because - // DStreams can't be serialized with closures, we can't proactively check + def transform[U: ClassTag](transformFunc: (RDD[T], Time) => RDD[U]): DStream[U] = ssc.withScope { + // because the DStream is reachable from the outer object here, and because + // DStreams can't be serialized with closures, we can't proactively check // it for serializability and so we pass the optional false to SparkContext.clean val cleanedF = context.sparkContext.clean(transformFunc, false) - val realTransformFunc = (rdds: Seq[RDD[_]], time: Time) => { + val realTransformFunc = (rdds: Seq[RDD[_]], time: Time) => { assert(rdds.length == 1) cleanedF(rdds.head.asInstanceOf[RDD[T]], time) } @@ -596,9 +676,9 @@ abstract class DStream[T: ClassTag] ( */ def transformWith[U: ClassTag, V: ClassTag]( other: DStream[U], transformFunc: (RDD[T], RDD[U]) => RDD[V] - ): DStream[V] = { - // because the DStream is reachable from the outer object here, and because - // DStreams can't be serialized with closures, we can't proactively check + ): DStream[V] = ssc.withScope { + // because the DStream is reachable from the outer object here, and because + // DStreams can't be serialized with closures, we can't proactively check // it for serializability and so we pass the optional false to SparkContext.clean val cleanedF = ssc.sparkContext.clean(transformFunc, false) transformWith(other, (rdd1: RDD[T], rdd2: RDD[U], time: Time) => cleanedF(rdd1, rdd2)) @@ -610,9 +690,9 @@ abstract class DStream[T: ClassTag] ( */ def transformWith[U: ClassTag, V: ClassTag]( other: DStream[U], transformFunc: (RDD[T], RDD[U], Time) => RDD[V] - ): DStream[V] = { - // because the DStream is reachable from the outer object here, and because - // DStreams can't be serialized with closures, we can't proactively check + ): DStream[V] = ssc.withScope { + // because the DStream is reachable from the outer object here, and because + // DStreams can't be serialized with closures, we can't proactively check // it for serializability and so we pass the optional false to SparkContext.clean val cleanedF = ssc.sparkContext.clean(transformFunc, false) val realTransformFunc = (rdds: Seq[RDD[_]], time: Time) => { @@ -628,7 +708,7 @@ abstract class DStream[T: ClassTag] ( * Print the first ten elements of each RDD generated in this DStream. This is an output * operator, so this DStream will be registered as an output stream and there materialized. */ - def print() { + def print(): Unit = ssc.withScope { print(10) } @@ -636,7 +716,7 @@ abstract class DStream[T: ClassTag] ( * Print the first num elements of each RDD generated in this DStream. This is an output * operator, so this DStream will be registered as an output stream and there materialized. */ - def print(num: Int) { + def print(num: Int): Unit = ssc.withScope { def foreachFunc: (RDD[T], Time) => Unit = { (rdd: RDD[T], time: Time) => { val firstNum = rdd.take(num + 1) @@ -668,7 +748,7 @@ abstract class DStream[T: ClassTag] ( * the new DStream will generate RDDs); must be a multiple of this * DStream's batching interval */ - def window(windowDuration: Duration, slideDuration: Duration): DStream[T] = { + def window(windowDuration: Duration, slideDuration: Duration): DStream[T] = ssc.withScope { new WindowedDStream(this, windowDuration, slideDuration) } @@ -686,7 +766,7 @@ abstract class DStream[T: ClassTag] ( reduceFunc: (T, T) => T, windowDuration: Duration, slideDuration: Duration - ): DStream[T] = { + ): DStream[T] = ssc.withScope { this.reduce(reduceFunc).window(windowDuration, slideDuration).reduce(reduceFunc) } @@ -711,7 +791,7 @@ abstract class DStream[T: ClassTag] ( invReduceFunc: (T, T) => T, windowDuration: Duration, slideDuration: Duration - ): DStream[T] = { + ): DStream[T] = ssc.withScope { this.map(x => (1, x)) .reduceByKeyAndWindow(reduceFunc, invReduceFunc, windowDuration, slideDuration, 1) .map(_._2) @@ -727,7 +807,9 @@ abstract class DStream[T: ClassTag] ( * the new DStream will generate RDDs); must be a multiple of this * DStream's batching interval */ - def countByWindow(windowDuration: Duration, slideDuration: Duration): DStream[Long] = { + def countByWindow( + windowDuration: Duration, + slideDuration: Duration): DStream[Long] = ssc.withScope { this.map(_ => 1L).reduceByWindow(_ + _, _ - _, windowDuration, slideDuration) } @@ -748,8 +830,7 @@ abstract class DStream[T: ClassTag] ( slideDuration: Duration, numPartitions: Int = ssc.sc.defaultParallelism) (implicit ord: Ordering[T] = null) - : DStream[(T, Long)] = - { + : DStream[(T, Long)] = ssc.withScope { this.map(x => (x, 1L)).reduceByKeyAndWindow( (x: Long, y: Long) => x + y, (x: Long, y: Long) => x - y, @@ -764,19 +845,21 @@ abstract class DStream[T: ClassTag] ( * Return a new DStream by unifying data of another DStream with this DStream. * @param that Another DStream having the same slideDuration as this DStream. */ - def union(that: DStream[T]): DStream[T] = new UnionDStream[T](Array(this, that)) + def union(that: DStream[T]): DStream[T] = ssc.withScope { + new UnionDStream[T](Array(this, that)) + } /** * Return all the RDDs defined by the Interval object (both end times included) */ - def slice(interval: Interval): Seq[RDD[T]] = { + def slice(interval: Interval): Seq[RDD[T]] = ssc.withScope { slice(interval.beginTime, interval.endTime) } /** * Return all the RDDs between 'fromTime' to 'toTime' (both included) */ - def slice(fromTime: Time, toTime: Time): Seq[RDD[T]] = { + def slice(fromTime: Time, toTime: Time): Seq[RDD[T]] = ssc.withScope { if (!isInitialized) { throw new SparkException(this + " has not been initialized") } @@ -810,7 +893,7 @@ abstract class DStream[T: ClassTag] ( * The file name at each batch interval is generated based on `prefix` and * `suffix`: "prefix-TIME_IN_MS.suffix". */ - def saveAsObjectFiles(prefix: String, suffix: String = "") { + def saveAsObjectFiles(prefix: String, suffix: String = ""): Unit = ssc.withScope { val saveFunc = (rdd: RDD[T], time: Time) => { val file = rddToFileName(prefix, suffix, time) rdd.saveAsObjectFile(file) @@ -823,7 +906,7 @@ abstract class DStream[T: ClassTag] ( * of elements. The file name at each batch interval is generated based on * `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix". */ - def saveAsTextFiles(prefix: String, suffix: String = "") { + def saveAsTextFiles(prefix: String, suffix: String = ""): Unit = ssc.withScope { val saveFunc = (rdd: RDD[T], time: Time) => { val file = rddToFileName(prefix, suffix, time) rdd.saveAsTextFile(file) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index eca69f00188e4..6c1fab56740ee 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -69,7 +69,7 @@ import org.apache.spark.util.{TimeStampedHashMap, Utils} * processing semantics are undefined. */ private[streaming] -class FileInputDStream[K, V, F <: NewInputFormat[K,V]]( +class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( @transient ssc_ : StreamingContext, directory: String, filter: Path => Boolean = FileInputDStream.defaultFilter, @@ -251,7 +251,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K,V]]( /** Generate one RDD from an array of files */ private def filesToRDD(files: Seq[String]): RDD[(K, V)] = { - val fileRDDs = files.map(file =>{ + val fileRDDs = files.map { file => val rdd = serializableConfOpt.map(_.value) match { case Some(config) => context.sparkContext.newAPIHadoopFile( file, @@ -267,7 +267,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K,V]]( "Refer to the streaming programming guide for more details.") } rdd - }) + } new UnionRDD(context.sparkContext, fileRDDs) } @@ -294,7 +294,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K,V]]( private def readObject(ois: ObjectInputStream): Unit = Utils.tryOrIOException { logDebug(this.getClass().getSimpleName + ".readObject used") ois.defaultReadObject() - generatedRDDs = new mutable.HashMap[Time, RDD[(K,V)]] () + generatedRDDs = new mutable.HashMap[Time, RDD[(K, V)]]() batchTimeToSelectedFiles = new mutable.HashMap[Time, Array[String]] with mutable.SynchronizedMap[Time, Array[String]] recentlySelectedFiles = new mutable.HashSet[String]() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala index 685a32e1d280d..c109ceccc6989 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala @@ -37,7 +37,7 @@ class ForEachDStream[T: ClassTag] ( override def generateJob(time: Time): Option[Job] = { parent.getOrCompute(time) match { case Some(rdd) => - val jobFunc = () => { + val jobFunc = () => createRDDWithLocalProperties(time) { ssc.sparkContext.setCallSite(creationSite) foreachFunc(rdd, time) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala index 9716adb62817c..d58c99a8ff321 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala @@ -17,10 +17,13 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.streaming.{Time, Duration, StreamingContext} - import scala.reflect.ClassTag +import org.apache.spark.SparkContext +import org.apache.spark.rdd.RDDOperationScope +import org.apache.spark.streaming.{Time, Duration, StreamingContext} +import org.apache.spark.util.Utils + /** * This is the abstract base class for all input streams. This class provides methods * start() and stop() which is called by Spark Streaming system to start and stop receiving data. @@ -44,10 +47,31 @@ abstract class InputDStream[T: ClassTag] (@transient ssc_ : StreamingContext) /** This is an unique identifier for the input stream. */ val id = ssc.getNewInputStreamId() + /** A human-readable name of this InputDStream */ + private[streaming] def name: String = { + // e.g. FlumePollingDStream -> "Flume polling stream" + val newName = Utils.getFormattedClassName(this) + .replaceAll("InputDStream", "Stream") + .split("(?=[A-Z])") + .filter(_.nonEmpty) + .mkString(" ") + .toLowerCase + .capitalize + s"$newName [$id]" + } + /** - * The name of this InputDStream. By default, it's the class name with its id. + * The base scope associated with the operation that created this DStream. + * + * For InputDStreams, we use the name of this DStream as the scope name. + * If an outer scope is given, we assume that it includes an alternative name for this stream. */ - private[streaming] def name: String = s"${getClass.getSimpleName}-$id" + protected[streaming] override val baseScope: Option[String] = { + val scopeName = Option(ssc.sc.getLocalProperty(SparkContext.RDD_SCOPE_KEY)) + .map { json => RDDOperationScope.fromJson(json).name + s" [$id]" } + .getOrElse(name.toLowerCase) + Some(new RDDOperationScope(scopeName).toJson) + } /** * Checks whether the 'time' is valid wrt slideDuration for generating RDD. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala index 8a58571632447..358e4c66df7ba 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala @@ -32,12 +32,14 @@ import org.apache.spark.streaming.StreamingContext.rddToFileName /** * Extra functions available on DStream of (key, value) pairs through an implicit conversion. */ -class PairDStreamFunctions[K, V](self: DStream[(K,V)]) +class PairDStreamFunctions[K, V](self: DStream[(K, V)]) (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K]) extends Serializable { private[streaming] def ssc = self.ssc + private[streaming] def sparkContext = self.context.sparkContext + private[streaming] def defaultPartitioner(numPartitions: Int = self.ssc.sc.defaultParallelism) = { new HashPartitioner(numPartitions) } @@ -46,7 +48,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * Return a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to * generate the RDDs with Spark's default number of partitions. */ - def groupByKey(): DStream[(K, Iterable[V])] = { + def groupByKey(): DStream[(K, Iterable[V])] = ssc.withScope { groupByKey(defaultPartitioner()) } @@ -54,7 +56,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * Return a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to * generate the RDDs with `numPartitions` partitions. */ - def groupByKey(numPartitions: Int): DStream[(K, Iterable[V])] = { + def groupByKey(numPartitions: Int): DStream[(K, Iterable[V])] = ssc.withScope { groupByKey(defaultPartitioner(numPartitions)) } @@ -62,7 +64,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * Return a new DStream by applying `groupByKey` on each RDD. The supplied * org.apache.spark.Partitioner is used to control the partitioning of each RDD. */ - def groupByKey(partitioner: Partitioner): DStream[(K, Iterable[V])] = { + def groupByKey(partitioner: Partitioner): DStream[(K, Iterable[V])] = ssc.withScope { val createCombiner = (v: V) => ArrayBuffer[V](v) val mergeValue = (c: ArrayBuffer[V], v: V) => (c += v) val mergeCombiner = (c1: ArrayBuffer[V], c2: ArrayBuffer[V]) => (c1 ++ c2) @@ -75,7 +77,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * merged using the associative reduce function. Hash partitioning is used to generate the RDDs * with Spark's default number of partitions. */ - def reduceByKey(reduceFunc: (V, V) => V): DStream[(K, V)] = { + def reduceByKey(reduceFunc: (V, V) => V): DStream[(K, V)] = ssc.withScope { reduceByKey(reduceFunc, defaultPartitioner()) } @@ -84,7 +86,9 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * merged using the supplied reduce function. Hash partitioning is used to generate the RDDs * with `numPartitions` partitions. */ - def reduceByKey(reduceFunc: (V, V) => V, numPartitions: Int): DStream[(K, V)] = { + def reduceByKey( + reduceFunc: (V, V) => V, + numPartitions: Int): DStream[(K, V)] = ssc.withScope { reduceByKey(reduceFunc, defaultPartitioner(numPartitions)) } @@ -93,9 +97,10 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * merged using the supplied reduce function. org.apache.spark.Partitioner is used to control * the partitioning of each RDD. */ - def reduceByKey(reduceFunc: (V, V) => V, partitioner: Partitioner): DStream[(K, V)] = { - val cleanedReduceFunc = ssc.sc.clean(reduceFunc) - combineByKey((v: V) => v, cleanedReduceFunc, cleanedReduceFunc, partitioner) + def reduceByKey( + reduceFunc: (V, V) => V, + partitioner: Partitioner): DStream[(K, V)] = ssc.withScope { + combineByKey((v: V) => v, reduceFunc, reduceFunc, partitioner) } /** @@ -104,12 +109,20 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * org.apache.spark.rdd.PairRDDFunctions in the Spark core documentation for more information. */ def combineByKey[C: ClassTag]( - createCombiner: V => C, - mergeValue: (C, V) => C, - mergeCombiner: (C, C) => C, - partitioner: Partitioner, - mapSideCombine: Boolean = true): DStream[(K, C)] = { - new ShuffledDStream[K, V, C](self, createCombiner, mergeValue, mergeCombiner, partitioner, + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiner: (C, C) => C, + partitioner: Partitioner, + mapSideCombine: Boolean = true): DStream[(K, C)] = ssc.withScope { + val cleanedCreateCombiner = sparkContext.clean(createCombiner) + val cleanedMergeValue = sparkContext.clean(mergeValue) + val cleanedMergeCombiner = sparkContext.clean(mergeCombiner) + new ShuffledDStream[K, V, C]( + self, + cleanedCreateCombiner, + cleanedMergeValue, + cleanedMergeCombiner, + partitioner, mapSideCombine) } @@ -121,7 +134,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval */ - def groupByKeyAndWindow(windowDuration: Duration): DStream[(K, Iterable[V])] = { + def groupByKeyAndWindow(windowDuration: Duration): DStream[(K, Iterable[V])] = ssc.withScope { groupByKeyAndWindow(windowDuration, self.slideDuration, defaultPartitioner()) } @@ -136,8 +149,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * DStream's batching interval */ def groupByKeyAndWindow(windowDuration: Duration, slideDuration: Duration) - : DStream[(K, Iterable[V])] = - { + : DStream[(K, Iterable[V])] = ssc.withScope { groupByKeyAndWindow(windowDuration, slideDuration, defaultPartitioner()) } @@ -157,7 +169,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) windowDuration: Duration, slideDuration: Duration, numPartitions: Int - ): DStream[(K, Iterable[V])] = { + ): DStream[(K, Iterable[V])] = ssc.withScope { groupByKeyAndWindow(windowDuration, slideDuration, defaultPartitioner(numPartitions)) } @@ -176,7 +188,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) windowDuration: Duration, slideDuration: Duration, partitioner: Partitioner - ): DStream[(K, Iterable[V])] = { + ): DStream[(K, Iterable[V])] = ssc.withScope { val createCombiner = (v: Iterable[V]) => new ArrayBuffer[V] ++= v val mergeValue = (buf: ArrayBuffer[V], v: Iterable[V]) => buf ++= v val mergeCombiner = (buf1: ArrayBuffer[V], buf2: ArrayBuffer[V]) => buf1 ++= buf2 @@ -198,7 +210,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) def reduceByKeyAndWindow( reduceFunc: (V, V) => V, windowDuration: Duration - ): DStream[(K, V)] = { + ): DStream[(K, V)] = ssc.withScope { reduceByKeyAndWindow(reduceFunc, windowDuration, self.slideDuration, defaultPartitioner()) } @@ -217,7 +229,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) reduceFunc: (V, V) => V, windowDuration: Duration, slideDuration: Duration - ): DStream[(K, V)] = { + ): DStream[(K, V)] = ssc.withScope { reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration, defaultPartitioner()) } @@ -238,7 +250,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) windowDuration: Duration, slideDuration: Duration, numPartitions: Int - ): DStream[(K, V)] = { + ): DStream[(K, V)] = ssc.withScope { reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration, defaultPartitioner(numPartitions)) } @@ -260,11 +272,10 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) windowDuration: Duration, slideDuration: Duration, partitioner: Partitioner - ): DStream[(K, V)] = { - val cleanedReduceFunc = ssc.sc.clean(reduceFunc) - self.reduceByKey(cleanedReduceFunc, partitioner) + ): DStream[(K, V)] = ssc.withScope { + self.reduceByKey(reduceFunc, partitioner) .window(windowDuration, slideDuration) - .reduceByKey(cleanedReduceFunc, partitioner) + .reduceByKey(reduceFunc, partitioner) } /** @@ -294,8 +305,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) slideDuration: Duration = self.slideDuration, numPartitions: Int = ssc.sc.defaultParallelism, filterFunc: ((K, V)) => Boolean = null - ): DStream[(K, V)] = { - + ): DStream[(K, V)] = ssc.withScope { reduceByKeyAndWindow( reduceFunc, invReduceFunc, windowDuration, slideDuration, defaultPartitioner(numPartitions), filterFunc @@ -328,7 +338,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) slideDuration: Duration, partitioner: Partitioner, filterFunc: ((K, V)) => Boolean - ): DStream[(K, V)] = { + ): DStream[(K, V)] = ssc.withScope { val cleanedReduceFunc = ssc.sc.clean(reduceFunc) val cleanedInvReduceFunc = ssc.sc.clean(invReduceFunc) @@ -349,7 +359,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) */ def updateStateByKey[S: ClassTag]( updateFunc: (Seq[V], Option[S]) => Option[S] - ): DStream[(K, S)] = { + ): DStream[(K, S)] = ssc.withScope { updateStateByKey(updateFunc, defaultPartitioner()) } @@ -365,7 +375,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) def updateStateByKey[S: ClassTag]( updateFunc: (Seq[V], Option[S]) => Option[S], numPartitions: Int - ): DStream[(K, S)] = { + ): DStream[(K, S)] = ssc.withScope { updateStateByKey(updateFunc, defaultPartitioner(numPartitions)) } @@ -382,9 +392,10 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) def updateStateByKey[S: ClassTag]( updateFunc: (Seq[V], Option[S]) => Option[S], partitioner: Partitioner - ): DStream[(K, S)] = { + ): DStream[(K, S)] = ssc.withScope { + val cleanedUpdateF = sparkContext.clean(updateFunc) val newUpdateFunc = (iterator: Iterator[(K, Seq[V], Option[S])]) => { - iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s))) + iterator.flatMap(t => cleanedUpdateF(t._2, t._3).map(s => (t._1, s))) } updateStateByKey(newUpdateFunc, partitioner, true) } @@ -406,7 +417,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)], partitioner: Partitioner, rememberPartitioner: Boolean - ): DStream[(K, S)] = { + ): DStream[(K, S)] = ssc.withScope { new StateDStream(self, ssc.sc.clean(updateFunc), partitioner, rememberPartitioner, None) } @@ -425,9 +436,10 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) updateFunc: (Seq[V], Option[S]) => Option[S], partitioner: Partitioner, initialRDD: RDD[(K, S)] - ): DStream[(K, S)] = { + ): DStream[(K, S)] = ssc.withScope { + val cleanedUpdateF = sparkContext.clean(updateFunc) val newUpdateFunc = (iterator: Iterator[(K, Seq[V], Option[S])]) => { - iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s))) + iterator.flatMap(t => cleanedUpdateF(t._2, t._3).map(s => (t._1, s))) } updateStateByKey(newUpdateFunc, partitioner, true, initialRDD) } @@ -451,7 +463,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) partitioner: Partitioner, rememberPartitioner: Boolean, initialRDD: RDD[(K, S)] - ): DStream[(K, S)] = { + ): DStream[(K, S)] = ssc.withScope { new StateDStream(self, ssc.sc.clean(updateFunc), partitioner, rememberPartitioner, Some(initialRDD)) } @@ -460,8 +472,8 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * Return a new DStream by applying a map function to the value of each key-value pairs in * 'this' DStream without changing the key. */ - def mapValues[U: ClassTag](mapValuesFunc: V => U): DStream[(K, U)] = { - new MapValuedDStream[K, V, U](self, mapValuesFunc) + def mapValues[U: ClassTag](mapValuesFunc: V => U): DStream[(K, U)] = ssc.withScope { + new MapValuedDStream[K, V, U](self, sparkContext.clean(mapValuesFunc)) } /** @@ -470,8 +482,8 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) */ def flatMapValues[U: ClassTag]( flatMapValuesFunc: V => TraversableOnce[U] - ): DStream[(K, U)] = { - new FlatMapValuedDStream[K, V, U](self, flatMapValuesFunc) + ): DStream[(K, U)] = ssc.withScope { + new FlatMapValuedDStream[K, V, U](self, sparkContext.clean(flatMapValuesFunc)) } /** @@ -479,7 +491,8 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * Hash partitioning is used to generate the RDDs with Spark's default number * of partitions. */ - def cogroup[W: ClassTag](other: DStream[(K, W)]): DStream[(K, (Iterable[V], Iterable[W]))] = { + def cogroup[W: ClassTag]( + other: DStream[(K, W)]): DStream[(K, (Iterable[V], Iterable[W]))] = ssc.withScope { cogroup(other, defaultPartitioner()) } @@ -487,8 +500,9 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * Return a new DStream by applying 'cogroup' between RDDs of `this` DStream and `other` DStream. * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. */ - def cogroup[W: ClassTag](other: DStream[(K, W)], numPartitions: Int) - : DStream[(K, (Iterable[V], Iterable[W]))] = { + def cogroup[W: ClassTag]( + other: DStream[(K, W)], + numPartitions: Int): DStream[(K, (Iterable[V], Iterable[W]))] = ssc.withScope { cogroup(other, defaultPartitioner(numPartitions)) } @@ -499,7 +513,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) def cogroup[W: ClassTag]( other: DStream[(K, W)], partitioner: Partitioner - ): DStream[(K, (Iterable[V], Iterable[W]))] = { + ): DStream[(K, (Iterable[V], Iterable[W]))] = ssc.withScope { self.transformWith( other, (rdd1: RDD[(K, V)], rdd2: RDD[(K, W)]) => rdd1.cogroup(rdd2, partitioner) @@ -510,7 +524,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * Return a new DStream by applying 'join' between RDDs of `this` DStream and `other` DStream. * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. */ - def join[W: ClassTag](other: DStream[(K, W)]): DStream[(K, (V, W))] = { + def join[W: ClassTag](other: DStream[(K, W)]): DStream[(K, (V, W))] = ssc.withScope { join[W](other, defaultPartitioner()) } @@ -518,7 +532,9 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * Return a new DStream by applying 'join' between RDDs of `this` DStream and `other` DStream. * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. */ - def join[W: ClassTag](other: DStream[(K, W)], numPartitions: Int): DStream[(K, (V, W))] = { + def join[W: ClassTag]( + other: DStream[(K, W)], + numPartitions: Int): DStream[(K, (V, W))] = ssc.withScope { join[W](other, defaultPartitioner(numPartitions)) } @@ -529,7 +545,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) def join[W: ClassTag]( other: DStream[(K, W)], partitioner: Partitioner - ): DStream[(K, (V, W))] = { + ): DStream[(K, (V, W))] = ssc.withScope { self.transformWith( other, (rdd1: RDD[(K, V)], rdd2: RDD[(K, W)]) => rdd1.join(rdd2, partitioner) @@ -541,7 +557,8 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * `other` DStream. Hash partitioning is used to generate the RDDs with Spark's default * number of partitions. */ - def leftOuterJoin[W: ClassTag](other: DStream[(K, W)]): DStream[(K, (V, Option[W]))] = { + def leftOuterJoin[W: ClassTag]( + other: DStream[(K, W)]): DStream[(K, (V, Option[W]))] = ssc.withScope { leftOuterJoin[W](other, defaultPartitioner()) } @@ -553,7 +570,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) def leftOuterJoin[W: ClassTag]( other: DStream[(K, W)], numPartitions: Int - ): DStream[(K, (V, Option[W]))] = { + ): DStream[(K, (V, Option[W]))] = ssc.withScope { leftOuterJoin[W](other, defaultPartitioner(numPartitions)) } @@ -565,7 +582,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) def leftOuterJoin[W: ClassTag]( other: DStream[(K, W)], partitioner: Partitioner - ): DStream[(K, (V, Option[W]))] = { + ): DStream[(K, (V, Option[W]))] = ssc.withScope { self.transformWith( other, (rdd1: RDD[(K, V)], rdd2: RDD[(K, W)]) => rdd1.leftOuterJoin(rdd2, partitioner) @@ -577,7 +594,8 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * `other` DStream. Hash partitioning is used to generate the RDDs with Spark's default * number of partitions. */ - def rightOuterJoin[W: ClassTag](other: DStream[(K, W)]): DStream[(K, (Option[V], W))] = { + def rightOuterJoin[W: ClassTag]( + other: DStream[(K, W)]): DStream[(K, (Option[V], W))] = ssc.withScope { rightOuterJoin[W](other, defaultPartitioner()) } @@ -589,7 +607,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) def rightOuterJoin[W: ClassTag]( other: DStream[(K, W)], numPartitions: Int - ): DStream[(K, (Option[V], W))] = { + ): DStream[(K, (Option[V], W))] = ssc.withScope { rightOuterJoin[W](other, defaultPartitioner(numPartitions)) } @@ -601,7 +619,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) def rightOuterJoin[W: ClassTag]( other: DStream[(K, W)], partitioner: Partitioner - ): DStream[(K, (Option[V], W))] = { + ): DStream[(K, (Option[V], W))] = ssc.withScope { self.transformWith( other, (rdd1: RDD[(K, V)], rdd2: RDD[(K, W)]) => rdd1.rightOuterJoin(rdd2, partitioner) @@ -613,7 +631,8 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) * `other` DStream. Hash partitioning is used to generate the RDDs with Spark's default * number of partitions. */ - def fullOuterJoin[W: ClassTag](other: DStream[(K, W)]): DStream[(K, (Option[V], Option[W]))] = { + def fullOuterJoin[W: ClassTag]( + other: DStream[(K, W)]): DStream[(K, (Option[V], Option[W]))] = ssc.withScope { fullOuterJoin[W](other, defaultPartitioner()) } @@ -625,7 +644,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) def fullOuterJoin[W: ClassTag]( other: DStream[(K, W)], numPartitions: Int - ): DStream[(K, (Option[V], Option[W]))] = { + ): DStream[(K, (Option[V], Option[W]))] = ssc.withScope { fullOuterJoin[W](other, defaultPartitioner(numPartitions)) } @@ -637,7 +656,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) def fullOuterJoin[W: ClassTag]( other: DStream[(K, W)], partitioner: Partitioner - ): DStream[(K, (Option[V], Option[W]))] = { + ): DStream[(K, (Option[V], Option[W]))] = ssc.withScope { self.transformWith( other, (rdd1: RDD[(K, V)], rdd2: RDD[(K, W)]) => rdd1.fullOuterJoin(rdd2, partitioner) @@ -651,7 +670,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) def saveAsHadoopFiles[F <: OutputFormat[K, V]]( prefix: String, suffix: String - )(implicit fm: ClassTag[F]) { + )(implicit fm: ClassTag[F]): Unit = ssc.withScope { saveAsHadoopFiles(prefix, suffix, keyClass, valueClass, fm.runtimeClass.asInstanceOf[Class[F]]) } @@ -667,7 +686,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) valueClass: Class[_], outputFormatClass: Class[_ <: OutputFormat[_, _]], conf: JobConf = new JobConf(ssc.sparkContext.hadoopConfiguration) - ) { + ): Unit = ssc.withScope { // Wrap conf in SerializableWritable so that ForeachDStream can be serialized for checkpoints val serializableConf = new SerializableWritable(conf) val saveFunc = (rdd: RDD[(K, V)], time: Time) => { @@ -684,7 +703,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) def saveAsNewAPIHadoopFiles[F <: NewOutputFormat[K, V]]( prefix: String, suffix: String - )(implicit fm: ClassTag[F]) { + )(implicit fm: ClassTag[F]): Unit = ssc.withScope { saveAsNewAPIHadoopFiles(prefix, suffix, keyClass, valueClass, fm.runtimeClass.asInstanceOf[Class[F]]) } @@ -700,7 +719,7 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)]) valueClass: Class[_], outputFormatClass: Class[_ <: NewOutputFormat[_, _]], conf: Configuration = ssc.sparkContext.hadoopConfiguration - ) { + ): Unit = ssc.withScope { // Wrap conf in SerializableWritable so that ForeachDStream can be serialized for checkpoints val serializableConf = new SerializableWritable(conf) val saveFunc = (rdd: RDD[(K, V)], time: Time) => { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index 5cfe43a1ce726..e4ff05e12f201 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -73,27 +73,38 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont val inputInfo = InputInfo(id, blockInfos.map(_.numRecords).sum) ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo) - // Are WAL record handles present with all the blocks - val areWALRecordHandlesPresent = blockInfos.forall { _.walRecordHandleOption.nonEmpty } + if (blockInfos.nonEmpty) { + // Are WAL record handles present with all the blocks + val areWALRecordHandlesPresent = blockInfos.forall { _.walRecordHandleOption.nonEmpty } - if (areWALRecordHandlesPresent) { - // If all the blocks have WAL record handle, then create a WALBackedBlockRDD - val isBlockIdValid = blockInfos.map { _.isBlockIdValid() }.toArray - val walRecordHandles = blockInfos.map { _.walRecordHandleOption.get }.toArray - new WriteAheadLogBackedBlockRDD[T]( - ssc.sparkContext, blockIds, walRecordHandles, isBlockIdValid) - } else { - // Else, create a BlockRDD. However, if there are some blocks with WAL info but not others - // then that is unexpected and log a warning accordingly. - if (blockInfos.find(_.walRecordHandleOption.nonEmpty).nonEmpty) { - if (WriteAheadLogUtils.enableReceiverLog(ssc.conf)) { - logError("Some blocks do not have Write Ahead Log information; " + - "this is unexpected and data may not be recoverable after driver failures") - } else { - logWarning("Some blocks have Write Ahead Log information; this is unexpected") + if (areWALRecordHandlesPresent) { + // If all the blocks have WAL record handle, then create a WALBackedBlockRDD + val isBlockIdValid = blockInfos.map { _.isBlockIdValid() }.toArray + val walRecordHandles = blockInfos.map { _.walRecordHandleOption.get }.toArray + new WriteAheadLogBackedBlockRDD[T]( + ssc.sparkContext, blockIds, walRecordHandles, isBlockIdValid) + } else { + // Else, create a BlockRDD. However, if there are some blocks with WAL info but not + // others then that is unexpected and log a warning accordingly. + if (blockInfos.find(_.walRecordHandleOption.nonEmpty).nonEmpty) { + if (WriteAheadLogUtils.enableReceiverLog(ssc.conf)) { + logError("Some blocks do not have Write Ahead Log information; " + + "this is unexpected and data may not be recoverable after driver failures") + } else { + logWarning("Some blocks have Write Ahead Log information; this is unexpected") + } } + new BlockRDD[T](ssc.sc, blockIds) + } + } else { + // If no block is ready now, creating WriteAheadLogBackedBlockRDD or BlockRDD + // according to the configuration + if (WriteAheadLogUtils.enableReceiverLog(ssc.conf)) { + new WriteAheadLogBackedBlockRDD[T]( + ssc.sparkContext, Array.empty, Array.empty, Array.empty) + } else { + new BlockRDD[T](ssc.sc, Array.empty) } - new BlockRDD[T](ssc.sc, blockIds) } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala index 1385ccbf56ee5..6a583bf2a3626 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala @@ -38,14 +38,14 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag]( _windowDuration: Duration, _slideDuration: Duration, partitioner: Partitioner - ) extends DStream[(K,V)](parent.ssc) { + ) extends DStream[(K, V)](parent.ssc) { - assert(_windowDuration.isMultipleOf(parent.slideDuration), + require(_windowDuration.isMultipleOf(parent.slideDuration), "The window duration of ReducedWindowedDStream (" + _windowDuration + ") " + "must be multiple of the slide duration of parent DStream (" + parent.slideDuration + ")" ) - assert(_slideDuration.isMultipleOf(parent.slideDuration), + require(_slideDuration.isMultipleOf(parent.slideDuration), "The slide duration of ReducedWindowedDStream (" + _slideDuration + ") " + "must be multiple of the slide duration of parent DStream (" + parent.slideDuration + ")" ) @@ -58,7 +58,7 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag]( super.persist(StorageLevel.MEMORY_ONLY_SER) reducedStream.persist(StorageLevel.MEMORY_ONLY_SER) - def windowDuration: Duration = _windowDuration + def windowDuration: Duration = _windowDuration override def dependencies: List[DStream[_]] = List(reducedStream) @@ -68,7 +68,7 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag]( override def parentRememberDuration: Duration = rememberDuration + windowDuration - override def persist(storageLevel: StorageLevel): DStream[(K,V)] = { + override def persist(storageLevel: StorageLevel): DStream[(K, V)] = { super.persist(storageLevel) reducedStream.persist(storageLevel) this @@ -118,7 +118,7 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag]( // Get the RDD of the reduced value of the previous window val previousWindowRDD = - getOrCompute(previousWindow.endTime).getOrElse(ssc.sc.makeRDD(Seq[(K,V)]())) + getOrCompute(previousWindow.endTime).getOrElse(ssc.sc.makeRDD(Seq[(K, V)]())) // Make the list of RDDs that needs to cogrouped together for reducing their reduced values val allRDDs = new ArrayBuffer[RDD[(K, V)]]() += previousWindowRDD ++= oldRDDs ++= newRDDs diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala index 7757ccac09a58..e0ffd5d86b435 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala @@ -25,19 +25,19 @@ import scala.reflect.ClassTag private[streaming] class ShuffledDStream[K: ClassTag, V: ClassTag, C: ClassTag]( - parent: DStream[(K,V)], + parent: DStream[(K, V)], createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiner: (C, C) => C, partitioner: Partitioner, mapSideCombine: Boolean = true - ) extends DStream[(K,C)] (parent.ssc) { + ) extends DStream[(K, C)] (parent.ssc) { override def dependencies: List[DStream[_]] = List(parent) override def slideDuration: Duration = parent.slideDuration - override def compute(validTime: Time): Option[RDD[(K,C)]] = { + override def compute(validTime: Time): Option[RDD[(K, C)]] = { parent.getOrCompute(validTime) match { case Some(rdd) => Some(rdd.combineByKey[C]( createCombiner, mergeValue, mergeCombiner, partitioner, mapSideCombine)) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala index 8b72bcf20653d..5ce5b7aae6e69 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala @@ -17,6 +17,8 @@ package org.apache.spark.streaming.dstream +import scala.util.control.NonFatal + import org.apache.spark.streaming.StreamingContext import org.apache.spark.storage.StorageLevel import org.apache.spark.util.NextIterator @@ -74,13 +76,17 @@ class SocketReceiver[T: ClassTag]( while(!isStopped && iterator.hasNext) { store(iterator.next) } - logInfo("Stopped receiving") - restart("Retrying connecting to " + host + ":" + port) + if (!isStopped()) { + restart("Socket data stream had no more data") + } else { + logInfo("Stopped receiving") + } } catch { case e: java.net.ConnectException => restart("Error connecting to " + host + ":" + port, e) - case t: Throwable => - restart("Error receiving data", t) + case NonFatal(e) => + logWarning("Error receiving data", e) + restart("Error receiving data", e) } finally { if (socket != null) { socket.close() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala index de8718d0a80fe..621d6dff788f4 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala @@ -51,7 +51,7 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag]( val finalFunc = (iterator: Iterator[(K, (Iterable[V], Iterable[S]))]) => { val i = iterator.map(t => { val itr = t._2._2.iterator - val headOption = if(itr.hasNext) Some(itr.next) else None + val headOption = if (itr.hasNext) Some(itr.next()) else None (t._1, t._2._1.toSeq, headOption) }) updateFuncLocal(i) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala index 899865a906c27..4efba039f8959 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala @@ -44,7 +44,7 @@ class WindowedDStream[T: ClassTag]( // Persist parent level by default, as those RDDs are going to be obviously reused. parent.persist(StorageLevel.MEMORY_ONLY_SER) - def windowDuration: Duration = _windowDuration + def windowDuration: Duration = _windowDuration override def dependencies: List[DStream[_]] = List(parent) @@ -68,7 +68,7 @@ class WindowedDStream[T: ClassTag]( new PartitionerAwareUnionRDD(ssc.sc, rddsInWindow) } else { logDebug("Using normal union for windowing at " + validTime) - new UnionRDD(ssc.sc,rddsInWindow) + new UnionRDD(ssc.sc, rddsInWindow) } Some(windowRDD) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala index 4bebcc5aa7ca0..8d73593ab6375 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala @@ -164,7 +164,7 @@ private[streaming] class BlockGenerator( private def keepPushingBlocks() { logInfo("Started block pushing thread") try { - while(!stopped) { + while (!stopped) { Option(blocksForPushing.poll(100, TimeUnit.MILLISECONDS)) match { case Some(block) => pushBlock(block) case None => @@ -191,7 +191,7 @@ private[streaming] class BlockGenerator( logError(message, t) listener.onError(message, t) } - + private def pushBlock(block: Block) { listener.onPushBlock(block.id, block.buffer) logInfo("Pushed block " + block.id) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala index 97db9ded83367..8df542b367d27 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala @@ -17,8 +17,9 @@ package org.apache.spark.streaming.receiver +import com.google.common.util.concurrent.{RateLimiter => GuavaRateLimiter} + import org.apache.spark.{Logging, SparkConf} -import com.google.common.util.concurrent.{RateLimiter=>GuavaRateLimiter} /** Provides waitToPush() method to limit the rate at which receivers consume data. * diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala index 651b534ac1900..207d64d9414ee 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -62,7 +62,7 @@ private[streaming] case class BlockManagerBasedStoreResult(blockId: StreamBlockI private[streaming] class BlockManagerBasedBlockHandler( blockManager: BlockManager, storageLevel: StorageLevel) extends ReceivedBlockHandler with Logging { - + def storeBlock(blockId: StreamBlockId, block: ReceivedBlock): ReceivedBlockStoreResult = { val putResult: Seq[(BlockId, BlockStatus)] = block match { case ArrayBufferBlock(arrayBuffer) => diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala index 4943f29395d12..33be067ebdaf2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala @@ -18,14 +18,14 @@ package org.apache.spark.streaming.receiver import java.nio.ByteBuffer +import java.util.concurrent.CountDownLatch import scala.collection.mutable.ArrayBuffer +import scala.concurrent._ import org.apache.spark.{Logging, SparkConf} import org.apache.spark.storage.StreamBlockId -import java.util.concurrent.CountDownLatch -import scala.concurrent._ -import ExecutionContext.Implicits.global +import org.apache.spark.util.ThreadUtils /** * Abstract class that is responsible for supervising a Receiver in the worker. @@ -46,6 +46,9 @@ private[streaming] abstract class ReceiverSupervisor( // Attach the executor to the receiver receiver.attachExecutor(this) + private val futureExecutionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("receiver-supervisor-future", 128)) + /** Receiver id */ protected val streamId = receiver.streamId @@ -111,6 +114,7 @@ private[streaming] abstract class ReceiverSupervisor( stoppingError = error.orNull stopReceiver(message, error) onStop(message, error) + futureExecutionContext.shutdownNow() stopLatch.countDown() } @@ -150,6 +154,8 @@ private[streaming] abstract class ReceiverSupervisor( /** Restart receiver with delay */ def restartReceiver(message: String, error: Option[Throwable], delay: Int) { Future { + // This is a blocking action so we should use "futureExecutionContext" which is a cached + // thread pool. logWarning("Restarting receiver with delay " + delay + " ms: " + message, error.getOrElse(null)) stopReceiver("Restarting receiver with delay " + delay + "ms: " + message, error) @@ -158,7 +164,7 @@ private[streaming] abstract class ReceiverSupervisor( logInfo("Starting receiver again") startReceiver() logInfo("Receiver started again") - } + }(futureExecutionContext) } /** Check if receiver has been marked for stopping */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 1d1ddaaccf217..4af9b6d3b56ab 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -126,6 +126,10 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { eventLoop.post(ErrorReported(msg, e)) } + def isStarted(): Boolean = synchronized { + eventLoop != null + } + private def processEvent(event: JobSchedulerEvent) { try { event match { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index a9f4147a5f020..7720259a5d794 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -153,7 +153,7 @@ private[streaming] class ReceivedBlockTracker( * returns only after the files are cleaned up. */ def cleanupOldBatches(cleanupThreshTime: Time, waitForCompletion: Boolean): Unit = synchronized { - assert(cleanupThreshTime.milliseconds < clock.getTimeMillis()) + require(cleanupThreshTime.milliseconds < clock.getTimeMillis()) val timesToCleanup = timeToAllocatedBlocks.keys.filter { _ < cleanupThreshTime }.toSeq logInfo("Deleting batches " + timesToCleanup) writeToLog(BatchCleanupEvent(timesToCleanup)) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index f73f7e705ee0d..f1504b09c9873 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -230,7 +230,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false class ReceiverLauncher { @transient val env = ssc.env @volatile @transient private var running = false - @transient val thread = new Thread() { + @transient val thread = new Thread() { override def run() { try { SparkEnv.set(env) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala index 3619e129ad9cf..f702bd5bc9466 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala @@ -17,11 +17,14 @@ package org.apache.spark.streaming.ui +import java.text.SimpleDateFormat +import java.util.Date + import scala.xml.Node import org.apache.spark.ui.{UIUtils => SparkUIUtils} -private[ui] abstract class BatchTableBase(tableId: String) { +private[ui] abstract class BatchTableBase(tableId: String, batchInterval: Long) { protected def columns: Seq[Node] = { Batch Time @@ -35,14 +38,15 @@ private[ui] abstract class BatchTableBase(tableId: String) { protected def baseRow(batch: BatchUIData): Seq[Node] = { val batchTime = batch.batchTime.milliseconds - val formattedBatchTime = SparkUIUtils.formatDate(batch.batchTime.milliseconds) + val formattedBatchTime = UIUtils.formatBatchTime(batchTime, batchInterval) val eventCount = batch.numRecords val schedulingDelay = batch.schedulingDelay val formattedSchedulingDelay = schedulingDelay.map(SparkUIUtils.formatDuration).getOrElse("-") val processingTime = batch.processingDelay val formattedProcessingTime = processingTime.map(SparkUIUtils.formatDuration).getOrElse("-") + val batchTimeId = s"batch-$batchTime" - + {formattedBatchTime} @@ -79,7 +83,8 @@ private[ui] abstract class BatchTableBase(tableId: String) { private[ui] class ActiveBatchTable( runningBatches: Seq[BatchUIData], - waitingBatches: Seq[BatchUIData]) extends BatchTableBase("active-batches-table") { + waitingBatches: Seq[BatchUIData], + batchInterval: Long) extends BatchTableBase("active-batches-table", batchInterval) { override protected def columns: Seq[Node] = super.columns ++ Status @@ -99,8 +104,8 @@ private[ui] class ActiveBatchTable( } } -private[ui] class CompletedBatchTable(batches: Seq[BatchUIData]) - extends BatchTableBase("completed-batches-table") { +private[ui] class CompletedBatchTable(batches: Seq[BatchUIData], batchInterval: Long) + extends BatchTableBase("completed-batches-table", batchInterval) { override protected def columns: Seq[Node] = super.columns ++ Total Delay diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala index 831f60e870f74..f75067669abe5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala @@ -17,6 +17,8 @@ package org.apache.spark.streaming.ui +import java.text.SimpleDateFormat +import java.util.Date import javax.servlet.http.HttpServletRequest import scala.xml.{NodeSeq, Node, Text} @@ -288,7 +290,8 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { val batchTime = Option(request.getParameter("id")).map(id => Time(id.toLong)).getOrElse { throw new IllegalArgumentException(s"Missing id parameter") } - val formattedBatchTime = SparkUIUtils.formatDate(batchTime.milliseconds) + val formattedBatchTime = + UIUtils.formatBatchTime(batchTime.milliseconds, streamingListener.batchDuration) val batchUIData = streamingListener.getBatchUIData(batchTime).getOrElse { throw new IllegalArgumentException(s"Batch $formattedBatchTime does not exist") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index ff0f2b18dc321..4ee7a486e370b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -166,8 +166,8 @@ private[ui] class StreamingPage(parent: StreamingTab) private def generateLoadResources(): Seq[Node] = { // scalastyle:off - - + + // scalastyle:on } @@ -186,6 +186,8 @@ private[ui] class StreamingPage(parent: StreamingTab) {SparkUIUtils.formatDate(startTime)} + ({listener.numTotalCompletedBatches} + completed batches, {listener.numTotalReceivedRecords} records)


    } @@ -199,9 +201,9 @@ private[ui] class StreamingPage(parent: StreamingTab) * @param times all time values that will be used in the graphs. */ private def generateTimeMap(times: Seq[Long]): Seq[Node] = { - val dateFormat = new SimpleDateFormat("HH:mm:ss") val js = "var timeFormat = {};\n" + times.map { time => - val formattedTime = dateFormat.format(new Date(time)) + val formattedTime = + UIUtils.formatBatchTime(time, listener.batchDuration, showYYYYMMSS = false) s"timeFormat[$time] = '$formattedTime';" }.mkString("\n") @@ -244,17 +246,6 @@ private[ui] class StreamingPage(parent: StreamingTab) val maxEventRate = eventRateForAllStreams.max.map(_.ceil.toLong).getOrElse(0L) val minEventRate = 0L - // JavaScript to show/hide the InputDStreams sub table. - val triangleJs = - s"""$$('#inputs-table').toggle('collapsed'); - |var status = false; - |if ($$(this).html() == '$BLACK_RIGHT_TRIANGLE_HTML') { - |$$(this).html('$BLACK_DOWN_TRIANGLE_HTML');status = true;} - |else {$$(this).html('$BLACK_RIGHT_TRIANGLE_HTML');status = false;} - |window.history.pushState('', - | document.title, window.location.pathname + '?show-streams-detail=' + status);""" - .stripMargin.replaceAll("\\n", "") // it must be only one single line - val batchInterval = UIUtils.convertToTimeUnit(listener.batchDuration, normalizedUnit) val jsCollector = new JsCollector @@ -326,10 +317,18 @@ private[ui] class StreamingPage(parent: StreamingTab)
    - {if (hasStream) { - {Unparsed(BLACK_RIGHT_TRIANGLE_HTML)} - }} - Input Rate + { + if (hasStream) { + + + + Input Rate + + + } else { + Input Rate + } + }
    Avg: {eventRateForAllStreams.formattedAvg} events/sec
    @@ -475,14 +474,14 @@ private[ui] class StreamingPage(parent: StreamingTab) val activeBatchesContent = {

    Active Batches ({runningBatches.size + waitingBatches.size})

    ++ - new ActiveBatchTable(runningBatches, waitingBatches).toNodeSeq + new ActiveBatchTable(runningBatches, waitingBatches, listener.batchDuration).toNodeSeq } val completedBatchesContent = {

    Completed Batches (last {completedBatches.size} out of {listener.numTotalCompletedBatches})

    ++ - new CompletedBatchTable(completedBatches).toNodeSeq + new CompletedBatchTable(completedBatches, listener.batchDuration).toNodeSeq } activeBatchesContent ++ completedBatchesContent diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala index f307b54bb9630..e0c0f57212f55 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala @@ -17,9 +17,11 @@ package org.apache.spark.streaming.ui +import org.eclipse.jetty.servlet.ServletContextHandler + import org.apache.spark.{Logging, SparkException} import org.apache.spark.streaming.StreamingContext -import org.apache.spark.ui.{SparkUI, SparkUITab} +import org.apache.spark.ui.{JettyUtils, SparkUI, SparkUITab} import StreamingTab._ @@ -30,6 +32,8 @@ import StreamingTab._ private[spark] class StreamingTab(val ssc: StreamingContext) extends SparkUITab(getSparkUI(ssc), "streaming") with Logging { + private val STATIC_RESOURCE_DIR = "org/apache/spark/streaming/ui/static" + val parent = getSparkUI(ssc) val listener = ssc.progressListener @@ -38,12 +42,18 @@ private[spark] class StreamingTab(val ssc: StreamingContext) attachPage(new StreamingPage(this)) attachPage(new BatchPage(this)) + var staticHandler: ServletContextHandler = null + def attach() { getSparkUI(ssc).attachTab(this) + staticHandler = JettyUtils.createStaticHandler(STATIC_RESOURCE_DIR, "/static/streaming") + getSparkUI(ssc).attachHandler(staticHandler) } def detach() { getSparkUI(ssc).detachTab(this) + getSparkUI(ssc).detachHandler(staticHandler) + staticHandler = null } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala index c206f973b2c66..86cfb1fa47370 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala @@ -17,9 +17,11 @@ package org.apache.spark.streaming.ui +import java.text.SimpleDateFormat +import java.util.TimeZone import java.util.concurrent.TimeUnit -object UIUtils { +private[streaming] object UIUtils { /** * Return the short string for a `TimeUnit`. @@ -62,7 +64,7 @@ object UIUtils { * Convert `milliseconds` to the specified `unit`. We cannot use `TimeUnit.convert` because it * will discard the fractional part. */ - def convertToTimeUnit(milliseconds: Long, unit: TimeUnit): Double = unit match { + def convertToTimeUnit(milliseconds: Long, unit: TimeUnit): Double = unit match { case TimeUnit.NANOSECONDS => milliseconds * 1000 * 1000 case TimeUnit.MICROSECONDS => milliseconds * 1000 case TimeUnit.MILLISECONDS => milliseconds @@ -71,4 +73,55 @@ object UIUtils { case TimeUnit.HOURS => milliseconds / 1000.0 / 60.0 / 60.0 case TimeUnit.DAYS => milliseconds / 1000.0 / 60.0 / 60.0 / 24.0 } + + // SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use. + private val batchTimeFormat = new ThreadLocal[SimpleDateFormat]() { + override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") + } + + private val batchTimeFormatWithMilliseconds = new ThreadLocal[SimpleDateFormat]() { + override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss.SSS") + } + + /** + * If `batchInterval` is less than 1 second, format `batchTime` with milliseconds. Otherwise, + * format `batchTime` without milliseconds. + * + * @param batchTime the batch time to be formatted + * @param batchInterval the batch interval + * @param showYYYYMMSS if showing the `yyyy/MM/dd` part. If it's false, the return value wll be + * only `HH:mm:ss` or `HH:mm:ss.SSS` depending on `batchInterval` + * @param timezone only for test + */ + def formatBatchTime( + batchTime: Long, + batchInterval: Long, + showYYYYMMSS: Boolean = true, + timezone: TimeZone = null): String = { + val oldTimezones = + (batchTimeFormat.get.getTimeZone, batchTimeFormatWithMilliseconds.get.getTimeZone) + if (timezone != null) { + batchTimeFormat.get.setTimeZone(timezone) + batchTimeFormatWithMilliseconds.get.setTimeZone(timezone) + } + try { + val formattedBatchTime = + if (batchInterval < 1000) { + batchTimeFormatWithMilliseconds.get.format(batchTime) + } else { + // If batchInterval >= 1 second, don't show milliseconds + batchTimeFormat.get.format(batchTime) + } + if (showYYYYMMSS) { + formattedBatchTime + } else { + formattedBatchTime.substring(formattedBatchTime.indexOf(' ') + 1) + } + } finally { + if (timezone != null) { + batchTimeFormat.get.setTimeZone(oldTimezones._1) + batchTimeFormatWithMilliseconds.get.setTimeZone(oldTimezones._2) + } + } + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index 87ba4f84a9ceb..fe6328b1ce727 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -200,7 +200,7 @@ private[streaming] class FileBasedWriteAheadLog( /** Initialize the log directory or recover existing logs inside the directory */ private def initializeOrRecover(): Unit = synchronized { val logDirectoryPath = new Path(logDirectory) - val fileSystem = HdfsUtils.getFileSystemForPath(logDirectoryPath, hadoopConf) + val fileSystem = HdfsUtils.getFileSystemForPath(logDirectoryPath, hadoopConf) if (fileSystem.exists(logDirectoryPath) && fileSystem.getFileStatus(logDirectoryPath).isDir) { val logFileInfo = logFilesTologInfo(fileSystem.listStatus(logDirectoryPath).map { _.getPath }) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala index 4d968f8bfa7a8..408936653c790 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala @@ -27,7 +27,7 @@ object RawTextHelper { * Splits lines and counts the words. */ def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, Long)] = { - val map = new OpenHashMap[String,Long] + val map = new OpenHashMap[String, Long] var i = 0 var j = 0 while (iter.hasNext) { @@ -98,7 +98,7 @@ object RawTextHelper { * before real workload starts. */ def warmUp(sc: SparkContext) { - for(i <- 0 to 1) { + for (i <- 0 to 1) { sc.parallelize(1 to 200000, 1000) .map(_ % 1331).map(_.toString) .mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10) diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index 2e00b980b9e44..1077b1b2cb7e3 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -1766,29 +1766,10 @@ public JavaStreamingContext call() { Assert.assertTrue("old context not recovered", !newContextCreated.get()); ssc.stop(); - // Function to create JavaStreamingContext using existing JavaSparkContext - // without any output operations (used to detect the new context) - Function creatingFunc2 = - new Function() { - public JavaStreamingContext call(JavaSparkContext context) { - newContextCreated.set(true); - return new JavaStreamingContext(context, Seconds.apply(1)); - } - }; - - JavaSparkContext sc = new JavaSparkContext(conf); - newContextCreated.set(false); - ssc = JavaStreamingContext.getOrCreate(emptyDir.getAbsolutePath(), creatingFunc2, sc); - Assert.assertTrue("new context not created", newContextCreated.get()); - ssc.stop(false); - newContextCreated.set(false); - ssc = JavaStreamingContext.getOrCreate(corruptedCheckpointDir, creatingFunc2, sc, true); - Assert.assertTrue("new context not created", newContextCreated.get()); - ssc.stop(false); - - newContextCreated.set(false); - ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc2, sc); + JavaSparkContext sc = new JavaSparkContext(conf); + ssc = JavaStreamingContext.getOrCreate(checkpointDir, creatingFunc, + new org.apache.hadoop.conf.Configuration()); Assert.assertTrue("old context not recovered", !newContextCreated.get()); ssc.stop(); } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index 87bc20f79c3cd..08faeaa58f419 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -255,7 +255,7 @@ class BasicOperationsSuite extends TestSuiteBase { Seq( ) ) val operation = (s1: DStream[String], s2: DStream[String]) => { - s1.map(x => (x,1)).cogroup(s2.map(x => (x, "x"))).mapValues(x => (x._1.toSeq, x._2.toSeq)) + s1.map(x => (x, 1)).cogroup(s2.map(x => (x, "x"))).mapValues(x => (x._1.toSeq, x._2.toSeq)) } testOperation(inputData1, inputData2, operation, outputData, true) } @@ -427,9 +427,9 @@ class BasicOperationsSuite extends TestSuiteBase { test("updateStateByKey - object lifecycle") { val inputData = Seq( - Seq("a","b"), + Seq("a", "b"), null, - Seq("a","c","a"), + Seq("a", "c", "a"), Seq("c"), null, null @@ -557,6 +557,9 @@ class BasicOperationsSuite extends TestSuiteBase { withTestServer(new TestServer()) { testServer => withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => testServer.start() + + val batchCounter = new BatchCounter(ssc) + // Set up the streaming context and input streams val networkStream = ssc.socketTextStream("localhost", testServer.port, StorageLevel.MEMORY_AND_DISK) @@ -587,7 +590,11 @@ class BasicOperationsSuite extends TestSuiteBase { for (i <- 0 until input.size) { testServer.send(input(i).toString + "\n") Thread.sleep(200) + val numCompletedBatches = batchCounter.getNumCompletedBatches clock.advance(batchDuration.milliseconds) + if (!batchCounter.waitUntilBatchesCompleted(numCompletedBatches + 1, 5000)) { + fail("Batch took more than 5 seconds to complete") + } collectRddInfo() } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala new file mode 100644 index 0000000000000..9b5e4dc819a2b --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import java.io.NotSerializableException + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.{HashPartitioner, SparkContext, SparkException, SparkFunSuite} +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.util.ReturnStatementInClosureException + +/** + * Test that closures passed to DStream operations are actually cleaned. + */ +class DStreamClosureSuite extends SparkFunSuite with BeforeAndAfterAll { + private var ssc: StreamingContext = null + + override def beforeAll(): Unit = { + val sc = new SparkContext("local", "test") + ssc = new StreamingContext(sc, Seconds(1)) + } + + override def afterAll(): Unit = { + ssc.stop(stopSparkContext = true) + ssc = null + } + + test("user provided closures are actually cleaned") { + val dstream = new DummyInputDStream(ssc) + val pairDstream = dstream.map { i => (i, i) } + // DStream + testMap(dstream) + testFlatMap(dstream) + testFilter(dstream) + testMapPartitions(dstream) + testReduce(dstream) + testForeach(dstream) + testForeachRDD(dstream) + testTransform(dstream) + testTransformWith(dstream) + testReduceByWindow(dstream) + // PairDStreamFunctions + testReduceByKey(pairDstream) + testCombineByKey(pairDstream) + testReduceByKeyAndWindow(pairDstream) + testUpdateStateByKey(pairDstream) + testMapValues(pairDstream) + testFlatMapValues(pairDstream) + // StreamingContext + testTransform2(ssc, dstream) + } + + /** + * Verify that the expected exception is thrown. + * + * We use return statements as an indication that a closure is actually being cleaned. + * We expect closure cleaner to find the return statements in the user provided closures. + */ + private def expectCorrectException(body: => Unit): Unit = { + try { + body + } catch { + case rse: ReturnStatementInClosureException => // Success! + case e @ (_: NotSerializableException | _: SparkException) => + throw new TestException( + s"Expected ReturnStatementInClosureException, but got $e.\n" + + "This means the closure provided by user is not actually cleaned.") + } + } + + // DStream operations + private def testMap(ds: DStream[Int]): Unit = expectCorrectException { + ds.map { _ => return; 1 } + } + private def testFlatMap(ds: DStream[Int]): Unit = expectCorrectException { + ds.flatMap { _ => return; Seq.empty } + } + private def testFilter(ds: DStream[Int]): Unit = expectCorrectException { + ds.filter { _ => return; true } + } + private def testMapPartitions(ds: DStream[Int]): Unit = expectCorrectException { + ds.mapPartitions { _ => return; Seq.empty.toIterator } + } + private def testReduce(ds: DStream[Int]): Unit = expectCorrectException { + ds.reduce { case (_, _) => return; 1 } + } + private def testForeach(ds: DStream[Int]): Unit = { + val foreachF1 = (rdd: RDD[Int], t: Time) => return + val foreachF2 = (rdd: RDD[Int]) => return + expectCorrectException { ds.foreach(foreachF1) } + expectCorrectException { ds.foreach(foreachF2) } + } + private def testForeachRDD(ds: DStream[Int]): Unit = { + val foreachRDDF1 = (rdd: RDD[Int], t: Time) => return + val foreachRDDF2 = (rdd: RDD[Int]) => return + expectCorrectException { ds.foreachRDD(foreachRDDF1) } + expectCorrectException { ds.foreachRDD(foreachRDDF2) } + } + private def testTransform(ds: DStream[Int]): Unit = { + val transformF1 = (rdd: RDD[Int]) => { return; rdd } + val transformF2 = (rdd: RDD[Int], time: Time) => { return; rdd } + expectCorrectException { ds.transform(transformF1) } + expectCorrectException { ds.transform(transformF2) } + } + private def testTransformWith(ds: DStream[Int]): Unit = { + val transformF1 = (rdd1: RDD[Int], rdd2: RDD[Int]) => { return; rdd1 } + val transformF2 = (rdd1: RDD[Int], rdd2: RDD[Int], time: Time) => { return; rdd2 } + expectCorrectException { ds.transformWith(ds, transformF1) } + expectCorrectException { ds.transformWith(ds, transformF2) } + } + private def testReduceByWindow(ds: DStream[Int]): Unit = { + val reduceF = (_: Int, _: Int) => { return; 1 } + expectCorrectException { ds.reduceByWindow(reduceF, Seconds(1), Seconds(2)) } + expectCorrectException { ds.reduceByWindow(reduceF, reduceF, Seconds(1), Seconds(2)) } + } + + // PairDStreamFunctions operations + private def testReduceByKey(ds: DStream[(Int, Int)]): Unit = { + val reduceF = (_: Int, _: Int) => { return; 1 } + expectCorrectException { ds.reduceByKey(reduceF) } + expectCorrectException { ds.reduceByKey(reduceF, 5) } + expectCorrectException { ds.reduceByKey(reduceF, new HashPartitioner(5)) } + } + private def testCombineByKey(ds: DStream[(Int, Int)]): Unit = { + expectCorrectException { + ds.combineByKey[Int]( + { _: Int => return; 1 }, + { case (_: Int, _: Int) => return; 1 }, + { case (_: Int, _: Int) => return; 1 }, + new HashPartitioner(5) + ) + } + } + private def testReduceByKeyAndWindow(ds: DStream[(Int, Int)]): Unit = { + val reduceF = (_: Int, _: Int) => { return; 1 } + val filterF = (_: (Int, Int)) => { return; false } + expectCorrectException { ds.reduceByKeyAndWindow(reduceF, Seconds(1)) } + expectCorrectException { ds.reduceByKeyAndWindow(reduceF, Seconds(1), Seconds(2)) } + expectCorrectException { ds.reduceByKeyAndWindow(reduceF, Seconds(1), Seconds(2), 5) } + expectCorrectException { + ds.reduceByKeyAndWindow(reduceF, Seconds(1), Seconds(2), new HashPartitioner(5)) + } + expectCorrectException { ds.reduceByKeyAndWindow(reduceF, reduceF, Seconds(2)) } + expectCorrectException { + ds.reduceByKeyAndWindow( + reduceF, reduceF, Seconds(2), Seconds(3), new HashPartitioner(5), filterF) + } + } + private def testUpdateStateByKey(ds: DStream[(Int, Int)]): Unit = { + val updateF1 = (_: Seq[Int], _: Option[Int]) => { return; Some(1) } + val updateF2 = (_: Iterator[(Int, Seq[Int], Option[Int])]) => { return; Seq((1, 1)).toIterator } + val initialRDD = ds.ssc.sparkContext.emptyRDD[Int].map { i => (i, i) } + expectCorrectException { ds.updateStateByKey(updateF1) } + expectCorrectException { ds.updateStateByKey(updateF1, 5) } + expectCorrectException { ds.updateStateByKey(updateF1, new HashPartitioner(5)) } + expectCorrectException { + ds.updateStateByKey(updateF1, new HashPartitioner(5), initialRDD) + } + expectCorrectException { + ds.updateStateByKey(updateF2, new HashPartitioner(5), true) + } + expectCorrectException { + ds.updateStateByKey(updateF2, new HashPartitioner(5), true, initialRDD) + } + } + private def testMapValues(ds: DStream[(Int, Int)]): Unit = expectCorrectException { + ds.mapValues { _ => return; 1 } + } + private def testFlatMapValues(ds: DStream[(Int, Int)]): Unit = expectCorrectException { + ds.flatMapValues { _ => return; Seq.empty } + } + + // StreamingContext operations + private def testTransform2(ssc: StreamingContext, ds: DStream[Int]): Unit = { + val transformF = (rdds: Seq[RDD[_]], time: Time) => { return; ssc.sparkContext.emptyRDD[Int] } + expectCorrectException { ssc.transform(Seq(ds), transformF) } + } + +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala new file mode 100644 index 0000000000000..8844c9d74b933 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} + +import org.apache.spark.{SparkContext, SparkFunSuite} +import org.apache.spark.rdd.RDDOperationScope +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.streaming.ui.UIUtils + +/** + * Tests whether scope information is passed from DStream operations to RDDs correctly. + */ +class DStreamScopeSuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll { + private var ssc: StreamingContext = null + private val batchDuration: Duration = Seconds(1) + + override def beforeAll(): Unit = { + ssc = new StreamingContext(new SparkContext("local", "test"), batchDuration) + } + + override def afterAll(): Unit = { + ssc.stop(stopSparkContext = true) + } + + before { assertPropertiesNotSet() } + after { assertPropertiesNotSet() } + + test("dstream without scope") { + val dummyStream = new DummyDStream(ssc) + dummyStream.initialize(Time(0)) + + // This DStream is not instantiated in any scope, so all RDDs + // created by this stream should similarly not have a scope + assert(dummyStream.baseScope === None) + assert(dummyStream.getOrCompute(Time(1000)).get.scope === None) + assert(dummyStream.getOrCompute(Time(2000)).get.scope === None) + assert(dummyStream.getOrCompute(Time(3000)).get.scope === None) + } + + test("input dstream without scope") { + val inputStream = new DummyInputDStream(ssc) + inputStream.initialize(Time(0)) + + val baseScope = inputStream.baseScope.map(RDDOperationScope.fromJson) + val scope1 = inputStream.getOrCompute(Time(1000)).get.scope + val scope2 = inputStream.getOrCompute(Time(2000)).get.scope + val scope3 = inputStream.getOrCompute(Time(3000)).get.scope + + // This DStream is not instantiated in any scope, so all RDDs + assertDefined(baseScope, scope1, scope2, scope3) + assert(baseScope.get.name.startsWith("dummy stream")) + assertScopeCorrect(baseScope.get, scope1.get, 1000) + assertScopeCorrect(baseScope.get, scope2.get, 2000) + assertScopeCorrect(baseScope.get, scope3.get, 3000) + } + + test("scoping simple operations") { + val inputStream = new DummyInputDStream(ssc) + val mappedStream = inputStream.map { i => i + 1 } + val filteredStream = mappedStream.filter { i => i % 2 == 0 } + filteredStream.initialize(Time(0)) + + val mappedScopeBase = mappedStream.baseScope.map(RDDOperationScope.fromJson) + val mappedScope1 = mappedStream.getOrCompute(Time(1000)).get.scope + val mappedScope2 = mappedStream.getOrCompute(Time(2000)).get.scope + val mappedScope3 = mappedStream.getOrCompute(Time(3000)).get.scope + val filteredScopeBase = filteredStream.baseScope.map(RDDOperationScope.fromJson) + val filteredScope1 = filteredStream.getOrCompute(Time(1000)).get.scope + val filteredScope2 = filteredStream.getOrCompute(Time(2000)).get.scope + val filteredScope3 = filteredStream.getOrCompute(Time(3000)).get.scope + + // These streams are defined in their respective scopes "map" and "filter", so all + // RDDs created by these streams should inherit the IDs and names of their parent + // DStream's base scopes + assertDefined(mappedScopeBase, mappedScope1, mappedScope2, mappedScope3) + assertDefined(filteredScopeBase, filteredScope1, filteredScope2, filteredScope3) + assert(mappedScopeBase.get.name === "map") + assert(filteredScopeBase.get.name === "filter") + assertScopeCorrect(mappedScopeBase.get, mappedScope1.get, 1000) + assertScopeCorrect(mappedScopeBase.get, mappedScope2.get, 2000) + assertScopeCorrect(mappedScopeBase.get, mappedScope3.get, 3000) + assertScopeCorrect(filteredScopeBase.get, filteredScope1.get, 1000) + assertScopeCorrect(filteredScopeBase.get, filteredScope2.get, 2000) + assertScopeCorrect(filteredScopeBase.get, filteredScope3.get, 3000) + } + + test("scoping nested operations") { + val inputStream = new DummyInputDStream(ssc) + val countStream = inputStream.countByWindow(Seconds(10), Seconds(1)) + countStream.initialize(Time(0)) + + val countScopeBase = countStream.baseScope.map(RDDOperationScope.fromJson) + val countScope1 = countStream.getOrCompute(Time(1000)).get.scope + val countScope2 = countStream.getOrCompute(Time(2000)).get.scope + val countScope3 = countStream.getOrCompute(Time(3000)).get.scope + + // Assert that all children RDDs inherit the DStream operation name correctly + assertDefined(countScopeBase, countScope1, countScope2, countScope3) + assert(countScopeBase.get.name === "countByWindow") + assertScopeCorrect(countScopeBase.get, countScope1.get, 1000) + assertScopeCorrect(countScopeBase.get, countScope2.get, 2000) + assertScopeCorrect(countScopeBase.get, countScope3.get, 3000) + + // All streams except the input stream should share the same scopes as `countStream` + def testStream(stream: DStream[_]): Unit = { + if (stream != inputStream) { + val myScopeBase = stream.baseScope.map(RDDOperationScope.fromJson) + val myScope1 = stream.getOrCompute(Time(1000)).get.scope + val myScope2 = stream.getOrCompute(Time(2000)).get.scope + val myScope3 = stream.getOrCompute(Time(3000)).get.scope + assertDefined(myScopeBase, myScope1, myScope2, myScope3) + assert(myScopeBase === countScopeBase) + assert(myScope1 === countScope1) + assert(myScope2 === countScope2) + assert(myScope3 === countScope3) + // Climb upwards to test the parent streams + stream.dependencies.foreach(testStream) + } + } + testStream(countStream) + } + + /** Assert that the RDD operation scope properties are not set in our SparkContext. */ + private def assertPropertiesNotSet(): Unit = { + assert(ssc != null) + assert(ssc.sc.getLocalProperty(SparkContext.RDD_SCOPE_KEY) == null) + assert(ssc.sc.getLocalProperty(SparkContext.RDD_SCOPE_NO_OVERRIDE_KEY) == null) + } + + /** Assert that the given RDD scope inherits the name and ID of the base scope correctly. */ + private def assertScopeCorrect( + baseScope: RDDOperationScope, + rddScope: RDDOperationScope, + batchTime: Long): Unit = { + assertScopeCorrect(baseScope.id, baseScope.name, rddScope, batchTime) + } + + /** Assert that the given RDD scope inherits the base name and ID correctly. */ + private def assertScopeCorrect( + baseScopeId: String, + baseScopeName: String, + rddScope: RDDOperationScope, + batchTime: Long): Unit = { + val formattedBatchTime = UIUtils.formatBatchTime( + batchTime, ssc.graph.batchDuration.milliseconds, showYYYYMMSS = false) + assert(rddScope.id === s"${baseScopeId}_$batchTime") + assert(rddScope.name.replaceAll("\\n", " ") === s"$baseScopeName @ $formattedBatchTime") + } + + /** Assert that all the specified options are defined. */ + private def assertDefined[T](options: Option[T]*): Unit = { + options.zipWithIndex.foreach { case (o, i) => assert(o.isDefined, s"Option $i was empty!") } + } + +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index 93e6b0cd7c661..b74d67c63a788 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -39,6 +39,7 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.scheduler.{StreamingListenerBatchCompleted, StreamingListener} import org.apache.spark.util.{ManualClock, Utils} import org.apache.spark.streaming.dstream.{InputDStream, ReceiverInputDStream} +import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD import org.apache.spark.streaming.receiver.Receiver class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { @@ -105,6 +106,36 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } } + test("socket input stream - no block in a batch") { + withTestServer(new TestServer()) { testServer => + testServer.start() + + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => + ssc.addStreamingListener(ssc.progressListener) + + val batchCounter = new BatchCounter(ssc) + val networkStream = ssc.socketTextStream( + "localhost", testServer.port, StorageLevel.MEMORY_AND_DISK) + val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] + val outputStream = new TestOutputStream(networkStream, outputBuffer) + outputStream.register() + ssc.start() + + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + clock.advance(batchDuration.milliseconds) + + // Make sure the first batch is finished + if (!batchCounter.waitUntilBatchesCompleted(1, 30000)) { + fail("Timeout: cannot finish all batches in 30 seconds") + } + + networkStream.generatedRDDs.foreach { case (_, rdd) => + assert(!rdd.isInstanceOf[WriteAheadLogBackedBlockRDD[_]]) + } + } + } + } + test("binary records stream") { val testDir: File = null try { @@ -387,7 +418,7 @@ class TestServer(portToBind: Int = 0) extends Logging { val servingThread = new Thread() { override def run() { try { - while(true) { + while (true) { logInfo("Accepting connections on port " + port) val clientSocket = serverSocket.accept() if (startLatch.getCount == 1) { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 23804237bda80..cca8cedb1d080 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -25,7 +25,7 @@ import scala.concurrent.duration._ import scala.language.postfixOps import org.apache.hadoop.conf.Configuration -import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ import org.apache.spark._ @@ -41,7 +41,11 @@ import org.apache.spark.util.{ManualClock, Utils} import WriteAheadLogBasedBlockHandler._ import WriteAheadLogSuite._ -class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matchers with Logging { +class ReceivedBlockHandlerSuite + extends SparkFunSuite + with BeforeAndAfter + with Matchers + with Logging { val conf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.rollingIntervalSecs", "1") val hadoopConf = new Configuration() diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala index b1af8d5eaacfb..6f0ee774cb5cf 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala @@ -25,10 +25,10 @@ import scala.language.{implicitConversions, postfixOps} import scala.util.Random import org.apache.hadoop.conf.Configuration -import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.{Logging, SparkConf, SparkException, SparkFunSuite} import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult import org.apache.spark.streaming.scheduler._ @@ -37,7 +37,7 @@ import org.apache.spark.streaming.util.WriteAheadLogSuite._ import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils} class ReceivedBlockTrackerSuite - extends FunSuite with BeforeAndAfter with Matchers with Logging { + extends SparkFunSuite with BeforeAndAfter with Matchers with Logging { val hadoopConf = new Configuration() val akkaTimeout = 10 seconds diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 5f93332896de1..819dd2ccfe915 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -17,24 +17,24 @@ package org.apache.spark.streaming -import java.io.File +import java.io.{File, NotSerializableException} import java.util.concurrent.atomic.AtomicInteger import org.apache.commons.io.FileUtils -import org.scalatest.{Assertions, BeforeAndAfter, FunSuite} -import org.scalatest.concurrent.Timeouts import org.scalatest.concurrent.Eventually._ +import org.scalatest.concurrent.Timeouts import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.SpanSugar._ +import org.scalatest.{Assertions, BeforeAndAfter} -import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.receiver.Receiver import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException, SparkFunSuite} -class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts with Logging { +class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeouts with Logging { val master = "local[2]" val appName = this.getClass.getSimpleName @@ -132,6 +132,41 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w } } + test("start with non-seriazable DStream checkpoints") { + val checkpointDir = Utils.createTempDir() + ssc = new StreamingContext(conf, batchDuration) + ssc.checkpoint(checkpointDir.getAbsolutePath) + addInputStream(ssc).foreachRDD { rdd => + // Refer to this.appName from inside closure so that this closure refers to + // the instance of StreamingContextSuite, and is therefore not serializable + rdd.count() + appName + } + + // Test whether start() fails early when checkpointing is enabled + val exception = intercept[NotSerializableException] { + ssc.start() + } + assert(exception.getMessage().contains("DStreams with their functions are not serializable")) + assert(ssc.getState() !== StreamingContextState.ACTIVE) + assert(StreamingContext.getActive().isEmpty) + } + + test("start failure should stop internal components") { + ssc = new StreamingContext(conf, batchDuration) + val inputStream = addInputStream(ssc) + val updateFunc = (values: Seq[Int], state: Option[Int]) => { + Some(values.sum + state.getOrElse(0)) + } + inputStream.map(x => (x, 1)).updateStateByKey[Int](updateFunc) + // Require that the start fails because checkpoint directory was not set + intercept[Exception] { + ssc.start() + } + assert(ssc.getState() === StreamingContextState.STOPPED) + assert(ssc.scheduler.isStarted === false) + } + + test("start multiple times") { ssc = new StreamingContext(master, appName, batchDuration) addInputStream(ssc).register() @@ -163,7 +198,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w ssc = new StreamingContext(master, appName, batchDuration) addInputStream(ssc).register() ssc.stop() - intercept[SparkException] { + intercept[IllegalStateException] { ssc.start() // start after stop should throw exception } assert(ssc.getState() === StreamingContextState.STOPPED) @@ -419,76 +454,16 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w ssc = StreamingContext.getOrCreate(checkpointPath, creatingFunction _) assert(ssc != null, "no context created") assert(!newContextCreated, "old context not recovered") - assert(ssc.conf.get("someKey") === "someValue") - } - } - - test("getOrCreate with existing SparkContext") { - val conf = new SparkConf().setMaster(master).setAppName(appName) - sc = new SparkContext(conf) - - // Function to create StreamingContext that has a config to identify it to be new context - var newContextCreated = false - def creatingFunction(sparkContext: SparkContext): StreamingContext = { - newContextCreated = true - new StreamingContext(sparkContext, batchDuration) + assert(ssc.conf.get("someKey") === "someValue", "checkpointed config not recovered") } - // Call ssc.stop(stopSparkContext = false) after a body of cody - def testGetOrCreate(body: => Unit): Unit = { - newContextCreated = false - try { - body - } finally { - if (ssc != null) { - ssc.stop(stopSparkContext = false) - } - ssc = null - } - } - - val emptyPath = Utils.createTempDir().getAbsolutePath() - - // getOrCreate should create new context with empty path + // getOrCreate should recover StreamingContext with existing SparkContext testGetOrCreate { - ssc = StreamingContext.getOrCreate(emptyPath, creatingFunction _, sc, createOnError = true) - assert(ssc != null, "no context created") - assert(newContextCreated, "new context not created") - assert(ssc.sparkContext === sc, "new StreamingContext does not use existing SparkContext") - } - - val corrutedCheckpointPath = createCorruptedCheckpoint() - - // getOrCreate should throw exception with fake checkpoint file and createOnError = false - intercept[Exception] { - ssc = StreamingContext.getOrCreate(corrutedCheckpointPath, creatingFunction _, sc) - } - - // getOrCreate should throw exception with fake checkpoint file - intercept[Exception] { - ssc = StreamingContext.getOrCreate( - corrutedCheckpointPath, creatingFunction _, sc, createOnError = false) - } - - // getOrCreate should create new context with fake checkpoint file and createOnError = true - testGetOrCreate { - ssc = StreamingContext.getOrCreate( - corrutedCheckpointPath, creatingFunction _, sc, createOnError = true) - assert(ssc != null, "no context created") - assert(newContextCreated, "new context not created") - assert(ssc.sparkContext === sc, "new StreamingContext does not use existing SparkContext") - } - - val checkpointPath = createValidCheckpoint() - - // StreamingContext.getOrCreate should recover context with checkpoint path - testGetOrCreate { - ssc = StreamingContext.getOrCreate(checkpointPath, creatingFunction _, sc) + sc = new SparkContext(conf) + ssc = StreamingContext.getOrCreate(checkpointPath, creatingFunction _) assert(ssc != null, "no context created") assert(!newContextCreated, "old context not recovered") - assert(ssc.sparkContext === sc, "new StreamingContext does not use existing SparkContext") - assert(!ssc.conf.contains("someKey"), - "recovered StreamingContext unexpectedly has old config") + assert(!ssc.conf.contains("someKey"), "checkpointed config unexpectedly recovered") } } @@ -641,7 +616,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w val anotherInput = addInputStream(anotherSsc) anotherInput.foreachRDD { rdd => rdd.count } - val exception = intercept[SparkException] { + val exception = intercept[IllegalStateException] { anotherSsc.start() } assert(exception.getMessage.contains("StreamingContext"), "Did not get the right exception") @@ -664,7 +639,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w def testForException(clue: String, expectedErrorMsg: String)(body: => Unit): Unit = { withClue(clue) { - val ex = intercept[SparkException] { + val ex = intercept[IllegalStateException] { body } assert(ex.getMessage.toLowerCase().contains(expectedErrorMsg)) @@ -773,7 +748,9 @@ class SlowTestReceiver(totalRecords: Int, recordsPerSecond: Int) def onStop() { // Simulate slow receiver by waiting for all records to be produced - while(!SlowTestReceiver.receivedAllRecords) Thread.sleep(100) + while (!SlowTestReceiver.receivedAllRecords) { + Thread.sleep(100) + } // no clean to be done, the receiving thread should stop on it own } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index 312cce408cfe7..1dc8960d60528 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -133,8 +133,10 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { /** Check if a sequence of numbers is in increasing order */ def isInIncreasingOrder(seq: Seq[Long]): Boolean = { - for(i <- 1 until seq.size) { - if (seq(i - 1) > seq(i)) return false + for (i <- 1 until seq.size) { + if (seq(i - 1) > seq(i)) { + return false + } } true } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index 4f70ae7f1f187..31b1aebf6a8ec 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -24,17 +24,35 @@ import scala.collection.mutable.SynchronizedBuffer import scala.language.implicitConversions import scala.reflect.ClassTag -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.scalatest.time.{Span, Seconds => ScalaTestSeconds} import org.scalatest.concurrent.Eventually.timeout import org.scalatest.concurrent.PatienceConfiguration -import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.rdd.RDD import org.apache.spark.streaming.dstream.{DStream, InputDStream, ForEachDStream} import org.apache.spark.streaming.scheduler._ import org.apache.spark.util.{ManualClock, Utils} +/** + * A dummy stream that does absolutely nothing. + */ +private[streaming] class DummyDStream(ssc: StreamingContext) extends DStream[Int](ssc) { + override def dependencies: List[DStream[Int]] = List.empty + override def slideDuration: Duration = Seconds(1) + override def compute(time: Time): Option[RDD[Int]] = Some(ssc.sc.emptyRDD[Int]) +} + +/** + * A dummy input stream that does absolutely nothing. + */ +private[streaming] class DummyInputDStream(ssc: StreamingContext) extends InputDStream[Int](ssc) { + override def start(): Unit = { } + override def stop(): Unit = { } + override def compute(time: Time): Option[RDD[Int]] = Some(ssc.sc.emptyRDD[Int]) +} + /** * This is a input stream just for the testsuites. This is equivalent to a checkpointable, * replayable, reliable message queue like Kafka. It requires a sequence as input, and @@ -186,7 +204,7 @@ class BatchCounter(ssc: StreamingContext) { * This is the base trait for Spark Streaming testsuites. This provides basic functionality * to run user-defined set of input on user-defined stream operations, and verify the output. */ -trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { +trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging { // Name of the framework for Spark context def framework: String = this.getClass.getSimpleName diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala index 441bbf95d0153..cbc24aee4fa1e 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala @@ -28,14 +28,11 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark._ - - - /** * Selenium tests for the Spark Web UI. */ class UISeleniumSuite - extends FunSuite with WebBrowser with Matchers with BeforeAndAfterAll with TestSuiteBase { + extends SparkFunSuite with WebBrowser with Matchers with BeforeAndAfterAll with TestSuiteBase { implicit var webDriver: WebDriver = _ @@ -197,4 +194,3 @@ class UISeleniumSuite } } } - diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala index 6859b65c7165f..cb017b798b2a4 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala @@ -21,15 +21,15 @@ import java.io.File import scala.util.Random import org.apache.hadoop.conf.Configuration -import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite} +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId} import org.apache.spark.streaming.util.{FileBasedWriteAheadLogSegment, FileBasedWriteAheadLogWriter} import org.apache.spark.util.Utils -import org.apache.spark.{SparkConf, SparkContext, SparkException} +import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite} class WriteAheadLogBackedBlockRDDSuite - extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach { + extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfterEach { val conf = new SparkConf() .setMaster("local[2]") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala index 5478b41845943..2e210397fe7c7 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala @@ -17,12 +17,12 @@ package org.apache.spark.streaming.scheduler -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.streaming.{Time, Duration, StreamingContext} -class InputInfoTrackerSuite extends FunSuite with BeforeAndAfter { +class InputInfoTrackerSuite extends SparkFunSuite with BeforeAndAfter { private var ssc: StreamingContext = _ diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala index 2a0f45830e03c..c9175d61b1f49 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala @@ -64,7 +64,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { listener.numTotalReceivedRecords should be (0) // onBatchStarted - val batchInfoStarted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, Some(2000), None) + val batchInfoStarted = BatchInfo(Time(1000), streamIdToNumRecords, 1000, Some(2000), None) listener.onBatchStarted(StreamingListenerBatchStarted(batchInfoStarted)) listener.waitingBatches should be (Nil) listener.runningBatches should be (List(BatchUIData(batchInfoStarted))) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/UIUtilsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/UIUtilsSuite.scala index 6df1a63ab2e37..d3ca2b58f36c2 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ui/UIUtilsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/UIUtilsSuite.scala @@ -17,12 +17,14 @@ package org.apache.spark.streaming.ui +import java.util.TimeZone import java.util.concurrent.TimeUnit -import org.scalatest.FunSuite import org.scalatest.Matchers -class UIUtilsSuite extends FunSuite with Matchers{ +import org.apache.spark.SparkFunSuite + +class UIUtilsSuite extends SparkFunSuite with Matchers{ test("shortTimeUnitString") { assert("ns" === UIUtils.shortTimeUnitString(TimeUnit.NANOSECONDS)) @@ -64,4 +66,14 @@ class UIUtilsSuite extends FunSuite with Matchers{ val convertedTime = UIUtils.convertToTimeUnit(milliseconds, unit) convertedTime should be (expectedTime +- 1E-6) } + + test("formatBatchTime") { + val tzForTest = TimeZone.getTimeZone("America/Los_Angeles") + val batchTime = 1431637480452L // Thu May 14 14:04:40 PDT 2015 + assert("2015/05/14 14:04:40" === UIUtils.formatBatchTime(batchTime, 1000, timezone = tzForTest)) + assert("2015/05/14 14:04:40.452" === + UIUtils.formatBatchTime(batchTime, 999, timezone = tzForTest)) + assert("14:04:40" === UIUtils.formatBatchTime(batchTime, 1000, false, timezone = tzForTest)) + assert("14:04:40.452" === UIUtils.formatBatchTime(batchTime, 999, false, timezone = tzForTest)) + } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala index 9ebf7b484f421..78fc344b00177 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/RateLimitedOutputStreamSuite.scala @@ -20,9 +20,9 @@ package org.apache.spark.streaming.util import java.io.ByteArrayOutputStream import java.util.concurrent.TimeUnit._ -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class RateLimitedOutputStreamSuite extends FunSuite { +class RateLimitedOutputStreamSuite extends SparkFunSuite { private def benchmark[U](f: => U): Long = { val start = System.nanoTime diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index 79098bcf4861c..325ff7c74c39d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -28,15 +28,15 @@ import scala.reflect.ClassTag import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.scalatest.concurrent.Eventually._ -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.BeforeAndAfter import org.apache.spark.util.{ManualClock, Utils} -import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} -class WriteAheadLogSuite extends FunSuite with BeforeAndAfter { +class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter { import WriteAheadLogSuite._ - + val hadoopConf = new Configuration() var tempDir: File = null var testDir: String = null @@ -359,7 +359,7 @@ object WriteAheadLogSuite { ): FileBasedWriteAheadLog = { if (manualClock.getTimeMillis() < 100000) manualClock.setTime(10000) val wal = new FileBasedWriteAheadLog(new SparkConf(), logDirectory, hadoopConf, 1, 1) - + // Ensure that 500 does not get sorted after 2000, so put a high base value. data.foreach { item => manualClock.advance(500) diff --git a/unsafe/pom.xml b/unsafe/pom.xml index 5b0733206b2bc..2fd17267ac427 100644 --- a/unsafe/pom.xml +++ b/unsafe/pom.xml @@ -42,6 +42,10 @@ com.google.code.findbugs jsr305 + + com.google.guava + guava + @@ -61,6 +65,11 @@ junit-interface test + + org.mockito + mockito-all + test + target/scala-${scala.binary.version}/classes diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 19d6a169fd2ad..0b4d8d286f5f9 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -23,6 +23,8 @@ import java.util.LinkedList; import java.util.List; +import com.google.common.annotations.VisibleForTesting; + import org.apache.spark.unsafe.*; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.array.LongArray; @@ -36,9 +38,8 @@ * This is backed by a power-of-2-sized hash table, using quadratic probing with triangular numbers, * which is guaranteed to exhaust the space. *

    - * The map can support up to 2^31 keys because we use 32 bit MurmurHash. If the key cardinality is - * higher than this, you should probably be using sorting instead of hashing for better cache - * locality. + * The map can support up to 2^29 keys. If the key cardinality is higher than this, you should + * probably be using sorting instead of hashing for better cache locality. *

    * This class is not thread safe. */ @@ -48,6 +49,11 @@ public final class BytesToBytesMap { private static final HashMapGrowthStrategy growthStrategy = HashMapGrowthStrategy.DOUBLING; + /** + * Special record length that is placed after the last record in a data page. + */ + private static final int END_OF_PAGE_MARKER = -1; + private final TaskMemoryManager memoryManager; /** @@ -64,7 +70,7 @@ public final class BytesToBytesMap { /** * Offset into `currentDataPage` that points to the location where new data can be inserted into - * the page. + * the page. This does not incorporate the page's base offset. */ private long pageCursor = 0; @@ -74,6 +80,15 @@ public final class BytesToBytesMap { */ private static final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes + /** + * The maximum number of keys that BytesToBytesMap supports. The hash table has to be + * power-of-2-sized and its backing Java array can contain at most (1 << 30) elements, since + * that's the largest power-of-2 that's less than Integer.MAX_VALUE. We need two long array + * entries per key, giving us a maximum capacity of (1 << 29). + */ + @VisibleForTesting + static final int MAX_CAPACITY = (1 << 29); + // This choice of page table size and page size means that we can address up to 500 gigabytes // of memory. @@ -143,6 +158,13 @@ public BytesToBytesMap( this.loadFactor = loadFactor; this.loc = new Location(); this.enablePerfMetrics = enablePerfMetrics; + if (initialCapacity <= 0) { + throw new IllegalArgumentException("Initial capacity must be greater than 0"); + } + if (initialCapacity > MAX_CAPACITY) { + throw new IllegalArgumentException( + "Initial capacity " + initialCapacity + " exceeds maximum capacity of " + MAX_CAPACITY); + } allocate(initialCapacity); } @@ -162,6 +184,55 @@ public BytesToBytesMap( */ public int size() { return size; } + private static final class BytesToBytesMapIterator implements Iterator { + + private final int numRecords; + private final Iterator dataPagesIterator; + private final Location loc; + + private int currentRecordNumber = 0; + private Object pageBaseObject; + private long offsetInPage; + + BytesToBytesMapIterator(int numRecords, Iterator dataPagesIterator, Location loc) { + this.numRecords = numRecords; + this.dataPagesIterator = dataPagesIterator; + this.loc = loc; + if (dataPagesIterator.hasNext()) { + advanceToNextPage(); + } + } + + private void advanceToNextPage() { + final MemoryBlock currentPage = dataPagesIterator.next(); + pageBaseObject = currentPage.getBaseObject(); + offsetInPage = currentPage.getBaseOffset(); + } + + @Override + public boolean hasNext() { + return currentRecordNumber != numRecords; + } + + @Override + public Location next() { + int keyLength = (int) PlatformDependent.UNSAFE.getLong(pageBaseObject, offsetInPage); + if (keyLength == END_OF_PAGE_MARKER) { + advanceToNextPage(); + keyLength = (int) PlatformDependent.UNSAFE.getLong(pageBaseObject, offsetInPage); + } + loc.with(pageBaseObject, offsetInPage); + offsetInPage += 8 + 8 + keyLength + loc.getValueLength(); + currentRecordNumber++; + return loc; + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + } + /** * Returns an iterator for iterating over the entries of this map. * @@ -171,27 +242,7 @@ public BytesToBytesMap( * `lookup()`, the behavior of the returned iterator is undefined. */ public Iterator iterator() { - return new Iterator() { - - private int nextPos = bitset.nextSetBit(0); - - @Override - public boolean hasNext() { - return nextPos != -1; - } - - @Override - public Location next() { - final int pos = nextPos; - nextPos = bitset.nextSetBit(nextPos + 1); - return loc.with(pos, 0, true); - } - - @Override - public void remove() { - throw new UnsupportedOperationException(); - } - }; + return new BytesToBytesMapIterator(size, dataPages.iterator(), loc); } /** @@ -268,8 +319,11 @@ public final class Location { private int valueLength; private void updateAddressesAndSizes(long fullKeyAddress) { - final Object page = memoryManager.getPage(fullKeyAddress); - final long keyOffsetInPage = memoryManager.getOffsetInPage(fullKeyAddress); + updateAddressesAndSizes( + memoryManager.getPage(fullKeyAddress), memoryManager.getOffsetInPage(fullKeyAddress)); + } + + private void updateAddressesAndSizes(Object page, long keyOffsetInPage) { long position = keyOffsetInPage; keyLength = (int) PlatformDependent.UNSAFE.getLong(page, position); position += 8; // word used to store the key size @@ -291,6 +345,12 @@ Location with(int pos, int keyHashcode, boolean isDefined) { return this; } + Location with(Object page, long keyOffsetInPage) { + this.isDefined = true; + updateAddressesAndSizes(page, keyOffsetInPage); + return this; + } + /** * Returns true if the key is defined at this position, and false otherwise. */ @@ -345,6 +405,8 @@ public int getValueLength() { *

    * It is only valid to call this method immediately after calling `lookup()` using the same key. *

    + * The key and value must be word-aligned (that is, their sizes must multiples of 8). + *

    * After calling this method, calls to `get[Key|Value]Address()` and `get[Key|Value]Length` * will return information on the data stored by this `putNewKey` call. *

    @@ -367,20 +429,29 @@ public void putNewKey( long valueBaseOffset, int valueLengthBytes) { assert (!isDefined) : "Can only set value once for a key"; - isDefined = true; assert (keyLengthBytes % 8 == 0); assert (valueLengthBytes % 8 == 0); + if (size == MAX_CAPACITY) { + throw new IllegalStateException("BytesToBytesMap has reached maximum capacity"); + } // Here, we'll copy the data into our data pages. Because we only store a relative offset from // the key address instead of storing the absolute address of the value, the key and value // must be stored in the same memory page. // (8 byte key length) (key) (8 byte value length) (value) final long requiredSize = 8 + keyLengthBytes + 8 + valueLengthBytes; - assert(requiredSize <= PAGE_SIZE_BYTES); + assert (requiredSize <= PAGE_SIZE_BYTES - 8); // Reserve 8 bytes for the end-of-page marker. size++; bitset.set(pos); - // If there's not enough space in the current page, allocate a new page: - if (currentDataPage == null || PAGE_SIZE_BYTES - pageCursor < requiredSize) { + // If there's not enough space in the current page, allocate a new page (8 bytes are reserved + // for the end-of-page marker). + if (currentDataPage == null || PAGE_SIZE_BYTES - 8 - pageCursor < requiredSize) { + if (currentDataPage != null) { + // There wasn't enough space in the current page, so write an end-of-page marker: + final Object pageBaseObject = currentDataPage.getBaseObject(); + final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor; + PlatformDependent.UNSAFE.putLong(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER); + } MemoryBlock newPage = memoryManager.allocatePage(PAGE_SIZE_BYTES); dataPages.add(newPage); pageCursor = 0; @@ -414,7 +485,7 @@ public void putNewKey( longArray.set(pos * 2 + 1, keyHashcode); updateAddressesAndSizes(storedKeyAddress); isDefined = true; - if (size > growthThreshold) { + if (size > growthThreshold && longArray.size() < MAX_CAPACITY) { growAndRehash(); } } @@ -427,8 +498,11 @@ public void putNewKey( * @param capacity the new map capacity */ private void allocate(int capacity) { - capacity = Math.max((int) Math.min(Integer.MAX_VALUE, nextPowerOf2(capacity)), 64); - longArray = new LongArray(memoryManager.allocate(capacity * 8 * 2)); + assert (capacity >= 0); + // The capacity needs to be divisible by 64 so that our bit set can be sized properly + capacity = Math.max((int) Math.min(MAX_CAPACITY, nextPowerOf2(capacity)), 64); + assert (capacity <= MAX_CAPACITY); + longArray = new LongArray(memoryManager.allocate(capacity * 8L * 2)); bitset = new BitSet(MemoryBlock.fromLongArray(new long[capacity / 64])); this.growthThreshold = (int) (capacity * loadFactor); @@ -494,10 +568,16 @@ public long getNumHashCollisions() { return numHashCollisions; } + @VisibleForTesting + int getNumDataPages() { + return dataPages.size(); + } + /** * Grows the size of the hash table and re-hash everything. */ - private void growAndRehash() { + @VisibleForTesting + void growAndRehash() { long resizeStartTime = -1; if (enablePerfMetrics) { resizeStartTime = System.nanoTime(); @@ -508,7 +588,7 @@ private void growAndRehash() { final int oldCapacity = (int) oldBitSet.capacity(); // Allocate the new data structures - allocate(Math.min(Integer.MAX_VALUE, growthStrategy.nextCapacity(oldCapacity))); + allocate(Math.min(growthStrategy.nextCapacity(oldCapacity), MAX_CAPACITY)); // Re-mask (we don't recompute the hashcode because we stored all 32 bits of it) for (int pos = oldBitSet.nextSetBit(0); pos >= 0; pos = oldBitSet.nextSetBit(pos + 1)) { diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java b/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java index 7c321baffe82d..20654e4eeaa02 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java @@ -32,7 +32,9 @@ public interface HashMapGrowthStrategy { class Doubling implements HashMapGrowthStrategy { @Override public int nextCapacity(int currentCapacity) { - return currentCapacity * 2; + assert (currentCapacity > 0); + // Guard against overflow + return (currentCapacity * 2 > 0) ? (currentCapacity * 2) : Integer.MAX_VALUE; } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java index 62c29c8cc1e4d..cbbe8594627a5 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java @@ -17,6 +17,12 @@ package org.apache.spark.unsafe.memory; +import java.lang.ref.WeakReference; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.Map; +import javax.annotation.concurrent.GuardedBy; + /** * Manages memory for an executor. Individual operators / tasks allocate memory through * {@link TaskMemoryManager} objects, which obtain their memory from ExecutorMemoryManager. @@ -33,6 +39,12 @@ public class ExecutorMemoryManager { */ final boolean inHeap; + @GuardedBy("this") + private final Map>> bufferPoolsBySize = + new HashMap>>(); + + private static final int POOLING_THRESHOLD_BYTES = 1024 * 1024; + /** * Construct a new ExecutorMemoryManager. * @@ -43,16 +55,57 @@ public ExecutorMemoryManager(MemoryAllocator allocator) { this.allocator = allocator; } + /** + * Returns true if allocations of the given size should go through the pooling mechanism and + * false otherwise. + */ + private boolean shouldPool(long size) { + // Very small allocations are less likely to benefit from pooling. + // At some point, we should explore supporting pooling for off-heap memory, but for now we'll + // ignore that case in the interest of simplicity. + return size >= POOLING_THRESHOLD_BYTES && allocator instanceof HeapMemoryAllocator; + } + /** * Allocates a contiguous block of memory. Note that the allocated memory is not guaranteed * to be zeroed out (call `zero()` on the result if this is necessary). */ MemoryBlock allocate(long size) throws OutOfMemoryError { - return allocator.allocate(size); + if (shouldPool(size)) { + synchronized (this) { + final LinkedList> pool = bufferPoolsBySize.get(size); + if (pool != null) { + while (!pool.isEmpty()) { + final WeakReference blockReference = pool.pop(); + final MemoryBlock memory = blockReference.get(); + if (memory != null) { + assert (memory.size() == size); + return memory; + } + } + bufferPoolsBySize.remove(size); + } + } + return allocator.allocate(size); + } else { + return allocator.allocate(size); + } } void free(MemoryBlock memory) { - allocator.free(memory); + final long size = memory.size(); + if (shouldPool(size)) { + synchronized (this) { + LinkedList> pool = bufferPoolsBySize.get(size); + if (pool == null) { + pool = new LinkedList>(); + bufferPoolsBySize.put(size, pool); + } + pool.add(new WeakReference(memory)); + } + } else { + allocator.free(memory); + } } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java index 9224988e6ad69..10881969dbc78 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java @@ -19,6 +19,7 @@ import java.util.*; +import com.google.common.annotations.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -43,14 +44,22 @@ * maximum size of a long[] array, allowing us to address 8192 * 2^32 * 8 bytes, which is * approximately 35 terabytes of memory. */ -public final class TaskMemoryManager { +public class TaskMemoryManager { private final Logger logger = LoggerFactory.getLogger(TaskMemoryManager.class); - /** - * The number of entries in the page table. - */ - private static final int PAGE_TABLE_SIZE = 1 << 13; + /** The number of bits used to address the page table. */ + private static final int PAGE_NUMBER_BITS = 13; + + /** The number of bits used to encode offsets in data pages. */ + @VisibleForTesting + static final int OFFSET_BITS = 64 - PAGE_NUMBER_BITS; // 51 + + /** The number of entries in the page table. */ + private static final int PAGE_TABLE_SIZE = 1 << PAGE_NUMBER_BITS; + + /** Maximum supported data page size */ + private static final long MAXIMUM_PAGE_SIZE = (1L << OFFSET_BITS); /** Bit mask for the lower 51 bits of a long. */ private static final long MASK_LONG_LOWER_51_BITS = 0x7FFFFFFFFFFFFL; @@ -101,11 +110,9 @@ public TaskMemoryManager(ExecutorMemoryManager executorMemoryManager) { * intended for allocating large blocks of memory that will be shared between operators. */ public MemoryBlock allocatePage(long size) { - if (logger.isTraceEnabled()) { - logger.trace("Allocating {} byte page", size); - } - if (size >= (1L << 51)) { - throw new IllegalArgumentException("Cannot allocate a page with more than 2^51 bytes"); + if (size > MAXIMUM_PAGE_SIZE) { + throw new IllegalArgumentException( + "Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE + " bytes"); } final int pageNumber; @@ -120,8 +127,8 @@ public MemoryBlock allocatePage(long size) { final MemoryBlock page = executorMemoryManager.allocate(size); page.pageNumber = pageNumber; pageTable[pageNumber] = page; - if (logger.isDebugEnabled()) { - logger.debug("Allocate page number {} ({} bytes)", pageNumber, size); + if (logger.isTraceEnabled()) { + logger.trace("Allocate page number {} ({} bytes)", pageNumber, size); } return page; } @@ -130,9 +137,6 @@ public MemoryBlock allocatePage(long size) { * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage(long)}. */ public void freePage(MemoryBlock page) { - if (logger.isTraceEnabled()) { - logger.trace("Freeing page number {} ({} bytes)", page.pageNumber, page.size()); - } assert (page.pageNumber != -1) : "Called freePage() on memory that wasn't allocated with allocatePage()"; executorMemoryManager.free(page); @@ -140,8 +144,8 @@ public void freePage(MemoryBlock page) { allocatedPages.clear(page.pageNumber); } pageTable[page.pageNumber] = null; - if (logger.isDebugEnabled()) { - logger.debug("Freed page number {} ({} bytes)", page.pageNumber, page.size()); + if (logger.isTraceEnabled()) { + logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size()); } } @@ -173,14 +177,36 @@ public void free(MemoryBlock memory) { /** * Given a memory page and offset within that page, encode this address into a 64-bit long. * This address will remain valid as long as the corresponding page has not been freed. + * + * @param page a data page allocated by {@link TaskMemoryManager#allocate(long)}. + * @param offsetInPage an offset in this page which incorporates the base offset. In other words, + * this should be the value that you would pass as the base offset into an + * UNSAFE call (e.g. page.baseOffset() + something). + * @return an encoded page address. */ public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) { - if (inHeap) { - assert (page.pageNumber != -1) : "encodePageNumberAndOffset called with invalid page"; - return (((long) page.pageNumber) << 51) | (offsetInPage & MASK_LONG_LOWER_51_BITS); - } else { - return offsetInPage; + if (!inHeap) { + // In off-heap mode, an offset is an absolute address that may require a full 64 bits to + // encode. Due to our page size limitation, though, we can convert this into an offset that's + // relative to the page's base offset; this relative offset will fit in 51 bits. + offsetInPage -= page.getBaseOffset(); } + return encodePageNumberAndOffset(page.pageNumber, offsetInPage); + } + + @VisibleForTesting + public static long encodePageNumberAndOffset(int pageNumber, long offsetInPage) { + assert (pageNumber != -1) : "encodePageNumberAndOffset called with invalid page"; + return (((long) pageNumber) << OFFSET_BITS) | (offsetInPage & MASK_LONG_LOWER_51_BITS); + } + + @VisibleForTesting + public static int decodePageNumber(long pagePlusOffsetAddress) { + return (int) ((pagePlusOffsetAddress & MASK_LONG_UPPER_13_BITS) >>> OFFSET_BITS); + } + + private static long decodeOffset(long pagePlusOffsetAddress) { + return (pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS); } /** @@ -189,7 +215,7 @@ public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) { */ public Object getPage(long pagePlusOffsetAddress) { if (inHeap) { - final int pageNumber = (int) ((pagePlusOffsetAddress & MASK_LONG_UPPER_13_BITS) >>> 51); + final int pageNumber = decodePageNumber(pagePlusOffsetAddress); assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE); final Object page = pageTable[pageNumber].getBaseObject(); assert (page != null); @@ -204,10 +230,15 @@ public Object getPage(long pagePlusOffsetAddress) { * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} */ public long getOffsetInPage(long pagePlusOffsetAddress) { + final long offsetInPage = decodeOffset(pagePlusOffsetAddress); if (inHeap) { - return (pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS); + return offsetInPage; } else { - return pagePlusOffsetAddress; + // In off-heap mode, an offset is an absolute address. In encodePageNumberAndOffset, we + // converted the absolute address into a relative address. Here, we invert that operation: + final int pageNumber = decodePageNumber(pagePlusOffsetAddress); + assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE); + return pageTable[pageNumber].getBaseOffset() + offsetInPage; } } diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index 7a5c0622d1ffb..81315f7c94645 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -25,24 +25,40 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import static org.mockito.AdditionalMatchers.geq; +import static org.mockito.Mockito.*; import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.memory.*; import org.apache.spark.unsafe.PlatformDependent; import static org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET; -import org.apache.spark.unsafe.memory.ExecutorMemoryManager; -import org.apache.spark.unsafe.memory.MemoryAllocator; -import org.apache.spark.unsafe.memory.MemoryLocation; -import org.apache.spark.unsafe.memory.TaskMemoryManager; +import static org.apache.spark.unsafe.PlatformDependent.LONG_ARRAY_OFFSET; + public abstract class AbstractBytesToBytesMapSuite { private final Random rand = new Random(42); private TaskMemoryManager memoryManager; + private TaskMemoryManager sizeLimitedMemoryManager; @Before public void setup() { memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(getMemoryAllocator())); + // Mocked memory manager for tests that check the maximum array size, since actually allocating + // such large arrays will cause us to run out of memory in our tests. + sizeLimitedMemoryManager = spy(memoryManager); + when(sizeLimitedMemoryManager.allocate(geq(1L << 20))).thenAnswer(new Answer() { + @Override + public MemoryBlock answer(InvocationOnMock invocation) throws Throwable { + if (((Long) invocation.getArguments()[0] / 8) > Integer.MAX_VALUE) { + throw new OutOfMemoryError("Requested array size exceeds VM limit"); + } + return memoryManager.allocate(1L << 20); + } + }); } @After @@ -101,6 +117,7 @@ public void emptyMap() { final int keyLengthInBytes = keyLengthInWords * 8; final byte[] key = getRandomByteArray(keyLengthInWords); Assert.assertFalse(map.lookup(key, BYTE_ARRAY_OFFSET, keyLengthInBytes).isDefined()); + Assert.assertFalse(map.iterator().hasNext()); } finally { map.free(); } @@ -159,7 +176,7 @@ public void setAndRetrieveAKey() { @Test public void iteratorTest() throws Exception { - final int size = 128; + final int size = 4096; BytesToBytesMap map = new BytesToBytesMap(memoryManager, size / 2); try { for (long i = 0; i < size; i++) { @@ -167,14 +184,26 @@ public void iteratorTest() throws Exception { final BytesToBytesMap.Location loc = map.lookup(value, PlatformDependent.LONG_ARRAY_OFFSET, 8); Assert.assertFalse(loc.isDefined()); - loc.putNewKey( - value, - PlatformDependent.LONG_ARRAY_OFFSET, - 8, - value, - PlatformDependent.LONG_ARRAY_OFFSET, - 8 - ); + // Ensure that we store some zero-length keys + if (i % 5 == 0) { + loc.putNewKey( + null, + PlatformDependent.LONG_ARRAY_OFFSET, + 0, + value, + PlatformDependent.LONG_ARRAY_OFFSET, + 8 + ); + } else { + loc.putNewKey( + value, + PlatformDependent.LONG_ARRAY_OFFSET, + 8, + value, + PlatformDependent.LONG_ARRAY_OFFSET, + 8 + ); + } } final java.util.BitSet valuesSeen = new java.util.BitSet(size); final Iterator iter = map.iterator(); @@ -183,11 +212,16 @@ public void iteratorTest() throws Exception { Assert.assertTrue(loc.isDefined()); final MemoryLocation keyAddress = loc.getKeyAddress(); final MemoryLocation valueAddress = loc.getValueAddress(); - final long key = PlatformDependent.UNSAFE.getLong( - keyAddress.getBaseObject(), keyAddress.getBaseOffset()); final long value = PlatformDependent.UNSAFE.getLong( valueAddress.getBaseObject(), valueAddress.getBaseOffset()); - Assert.assertEquals(key, value); + final long keyLength = loc.getKeyLength(); + if (keyLength == 0) { + Assert.assertTrue("value " + value + " was not divisible by 5", value % 5 == 0); + } else { + final long key = PlatformDependent.UNSAFE.getLong( + keyAddress.getBaseObject(), keyAddress.getBaseOffset()); + Assert.assertEquals(value, key); + } valuesSeen.set((int) value); } Assert.assertEquals(size, valuesSeen.cardinality()); @@ -196,6 +230,74 @@ public void iteratorTest() throws Exception { } } + @Test + public void iteratingOverDataPagesWithWastedSpace() throws Exception { + final int NUM_ENTRIES = 1000 * 1000; + final int KEY_LENGTH = 16; + final int VALUE_LENGTH = 40; + final BytesToBytesMap map = new BytesToBytesMap(memoryManager, NUM_ENTRIES); + // Each record will take 8 + 8 + 16 + 40 = 72 bytes of space in the data page. Our 64-megabyte + // pages won't be evenly-divisible by records of this size, which will cause us to waste some + // space at the end of the page. This is necessary in order for us to take the end-of-record + // handling branch in iterator(). + try { + for (int i = 0; i < NUM_ENTRIES; i++) { + final long[] key = new long[] { i, i }; // 2 * 8 = 16 bytes + final long[] value = new long[] { i, i, i, i, i }; // 5 * 8 = 40 bytes + final BytesToBytesMap.Location loc = map.lookup( + key, + LONG_ARRAY_OFFSET, + KEY_LENGTH + ); + Assert.assertFalse(loc.isDefined()); + loc.putNewKey( + key, + LONG_ARRAY_OFFSET, + KEY_LENGTH, + value, + LONG_ARRAY_OFFSET, + VALUE_LENGTH + ); + } + Assert.assertEquals(2, map.getNumDataPages()); + + final java.util.BitSet valuesSeen = new java.util.BitSet(NUM_ENTRIES); + final Iterator iter = map.iterator(); + final long key[] = new long[KEY_LENGTH / 8]; + final long value[] = new long[VALUE_LENGTH / 8]; + while (iter.hasNext()) { + final BytesToBytesMap.Location loc = iter.next(); + Assert.assertTrue(loc.isDefined()); + Assert.assertEquals(KEY_LENGTH, loc.getKeyLength()); + Assert.assertEquals(VALUE_LENGTH, loc.getValueLength()); + PlatformDependent.copyMemory( + loc.getKeyAddress().getBaseObject(), + loc.getKeyAddress().getBaseOffset(), + key, + LONG_ARRAY_OFFSET, + KEY_LENGTH + ); + PlatformDependent.copyMemory( + loc.getValueAddress().getBaseObject(), + loc.getValueAddress().getBaseOffset(), + value, + LONG_ARRAY_OFFSET, + VALUE_LENGTH + ); + for (long j : key) { + Assert.assertEquals(key[0], j); + } + for (long j : value) { + Assert.assertEquals(key[0], j); + } + valuesSeen.set((int) key[0]); + } + Assert.assertEquals(NUM_ENTRIES, valuesSeen.cardinality()); + } finally { + map.free(); + } + } + @Test public void randomizedStressTest() { final int size = 65536; @@ -247,4 +349,35 @@ public void randomizedStressTest() { map.free(); } } + + @Test + public void initialCapacityBoundsChecking() { + try { + new BytesToBytesMap(sizeLimitedMemoryManager, 0); + Assert.fail("Expected IllegalArgumentException to be thrown"); + } catch (IllegalArgumentException e) { + // expected exception + } + + try { + new BytesToBytesMap(sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY + 1); + Assert.fail("Expected IllegalArgumentException to be thrown"); + } catch (IllegalArgumentException e) { + // expected exception + } + + // Can allocate _at_ the max capacity + BytesToBytesMap map = + new BytesToBytesMap(sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY); + map.free(); + } + + @Test + public void resizingLargeMap() { + // As long as a map's capacity is below the max, we should be able to resize up to the max + BytesToBytesMap map = + new BytesToBytesMap(sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY - 64); + map.growAndRehash(); + map.free(); + } } diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java index 932882f1ca248..06fb081183659 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java @@ -38,4 +38,27 @@ public void leakedPageMemoryIsDetected() { Assert.assertEquals(4096, manager.cleanUpAllAllocatedMemory()); } + @Test + public void encodePageNumberAndOffsetOffHeap() { + final TaskMemoryManager manager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE)); + final MemoryBlock dataPage = manager.allocatePage(256); + // In off-heap mode, an offset is an absolute address that may require more than 51 bits to + // encode. This test exercises that corner-case: + final long offset = ((1L << TaskMemoryManager.OFFSET_BITS) + 10); + final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, offset); + Assert.assertEquals(null, manager.getPage(encodedAddress)); + Assert.assertEquals(offset, manager.getOffsetInPage(encodedAddress)); + } + + @Test + public void encodePageNumberAndOffsetOnHeap() { + final TaskMemoryManager manager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + final MemoryBlock dataPage = manager.allocatePage(256); + final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, 64); + Assert.assertEquals(dataPage.getBaseObject(), manager.getPage(encodedAddress)); + Assert.assertEquals(64, manager.getOffsetInPage(encodedAddress)); + } + } diff --git a/yarn/pom.xml b/yarn/pom.xml index 7c8c3613e7a05..e207a46809684 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -30,6 +30,7 @@ Spark Project YARN yarn + 1.9 @@ -38,6 +39,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.hadoop hadoop-yarn-api @@ -85,7 +93,12 @@ jetty-servlet - + + + org.apache.hadoop hadoop-yarn-server-tests @@ -97,59 +110,44 @@ mockito-all test + + org.mortbay.jetty + jetty + 6.1.26 + + + org.mortbay.jetty + servlet-api + + + test + + + com.sun.jersey + jersey-core + ${jersey.version} + test + + + com.sun.jersey + jersey-json + ${jersey.version} + test + + + stax + stax-api + + + + + com.sun.jersey + jersey-server + ${jersey.version} + test + - - - - - hadoop-2.2 - - 1.9 - - - - org.mortbay.jetty - jetty - 6.1.26 - - - org.mortbay.jetty - servlet-api - - - test - - - com.sun.jersey - jersey-core - ${jersey.version} - test - - - com.sun.jersey - jersey-json - ${jersey.version} - test - - - stax - stax-api - - - - - com.sun.jersey - jersey-server - ${jersey.version} - test - - - - - + target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala index aaae6f9734a85..77af46c192cc2 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala @@ -60,8 +60,11 @@ private[yarn] class AMDelegationTokenRenewer( private val hadoopUtil = YarnSparkHadoopUtil.get - private val daysToKeepFiles = sparkConf.getInt("spark.yarn.credentials.file.retention.days", 5) - private val numFilesToKeep = sparkConf.getInt("spark.yarn.credentials.file.retention.count", 5) + private val credentialsFile = sparkConf.get("spark.yarn.credentials.file") + private val daysToKeepFiles = + sparkConf.getInt("spark.yarn.credentials.file.retention.days", 5) + private val numFilesToKeep = + sparkConf.getInt("spark.yarn.credentials.file.retention.count", 5) /** * Schedule a login from the keytab and principal set using the --principal and --keytab @@ -121,7 +124,7 @@ private[yarn] class AMDelegationTokenRenewer( import scala.concurrent.duration._ try { val remoteFs = FileSystem.get(hadoopConf) - val credentialsPath = new Path(sparkConf.get("spark.yarn.credentials.file")) + val credentialsPath = new Path(credentialsFile) val thresholdTime = System.currentTimeMillis() - (daysToKeepFiles days).toMillis hadoopUtil.listFilesSorted( remoteFs, credentialsPath.getParent, @@ -160,7 +163,7 @@ private[yarn] class AMDelegationTokenRenewer( val keytabLoggedInUGI = UserGroupInformation.loginUserFromKeytabAndReturnUGI(principal, keytab) logInfo("Successfully logged into KDC.") val tempCreds = keytabLoggedInUGI.getCredentials - val credentialsPath = new Path(sparkConf.get("spark.yarn.credentials.file")) + val credentialsPath = new Path(credentialsFile) val dst = credentialsPath.getParent keytabLoggedInUGI.doAs(new PrivilegedExceptionAction[Void] { // Get a copy of the credentials @@ -186,8 +189,7 @@ private[yarn] class AMDelegationTokenRenewer( } val nextSuffix = lastCredentialsFileSuffix + 1 val tokenPathStr = - sparkConf.get("spark.yarn.credentials.file") + - SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM + nextSuffix + credentialsFile + SparkHadoopUtil.SPARK_YARN_CREDS_COUNTER_DELIM + nextSuffix val tokenPath = new Path(tokenPathStr) val tempTokenPath = new Path(tokenPathStr + SparkHadoopUtil.SPARK_YARN_CREDS_TEMP_EXTENSION) logInfo("Writing out delegation tokens to " + tempTokenPath.toString) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 29752969e6152..760e458972d98 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -34,7 +34,7 @@ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, Spar import org.apache.spark.SparkException import org.apache.spark.deploy.{PythonRunner, SparkHadoopUtil} import org.apache.spark.deploy.history.HistoryServer -import org.apache.spark.scheduler.cluster.YarnSchedulerBackend +import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, YarnSchedulerBackend} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.util._ @@ -220,7 +220,7 @@ private[spark] class ApplicationMaster( sparkContextRef.compareAndSet(sc, null) } - private def registerAM(uiAddress: String, securityMgr: SecurityManager) = { + private def registerAM(_rpcEnv: RpcEnv, uiAddress: String, securityMgr: SecurityManager) = { val sc = sparkContextRef.get() val appId = client.getAttemptId().getApplicationId().toString() @@ -231,8 +231,14 @@ private[spark] class ApplicationMaster( .map { address => s"${address}${HistoryServer.UI_PATH_PREFIX}/${appId}/${attemptId}" } .getOrElse("") - allocator = client.register(yarnConf, - if (sc != null) sc.getConf else sparkConf, + val _sparkConf = if (sc != null) sc.getConf else sparkConf + val driverUrl = _rpcEnv.uriOf( + SparkEnv.driverActorSystemName, + RpcAddress(_sparkConf.get("spark.driver.host"), _sparkConf.get("spark.driver.port").toInt), + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) + allocator = client.register(driverUrl, + yarnConf, + _sparkConf, if (sc != null) sc.preferredNodeLocationData else Map(), uiAddress, historyAddress, @@ -279,7 +285,7 @@ private[spark] class ApplicationMaster( sc.getConf.get("spark.driver.host"), sc.getConf.get("spark.driver.port"), isClusterMode = true) - registerAM(sc.ui.map(_.appUIAddress).getOrElse(""), securityMgr) + registerAM(rpcEnv, sc.ui.map(_.appUIAddress).getOrElse(""), securityMgr) userClassThread.join() } } @@ -289,7 +295,7 @@ private[spark] class ApplicationMaster( rpcEnv = RpcEnv.create("sparkYarnAM", Utils.localHostName, port, sparkConf, securityMgr) waitForSparkDriver() addAmIpFilter() - registerAM(sparkConf.get("spark.driver.appUIAddress", ""), securityMgr) + registerAM(rpcEnv, sparkConf.get("spark.driver.appUIAddress", ""), securityMgr) // In client mode the actor will stop the reporter thread. reporterThread.join() @@ -300,11 +306,14 @@ private[spark] class ApplicationMaster( val expiryInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000) // we want to be reasonably responsive without causing too many requests to RM. - val schedulerInterval = - sparkConf.getTimeAsMs("spark.yarn.scheduler.heartbeat.interval-ms", "5s") + val heartbeatInterval = math.max(0, math.min(expiryInterval / 2, + sparkConf.getTimeAsMs("spark.yarn.scheduler.heartbeat.interval-ms", "3s"))) - // must be <= expiryInterval / 2. - val interval = math.max(0, math.min(expiryInterval / 2, schedulerInterval)) + // we want to check more frequently for pending containers + val initialAllocationInterval = math.min(heartbeatInterval, + sparkConf.getTimeAsMs("spark.yarn.scheduler.initial-allocation.interval", "200ms")) + + var nextAllocationInterval = initialAllocationInterval // The number of failures in a row until Reporter thread give up val reporterMaxFailures = sparkConf.getInt("spark.yarn.scheduler.reporterThread.maxFailures", 5) @@ -330,15 +339,27 @@ private[spark] class ApplicationMaster( if (!NonFatal(e) || failureCount >= reporterMaxFailures) { finish(FinalApplicationStatus.FAILED, ApplicationMaster.EXIT_REPORTER_FAILURE, "Exception was thrown " + - s"${failureCount} time(s) from Reporter thread.") - + s"$failureCount time(s) from Reporter thread.") } else { - logWarning(s"Reporter thread fails ${failureCount} time(s) in a row.", e) + logWarning(s"Reporter thread fails $failureCount time(s) in a row.", e) } } } try { - Thread.sleep(interval) + val numPendingAllocate = allocator.getNumPendingAllocate + val sleepInterval = + if (numPendingAllocate > 0) { + val currentAllocationInterval = + math.min(heartbeatInterval, nextAllocationInterval) + nextAllocationInterval = currentAllocationInterval * 2 // avoid overflow + currentAllocationInterval + } else { + nextAllocationInterval = initialAllocationInterval + heartbeatInterval + } + logDebug(s"Number of pending allocations is $numPendingAllocate. " + + s"Sleeping for $sleepInterval.") + Thread.sleep(sleepInterval) } catch { case e: InterruptedException => } @@ -349,7 +370,8 @@ private[spark] class ApplicationMaster( t.setDaemon(true) t.setName("Reporter") t.start() - logInfo("Started progress reporter thread - sleep time : " + interval) + logInfo(s"Started progress reporter thread with (heartbeat : $heartbeatInterval, " + + s"initial allocation : $initialAllocationInterval) intervals") t } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index d21a7393478ce..234051eb7d3bb 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.yarn -import java.io.{ByteArrayInputStream, DataInputStream, File, FileOutputStream} +import java.io.{ByteArrayInputStream, DataInputStream, File, FileOutputStream, IOException} import java.net.{InetAddress, UnknownHostException, URI, URISyntaxException} import java.nio.ByteBuffer import java.security.PrivilegedExceptionAction @@ -91,30 +91,52 @@ private[spark] class Client( * available in the alpha API. */ def submitApplication(): ApplicationId = { - // Setup the credentials before doing anything else, so we have don't have issues at any point. - setupCredentials() - yarnClient.init(yarnConf) - yarnClient.start() - - logInfo("Requesting a new application from cluster with %d NodeManagers" - .format(yarnClient.getYarnClusterMetrics.getNumNodeManagers)) - - // Get a new application from our RM - val newApp = yarnClient.createApplication() - val newAppResponse = newApp.getNewApplicationResponse() - val appId = newAppResponse.getApplicationId() - - // Verify whether the cluster has enough resources for our AM - verifyClusterResources(newAppResponse) - - // Set up the appropriate contexts to launch our AM - val containerContext = createContainerLaunchContext(newAppResponse) - val appContext = createApplicationSubmissionContext(newApp, containerContext) - - // Finally, submit and monitor the application - logInfo(s"Submitting application ${appId.getId} to ResourceManager") - yarnClient.submitApplication(appContext) - appId + var appId: ApplicationId = null + try { + // Setup the credentials before doing anything else, + // so we have don't have issues at any point. + setupCredentials() + yarnClient.init(yarnConf) + yarnClient.start() + + logInfo("Requesting a new application from cluster with %d NodeManagers" + .format(yarnClient.getYarnClusterMetrics.getNumNodeManagers)) + + // Get a new application from our RM + val newApp = yarnClient.createApplication() + val newAppResponse = newApp.getNewApplicationResponse() + appId = newAppResponse.getApplicationId() + + // Verify whether the cluster has enough resources for our AM + verifyClusterResources(newAppResponse) + + // Set up the appropriate contexts to launch our AM + val containerContext = createContainerLaunchContext(newAppResponse) + val appContext = createApplicationSubmissionContext(newApp, containerContext) + + // Finally, submit and monitor the application + logInfo(s"Submitting application ${appId.getId} to ResourceManager") + yarnClient.submitApplication(appContext) + appId + } catch { + case e: Throwable => + if (appId != null) { + val appStagingDir = getAppStagingDir(appId) + try { + val preserveFiles = sparkConf.getBoolean("spark.yarn.preserve.staging.files", false) + val stagingDirPath = new Path(appStagingDir) + val fs = FileSystem.get(hadoopConf) + if (!preserveFiles && fs.exists(stagingDirPath)) { + logInfo("Deleting staging directory " + stagingDirPath) + fs.delete(stagingDirPath, true) + } + } catch { + case ioe: IOException => + logWarning("Failed to cleanup staging dir " + appStagingDir, ioe) + } + } + throw e + } } /** @@ -1120,9 +1142,9 @@ object Client extends Logging { logDebug("HiveMetaStore configured in localmode") } } catch { - case e:java.lang.NoSuchMethodException => { logInfo("Hive Method not found " + e); return } - case e:java.lang.ClassNotFoundException => { logInfo("Hive Class not found " + e); return } - case e:Exception => { logError("Unexpected Exception " + e) + case e: java.lang.NoSuchMethodException => { logInfo("Hive Method not found " + e); return } + case e: java.lang.ClassNotFoundException => { logInfo("Hive Class not found " + e); return } + case e: Exception => { logError("Unexpected Exception " + e) throw new RuntimeException("Unexpected exception", e) } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 5653c9f14dc6d..9c7b1b3988082 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -98,6 +98,12 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) numExecutors = initialNumExecutors } + principal = Option(principal) + .orElse(sparkConf.getOption("spark.yarn.principal")) + .orNull + keytab = Option(keytab) + .orElse(sparkConf.getOption("spark.yarn.keytab")) + .orNull } /** diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala index c592ecfdfce06..3d3a966960e9f 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala @@ -43,22 +43,22 @@ private[spark] class ClientDistributedCacheManager() extends Logging { * Add a resource to the list of distributed cache resources. This list can * be sent to the ApplicationMaster and possibly the executors so that it can * be downloaded into the Hadoop distributed cache for use by this application. - * Adds the LocalResource to the localResources HashMap passed in and saves + * Adds the LocalResource to the localResources HashMap passed in and saves * the stats of the resources to they can be sent to the executors and verified. * * @param fs FileSystem * @param conf Configuration * @param destPath path to the resource * @param localResources localResource hashMap to insert the resource into - * @param resourceType LocalResourceType + * @param resourceType LocalResourceType * @param link link presented in the distributed cache to the destination - * @param statCache cache to store the file/directory stats + * @param statCache cache to store the file/directory stats * @param appMasterOnly Whether to only add the resource to the app master */ def addResource( fs: FileSystem, conf: Configuration, - destPath: Path, + destPath: Path, localResources: HashMap[String, LocalResource], resourceType: LocalResourceType, link: String, @@ -74,15 +74,15 @@ private[spark] class ClientDistributedCacheManager() extends Logging { amJarRsrc.setSize(destStatus.getLen()) if (link == null || link.isEmpty()) throw new Exception("You must specify a valid link name") localResources(link) = amJarRsrc - + if (!appMasterOnly) { val uri = destPath.toUri() val pathURI = new URI(uri.getScheme(), uri.getAuthority(), uri.getPath(), null, link) if (resourceType == LocalResourceType.FILE) { - distCacheFiles(pathURI.toString()) = (destStatus.getLen().toString(), + distCacheFiles(pathURI.toString()) = (destStatus.getLen().toString(), destStatus.getModificationTime().toString(), visibility.name()) } else { - distCacheArchives(pathURI.toString()) = (destStatus.getLen().toString(), + distCacheArchives(pathURI.toString()) = (destStatus.getLen().toString(), destStatus.getModificationTime().toString(), visibility.name()) } } @@ -95,13 +95,13 @@ private[spark] class ClientDistributedCacheManager() extends Logging { val (keys, tupleValues) = distCacheFiles.unzip val (sizes, timeStamps, visibilities) = tupleValues.unzip3 if (keys.size > 0) { - env("SPARK_YARN_CACHE_FILES") = keys.reduceLeft[String] { (acc,n) => acc + "," + n } - env("SPARK_YARN_CACHE_FILES_TIME_STAMPS") = - timeStamps.reduceLeft[String] { (acc,n) => acc + "," + n } - env("SPARK_YARN_CACHE_FILES_FILE_SIZES") = - sizes.reduceLeft[String] { (acc,n) => acc + "," + n } - env("SPARK_YARN_CACHE_FILES_VISIBILITIES") = - visibilities.reduceLeft[String] { (acc,n) => acc + "," + n } + env("SPARK_YARN_CACHE_FILES") = keys.reduceLeft[String] { (acc, n) => acc + "," + n } + env("SPARK_YARN_CACHE_FILES_TIME_STAMPS") = + timeStamps.reduceLeft[String] { (acc, n) => acc + "," + n } + env("SPARK_YARN_CACHE_FILES_FILE_SIZES") = + sizes.reduceLeft[String] { (acc, n) => acc + "," + n } + env("SPARK_YARN_CACHE_FILES_VISIBILITIES") = + visibilities.reduceLeft[String] { (acc, n) => acc + "," + n } } } @@ -112,13 +112,13 @@ private[spark] class ClientDistributedCacheManager() extends Logging { val (keys, tupleValues) = distCacheArchives.unzip val (sizes, timeStamps, visibilities) = tupleValues.unzip3 if (keys.size > 0) { - env("SPARK_YARN_CACHE_ARCHIVES") = keys.reduceLeft[String] { (acc,n) => acc + "," + n } - env("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") = - timeStamps.reduceLeft[String] { (acc,n) => acc + "," + n } + env("SPARK_YARN_CACHE_ARCHIVES") = keys.reduceLeft[String] { (acc, n) => acc + "," + n } + env("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") = + timeStamps.reduceLeft[String] { (acc, n) => acc + "," + n } env("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") = - sizes.reduceLeft[String] { (acc,n) => acc + "," + n } - env("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") = - visibilities.reduceLeft[String] { (acc,n) => acc + "," + n } + sizes.reduceLeft[String] { (acc, n) => acc + "," + n } + env("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") = + visibilities.reduceLeft[String] { (acc, n) => acc + "," + n } } } @@ -160,7 +160,7 @@ private[spark] class ClientDistributedCacheManager() extends Logging { def ancestorsHaveExecutePermissions( fs: FileSystem, path: Path, - statCache: Map[URI, FileStatus]): Boolean = { + statCache: Map[URI, FileStatus]): Boolean = { var current = path while (current != null) { // the subdirs in the path should have execute permissions for others @@ -197,7 +197,7 @@ private[spark] class ClientDistributedCacheManager() extends Logging { def getFileStatus(fs: FileSystem, uri: URI, statCache: Map[URI, FileStatus]): FileStatus = { val stat = statCache.get(uri) match { case Some(existstat) => existstat - case None => + case None => val newStat = fs.getFileStatus(new Path(uri)) statCache.put(uri, newStat) newStat diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 8a08f561a2df2..21193e7c625e3 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -34,10 +34,8 @@ import org.apache.hadoop.yarn.util.RackResolver import org.apache.log4j.{Level, Logger} -import org.apache.spark.{SparkEnv, Logging, SecurityManager, SparkConf} +import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ -import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend -import org.apache.spark.util.AkkaUtils /** * YarnAllocator is charged with requesting containers from the YARN ResourceManager and deciding @@ -53,6 +51,7 @@ import org.apache.spark.util.AkkaUtils * synchronized. */ private[yarn] class YarnAllocator( + driverUrl: String, conf: Configuration, sparkConf: SparkConf, amClient: AMRMClient[ContainerRequest], @@ -107,13 +106,6 @@ private[yarn] class YarnAllocator( new ThreadFactoryBuilder().setNameFormat("ContainerLauncher #%d").setDaemon(true).build()) launcherPool.allowCoreThreadTimeOut(true) - private val driverUrl = AkkaUtils.address( - AkkaUtils.protocol(securityMgr.akkaSSLOptions.enabled), - SparkEnv.driverActorSystemName, - sparkConf.get("spark.driver.host"), - sparkConf.get("spark.driver.port"), - CoarseGrainedSchedulerBackend.ENDPOINT_NAME) - // For testing private val launchContainers = sparkConf.getBoolean("spark.yarn.launchContainers", true) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index b134751366522..7f533ee55e8bb 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -55,6 +55,7 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg * @param uiHistoryAddress Address of the application on the History Server. */ def register( + driverUrl: String, conf: YarnConfiguration, sparkConf: SparkConf, preferredNodeLocations: Map[String, Set[SplitInfo]], @@ -72,7 +73,7 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg amClient.registerApplicationMaster(Utils.localHostName(), 0, uiAddress) registered = true } - new YarnAllocator(conf, sparkConf, amClient, getAttemptId(), args, securityMgr) + new YarnAllocator(driverUrl, conf, sparkConf, amClient, getAttemptId(), args, securityMgr) } /** @@ -89,9 +90,7 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg /** Returns the attempt ID. */ def getAttemptId(): ApplicationAttemptId = { - val containerIdString = System.getenv(ApplicationConstants.Environment.CONTAINER_ID.name()) - val containerId = ConverterUtils.toContainerId(containerIdString) - containerId.getApplicationAttemptId() + YarnSparkHadoopUtil.get.getContainerId.getApplicationAttemptId() } /** Returns the configuration for the AmIpFilter to add to the Spark UI. */ diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index ba91872107d0c..68d01c17ef720 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -33,7 +33,8 @@ import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.ApplicationConstants.Environment -import org.apache.hadoop.yarn.api.records.{Priority, ApplicationAccessType} +import org.apache.hadoop.yarn.api.records.{ApplicationAccessType, ContainerId, Priority} +import org.apache.hadoop.yarn.util.ConverterUtils import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.{SecurityManager, SparkConf, SparkException} @@ -136,12 +137,16 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { tokenRenewer.foreach(_.stop()) } + private[spark] def getContainerId: ContainerId = { + val containerIdString = System.getenv(ApplicationConstants.Environment.CONTAINER_ID.name()) + ConverterUtils.toContainerId(containerIdString) + } } object YarnSparkHadoopUtil { - // Additional memory overhead + // Additional memory overhead // 10% was arrived at experimentally. In the interest of minimizing memory waste while covering - // the common cases. Memory overhead tends to grow with container size. + // the common cases. Memory overhead tends to grow with container size. val MEMORY_OVERHEAD_FACTOR = 0.10 val MEMORY_OVERHEAD_MIN = 384 diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala index aeb218a575455..1ace1a97d5156 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala @@ -17,10 +17,19 @@ package org.apache.spark.scheduler.cluster +import java.net.NetworkInterface + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.yarn.api.records.NodeState +import org.apache.hadoop.yarn.client.api.YarnClient +import org.apache.hadoop.yarn.conf.YarnConfiguration + import org.apache.spark.SparkContext +import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.scheduler.TaskSchedulerImpl -import org.apache.spark.util.IntParam +import org.apache.spark.util.{IntParam, Utils} private[spark] class YarnClusterSchedulerBackend( scheduler: TaskSchedulerImpl, @@ -53,4 +62,70 @@ private[spark] class YarnClusterSchedulerBackend( logError("Application attempt ID is not set.") super.applicationAttemptId } + + override def getDriverLogUrls: Option[Map[String, String]] = { + var yarnClientOpt: Option[YarnClient] = None + var driverLogs: Option[Map[String, String]] = None + try { + val yarnConf = new YarnConfiguration(sc.hadoopConfiguration) + val containerId = YarnSparkHadoopUtil.get.getContainerId + yarnClientOpt = Some(YarnClient.createYarnClient()) + yarnClientOpt.foreach { yarnClient => + yarnClient.init(yarnConf) + yarnClient.start() + + // For newer versions of YARN, we can find the HTTP address for a given node by getting a + // container report for a given container. But container reports came only in Hadoop 2.4, + // so we basically have to get the node reports for all nodes and find the one which runs + // this container. For that we have to compare the node's host against the current host. + // Since the host can have multiple addresses, we need to compare against all of them to + // find out if one matches. + + // Get all the addresses of this node. + val addresses = + NetworkInterface.getNetworkInterfaces.asScala + .flatMap(_.getInetAddresses.asScala) + .toSeq + + // Find a node report that matches one of the addresses + val nodeReport = + yarnClient.getNodeReports(NodeState.RUNNING).asScala.find { x => + val host = x.getNodeId.getHost + addresses.exists { address => + address.getHostAddress == host || + address.getHostName == host || + address.getCanonicalHostName == host + } + } + + // Now that we have found the report for the Node Manager that the AM is running on, we + // can get the base HTTP address for the Node manager from the report. + // The format used for the logs for each container is well-known and can be constructed + // using the NM's HTTP address and the container ID. + // The NM may be running several containers, but we can build the URL for the AM using + // the AM's container ID, which we already know. + nodeReport.foreach { report => + val httpAddress = report.getHttpAddress + // lookup appropriate http scheme for container log urls + val yarnHttpPolicy = yarnConf.get( + YarnConfiguration.YARN_HTTP_POLICY_KEY, + YarnConfiguration.YARN_HTTP_POLICY_DEFAULT + ) + val user = Utils.getCurrentUserName() + val httpScheme = if (yarnHttpPolicy == "HTTPS_ONLY") "https://" else "http://" + val baseUrl = s"$httpScheme$httpAddress/node/containerlogs/$containerId/$user" + logDebug(s"Base URL for logs: $baseUrl") + driverLogs = Some( + Map("stderr" -> s"$baseUrl/stderr?start=0", "stdout" -> s"$baseUrl/stdout?start=0")) + } + } + } catch { + case e: Exception => + logInfo("Node Report API is not available in the version of YARN being used, so AM" + + " logs link will not appear in application UI", e) + } finally { + yarnClientOpt.foreach(_.close()) + } + driverLogs + } } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala index 80b57d1355a3a..804dfecde7867 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.deploy.yarn import java.net.URI -import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar import org.mockito.Mockito.when @@ -36,16 +35,18 @@ import org.apache.hadoop.yarn.util.{Records, ConverterUtils} import scala.collection.mutable.HashMap import scala.collection.mutable.Map +import org.apache.spark.SparkFunSuite -class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { + +class ClientDistributedCacheManagerSuite extends SparkFunSuite with MockitoSugar { class MockClientDistributedCacheManager extends ClientDistributedCacheManager { - override def getVisibility(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]): + override def getVisibility(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]): LocalResourceVisibility = { LocalResourceVisibility.PRIVATE } } - + test("test getFileStatus empty") { val distMgr = new ClientDistributedCacheManager() val fs = mock[FileSystem] @@ -60,7 +61,7 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { val distMgr = new ClientDistributedCacheManager() val fs = mock[FileSystem] val uri = new URI("/tmp/testing") - val realFileStatus = new FileStatus(10, false, 1, 1024, 10, 10, null, "testOwner", + val realFileStatus = new FileStatus(10, false, 1, 1024, 10, 10, null, "testOwner", null, new Path("/tmp/testing")) when(fs.getFileStatus(new Path(uri))).thenReturn(new FileStatus()) val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus](uri -> realFileStatus) @@ -77,7 +78,7 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() when(fs.getFileStatus(destPath)).thenReturn(new FileStatus()) - distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, "link", + distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, "link", statCache, false) val resource = localResources("link") assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE) @@ -100,11 +101,11 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { assert(env.get("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") === None) // add another one and verify both there and order correct - val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", + val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", null, new Path("/tmp/testing2")) val destPath2 = new Path("file:///foo.invalid.com:8080/tmp/testing2") when(fs.getFileStatus(destPath2)).thenReturn(realFileStatus) - distMgr.addResource(fs, conf, destPath2, localResources, LocalResourceType.FILE, "link2", + distMgr.addResource(fs, conf, destPath2, localResources, LocalResourceType.FILE, "link2", statCache, false) val resource2 = localResources("link2") assert(resource2.getVisibility() === LocalResourceVisibility.PRIVATE) @@ -116,7 +117,7 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { val env2 = new HashMap[String, String]() distMgr.setDistFilesEnv(env2) val timestamps = env2("SPARK_YARN_CACHE_FILES_TIME_STAMPS").split(',') - val files = env2("SPARK_YARN_CACHE_FILES").split(',') + val files = env2("SPARK_YARN_CACHE_FILES").split(',') val sizes = env2("SPARK_YARN_CACHE_FILES_FILE_SIZES").split(',') val visibilities = env2("SPARK_YARN_CACHE_FILES_VISIBILITIES") .split(',') assert(files(0) === "file:/foo.invalid.com:8080/tmp/testing#link") @@ -140,7 +141,7 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { when(fs.getFileStatus(destPath)).thenReturn(new FileStatus()) intercept[Exception] { - distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, null, + distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, null, statCache, false) } assert(localResources.get("link") === None) @@ -154,11 +155,11 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing") val localResources = HashMap[String, LocalResource]() val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() - val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", + val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", null, new Path("/tmp/testing")) when(fs.getFileStatus(destPath)).thenReturn(realFileStatus) - distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link", + distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link", statCache, true) val resource = localResources("link") assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE) @@ -188,11 +189,11 @@ class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar { val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing") val localResources = HashMap[String, LocalResource]() val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() - val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", + val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner", null, new Path("/tmp/testing")) when(fs.getFileStatus(destPath)).thenReturn(realFileStatus) - distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link", + distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link", statCache, false) val resource = localResources("link") assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE) diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index 508819e242a26..01d33c9ce9297 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -33,12 +33,12 @@ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.mockito.Matchers._ import org.mockito.Mockito._ -import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfterAll, Matchers} -import org.apache.spark.{SparkException, SparkConf} +import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} import org.apache.spark.util.Utils -class ClientSuite extends FunSuite with Matchers with BeforeAndAfterAll { +class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll { override def beforeAll(): Unit = { System.setProperty("SPARK_YARN_MODE", "true") @@ -203,7 +203,7 @@ class ClientSuite extends FunSuite with Matchers with BeforeAndAfterAll { def getFieldValue2[A: ClassTag, A1: ClassTag, B]( clazz: Class[_], field: String, - defaults: => B)(mapTo: A => B)(mapTo1: A1 => B): B = { + defaults: => B)(mapTo: A => B)(mapTo1: A1 => B): B = { Try(clazz.getField(field)).map(_.get(null)).map { case v: A => mapTo(v) case v1: A1 => mapTo1(v1) diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index 455f1019d86dd..7509000771d94 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -26,13 +26,13 @@ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest -import org.apache.spark.SecurityManager +import org.apache.spark.{SecurityManager, SparkFunSuite} import org.apache.spark.SparkConf import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.deploy.yarn.YarnAllocator._ import org.apache.spark.scheduler.SplitInfo -import org.scalatest.{BeforeAndAfterEach, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfterEach, Matchers} class MockResolver extends DNSToSwitchMapping { @@ -46,7 +46,7 @@ class MockResolver extends DNSToSwitchMapping { def reloadCachedMappings(names: JList[String]) {} } -class YarnAllocatorSuite extends FunSuite with Matchers with BeforeAndAfterEach { +class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach { val conf = new Configuration() conf.setClass( CommonConfigurationKeysPublic.NET_TOPOLOGY_NODE_SWITCH_MAPPING_IMPL_KEY, @@ -90,6 +90,7 @@ class YarnAllocatorSuite extends FunSuite with Matchers with BeforeAndAfterEach "--jar", "somejar.jar", "--class", "SomeClass") new YarnAllocator( + "not used", conf, sparkConf, rmClient, diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index d3c606e0ed998..d8bc2534c1a6a 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -23,17 +23,19 @@ import java.util.concurrent.TimeUnit import scala.collection.JavaConversions._ import scala.collection.mutable +import scala.io.Source import com.google.common.base.Charsets.UTF_8 import com.google.common.io.ByteStreams import com.google.common.io.Files import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.server.MiniYARNCluster -import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers} +import org.scalatest.{BeforeAndAfterAll, Matchers} -import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException, TestUtils} +import org.apache.spark._ import org.apache.spark.scheduler.cluster.ExecutorInfo -import org.apache.spark.scheduler.{SparkListenerJobStart, SparkListener, SparkListenerExecutorAdded} +import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationStart, + SparkListenerExecutorAdded} import org.apache.spark.util.Utils /** @@ -41,7 +43,7 @@ import org.apache.spark.util.Utils * applications, and require the Spark assembly to be built before they can be successfully * run. */ -class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers with Logging { +class YarnClusterSuite extends SparkFunSuite with BeforeAndAfterAll with Matchers with Logging { // log4j configuration for the YARN containers, so that their output is collected // by YARN instead of trying to overwrite unit-tests.log. @@ -290,10 +292,15 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit private[spark] class SaveExecutorInfo extends SparkListener { val addedExecutorInfos = mutable.Map[String, ExecutorInfo]() + var driverLogs: Option[collection.Map[String, String]] = None override def onExecutorAdded(executor: SparkListenerExecutorAdded) { addedExecutorInfos(executor.executorId) = executor.executorInfo } + + override def onApplicationStart(appStart: SparkListenerApplicationStart): Unit = { + driverLogs = appStart.driverLogs + } } private object YarnClusterDriver extends Logging with Matchers { @@ -314,6 +321,7 @@ private object YarnClusterDriver extends Logging with Matchers { val sc = new SparkContext(new SparkConf() .set("spark.extraListeners", classOf[SaveExecutorInfo].getName) .setAppName("yarn \"test app\" 'with quotes' and \\back\\slashes and $dollarSigns")) + val conf = sc.getConf val status = new File(args(0)) var result = "failure" try { @@ -335,6 +343,20 @@ private object YarnClusterDriver extends Logging with Matchers { executorInfos.foreach { info => assert(info.logUrlMap.nonEmpty) } + + // If we are running in yarn-cluster mode, verify that driver logs are downloadable. + if (conf.get("spark.master") == "yarn-cluster") { + assert(listener.driverLogs.nonEmpty) + val driverLogs = listener.driverLogs.get + assert(driverLogs.size === 2) + assert(driverLogs.containsKey("stderr")) + assert(driverLogs.containsKey("stdout")) + val stderr = driverLogs("stderr") // YARN puts everything in stderr. + val lines = Source.fromURL(stderr).getLines() + // Look for a line that contains YarnClusterSchedulerBackend, since that is guaranteed in + // cluster mode. + assert(lines.exists(_.contains("YarnClusterSchedulerBackend"))) + } } } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala index e10b985c3c236..49bee0866dd43 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala @@ -25,15 +25,15 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.yarn.conf.YarnConfiguration -import org.scalatest.{FunSuite, Matchers} +import org.scalatest.Matchers import org.apache.hadoop.yarn.api.records.ApplicationAccessType -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException, SparkFunSuite} import org.apache.spark.util.Utils -class YarnSparkHadoopUtilSuite extends FunSuite with Matchers with Logging { +class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging { val hasBash = try {